"""
Retrievers used to fetch the documents most relevant to a given query.
"""
from beartype.typing import Any, List, Optional
from langchain_core.documents import Document
from loguru import logger
# from langchain.storage import LocalFileStore
from langchain_core.retrievers import BaseRetriever
from langchain_core.embeddings import Embeddings
from wdoc.utils.env import env
from wdoc.utils.tasks.types import wdocTask
from wdoc.utils.misc import cache_dir, get_splitter
from wdoc.utils.prompts import multiquery_parser, prompts
from wdoc.utils.customs.compressed_embeddings_cacher import LocalFileStore
[docs]
def create_multiquery_retriever(
    llm: "langchain_litellm.ChatLiteLLM",
    retriever: BaseRetriever,
) -> BaseRetriever:
    """Wrap ``retriever`` in a ``MultiQueryRetriever`` driven by ``llm``.

    The "advanced" mode is used: the LLM is prompted with
    ``prompts.multiquery`` and its output is parsed by the pydantic-based
    ``multiquery_parser``. That runnable chain is passed to
    ``MultiQueryRetriever`` as its ``llm_chain``.

    Args:
        llm: Chat model used to generate the alternative queries.
        retriever: Underlying retriever the generated queries are run against.

    Returns:
        The configured ``MultiQueryRetriever``.
    """
    query_chain = prompts.multiquery | llm | multiquery_parser

    from langchain_classic.retrievers.multi_query import MultiQueryRetriever

    # TODO: pydantic parsing can fail for some models, so a fallback to the
    # default MultiQueryRetriever.from_llm(retriever=..., llm=...) is wanted.
    # Wrapping the whole retriever with .with_fallbacks() is not the right
    # place for it: the fallback has to be attached to the llm_chain instead.
    return MultiQueryRetriever(retriever=retriever, llm_chain=query_chain)
[docs]
def create_parent_retriever(
    task: wdocTask,
    loaded_embeddings: Any,
    loaded_docs: List[Document],
    top_k: int,
    relevancy: float,
) -> BaseRetriever:
    """Build and populate a ``ParentDocumentRetriever`` over ``loaded_embeddings``.

    Child chunks are produced by the task's splitter. Parent chunks are
    produced by a copy of that splitter whose chunk size is four times
    larger. Parent documents are kept in an on-disk ``LocalFileStore`` under
    ``cache_dir / "parent_retriever"``.

    See https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever

    Args:
        task: Task used to select the text splitter via ``get_splitter``.
        loaded_embeddings: Vector store that receives the child chunks.
            NOTE(review): this is the same store the other retrievers in this
            module search, so ``add_documents`` below also inserts the child
            chunks into it — confirm this is intended.
        loaded_docs: Documents to index into the parent retriever.
        top_k: Maximum number of child chunks fetched from the vector store.
        relevancy: Minimum relevance score for a child chunk to be kept,
            interpreted the same way as for the module's other retrievers.

    Returns:
        The populated ``ParentDocumentRetriever``.
    """
    import copy

    from langchain_classic.retrievers import ParentDocumentRetriever

    child_splitter = get_splitter(task)
    # Scale a copy rather than the returned object: if get_splitter ever hands
    # out a shared/cached instance, an in-place change would also enlarge the
    # child splitter (and the splitter of every later caller).
    parent_splitter = copy.copy(get_splitter(task))
    parent_splitter._chunk_size *= 4

    docstore = LocalFileStore(
        database_path=cache_dir / "parent_retriever",
        verbose=env.WDOC_VERBOSE,
        name="parent_retriever",
    )

    parent = ParentDocumentRetriever(
        vectorstore=loaded_embeddings,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        # With search_type="similarity", score_threshold is forwarded to the
        # plain similarity_search, where stores such as FAISS compare it to raw
        # distances. Use the relevance-score mode, like the other retrievers.
        search_type="similarity_score_threshold",
        search_kwargs={
            "k": top_k,
            "score_threshold": relevancy,
        },
    )
    parent.add_documents(loaded_docs)
    return parent
[docs]
def get_all_texts(loaded_embeddings: Any) -> List[str]:
    """Return the ``page_content`` of every document held in a vector store's docstore.

    Args:
        loaded_embeddings: Vector store exposing a ``docstore`` whose private
            ``_dict`` mapping holds the stored documents (presumably
            langchain's FAISS with an in-memory docstore — confirm). Despite
            the parameter name, this is the store, not an ``Embeddings`` model.

    Returns:
        The documents' texts, in the iteration order of ``docstore._dict``.
    """
    # Only the documents are needed; the ids (keys) are ignored.
    return [doc.page_content for doc in loaded_embeddings.docstore._dict.values()]
[docs]
def _score_threshold_retriever(
    loaded_embeddings, top_k: int, relevancy: float
) -> BaseRetriever:
    """Return ``loaded_embeddings`` as a top-k retriever filtered by relevance score."""
    return loaded_embeddings.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={
            "k": top_k,
            "score_threshold": relevancy,
        },
    )


def create_retrievers(
    query_retrievers: str,
    loaded_embeddings,
    embedding_engine,
    llm,
    top_k: int,
    relevancy: float,
    task: wdocTask,
    loaded_docs: Optional[List[Document]],
) -> BaseRetriever:
    """Create the retriever(s) requested by ``query_retrievers`` and return a single retriever.

    Args:
        query_retrievers: Free-form selector string. Each of "multiquery",
            "knn", "svm", "parent" and "basic" that appears in it
            (case-insensitive substring match) enables that retriever.
        loaded_embeddings: Vector store searched by the retrievers. It must
            provide ``as_retriever`` and a ``docstore._dict`` mapping.
        embedding_engine: Embedding model used by the KNN/SVM retrievers and by
            the redundancy filter applied when several retrievers are merged.
        llm: Chat model used by the multiquery retriever.
        top_k: Number of results each retriever returns.
        relevancy: Minimum relevance score / threshold passed to each retriever.
        task: Task forwarded to the parent retriever (selects its splitter).
        loaded_docs: Documents for the parent retriever. If falsy, the documents
            stored in ``loaded_embeddings.docstore`` are used instead.

    Returns:
        The only retriever if exactly one was selected. Otherwise, a
        ``ContextualCompressionRetriever`` over a ``MergerRetriever`` of all
        selected retrievers, with near-duplicate results removed.

    Raises:
        AssertionError: If ``query_retrievers`` selects no retriever.
    """
    # Lowercase once so every retriever name is matched case-insensitively.
    wanted = query_retrievers.lower()
    retrievers: List[BaseRetriever] = []
    all_texts: Optional[List[str]] = None  # shared, lazily computed by knn/svm

    if "multiquery" in wanted:
        retrievers.append(
            create_multiquery_retriever(
                llm=llm,
                retriever=_score_threshold_retriever(
                    loaded_embeddings, top_k, relevancy
                ),
            )
        )

    if "knn" in wanted:
        if all_texts is None:
            all_texts = get_all_texts(loaded_embeddings)
        from langchain_community.retrievers import KNNRetriever

        retrievers.append(
            KNNRetriever.from_texts(
                all_texts,
                embedding_engine,
                relevancy_threshold=relevancy,
                k=top_k,
            )
        )

    if "svm" in wanted:
        if all_texts is None:
            all_texts = get_all_texts(loaded_embeddings)
        from langchain_community.retrievers import SVMRetriever

        retrievers.append(
            SVMRetriever.from_texts(
                all_texts,
                embedding_engine,
                relevancy_threshold=relevancy,
                k=top_k,
            )
        )

    if "parent" in wanted:
        if not loaded_docs:
            logger.warning(
                "To use the 'parent' retriever, we have have loaded documents but we haven't. This might be because you are loading from an index directly instead of creating embeddings during this run. As an experimental workaround, we load the documents from the loaded embeddings."
            )
            loaded_docs = list(loaded_embeddings.docstore._dict.values())
        retrievers.append(
            create_parent_retriever(
                task=task,
                loaded_embeddings=loaded_embeddings,
                loaded_docs=loaded_docs,
                top_k=top_k,
                relevancy=relevancy,
            )
        )

    if "basic" in wanted:
        retrievers.append(
            _score_threshold_retriever(loaded_embeddings, top_k, relevancy)
        )

    assert retrievers, (
        "No retriever selected. Probably cause by a wrong cli_command or query_retrievers arg."
    )

    if len(retrievers) == 1:
        return retrievers[0]

    from langchain_classic.retrievers import ContextualCompressionRetriever
    from langchain_classic.retrievers.document_compressors import (
        DocumentCompressorPipeline,
    )
    from langchain_classic.retrievers.merger_retriever import MergerRetriever
    from langchain_community.document_transformers import EmbeddingsRedundantFilter

    merge_retriever = MergerRetriever(retrievers=retrievers)
    # Several retrievers often return the same chunk; drop results whose
    # embeddings are near-identical (similarity >= 0.999) after merging.
    redundancy_filter = EmbeddingsRedundantFilter(
        embeddings=embedding_engine,
        similarity_threshold=0.999,
    )
    filter_pipeline = DocumentCompressorPipeline(transformers=[redundancy_filter])
    return ContextualCompressionRetriever(
        base_compressor=filter_pipeline, base_retriever=merge_retriever
    )