Source code for wdoc.utils.tasks.query

"""
Chain (logic) used to query a document.
"""

import re
import time

import numpy as np
from beartype.typing import List, Optional, Union
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import chain
from langchain_core.runnables.base import RunnableLambda
from tqdm.asyncio import tqdm
from loguru import logger

from wdoc.utils.env import env
from wdoc.utils.errors import (
    InvalidDocEvaluationByLLMEval,
    NoDocumentsAfterLLMEvalFiltering,
    NoDocumentsRetrieved,
    ShouldIncreaseTopKAfterLLMEvalFiltering,
)
from wdoc.utils.misc import get_tkn_length, thinking_answer_parser, log_and_time_fn

# Whole-word match of "IRRELEVANT"; check_intermediate_answer uses it to reject
# short intermediate answers that declare themselves irrelevant.
irrelevant_regex = re.compile(r"\bIRRELEVANT\b")


@log_and_time_fn
def check_intermediate_answer(ans: str) -> bool:
    """Tell whether an intermediate answer should be kept.

    An answer is rejected when it contains ``<answer>IRRELEVANT</answer>``,
    or when it is short (fewer characters than twice ``len("IRRELEVANT")``)
    and contains the whole word IRRELEVANT. Every other answer is kept.

    Parameters
    ----------
    ans : str
        Raw intermediate answer produced by the LLM.

    Returns
    -------
    bool
        True if the answer is considered relevant.
    """
    if "<answer>IRRELEVANT</answer>" in ans:
        return False
    # Long answers are always kept, even if they mention IRRELEVANT somewhere.
    is_long_enough = len(ans) >= len("IRRELEVANT") * 2
    return is_long_enough or irrelevant_regex.search(ans) is None
@log_and_time_fn
def sieve_documents(instance) -> RunnableLambda:
    """Build a chain step that caps the number of retrieved documents.

    When several retrievers are combined, far more than ``top_k`` documents
    can be returned. This step keeps only the first ``instance.top_k`` entries
    of ``inputs["unfiltered_docs"]``.

    Parameters
    ----------
    instance
        Object exposing ``top_k`` and ``max_top_k`` attributes. The attributes
        are read each time the step runs (not captured by value), so updates
        to ``top_k`` made elsewhere are taken into account.

    Returns
    -------
    RunnableLambda
        A chain taking a dict with ``question_to_answer`` and
        ``unfiltered_docs`` keys and returning that same dict, with
        ``unfiltered_docs`` cropped if needed.
    """

    def _sieve(inputs: dict) -> dict:
        assert "question_to_answer" in inputs, inputs.keys()
        assert "unfiltered_docs" in inputs, inputs.keys()
        # we have to pass an instance otherwise we can't know if the top_k got updated
        assert hasattr(instance, "top_k")
        assert hasattr(instance, "max_top_k")
        if instance.max_top_k:
            assert instance.max_top_k >= instance.top_k
        # Compare the number of *documents* to top_k (len(inputs) would only
        # count the keys of the dict).
        n_docs = len(inputs["unfiltered_docs"])
        if n_docs > instance.top_k:
            logger.warning(
                "Number of documents found via embeddings was "
                f"'{n_docs}' which is > top_k ({instance.top_k}) "
                "so we crop"
            )
            inputs["unfiltered_docs"] = inputs["unfiltered_docs"][: instance.top_k]
        return inputs

    _sieve = chain(_sieve)
    return _sieve
@log_and_time_fn
def refilter_docs(inputs: dict) -> List[Document]:
    """Filter documents found via RAG based on the score answered by the eval LLM.

    Parameters
    ----------
    inputs : dict
        Must contain ``unfiltered_docs`` (a list of documents) and
        ``evaluations`` (a list of the same length). Each evaluation is either
        one raw eval-LLM output or a list of them (several evaluations of the
        same document). Each output is parsed with ``thinking_answer_parser``.

    Returns
    -------
    List[Document]
        The documents whose mean score is >= 3. Scores that cannot be
        converted to ``int`` count as 5, and a document with no evaluation at
        all also gets 5, so such documents are kept.

    Raises
    ------
    NoDocumentsRetrieved
        If ``unfiltered_docs`` is empty.
    NoDocumentsAfterLLMEvalFiltering
        If no document passes the filter.
    """
    unfiltered_docs = inputs["unfiltered_docs"]
    evaluations = inputs["evaluations"]
    assert isinstance(unfiltered_docs, list), (
        f"unfiltered_docs should be a list, not {type(unfiltered_docs)}"
    )
    assert isinstance(evaluations, list), (
        f"evaluations should be a list, not {type(evaluations)}"
    )
    assert len(unfiltered_docs) == len(evaluations), (
        f"len of unfiltered_docs is {len(unfiltered_docs)} but len of evaluations is {len(evaluations)}"
    )
    if not unfiltered_docs:
        raise NoDocumentsRetrieved("No document corresponding to the query")

    filtered_docs = []
    # lengths were asserted equal above, so zip pairs each doc with its evals
    for doc, evals in zip(unfiltered_docs, evaluations):
        if not isinstance(evals, list):
            evals = [evals]
        answers = [thinking_answer_parser(ev)["answer"] for ev in evals]
        for ia, a in enumerate(answers):
            try:
                a = int(a)
            except (TypeError, ValueError) as err:
                logger.warning(
                    f"Document was not evaluated with a number: '{err}' for answer '{a}'\nKeeping the document anyway."
                )
                a = 5
            answers[ia] = a
        if not answers:
            # Avoid a ZeroDivisionError below; treat it like an unparsable score.
            logger.warning(
                "Document received no evaluation at all.\nKeeping the document anyway."
            )
            answers = [5]
        if sum(answers) / len(answers) >= 3:
            filtered_docs.append(doc)

    if not filtered_docs:
        raise NoDocumentsAfterLLMEvalFiltering(
            "No document remained after filtering with the query"
        )
    return filtered_docs


refilter_docs = chain(refilter_docs)
@log_and_time_fn
def retrieve_documents_for_query(retriever):
    """
    Build the document-retrieval step of the query chain.

    Parameters
    ----------
    retriever : object
        Anything with an ``invoke`` method taking the embedding question and
        returning the retrieved documents.

    Returns
    -------
    RunnableLambda
        A chain that takes a dict with ``question_for_embedding`` and
        ``question_to_answer`` keys and returns a dict with
        ``unfiltered_docs`` (the retriever output) and ``question_to_answer``
        (passed through unchanged).
    """

    @chain
    def _retrieve_documents(inputs):
        embedding_question = inputs["question_for_embedding"]
        found = retriever.invoke(embedding_question)
        return {
            "unfiltered_docs": found,
            "question_to_answer": inputs["question_to_answer"],
        }

    return _retrieve_documents
@log_and_time_fn
def parse_eval_output(output: str) -> str:
    """
    Parse an LLM's answer about whether a document is relevant into an
    integer as str (the eval prompt asks for 0 to 10; the range itself is not
    enforced here).

    For example, it turns an LLM answer from:
    '''
    <think>
    I am thinking hard about if the document is relevant to the user query
    on a scale of 0 (irrelevant) to 10 (very relevant).
    ...
    </think>
    <answer>10</answer>
    '''
    into simply: '10'

    Minus signs are dropped. If the answer is not a plain number, the first
    line that is one is used; otherwise the answer must contain exactly one
    standalone number.

    Returns
    -------
    str
        The parsed score. When the output is empty, contains no number, or is
        ambiguous (several numbers), "5" is returned if
        ``env.WDOC_CONTINUE_ON_INVALID_EVAL`` is set.

    Raises
    ------
    InvalidDocEvaluationByLLMEval
        In those same invalid cases when
        ``env.WDOC_CONTINUE_ON_INVALID_EVAL`` is not set.
    """
    mess = (
        f"The eval LLM returned an output that can't be parsed as expected: '{output}'"
    )

    def _invalid() -> str:
        # Fall back to a neutral score or raise, depending on the env setting.
        if env.WDOC_CONTINUE_ON_INVALID_EVAL:
            logger.warning(mess)
            return "5"
        raise InvalidDocEvaluationByLLMEval(mess)

    # empty
    if not output.strip():
        return _invalid()

    parsed = thinking_answer_parser(output)
    logger.debug(f"Eval LLM output: '{output}'")
    answer = parsed["answer"]
    answer = answer.replace("-", "")  # negative ints are not accepted anyway

    # isdecimal() rather than isdigit(): characters such as "²" satisfy
    # isdigit() but make int() raise ValueError.
    if not answer.isdecimal():
        number_lines = [li for li in answer.splitlines() if li.isdecimal()]
        if number_lines:
            answer = number_lines[0]
    if answer.isdecimal():
        return str(int(answer))

    digits = [d for d in re.split(r"\b", parsed["answer"]) if d.isdecimal()]
    if not digits:  # contains no digits
        return _invalid()
    if len(digits) == 1:  # good
        return digits[0]
    return _invalid()  # ambiguous
@log_and_time_fn
def collate_relevant_intermediate_answers(
    list_ia: List[str],
) -> str:
    """Rewrite the relevant intermediate answers as a single string readable
    by the combining LLM.

    Each answer has the bullet "•" normalized to "-", is followed by a
    newline, and has its leading whitespace stripped.

    Parameters
    ----------
    list_ia : List[str]
        At least two intermediate answers, all accepted by
        ``check_intermediate_answer``.

    Returns
    -------
    str
        The concatenated answers.
    """
    assert list_ia == [ia for ia in list_ia if check_intermediate_answer(ia)], (
        "collate_relevant_intermediate_answers should only be receiving relevant answers"
    )
    # Only index list_ia when it is non-empty: otherwise building the message
    # would raise IndexError and hide the AssertionError.
    assert len(list_ia) >= 2, (
        f"Cannot collate a single intermediate answer!\n{list_ia[0]}"
        if list_ia
        else "Cannot collate an empty list of intermediate answers!"
    )
    cleaned = (
        ia.replace("- • ", "- ").replace("• ", "- ")  # occasional bad md
        for ia in list_ia
    )
    return "".join(f"{ia}\n".lstrip() for ia in cleaned)
@log_and_time_fn
def semantic_batching(
    texts: List[str],
    embedding_engine: Embeddings,
) -> List[List[str]]:
    """
    Given a list of text, embed them, do a hierarchical clustering then sort
    the list according to the leaf order, then create buckets that best
    contain each subtopic while keeping a reasonable number of tokens.
    This probably helps the LLM to combine the intermediate answers into one.
    Note that the documents are also sorted inside each batch, so that
    iterating over each document of each batch in order will follow the
    optimal leaf order.

    Parameters
    ----------
    texts : List[str]
        Texts to batch; at least 2 are required. Duplicates are removed
        (first occurrence kept).
    embedding_engine : Embeddings
        Used via ``embed_documents``; retried up to 3 times on error.

    Returns
    -------
    List[List[str]]
        The batches. A single batch with all unique texts is returned when at
        most 3 unique texts remain or when clustering always finds a single
        cluster. Otherwise every batch contains at least 2 texts and each
        text appears exactly once.
    """
    import scipy.cluster.hierarchy
    import scipy.spatial.distance
    import pandas as pd
    import sklearn.decomposition as decomposition
    import sklearn.metrics as metrics
    import sklearn.preprocessing as preprocessing

    assert texts, "No input text received"
    assert len(texts) > 1, f"received only one text: {texts}"

    # deduplicate texts, keeping the first occurrence order
    texts = list(dict.fromkeys(texts))

    if len(texts) <= 3:
        logger.debug(
            f"Returned texts in semantic_batching because there were only {len(texts)}"
        )
        return [texts]

    text_index = {t: i for i, t in enumerate(texts)}
    text_sizes = {t: get_tkn_length(t) for t in texts}
    # log the token size of each text (not the texts themselves)
    itext_sizes = {i: text_sizes[t] for i, t in enumerate(texts)}
    logger.debug(f"Input text sizes in semantic_batching: {itext_sizes}")

    # get embeddings
    n_trial = 3
    for trial in range(n_trial):
        try:
            embeds = np.array(embedding_engine.embed_documents(texts)).squeeze()
            break
        except Exception as e:
            logger.warning(
                f"Error at trial {trial + 1}/{n_trial} when trying to embed texts for semantic batching: '{e}'"
            )
            if trial + 1 >= n_trial:
                logger.warning("Too many errors so crashing")
                raise
            time.sleep(2)

    n_dim = embeds.shape[1]
    assert n_dim > 2, (
        f"Unexpected number of dimension: {n_dim}, shape was {embeds.shape}"
    )
    max_n_dim = min(100, len(texts))

    # optional dimension reduction to gain time; any failure (including a too
    # low explained variance) falls back to the raw embeddings
    embeddings = None
    try:
        if n_dim > max_n_dim:
            scaler = preprocessing.StandardScaler()
            embed_scaled = scaler.fit_transform(embeds)
            pca = decomposition.PCA(n_components=max_n_dim)
            embeds_reduced = pca.fit_transform(embed_scaled)
            assert embeds_reduced.shape[0] == embeds.shape[0]
            vr = np.cumsum(pca.explained_variance_ratio_)[-1]
            if vr <= 0.90:
                logger.warning(
                    f"Found lower than exepcted PCA explained variance ratio: {vr:.4f}"
                )
            assert vr >= 0.75, (
                f"Found substancially low explained variance ratio afer pca at {vr:.4f} so not using dimension reduction"
            )
            embeddings = pd.DataFrame(
                columns=[f"v_{i}" for i in range(embeds_reduced.shape[1])],
                index=list(range(len(texts))),
                data=embeds_reduced,
            )
    except Exception as err:
        logger.warning(
            f"Error when doing dimension reduction for semantic batching. Original shape: {embeds.shape}. Error: '{err}'\nContinuing anyway."
        )
    if embeddings is None:
        embeddings = pd.DataFrame(
            columns=[f"v_{i}" for i in range(embeds.shape[1])],
            index=list(range(len(texts))),
            data=embeds,
        )

    # get the pairwise distance matrix (rows/columns indexed like `texts`)
    pd_dist = pd.DataFrame(
        columns=embeddings.index,
        index=embeddings.index,
        data=metrics.pairwise_distances(
            embeddings.values,
            n_jobs=-1,
            metric="euclidean",
        ),
    )
    # make sure the diagonal is exactly 0 and not a very small float
    for ind in pd_dist.index:
        pd_dist.at[ind, ind] = 0
    # make sure it's symmetric
    pd_dist = pd_dist.add(pd_dist.T).div(2)

    # get the hierarchical semantic sorting order
    dist = scipy.spatial.distance.squareform(pd_dist.values)  # condensed format
    Z = scipy.cluster.hierarchy.linkage(dist, method="ward", optimal_ordering=True)
    assert len(Z.shape) == 2 and Z.shape[1] == 4, f"Unexpected Z shape: {Z.shape}"
    # order[k] is the index (in `texts`) of the k-th leaf from left to right
    order = scipy.cluster.hierarchy.leaves_list(Z)
    assert len(order.shape) == 1
    # leaf_rank[i] is the position of texts[i] in the leaf order (inverse
    # permutation of `order`), used to sort texts inside each bucket
    leaf_rank = np.argsort(order)

    # TODO: if <= 6 texts we should make 2 or 3 batch just using the order

    # get each bucket if we were only looking at the number of texts
    n_texts = len(pd_dist.index)
    cluster_trials = {}
    cluster_mean_tkn = {}
    for divider in [2, 3, 4, 5, 6]:
        if divider > n_texts:
            continue
        cluster_labels = scipy.cluster.hierarchy.fcluster(
            Z, n_texts // divider, criterion="maxclust"
        )
        labels = np.unique(cluster_labels)
        labels.sort()
        if len(labels) == 1:
            # re cluster if only one label found
            continue

        # use heuristics to find the best number of dividers by looking
        # at the average number of token in each clusters
        total_mean = 0
        for lab in labels:
            lt = [
                texts[int(ind.squeeze())] for ind in np.argwhere(cluster_labels == lab)
            ]
            lsize = sum(text_sizes[t] for t in lt)
            total_mean += lsize / len(lt)
        total_mean /= len(labels)
        cluster_mean_tkn[divider] = total_mean
        cluster_trials[divider] = cluster_labels

    if not cluster_trials:
        assert len(labels) == 1
        logger.warning(
            f"The clustering algorithm always found the same cluster for the {len(texts)} texts. Assuming the order won't matter."
        )
        return [texts]

    max_tkn = env.WDOC_SEMANTIC_BATCH_MAX_TOKEN_SIZE
    # prefer the first divider whose mean cluster size is in [max/2, max)
    best_clusters = None
    for d, ct in cluster_mean_tkn.items():
        if max_tkn / 2 <= ct < max_tkn:
            best_clusters = cluster_trials[d]
            break
    if best_clusters is None:
        # otherwise take the (first) divider with the smallest mean
        best_d = min(cluster_mean_tkn, key=cluster_mean_tkn.get)
        best_clusters = cluster_trials[best_d]
    assert best_clusters is not None

    cluster_labels = best_clusters
    labels = np.unique(cluster_labels)
    labels.sort()
    assert len(labels) > 1, cluster_labels

    # make sure no cluster contains only one text
    while not all((cluster_labels == lab).sum() > 1 for lab in labels):
        logger.debug("Remapping clusters.")
        for lab in labels:
            if (cluster_labels == lab).sum() == 1:
                i_single = int(np.argmax(cluster_labels == lab))
                # the closest is always itself so checking the 2nd closest
                t_close = pd_dist.loc[i_single, :].nsmallest(2).index.tolist()
                assert i_single == t_close[0]
                t_closest = t_close[1]
                l_closest = cluster_labels[t_closest]
                if (cluster_labels == l_closest).sum() + 1 == len(texts):
                    # merging small to big would result in only one cluster:
                    # better to even them out
                    assert len(labels) == 2, labels
                    cluster_labels[t_closest] = lab
                    logger.debug(f"Remapped one item from cluster {l_closest} to {lab}")
                else:
                    # good to go
                    cluster_labels[cluster_labels == lab] = l_closest
                    logger.debug(
                        f"Remapped single item of cluster {lab} to {l_closest}"
                    )
                break
        labels = np.unique(cluster_labels)
        labels.sort()
    assert len(labels) > 1, cluster_labels
    assert all((cluster_labels == lab).sum() > 1 for lab in labels), cluster_labels

    # Create buckets: fill each bucket until reaching max_tkn, never mixing
    # two clusters in the same bucket
    buckets = []
    for lab in labels:
        lab_ind = np.argwhere(cluster_labels == lab)
        assert len(lab_ind) > 1, f"{lab_ind}\n{cluster_labels}"
        assert len(lab_ind) < len(texts), f"{lab_ind}\n{cluster_labels}"
        current_bucket = []
        current_tokens = 0
        for clustid in lab_ind:
            text = texts[int(clustid.squeeze())]
            size = text_sizes[text]
            if current_tokens + size > max_tkn and current_bucket:
                buckets.append(current_bucket)
                current_bucket = [text]
                # the new bucket already holds `text`, so count its tokens
                current_tokens = size
            else:
                current_bucket.append(text)
                current_tokens += size
        assert current_bucket
        buckets.append(current_bucket)
    assert all(bucket for bucket in buckets), "Empty buckets"

    # sort each bucket based on the optimal leaf order
    for ib, b in enumerate(buckets):
        buckets[ib] = sorted(b, key=lambda t: leaf_rank[text_index[t]])

    # now if any bucket contains only one text, that means it has too many
    # tokens itself, so we reequilibrate from the neighbouring buckets
    while not all(len(b) >= 2 for b in buckets):
        logger.debug(f"Merging sub buckets. Current len: {len(buckets)}")
        for ib, b in enumerate(buckets):
            assert b
            if len(b) != 1:
                continue
            # figure out which bucket to merge with
            if ib == 0:
                # first, merge with next
                next_id = ib + 1
            elif ib + 1 == len(buckets):
                # last, take the penultimate
                next_id = ib - 1
            else:
                # not first nor last, take the neighbour with least minimal distance
                i_cur = text_index[b[0]]
                prev_dist = min(
                    pd_dist.loc[i_cur, text_index[t]] for t in buckets[ib - 1]
                )
                next_dist = min(
                    pd_dist.loc[i_cur, text_index[t]] for t in buckets[ib + 1]
                )
                assert prev_dist > 0 and next_dist > 0
                next_id = ib - 1 if prev_dist < next_dist else ib + 1
            assert buckets[next_id], buckets[next_id]
            logger.debug(f"Next_id is {next_id}")
            if len(buckets[next_id]) <= 2:
                # 1 text: both texts are big, merge them anyway.
                # 2 texts: merging 2:1 -> 1:2 would create a loop.
                # Either way move the single text into the neighbour, at the
                # side that keeps the leaf order.
                if next_id > ib:
                    buckets[next_id].insert(0, b.pop())
                else:
                    buckets[next_id].append(b.pop())
                assert not b, b
            else:
                # borrow the adjacent text of the neighbour, at the correct position
                if next_id > ib:
                    b.append(buckets[next_id].pop(0))
                else:
                    b.insert(0, buckets[next_id].pop(-1))
                assert id(b) == id(buckets[ib])
            break
        buckets = [b for b in buckets if b]

    assert all(len(b) >= 2 for b in buckets), (
        f"Invalid size of buckets: '{[len(b) for b in buckets]}'"
    )
    unchained = [t for b in buckets for t in b]
    assert len(unchained) == len(set(unchained)), (
        "There were duplicate texts in buckets!"
    )
    assert all(t in text_index for t in unchained), "Some text of buckets were added!"
    assert sorted(unchained) == sorted(texts), (
        "There is an issue with semantic_batching"
    )

    logger.debug("Printing size of each bucket in semantic_batching:")
    for ib, b in enumerate(buckets):
        sizes = [text_sizes[bb] for bb in b]
        logger.debug(f"{ib}: {sizes}")
    return buckets
def pbar_chain(
    llm: Union[
        "langchain_litellm.ChatLiteLLM",
        "langchain_community.chat_models.fake.FakeListChatModel",
    ],
    len_func: str,
    **tqdm_kwargs,
) -> RunnableLambda:
    """Build a pass-through chain step that opens a new tqdm progress bar.

    Parameters
    ----------
    llm
        Model whose first callback exposes a ``pbar`` list; the new bar is
        appended to it (``pbar_closer`` later closes ``pbar[-1]``).
    len_func : str
        Python expression given to ``eval`` inside the step to compute the
        bar's total. It is evaluated in the step's own frame, so it can refer
        to ``inputs`` and ``llm``. It must come from trusted code, never from
        user-provided text.
    **tqdm_kwargs
        Forwarded to ``tqdm``.

    Returns
    -------
    RunnableLambda
        A step that returns its input unchanged.
    """

    def actual_pbar_chain(
        inputs: Union[dict, List],
        llm: Union[
            "langchain_litellm.ChatLiteLLM",
            "langchain_community.chat_models.fake.FakeListChatModel",
        ] = llm,
    ) -> Union[dict, List]:
        bars = llm.callbacks[0].pbar
        # NOTE: eval() reads this frame's locals; `inputs` and `llm` must keep
        # these exact names for existing `len_func` expressions to work.
        bars.append(tqdm(total=eval(len_func), **tqdm_kwargs))
        newest = bars[-1]
        if not newest.total:
            logger.warning(f"Empty total for pbar: {newest}")
        return inputs

    return chain(actual_pbar_chain)
def pbar_closer(
    llm: Union[
        "langchain_litellm.ChatLiteLLM",
        "langchain_community.chat_models.fake.FakeListChatModel",
    ],
) -> RunnableLambda:
    """Build a pass-through chain step that closes the bar opened by ``pbar_chain``.

    The most recent bar in ``llm.callbacks[0].pbar`` is advanced to its total
    and then closed.

    Returns
    -------
    RunnableLambda
        A step that returns its input unchanged.
    """

    def actual_pbar_closer(
        inputs: Union[dict, List],
        llm: Union[
            "langchain_litellm.ChatLiteLLM",
            "langchain_community.chat_models.fake.FakeListChatModel",
        ] = llm,
    ) -> Union[dict, List]:
        bar = llm.callbacks[0].pbar[-1]
        remaining = bar.total - bar.n
        bar.update(remaining)
        # pin the counter exactly to the total before closing
        bar.n = bar.total
        bar.close()
        return inputs

    return chain(actual_pbar_closer)
@log_and_time_fn
def source_replace(input: str, mapping: dict) -> str:
    """
    Replace document identifiers in text with their corresponding numbers.

    Each document ID found in ``input`` (like WDOC_1, WDOC_2) is replaced by
    a markdown link ``[n](#document-n)`` where ``n`` is ``mapping[doc_id]``.
    Longer identifiers are replaced first, so WDOC_2 never replaces part of
    WDOC_21 regardless of the order of ``mapping``.

    Parameters
    ----------
    input : str
        The text containing document identifiers to replace.
    mapping : dict
        Dictionary mapping document IDs (str) to document numbers.

    Returns
    -------
    str
        Text with document identifiers replaced by numbered links.
    """
    result = input
    # Longest first: an ID can only be a substring of a longer ID, and sorting
    # makes this independent of the dict's insertion order.
    for doc_id in sorted(mapping, key=len, reverse=True):
        doc_num = str(mapping[doc_id])
        result = result.replace(doc_id, f"[{doc_num}](#document-{doc_num})")
    return result
@log_and_time_fn
def autoincrease_top_k(
    filtered_docs: List[Document], top_k: int, max_top_k: Optional[int]
) -> List[Document]:
    """
    Check if the number of filtered documents suggests top_k should be increased.

    This function evaluates the ratio of filtered documents to top_k and raises
    an exception if the ratio is too high (>=0.9), suggesting that more documents
    should be retrieved. This mechanism allows the query system to automatically
    increase top_k when it appears that good documents might be getting cut off
    due to the limit.

    Parameters
    ----------
    filtered_docs : List[Document]
        The list of documents that passed the LLM evaluation filtering.
    top_k : int
        The current top_k value used for document retrieval (assumed > 0).
    max_top_k : Optional[int]
        The maximum allowed value for top_k. If None (or 0), no automatic
        increase will be attempted.

    Returns
    -------
    List[Document]
        The same list of filtered documents (unchanged).

    Raises
    ------
    ShouldIncreaseTopKAfterLLMEvalFiltering
        When the ratio of filtered documents to top_k is >= 0.9 and top_k can still
        be increased (i.e., top_k < max_top_k).

    Notes
    -----
    This function is designed to be used in a langchain pipeline where the exception
    can be caught to retry the query with an increased top_k value. The function
    logs a warning when the ratio suggests top_k should be increased but max_top_k
    has been reached.
    """
    if not max_top_k:
        return filtered_docs

    ratio = len(filtered_docs) / top_k
    if ratio < 0.9:
        return filtered_docs

    if top_k < max_top_k:
        mess = (
            f"Number of documents found: {len(filtered_docs)}, "
            f"top_k is {top_k} so ratio={ratio:.1f}, hence "
            f"top_k should be increased. Max_top_k is {max_top_k}"
        )
        logger.warning(mess)
        raise ShouldIncreaseTopKAfterLLMEvalFiltering(mess)

    logger.warning(
        f"Number of documents found: {len(filtered_docs)}, "
        f"top_k is {top_k} so ratio={ratio:.1f}, hence "
        f"top_k should be increased but we reached "
        f"max_top_k ({max_top_k}) so continuing."
    )
    return filtered_docs