"""
Chain (logic) used to query a document.
"""
import re
import time
import numpy as np
from beartype.typing import List, Optional, Union
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import chain
from langchain_core.runnables.base import RunnableLambda
from tqdm.asyncio import tqdm
from loguru import logger
from wdoc.utils.env import env
from wdoc.utils.errors import (
InvalidDocEvaluationByLLMEval,
NoDocumentsAfterLLMEvalFiltering,
NoDocumentsRetrieved,
ShouldIncreaseTopKAfterLLMEvalFiltering,
)
from wdoc.utils.misc import get_tkn_length, thinking_answer_parser, log_and_time_fn
# Pre-compiled pattern matching the standalone, upper-case word "IRRELEVANT"
# (delimited by word boundaries on both sides).
irrelevant_regex = re.compile(r"\bIRRELEVANT\b")
[docs]
@log_and_time_fn
def sieve_documents(instance) -> RunnableLambda:
    """Build a chain step that caps the number of retrieved documents.

    When multiple retrievers are used, the merged result can contain a lot
    more documents than ``top_k``. The returned step crops
    ``inputs["unfiltered_docs"]`` to its first ``instance.top_k`` entries.

    Parameters
    ----------
    instance : object
        Object exposing ``top_k`` and ``max_top_k`` attributes. The instance
        itself is passed (rather than the values) so that ``top_k`` is read
        when the step runs, i.e. after any later update.

    Returns
    -------
    RunnableLambda
        The result of ``chain(_sieve)``: a step taking the ``inputs`` dict
        (which must contain ``"question_to_answer"`` and
        ``"unfiltered_docs"``) and returning that same dict, cropped in place.
    """

    def _sieve(inputs: dict) -> dict:
        assert "question_to_answer" in inputs, inputs.keys()
        assert "unfiltered_docs" in inputs, inputs.keys()
        # we have to pass an instance otherwise we can't know if the top_k got updated
        assert hasattr(instance, "top_k")
        assert hasattr(instance, "max_top_k")
        if instance.max_top_k:
            assert instance.max_top_k >= instance.top_k

        # Count the documents, not the keys of the inputs dict: comparing
        # len(inputs) to top_k would (almost) never trigger the crop.
        n_docs = len(inputs["unfiltered_docs"])
        if n_docs > instance.top_k:
            logger.warning(
                "Number of documents found via embeddings was "
                f"'{n_docs}' which is > top_k ({instance.top_k}) "
                "so we crop"
            )
            inputs["unfiltered_docs"] = inputs["unfiltered_docs"][: instance.top_k]
        return inputs

    _sieve = chain(_sieve)
    return _sieve
@chain
@log_and_time_fn
def refilter_docs(inputs: dict) -> List[Document]:
    """
    Keep only the retrieved documents that the eval LLM scored high enough.

    Parameters
    ----------
    inputs : dict
        Must contain ``"unfiltered_docs"`` (list of documents) and
        ``"evaluations"`` (list of the same length; each entry is either one
        evaluation or a list of evaluations for the matching document).

    Returns
    -------
    List[Document]
        Documents whose mean score is at least 3. An evaluation whose parsed
        answer cannot be converted to ``int`` counts as 5, so the document
        is kept.

    Raises
    ------
    NoDocumentsRetrieved
        If ``"unfiltered_docs"`` is empty.
    NoDocumentsAfterLLMEvalFiltering
        If no document passes the filter.
    """
    candidates = inputs["unfiltered_docs"]
    verdicts = inputs["evaluations"]

    assert isinstance(candidates, list), (
        f"unfiltered_docs should be a list, not {type(candidates)}"
    )
    assert isinstance(verdicts, list), (
        f"evaluations should be a list, not {type(verdicts)}"
    )
    assert len(candidates) == len(verdicts), (
        f"len of unfiltered_docs is {len(candidates)} but len of evaluations is {len(verdicts)}"
    )

    if not candidates:
        raise NoDocumentsRetrieved("No document corresponding to the query")

    def _to_score(raw):
        # A non numeric answer is given a passing score of 5.
        try:
            return int(raw)
        except Exception as err:
            logger.warning(
                f"Document was not evaluated with a number: '{err}' for answer '{raw}'\nKeeping the document anyway."
            )
            return 5

    kept = []
    for position, doc_evals in enumerate(verdicts):  # one entry per document
        if not isinstance(doc_evals, list):
            doc_evals = [doc_evals]
        # parse every evaluation first, then convert each answer to a score
        raw_answers = [thinking_answer_parser(ev)["answer"] for ev in doc_evals]
        scores = [_to_score(raw) for raw in raw_answers]
        mean_score = sum(scores) / len(scores)
        if mean_score >= 3:
            kept.append(candidates[position])

    if kept:
        return kept
    raise NoDocumentsAfterLLMEvalFiltering(
        "No document remained after filtering with the query"
    )
[docs]
@log_and_time_fn
def retrieve_documents_for_query(retriever):
    """
    Build the document-retrieval step of a query chain.

    Parameters
    ----------
    retriever : object
        Retriever whose ``invoke`` method is called with the embedding
        question.

    Returns
    -------
    RunnableLambda
        A chain step mapping an input dict holding
        ``"question_for_embedding"`` and ``"question_to_answer"`` to a dict
        holding ``"unfiltered_docs"`` (the retriever output) and the
        untouched ``"question_to_answer"``.
    """

    @chain
    def _retrieve_documents(inputs):
        # the retriever is queried first, then the question is forwarded
        found = retriever.invoke(inputs["question_for_embedding"])
        question = inputs["question_to_answer"]
        return {"unfiltered_docs": found, "question_to_answer": question}

    return _retrieve_documents
[docs]
@log_and_time_fn
def parse_eval_output(output: str) -> str:
    """
    Reduce the eval LLM's relevance verdict to a single integer, as a str.

    For example, an LLM answer such as:
    '''
    <think>
    I am thinking hard about if the document is relevant to the user query
    on a scale of 0 (irrelevant) to 10 (very relevant).
    ...
    </think>
    <answer>10</answer>
    '''
    becomes simply: '10'

    Resolution order, applied to the answer part of the parsed output:
    1. the answer (dashes removed) is all digits;
    2. otherwise, the first of its lines that is all digits;
    3. otherwise, the only all-digit token found when splitting the original
       answer on word boundaries.

    If the output is empty, or the answer holds no number or several
    candidate numbers, then '5' is returned when
    ``env.WDOC_CONTINUE_ON_INVALID_EVAL`` is truthy, else
    ``InvalidDocEvaluationByLLMEval`` is raised.
    """
    mess = (
        f"The eval LLM returned an output that can't be parsed as expected: '{output}'"
    )

    def _unparsable() -> str:
        # single exit point for every "cannot parse" situation
        if not env.WDOC_CONTINUE_ON_INVALID_EVAL:
            raise InvalidDocEvaluationByLLMEval(mess)
        logger.warning(mess)
        return "5"

    # empty
    if not output.strip():
        return _unparsable()

    parsed = thinking_answer_parser(output)
    logger.debug(f"Eval LLM output: '{output}'")

    raw_answer = parsed["answer"]
    # negative ints are not accepted anyway
    candidate = raw_answer.replace("-", "")
    if not candidate.isdigit():
        numeric_lines = [line for line in candidate.splitlines() if line.isdigit()]
        if numeric_lines:
            candidate = numeric_lines[0]
    if candidate.isdigit():
        return str(int(candidate))

    # last resort: look for isolated numbers in the untouched answer
    isolated = [tok for tok in re.split(r"\b", raw_answer) if tok.isdigit()]
    if len(isolated) == 1:  # good
        return isolated[0]
    # either no digits at all, or an ambiguous answer
    return _unparsable()
[docs]
@log_and_time_fn
def semantic_batching(
    texts: List[str],
    embedding_engine: Embeddings,
) -> List[List[str]]:
    """
    Group texts into semantically coherent batches of bounded token size.

    Given a list of texts, embed them, do a hierarchical clustering (Ward
    linkage over euclidean distances, with optimal leaf ordering), then
    create buckets that best contain each subtopic while keeping a reasonable
    number of tokens (``env.WDOC_SEMANTIC_BATCH_MAX_TOKEN_SIZE``).
    This probably helps the LLM to combine the intermediate answers
    into one.

    Note that the documents are also sorted inside each batch according to
    the leaf order of the clustering.

    Parameters
    ----------
    texts : List[str]
        Texts to batch; at least two are required. Duplicates are dropped,
        keeping the first occurrence.
    embedding_engine : Embeddings
        Object whose ``embed_documents(texts)`` returns one vector per text.

    Returns
    -------
    List[List[str]]
        The batches. Every deduplicated text appears in exactly one batch.
        With 3 texts or fewer after deduplication, or if clustering always
        yields a single cluster, one batch holding all texts is returned.
        Otherwise every batch holds at least two texts.

    Raises
    ------
    AssertionError
        On invalid input or if an internal consistency check fails.
    Exception
        Whatever ``embed_documents`` raised, if it failed 3 times in a row.
    """
    import scipy
    import pandas as pd
    import sklearn.decomposition as decomposition
    import sklearn.metrics as metrics
    import sklearn.preprocessing as preprocessing

    assert texts, "No input text received"
    assert len(texts) > 1, f"received only one text: {texts}"

    # deduplicate texts, keeping the first occurrence of each
    texts = list(dict.fromkeys(texts))
    if len(texts) <= 3:
        logger.debug(
            f"Returned texts in semantic_batching because there were only {len(texts)}"
        )
        return [texts]

    # texts are unique from here on, so they can be used as dict keys
    text_pos = {t: i for i, t in enumerate(texts)}
    text_sizes = {t: get_tkn_length(t) for t in texts}
    # log the token size of each text by index (not the text itself)
    itext_sizes = {i: text_sizes[t] for i, t in enumerate(texts)}
    logger.debug(f"Input text sizes in semantic_batching: {itext_sizes}")

    # get embeddings, retrying a few times as the backend may fail transiently
    n_trial = 3
    for trial in range(n_trial):
        try:
            embeds = np.array(embedding_engine.embed_documents(texts)).squeeze()
            break
        except Exception as e:
            logger.warning(
                f"Error at trial {trial + 1}/{n_trial} when trying to embed texts for semantic batching: '{e}'"
            )
            if trial + 1 >= n_trial:
                logger.warning("Too many errors so crashing")
                raise
            else:
                time.sleep(2)
    n_dim = embeds.shape[1]
    assert n_dim > 2, (
        f"Unexpected number of dimension: {n_dim}, shape was {embeds.shape}"
    )

    max_n_dim = min(100, len(texts))

    # optional dimension reduction to gain time; any failure (including a
    # too low explained variance) falls back to the raw embeddings
    embeddings = None
    try:
        if n_dim > max_n_dim:
            scaler = preprocessing.StandardScaler()
            embed_scaled = scaler.fit_transform(embeds)
            pca = decomposition.PCA(n_components=max_n_dim)
            embeds_reduced = pca.fit_transform(embed_scaled)
            assert embeds_reduced.shape[0] == embeds.shape[0]
            vr = np.cumsum(pca.explained_variance_ratio_)[-1]
            if vr <= 0.90:
                logger.warning(
                    f"Found lower than exepcted PCA explained variance ratio: {vr:.4f}"
                )
            assert vr >= 0.75, (
                f"Found substancially low explained variance ratio afer pca at {vr:.4f} so not using dimension reduction"
            )
            embeddings = pd.DataFrame(
                columns=[f"v_{i}" for i in range(embeds_reduced.shape[1])],
                index=[i for i in range(len(texts))],
                data=embeds_reduced,
            )
    except Exception as err:
        logger.warning(
            f"Error when doing dimension reduction for semantic batching. Original shape: {embeds.shape}. Error: '{err}'\nContinuing anyway."
        )
    if embeddings is None:
        embeddings = pd.DataFrame(
            columns=[f"v_{i}" for i in range(embeds.shape[1])],
            index=[i for i in range(len(texts))],
            data=embeds,
        )

    # get the pairwise distance matrix
    pairwise_distances = metrics.pairwise_distances
    pd_dist = pd.DataFrame(
        columns=embeddings.index,
        index=embeddings.index,
        data=pairwise_distances(
            embeddings.values,
            n_jobs=-1,
            metric="euclidean",
        ),
    )
    # make sure the intersection is 0 and not a very small float
    for ind in pd_dist.index:
        pd_dist.at[ind, ind] = 0
    # make sure it's symmetric
    pd_dist = pd_dist.add(pd_dist.T).div(2)

    # get the hierarchical semantic sorting order
    dist: np.ndarray = scipy.spatial.distance.squareform(
        pd_dist.values
    )  # convert to condensed format
    Z: np.ndarray = scipy.cluster.hierarchy.linkage(
        dist, method="ward", optimal_ordering=True
    )
    assert len(Z.shape) == 2 and Z.shape[1] == 4, f"Unexpected Z shape: {Z.shape}"

    # order[k] is the index (in texts) of the k-th leaf from left to right
    order: np.typing.NDArray[np.integer] = scipy.cluster.hierarchy.leaves_list(Z)
    assert len(order.shape) == 1
    # inverse permutation: rank of each text in the leaf order
    leaf_rank = {int(leaf): rank for rank, leaf in enumerate(order)}

    # TODO: if <= 6 texts we should make 2 or 3 batch just using the order

    # get each bucket if we were only looking at the number of texts
    cluster_trials = {}
    cluster_mean_tkn = {}
    for divider in [2, 3, 4, 5, 6]:
        if divider > len(pd_dist.index):
            continue
        cluster_labels = scipy.cluster.hierarchy.fcluster(
            Z, len(pd_dist.index) // divider, criterion="maxclust"
        )
        labels = np.unique(cluster_labels)
        labels.sort()
        if len(labels) == 1:  # re cluster if only one label found
            continue

        # use heuristics to find the best number of dividers by looking
        # at the average number of token in each clusters
        total_mean = 0
        for lab in labels:
            lt = [
                texts[int(ind.squeeze())] for ind in np.argwhere(cluster_labels == lab)
            ]
            lsize = sum([text_sizes[t] for t in lt])
            lmean = lsize / len(lt)
            total_mean += lmean
        total_mean /= len(labels)
        cluster_mean_tkn[divider] = total_mean
        cluster_trials[divider] = cluster_labels

    if not cluster_trials:
        assert len(labels) == 1
        logger.warning(
            f"The clustering algorithm always found the same cluster for the {len(texts)} texts. Assuming the order won't matter."
        )
        return [texts]

    # prefer the first clustering whose mean cluster size sits in the upper
    # half of the token budget, otherwise take the one with the smallest mean
    best_clusters = None
    for d, ct in cluster_mean_tkn.items():
        if (
            ct < env.WDOC_SEMANTIC_BATCH_MAX_TOKEN_SIZE
            and ct >= env.WDOC_SEMANTIC_BATCH_MAX_TOKEN_SIZE / 2
        ):
            best_clusters = cluster_trials[d]
            break
    if best_clusters is None:
        best_tkns = min(list(cluster_mean_tkn.values()))
        for d, ct in cluster_mean_tkn.items():
            if ct == best_tkns:
                best_clusters = cluster_trials[d]
                break
    assert best_clusters is not None
    cluster_labels = best_clusters
    labels = np.unique(cluster_labels)
    labels.sort()
    assert len(labels) > 1, cluster_labels

    # make sure no cluster contains only one text
    while not all((cluster_labels == lab).sum() > 1 for lab in labels):
        logger.debug("Remapping clusters.")
        for lab in labels:
            if (cluster_labels == lab).sum() == 1:
                t = texts[np.argmax(cluster_labels == lab)]
                # the closest is always itself so checking the 2nd closest
                t_close = (pd_dist.loc[text_pos[t], :]).nsmallest(2).index.tolist()
                assert text_pos[t] == t_close[0]
                t_closest = t_close[1]
                l_closest = cluster_labels[t_closest]
                if (cluster_labels == l_closest).sum() + 1 == len(texts):
                    # merging small to big would result in only one cluster:
                    # better to even them out
                    assert len(labels) == 2, labels
                    cluster_labels[t_closest] = lab
                    logger.debug(f"Remapped one item from cluster {l_closest} to {lab}")
                else:  # good to go
                    cluster_labels[cluster_labels == lab] = l_closest
                    logger.debug(
                        f"Remapped single item of cluster {lab} to {l_closest}"
                    )
                # labels changed: recompute them before looking further
                break
        labels = np.unique(cluster_labels)
        labels.sort()
        assert len(labels) > 1, cluster_labels
    assert all((cluster_labels == lab).sum() > 1 for lab in labels), cluster_labels

    # Create buckets
    buckets = []
    current_bucket = []
    current_tokens = 0
    # fill each bucket until reaching max_token; a bucket never spans clusters
    for lab in labels:
        lab_ind = np.argwhere(cluster_labels == lab)
        assert len(lab_ind) > 1, f"{lab_ind}\n{cluster_labels}"
        assert len(lab_ind) < len(texts), f"{lab_ind}\n{cluster_labels}"
        for clustid in lab_ind:
            text = texts[int(clustid.squeeze())]
            size = text_sizes[text]
            if (
                current_tokens + size > env.WDOC_SEMANTIC_BATCH_MAX_TOKEN_SIZE
            ) and current_bucket:
                buckets.append(current_bucket)
                current_bucket = [text]
                # the new bucket already holds this text: count its tokens
                current_tokens = size
            else:
                current_bucket.append(text)
                current_tokens += size
        assert current_bucket
        buckets.append(current_bucket)
        current_bucket = []
        current_tokens = 0
    assert all(bucket for bucket in buckets), "Empty buckets"

    # sort each bucket based on the optimal order (rank of each text among
    # the leaves, not the raw permutation value)
    for ib, b in enumerate(buckets):
        buckets[ib] = sorted(b, key=lambda t: leaf_rank[text_pos[t]])

    # now if any bucket contains only one text, that means it has too many
    # tokens itself, so we reequilibrate from the previous buckets
    while not all(len(b) >= 2 for b in buckets):
        logger.debug(f"Merging sub buckets. Current len: {len(buckets)}")
        for ib, b in enumerate(buckets):
            assert b
            if len(b) == 1:
                # figure out which bucket to merge with
                if ib == 0:  # first , merge with next
                    next_id = ib + 1
                elif ib + 1 == len(buckets):  # last, take the penultimate
                    next_id = ib - 1
                else:  # not first nor last, take the neighbour with least minimal distance
                    t_cur = b[0]
                    prev_dist = min(
                        [
                            pd_dist.loc[text_pos[t_cur], text_pos[t]]
                            for t in buckets[ib - 1]
                        ]
                    )
                    next_dist = min(
                        [
                            pd_dist.loc[text_pos[t_cur], text_pos[t]]
                            for t in buckets[ib + 1]
                        ]
                    )
                    assert prev_dist > 0 and next_dist > 0
                    if prev_dist < next_dist:
                        next_id = ib - 1
                    else:
                        next_id = ib + 1
                assert buckets[next_id], buckets[next_id]
                logger.debug(f"Next_id is {next_id}")

                if len(buckets[next_id]) <= 2:
                    # neighbour has 1 text (both texts are big, merge them
                    # anyway) or 2 texts (merging 2:1 -> 1:2 would create a
                    # loop): move our text into the neighbour
                    if next_id > ib:
                        buckets[next_id].insert(0, b.pop())
                    else:
                        buckets[next_id].append(b.pop())
                    assert not b, b
                else:
                    # borrow one text from the neighbour, at the correct position
                    if next_id > ib:
                        b.append(buckets[next_id].pop(0))
                    else:
                        # the borrowed text comes from the previous bucket
                        # so it must precede the current one
                        b.insert(0, buckets[next_id].pop(-1))
                    assert id(b) == id(buckets[ib])
                # buckets changed: drop emptied ones and rescan
                break
        buckets = [b for b in buckets if b]
    assert all(len(b) >= 2 for b in buckets), (
        f"Invalid size of buckets: '{[len(b) for b in buckets]}'"
    )

    # sanity checks: the buckets are exactly a partition of the texts
    unchained = [t for b in buckets for t in b]
    assert len(unchained) == len(set(unchained)), (
        "There were duplicate texts in buckets!"
    )
    assert all(t in text_pos for t in unchained), "Some text of buckets were added!"
    assert sorted(unchained) == sorted(texts), (
        "There is an issue with semantic_batching"
    )

    logger.debug("Printing size of each bucket in semantic_batching:")
    for ib, b in enumerate(buckets):
        sizes = [get_tkn_length(bb) for bb in b]
        logger.debug(f"{ib}: {sizes}")

    return buckets
[docs]
def pbar_chain(
    llm: Union[
        "langchain_litellm.ChatLiteLLM",
        "langchain_community.chat_models.fake.FakeListChatModel",
    ],
    len_func: str,
    **tqdm_kwargs,
) -> RunnableLambda:
    """
    Create a chain that just sets a tqdm progress bar.

    Parameters
    ----------
    llm :
        Model whose first callback (``llm.callbacks[0]``) exposes a ``pbar``
        collection; the new progress bar is appended to it.
    len_func : str
        Python expression, as a string, evaluated with ``eval`` inside the
        step to compute the ``total`` of the progress bar.
    **tqdm_kwargs
        Extra keyword arguments forwarded to ``tqdm``.

    Returns
    -------
    RunnableLambda
        The result of ``chain(actual_pbar_chain)``: a step that creates the
        progress bar and returns its ``inputs`` unchanged.
    """

    def actual_pbar_chain(
        inputs: Union[dict, List],
        llm: Union[
            "langchain_litellm.ChatLiteLLM",
            "langchain_community.chat_models.fake.FakeListChatModel",
        ] = llm,
    ) -> Union[dict, List]:
        # NOTE(review): eval() runs arbitrary code; len_func must only ever
        # come from trusted, developer-written strings. The expression is
        # resolved against this function's scope, so it can refer to the
        # local names `inputs` and `llm`: do not rename them.
        llm.callbacks[0].pbar.append(
            tqdm(
                total=eval(len_func),
                **tqdm_kwargs,
            )
        )
        # a falsy total (0 or None) is only reported, not treated as an error
        if not llm.callbacks[0].pbar[-1].total:
            logger.warning(f"Empty total for pbar: {llm.callbacks[0].pbar[-1]}")
        return inputs

    actual_pbar_chain = chain(actual_pbar_chain)
    return actual_pbar_chain
[docs]
def pbar_closer(
    llm: Union[
        "langchain_litellm.ChatLiteLLM",
        "langchain_community.chat_models.fake.FakeListChatModel",
    ],
) -> RunnableLambda:
    """
    Create a chain that completes and closes the latest progress bar.

    The returned step takes the last bar stored in
    ``llm.callbacks[0].pbar``, advances it to its total, closes it and
    returns its ``inputs`` unchanged.
    """

    @chain
    def actual_pbar_closer(
        inputs: Union[dict, List],
        llm: Union[
            "langchain_litellm.ChatLiteLLM",
            "langchain_community.chat_models.fake.FakeListChatModel",
        ] = llm,
    ) -> Union[dict, List]:
        latest_bar = llm.callbacks[0].pbar[-1]
        # fill whatever is left, then pin the counter to the total
        remaining = latest_bar.total - latest_bar.n
        latest_bar.update(remaining)
        latest_bar.n = latest_bar.total
        latest_bar.close()
        return inputs

    return actual_pbar_closer
[docs]
@log_and_time_fn
def source_replace(input: str, mapping: dict) -> str:
    """
    Replace document identifiers in text with their corresponding numbers.

    Each document ID found in the text (like WDOC_1, WDOC_2) is substituted
    with a markdown link ``[N](#document-N)`` where ``N`` is the value
    associated with that ID in ``mapping``.

    Identifiers are processed from longest to shortest so that an ID that is
    a prefix of another one (WDOC_2 vs WDOC_21) can never replace part of
    the longer one, whatever the insertion order of ``mapping``. Identifiers
    of equal length are processed in reverse insertion order.

    Parameters
    ----------
    input : str
        The text containing document identifiers to replace.
    mapping : dict
        Dictionary mapping document IDs to document numbers.

    Returns
    -------
    str
        Text with document identifiers replaced by numbers.
    """
    result = input

    # sorted() is stable even with reverse=True, so feeding it the reversed
    # keys keeps the reverse insertion order among IDs of the same length.
    doc_ids = sorted(list(mapping.keys())[::-1], key=len, reverse=True)

    for doc_id in doc_ids:
        doc_num = str(mapping[doc_id])
        result = result.replace(doc_id, f"[{doc_num}](#document-{doc_num})")
    return result
[docs]
@log_and_time_fn
def autoincrease_top_k(
    filtered_docs: List[Document], top_k: int, max_top_k: Optional[int]
) -> List[Document]:
    """
    Signal that top_k looks too small given how many documents were kept.

    When the number of documents that survived the LLM evaluation is close
    to ``top_k`` (ratio >= 0.9), good documents may have been cut off by the
    limit. In that case an exception is raised so that the caller can retry
    the query with a larger ``top_k``, unless ``max_top_k`` is already
    reached, in which case a warning is logged and the query continues.

    Parameters
    ----------
    filtered_docs : List[Document]
        Documents that passed the LLM evaluation filtering.
    top_k : int
        The top_k value used for the retrieval.
    max_top_k : Optional[int]
        Upper bound for top_k. If falsy (None or 0), nothing is checked.

    Returns
    -------
    List[Document]
        ``filtered_docs``, unchanged.

    Raises
    ------
    ShouldIncreaseTopKAfterLLMEvalFiltering
        When ``len(filtered_docs) / top_k >= 0.9`` and ``top_k < max_top_k``.
    """
    # automatic increase is disabled
    if not max_top_k:
        return filtered_docs

    n_kept = len(filtered_docs)
    ratio = n_kept / top_k
    # enough documents were rejected: top_k is large enough
    if ratio < 0.9:
        return filtered_docs

    summary = (
        f"Number of documents found: {n_kept}, "
        f"top_k is {top_k} so ratio={ratio:.1f}, hence "
    )

    # the ceiling is reached: report it and carry on
    if top_k >= max_top_k:
        logger.warning(
            summary
            + "top_k should be increased but we eached "
            + f"max_top_k ({max_top_k}) so continuing."
        )
        return filtered_docs

    mess = summary + f"top_k should be increased. Max_top_k is {max_top_k}"
    logger.warning(mess)
    raise ShouldIncreaseTopKAfterLLMEvalFiltering(mess)