# Source code for wdoc.utils.embeddings

"""
* Class used to create the embeddings.
* Loads and store embeddings for each document.
"""

# import math
import os
import time
from pathlib import Path

import numpy as np
from beartype.typing import Any, List, Optional, Union
from joblib import Parallel, delayed
from langchain_classic.embeddings import CacheBackedEmbeddings
from langchain_core.vectorstores.base import VectorStore
from langchain_core.embeddings import Embeddings
from tqdm.asyncio import tqdm
from loguru import logger

# from langchain_classic.storage import LocalFileStore
from wdoc.utils.customs.compressed_embeddings_cacher import LocalFileStore
from wdoc.utils.env import env
from wdoc.utils.misc import ModelName, cache_dir, get_tkn_length, cache_file_in_memory

# Directory holding the on-disk embeddings cache (one store per model).
embeddings_cache_dir = cache_dir / "CacheEmbedding"
# parents=True so this does not fail if the parent cache directory has not
# been created yet.
embeddings_cache_dir.mkdir(parents=True, exist_ok=True)

# Source: https://api.python.langchain.com/en/latest/_modules/langchain_community/embeddings/huggingface.html#HuggingFaceEmbeddings
# Instructions prepended by instruct-style embedding models to documents and
# queries respectively.
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
    "Represent the question for retrieving supporting documents: "
)


def __get_faiss_vectorstore__():
    """Pick the FAISS vectorstore class matching the current env settings.

    The choice is made at call time rather than at import time, so the env
    variable WDOC_FAISS_BINARY (and WDOC_FAISS_COMPRESSION) can still be
    modified after importing wdoc.

    Returns:
        ``BinaryFAISS`` when WDOC_FAISS_BINARY is set, ``CompressedFAISS``
        when only WDOC_FAISS_COMPRESSION is set, otherwise langchain's plain
        ``FAISS`` class.

    Raises:
        AssertionError: when WDOC_FAISS_BINARY is combined with
            WDOC_MOD_FAISS_SCORE_FN, or enabled while WDOC_FAISS_COMPRESSION
            is disabled.
    """
    if not env.WDOC_FAISS_BINARY:
        if not env.WDOC_FAISS_COMPRESSION:
            from langchain_community.vectorstores import FAISS

            return FAISS

        from wdoc.utils.customs.binary_faiss_vectorstore import CompressedFAISS

        return CompressedFAISS

    # Binary mode: incompatible with the custom score function and it
    # requires compression to be enabled.
    assert not env.WDOC_MOD_FAISS_SCORE_FN, (
        "You can't use the env variable WDOC_MOD_FAISS_SCORE_FN=true and WDOC_FAISS_BINARY=true at the same time."
    )
    assert env.WDOC_FAISS_COMPRESSION, (
        "You can't use the env variable WDOC_FAISS_BINARY=true and WDOC_FAISS_COMPRESSION=false at the same time."
    )
    from wdoc.utils.customs.binary_faiss_vectorstore import BinaryFAISS

    return BinaryFAISS


def faiss_custom_score_function(distance: float) -> float:
    """
    Convert a FAISS distance into a relevance score, used instead of
    langchain's default in order to keep the score positive.

    Related issue: https://github.com/langchain-ai/langchain/issues/17333

    langchain's default is the euclidean relevance score:
        return 1.0 - distance / math.sqrt(2)

    Here the score is ``1 - (1 + distance) / 2``. It is meant to be a
    similarity in [0, 1], where 0 is the most dissimilar and 1 the most
    similar document. Mathematically that range only holds for distances in
    [-1, 1]. NOTE(review): confirm the range of distances the index produces.
    """
    # To disable this function but still check the sign of the distance, add
    # "import math" and use instead:
    #     assert distance >= 0, distance
    #     return 1.0 - distance / math.sqrt(2)
    shifted_half = (1 + distance) / 2
    return 1 - shifted_half
def _load_legacy_embeddings(
    modelname: ModelName,
    api_base: Optional[str],
    embed_kwargs: dict,
    private: bool,
    instruct: bool,
) -> Embeddings:
    """Build the embeddings object with the previous, backend-specific code.

    Supported backends are "openai", "huggingface" and "sentencetransformers".

    Raises:
        AssertionError: when a required setting or credential is missing, or
            when private is set with the huggingface backend.
        ValueError: for any other backend.
    """
    if modelname.backend == "openai":
        if private:
            assert api_base, "If private is set, api_base must be set too"
        else:
            assert (
                "OPENAI_API_KEY" in os.environ
                and os.environ["OPENAI_API_KEY"]
                and "REDACTED" not in os.environ["OPENAI_API_KEY"]
            ), "Missing OPENAI_API_KEY"
        from langchain_openai import OpenAIEmbeddings

        openai_kwargs = {}
        # OpenAIEmbeddings names this field `openai_api_base` (alias
        # `base_url`). An unknown `api_base` kwarg is moved into model_kwargs
        # and rejected by the OpenAI client at request time. Only pass it
        # when set, so the library's env-based default still applies.
        if api_base:
            openai_kwargs["openai_api_base"] = api_base
        # In private mode the key is optional: don't KeyError when it's absent.
        if os.environ.get("OPENAI_API_KEY"):
            openai_kwargs["openai_api_key"] = os.environ["OPENAI_API_KEY"]
        return OpenAIEmbeddings(
            model=modelname.model,
            dimensions=env.WDOC_DEFAULT_EMBED_DIMENSION,  # defaults to None
            **openai_kwargs,
            **embed_kwargs,
        )

    if modelname.backend == "huggingface":
        assert not private, (
            "Set private but tried to use huggingface embeddings, which might not be as private as using sentencetransformers"
        )
        model_kwargs = {
            "device": "cpu",
            # "device": "cuda",
        }
        model_kwargs.update(embed_kwargs)

        # NOTE(review): this branch only runs when backend == "huggingface",
        # so this condition can never be True. Kept as-is until the intended
        # check (probably on the model name) is confirmed.
        is_gemma = modelname.backend == "google" and "gemma" in modelname.model.lower()
        if is_gemma:
            assert (
                "HUGGINGFACE_API_KEY" in os.environ
                and os.environ["HUGGINGFACE_API_KEY"]
                and "REDACTED" not in os.environ["HUGGINGFACE_API_KEY"]
            ), "Missing HUGGINGFACE_API_KEY"
            hftkn = os.environ["HUGGINGFACE_API_KEY"]  # your token to use the models
            model_kwargs["use_auth_token"] = hftkn

        if instruct:
            from langchain_community.embeddings import HuggingFaceInstructEmbeddings

            embeddings = HuggingFaceInstructEmbeddings(
                model_name=modelname.model,
                model_kwargs=model_kwargs,
                embed_instruction=DEFAULT_EMBED_INSTRUCTION,
                query_instruction=DEFAULT_QUERY_INSTRUCTION,
            )
        else:
            from langchain_community.embeddings import HuggingFaceEmbeddings

            embeddings = HuggingFaceEmbeddings(
                model_name=modelname.model,
                model_kwargs=model_kwargs,
            )

        if is_gemma:
            # please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)`
            # or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[pad]'})
            embeddings.client.tokenizer.pad_token = (
                embeddings.client.tokenizer.eos_token
            )
        return embeddings

    if modelname.backend == "sentencetransformers":
        if private:
            logger.warning("Private is set and will use sentencetransformers backend")
        # Build a new dict instead of mutating the caller's embed_kwargs.
        encode_kwargs = {**embed_kwargs, "batch_size": 1, "device": None}
        from langchain_community.embeddings import SentenceTransformerEmbeddings

        return SentenceTransformerEmbeddings(
            model_name=modelname.model,
            encode_kwargs=encode_kwargs,
        )

    raise ValueError(f"Invalid embedding backend: {modelname.backend}")


def load_embeddings_engine(
    modelname: ModelName,
    cli_kwargs: dict,
    api_base: Optional[str],
    embed_kwargs: dict,
    private: bool,
    do_test: bool,
) -> Embeddings:
    """
    Create the Embeddings object used to compute embeddings.

    The experimental LiteLLMEmbeddings backend is tried first. It is replaced
    by the previous backend-specific implementation in two cases: when it
    cannot be built, or when it fails ``test_embeddings`` (with ``do_test``
    set) and the backend is "openai", "huggingface" or "sentencetransformers".

    Unless WDOC_DISABLE_EMBEDDINGS_CACHE is set, the result is wrapped into a
    CacheBackedEmbeddings backed by an on-disk LocalFileStore. This adds a
    caching layer for document embeddings.

    Args:
        modelname: the embedding model. Its ``backend``, ``model``,
            ``original`` and ``sanitized`` attributes are used.
        cli_kwargs: CLI arguments. Only the ``embed_instruct`` key is read.
        api_base: optional custom endpoint URL.
        embed_kwargs: extra keyword arguments for the embeddings backend.
            Not mutated.
        private: privacy mode, enforced by per-backend checks.
        do_test: run ``test_embeddings`` on the engine. Failures after the
            backend has been chosen are only logged.

    Returns:
        The Embeddings object, cache-backed unless the cache is disabled.

    Raises:
        ValueError: when LiteLLMEmbeddings cannot be built and the backend
            has no previous implementation.
        AssertionError: when the previous implementation's required
            settings are missing.
    """
    from wdoc.utils.customs.litellm_embeddings import LiteLLMEmbeddings

    logger.debug("Loading the embeddings engine")
    instruct = bool(cli_kwargs.get("embed_instruct"))

    logger.debug(
        f"Selected embedding model '{modelname}' of backend {modelname.backend}"
    )

    embeddings = None
    try:
        embeddings = LiteLLMEmbeddings(
            model=modelname.original,
            dimensions=env.WDOC_DEFAULT_EMBED_DIMENSION,  # defaults to None
            api_base=api_base,
            private=private,
            **embed_kwargs,
        )
        if do_test:
            test_embeddings(embeddings)
    except Exception as e:
        logger.warning(
            f"Failed to use the experimental LiteLLMEmbeddings backend, defaulting to using the previous implementation. Error was '{e}'. Please open a github issue to help the developper debug this until it is stable enough."
        )
        # If the LiteLLM object was built but failed its test, actually fall
        # back (as the warning says) when a previous implementation exists.
        # For other backends, keep the LiteLLM object as there is nothing
        # to fall back to.
        if modelname.backend in ("openai", "huggingface", "sentencetransformers"):
            embeddings = None

    if embeddings is None:
        embeddings = _load_legacy_embeddings(
            modelname=modelname,
            api_base=api_base,
            embed_kwargs=embed_kwargs,
            private=private,
            instruct=instruct,
        )

    if do_test:
        try:
            test_embeddings(embeddings)
        except Exception as e:
            logger.warning(
                f"Error when testing embeddings, something is probably wrong with the backend. Error is '{e}'. Please open a github issue to help the developper"
            )

    lfs = LocalFileStore(
        database_path=embeddings_cache_dir / modelname.sanitized,
        expiration_days=env.WDOC_EXPIRE_CACHE_DAYS,
        verbose=env.WDOC_VERBOSE,
        name="Embeddings_" + modelname.sanitized,
    )
    if env.WDOC_DISABLE_EMBEDDINGS_CACHE:
        logger.info(
            "Embeddings cache is disabled - using direct embeddings without caching"
        )
        cached_embeddings = embeddings
    else:
        # Count the cached keys without materializing them all in a list.
        n_cached = sum(1 for _ in lfs.yield_keys())
        logger.info(f"Found {n_cached} embeddings in local cache")
        cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
            underlying_embeddings=embeddings,
            document_embedding_cache=lfs,
            namespace=modelname.sanitized,
        )
        if do_test:
            try:
                test_embeddings(cached_embeddings)
            except Exception as e:
                logger.warning(
                    f"Error when testing embeddings after loading the cache, something is probably wrong with the backend. Error is '{e}'. Please open a github issue to help the developper"
                )

    logger.debug("Done loading cached embeddings")
    return cached_embeddings
def _make_batches(
    token_counts: List[int],
    max_tokens_per_batch: int = 100_000,
    max_docs_per_batch: int = 1000,
) -> List[List[int]]:
    """Split documents into contiguous ``[start, end)`` index ranges.

    A new batch starts when adding the next document would push the current
    batch above ``max_tokens_per_batch`` tokens, or when the batch already
    holds ``max_docs_per_batch`` documents. A single document larger than
    ``max_tokens_per_batch`` still gets a batch of its own.
    """
    batches = []
    start = 0
    n_tokens = 0
    n_docs = 0
    for i, doc_tokens in enumerate(token_counts):
        would_overflow = (
            n_tokens + doc_tokens > max_tokens_per_batch
            or n_docs >= max_docs_per_batch
        )
        if n_docs > 0 and would_overflow:
            # Close the current batch; this document starts the next one.
            batches.append([start, i])
            start = i
            n_tokens = 0
            n_docs = 0
        n_tokens += doc_tokens
        n_docs += 1
    if n_docs > 0:
        batches.append([start, len(token_counts)])
    return batches


def create_embeddings(
    modelname: ModelName,
    cached_embeddings: Embeddings,
    save_embeds_as: Union[str, Path],
    load_embeds_from: Optional[Union[str, Path]],
    loaded_docs: Any,
    dollar_limit: Union[int, float],
    private: bool,
) -> VectorStore:
    """
    Build, or reload, the FAISS vectorstore holding the document embeddings.

    If ``load_embeds_from`` is set, the vectorstore saved there is reloaded
    and returned. Otherwise every document of ``loaded_docs`` is embedded
    through ``cached_embeddings``, which lets the caching layer (when present)
    reuse already computed embeddings. Documents are embedded in batches of
    at most 100k tokens / 1000 documents, using 10 threads. The per-batch
    indexes are merged, saved to ``save_embeds_as`` and returned.

    Before embedding, the dollar cost is estimated from litellm's price table.
    If it exceeds ``dollar_limit``, the user is asked for confirmation.

    Raises:
        AssertionError: when ``load_embeds_from`` does not exist.
        ValueError: when there are no documents to embed.
        SystemExit: when the user refuses the estimated cost.
    """
    import litellm

    logger.debug("Creating embeddings")

    relevance_score_fn = (
        faiss_custom_score_function if env.WDOC_MOD_FAISS_SCORE_FN else None
    )

    # reload passed embeddings
    if load_embeds_from:
        logger.warning("Reloading documents and embeddings from file")
        path = Path(load_embeds_from)
        assert path.exists(), f"file not found at '{path}'"
        cache_file_in_memory(path, recursive=True)
        # SECURITY NOTE: allow_dangerous_deserialization=True lets load_local
        # deserialize (unpickle) stored data. Only load vectorstores from
        # trusted locations.
        db = __get_faiss_vectorstore__().load_local(
            str(path),
            cached_embeddings,
            relevance_score_fn=relevance_score_fn,
            allow_dangerous_deserialization=True,
        )
        n_doc = len(db.index_to_docstore_id.keys())
        logger.warning(f"Loaded {n_doc} documents")
        return db

    ti = time.time()
    docs = loaded_docs
    logger.info(f"Docs to embed: {len(docs)}")
    if not docs:
        # Without this, the merge loop below leaves db as None and
        # save_local fails with an obscure AttributeError.
        raise ValueError("No documents to embed")

    # check price of embedding (token counts are reused for batching below)
    token_counts = [get_tkn_length(doc) for doc in docs]
    full_tkn = sum(token_counts)
    logger.info(f"Total number of tokens in documents: '{full_tkn}'")

    price = 0
    if modelname.backend in [
        "ollama",
        "huggingface",
        "sentence-transformers",
        "sentencetransformers",
    ]:
        logger.info("Local embedding model detected, setting the price to 0")
    elif private:
        logger.info("Not checking token price because private is set")
    else:
        for price_key in (modelname.original, modelname.model):
            if price_key in litellm.model_cost:
                cost_info = litellm.model_cost[price_key]
                price = cost_info["input_cost_per_token"]
                # Embedding models are only expected to bill input tokens.
                assert cost_info.get("output_cost_per_token", 0) == 0
                break
        else:
            logger.warning(
                f"Couldn't find the price of embedding model {modelname.original}. Assuming the cost is zero"
            )

    dol_price = full_tkn * price
    logger.warning(f"Total cost to embed all tokens is ${dol_price:.6f}")
    if dol_price > dollar_limit:
        ans = input("Do you confirm you are okay to pay this? (y/n)\n>")
        if ans.lower() not in ["y", "yes"]:
            logger.warning("Quitting.")
            raise SystemExit()

    faiss_class = __get_faiss_vectorstore__()
    # L2 normalization is disabled in binary FAISS mode.
    normalize_L2 = not env.WDOC_FAISS_BINARY

    def _index_batch(batch: List, embeddings: Embeddings) -> VectorStore:
        """Embed one batch of documents into a new FAISS index."""
        return faiss_class.from_documents(
            batch,
            embeddings,
            normalize_L2=normalize_L2,
            relevance_score_fn=relevance_score_fn,
        )

    def embed_one_batch(batch: List, ib: int) -> VectorStore:
        """Embed one batch, retrying up to 3 times.

        After the last failure, retry once more with the cache bypassed
        (unless the cache is disabled, in which case re-raise).
        """
        n_trial = 3
        for trial in range(n_trial):
            try:
                return _index_batch(batch, cached_embeddings)
            except Exception as e:
                logger.warning(
                    f"Thread #{ib + 1} Error at trial {trial + 1}/{n_trial} when trying to embed documents: {e}"
                )
                if trial + 1 < n_trial:
                    time.sleep(1)
                    continue
                if env.WDOC_DISABLE_EMBEDDINGS_CACHE:
                    logger.exception(
                        "Too many errors when asking provider for embeddings but no cache to bypass so crashing"
                    )
                    raise
                logger.warning(
                    "Too many errors when asking provider for embeddings so try bypassing the embeddings cache"
                )
                return _index_batch(batch, cached_embeddings.underlying_embeddings)

    # create a faiss index for each batch of documents
    # (max 100k tokens and max 1000 docs per batch)
    batches = _make_batches(
        token_counts,
        max_tokens_per_batch=100_000,
        max_docs_per_batch=1000,
    )

    temp_dbs = Parallel(
        backend="threading",
        n_jobs=10,
        verbose=0 if not env.WDOC_VERBOSE else 51,
    )(
        delayed(embed_one_batch)(
            batch=docs[start:end],
            ib=ib,
        )
        for ib, (start, end) in tqdm(
            enumerate(batches),
            total=len(batches),
            desc="Embedding by batch",
        )
    )

    db = None
    for temp in temp_dbs:
        if db is None:
            db = temp
        else:
            db.merge_from(temp)

    logger.info(f"Done creating index (total time: {time.time() - ti:.2f}s)")

    # saving embeddings
    logger.debug("Saving embeddings to file")
    db.save_local(save_embeds_as)
    logger.debug("Done saving embeddings to file")

    logger.debug("Done creating embeddings")
    return db
def test_embeddings(embeddings: Embeddings) -> None:
    """Quick sanity check of an embeddings backend, to spot problems early.

    Embeds one query and one document, then checks three things:

    * both vectors have the same shape;
    * they are not identical, since the inputs differ;
    * neither vector is all zeros.

    Raises:
        AssertionError: if a check fails. Errors raised by the backend itself
            propagate unchanged.
    """
    logger.debug("Testing embeddings")
    query_vec = np.array(embeddings.embed_query("This is a test"))
    doc_vec = np.array(embeddings.embed_documents(["This is another test"])[0])

    query_shape, doc_shape = query_vec.shape, doc_vec.shape
    assert query_shape == doc_shape, (
        f"Test vectors 1 has shape {query_shape} but vector 2 has shape {doc_shape}"
    )
    # At least one component must differ.
    assert (query_vec != doc_vec).any(), (
        "Test vectors 1 and 2 are identical despite different inputs"
    )
    # Each vector must have at least one non-zero component.
    assert (query_vec != 0).any() and (doc_vec != 0).any(), (
        "Test vectors 1 or 2 or both is only zeroes"
    )
    logger.debug("Done testing embeddings")