# Source code for wdoc.utils.customs.litellm_embeddings

"""
Custom embeddings to use litellm. This allows using for example
"ollama/bge-m3" as a model name.
Source: https://python.langchain.com/docs/how_to/custom_embeddings/
"""

from typing import List, Optional

from langchain_core.embeddings import Embeddings


class LiteLLMEmbeddings(Embeddings):
    """Litellm embedding model integration.

    Wraps ``litellm.embedding`` behind the langchain ``Embeddings`` interface
    so that any ``provider/model`` name understood by litellm (for example
    ``"ollama/bge-m3"``) can be used as an embedding model.
    """

    def __init__(
        self,
        model: str,
        dimensions: Optional[int],
        api_base: Optional[str],
        private: bool,
        **embed_kwargs,
    ):
        """Validate the configuration and store it on the instance.

        Args:
            model: Model name in ``provider/model`` form; must contain a "/".
            dimensions: Forwarded as ``dimensions`` to ``litellm.embedding``.
            api_base: Forwarded as ``api_base`` to ``litellm.embedding``.
            private: When true and ``api_base`` is not set, the model name
                must contain "ollama" or "huggingface", otherwise
                construction is refused.
            **embed_kwargs: Extra keyword arguments forwarded verbatim to
                ``litellm.embedding`` on every call.

        Raises:
            AssertionError: If ``model`` has no "/" or the ``private`` guard
                rejects the model.
        """
        # Lazy import; also publishes ``litellm`` as a module-level name.
        self._load_litellm()

        # Explicit raises rather than ``assert`` statements: asserts are
        # removed under ``python -O``, which would silently disable the
        # privacy guard below. The exception type is kept for callers.
        if "/" not in model:
            raise AssertionError(
                "model must contain a /, for example 'ollama/bge-m3' or 'openai/text-embedding-ada-002'"
            )
        if private and not api_base:
            if not any(provider in model for provider in ("ollama", "huggingface")):
                raise AssertionError(
                    "--private argument is set and api_base not overridden BUT the model does not contain ollama nor huggingface, this can be a mistake so crashing out of abundance of caution. If you think this is a bug please open an issue on github."
                )

        self.model = model
        self.dimensions = dimensions
        self.private = private
        self.api_base = api_base
        self.embed_kwargs = embed_kwargs

    @staticmethod
    def _load_litellm():
        """Import litellm lazily, bind it at module level and return it."""
        # The ``global`` declaration has to come before the import that
        # binds the name.
        global litellm
        import litellm

        return litellm

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Strings to embed. None of them may be empty or
                whitespace-only.

        Returns:
            One embedding per input text. An empty ``texts`` yields ``[]``
            without calling litellm. If the litellm response has no ``data``
            attribute it is returned unchanged.

        Raises:
            AssertionError: If any text is empty or whitespace-only.
            Exception: If the ``data`` of the response cannot be parsed.
        """
        if not texts:
            # Nothing to embed: skip the provider call entirely.
            return []
        if any(not t.strip() for t in texts):
            raise AssertionError(
                f"The texts to embed include an empty string, which usually errors out providers. Texts={texts}"
            )

        # Re-resolve the module here so that instances created without
        # running __init__ (e.g. unpickled ones) still work.
        llm = self._load_litellm()

        # https://docs.litellm.ai/docs/embedding/supported_embedding
        vecs = llm.embedding(
            model=self.model,
            input=texts,
            dimensions=self.dimensions,
            encoding_format="float",
            timeout=600,
            api_base=self.api_base,
            user="wdoc_embeddings",
            drop_params=True,
            # 'sentence-similarity', 'feature-extraction', 'rerank', 'embed', 'similarity'
            # input_type="feature-extraction",  # seems to crash for openai
            **self.embed_kwargs,
        )

        if not hasattr(vecs, "data"):
            return vecs

        # Must be an EmbeddingsResponse format, for example ollama.
        data = vecs.data
        if isinstance(data, list) and data:
            first = data[0]
            if isinstance(first, dict):
                return [v["embedding"] for v in data]
            if hasattr(first, "embedding"):
                return [v.embedding for v in data]

        # Unknown shape, or an empty ``data`` list for a non-empty input.
        raise Exception(
            f"Failed to parsed output of litellm embedding for model '{self.model}'. String rendering is '{vecs}'. Please open a github issue to get that fixed."
        )

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Embeds ``text`` as a one-element batch through ``embed_documents``
        and returns its single embedding.
        """
        return self.embed_documents([text])[0]