Source code for wdoc.utils.tasks.shared_query_search

"""
Shared utilities for query and search tasks.
"""

import asyncio
import copy
from beartype.typing import Callable, List, Tuple

from langchain_community.chat_models.fake import FakeListChatModel
from langchain_core.runnables import chain
from loguru import logger

from wdoc.utils.env import env
from wdoc.utils.misc import log_and_time_fn
from wdoc.utils.tasks.query import parse_eval_output


@log_and_time_fn
def split_query_parts(query: str) -> Tuple[str, str]:
    """
    Split query into parts for embedding search and answering.

    If the query contains ">>>>", splits it into:
    - query_for_embedding: part before >>>>
    - query_to_answer: part after >>>>

    Otherwise returns the same query for both purposes.

    Parameters
    ----------
    query : str
        The input query string

    Returns
    -------
    Tuple[str, str]
        A tuple of (query_for_embedding, query_to_answer). When the separator
        is present, both parts are stripped of surrounding whitespace.

    Raises
    ------
    AssertionError
        If query contains more than one occurrence of ">>>>"
    """
    separator = ">>>>"
    if separator not in query:
        return copy.copy(query), copy.copy(query)

    parts = query.split(separator)
    # Raise explicitly instead of using a bare `assert`: assert statements are
    # removed under `python -O`, which would let a query with several
    # separators through and silently drop everything after the second part.
    # AssertionError is kept so existing callers see the same exception type.
    if len(parts) != 2:
        raise AssertionError(
            "The query must contain a maximum of 1 occurence of '>>>>'"
        )

    query_fe, query_an = (part.strip() for part in parts)
    return query_fe, query_an
@log_and_time_fn
def create_evaluate_doc_chain(
    eval_llm,
    eval_llm_params: List[str],
    query_eval_check_number: int,
    eval_cache_wrapper: Callable,
    prompts,
):
    """
    Create a document evaluation chain for assessing document relevance.

    This function creates a chain that evaluates documents for relevance
    to a query using an LLM. It handles different model configurations
    and caching strategies:

    - if ``eval_llm`` is a ``FakeListChatModel``, no call is made and every
      evaluation is the string "10";
    - if "n" is in ``eval_llm_params`` or only one evaluation is requested,
      a single synchronous call is made (with cache first, then retried
      without cache on any exception);
    - otherwise ``query_eval_check_number`` async calls are gathered on an
      event loop, each expected to return exactly one generation.

    Parameters
    ----------
    eval_llm : object
        The evaluation LLM instance. Its first callback
        (``eval_llm.callbacks[0]``) has its token counters incremented and,
        if it has a truthy ``pbar``, the last progress bar is advanced by 1.
    eval_llm_params : List[str]
        List of supported parameters for the evaluation LLM. "n" is removed
        from this list in place when the model returned a single output
        although several were requested.
    query_eval_check_number : int
        Number of evaluation checks to perform
    eval_cache_wrapper : Callable
        Function to wrap the evaluation for caching
    prompts : object
        Prompts object containing the evaluation prompt

    Returns
    -------
    chain
        A langchain chain object for document evaluation

    Raises
    ------
    AssertionError
        Raised when the chain is invoked (not at creation time) if the LLM
        produced no generation, more than one generation per async call, or
        a number of outputs different from ``query_eval_check_number``.
    """

    @eval_cache_wrapper
    def evaluate_doc_chain(
        inputs: dict,
        # The three defaults below are not read in the body; they are part of
        # the signature so the cache wrapper can see them.
        query_nb: int = query_eval_check_number,
        eval_model_string: str = eval_llm._get_llm_string(),  # just for caching
        eval_prompt: str = str(prompts.evaluate.to_json()),
    ) -> List[str]:
        if isinstance(eval_llm, FakeListChatModel):
            outputs = ["10" for _ in range(query_eval_check_number)]
            new_p = 0
            new_c = 0
            new_r = 0

        elif "n" in eval_llm_params or query_eval_check_number == 1:

            def _parse_outputs(out) -> List[str]:
                reasons = [
                    gen.generation_info["finish_reason"] for gen in out.generations
                ]
                outputs = [gen.text for gen in out.generations]
                # don't always crash if finish_reason is not stop, because it
                # can sometimes still be parsed.
                if not all(r == "stop" for r in reasons):
                    logger.warning(
                        f"Unexpected generation finish_reason: '{reasons}' for generations: '{outputs}'. Expected 'stop'"
                    )
                assert outputs, "No generations found by query eval llm"
                # parse_eval_output will crash if the output is bad anyway
                outputs = [parse_eval_output(o) for o in outputs]
                return outputs

            try:
                out = eval_llm._generate_with_cache(
                    prompts.evaluate.format_messages(**inputs),
                    request_timeout=env.WDOC_LLM_REQUEST_TIMEOUT,
                )
                outputs = _parse_outputs(out)
            except Exception:  # retry without cache
                logger.debug(
                    "Failed to run eval_llm on an input. Retrying without cache."
                )
                out = eval_llm._generate(
                    prompts.evaluate.format_messages(**inputs),
                    request_timeout=env.WDOC_LLM_REQUEST_TIMEOUT,
                )
                outputs = _parse_outputs(out)

            if out.llm_output:
                usage = out.llm_output["token_usage"]
                new_p = usage["prompt_tokens"]
                new_c = usage["completion_tokens"]
                # Whatever is left of the total is counted as reasoning tokens.
                new_r = usage["total_tokens"] - new_p - new_c
            else:
                new_p = 0
                new_c = 0
                new_r = 0

        else:
            outputs = []
            new_p = 0
            new_c = 0
            new_r = 0

            async def do_eval(subinputs):
                try:
                    val = await eval_llm._agenerate_with_cache(
                        prompts.evaluate.format_messages(**subinputs),
                        request_timeout=env.WDOC_LLM_REQUEST_TIMEOUT,
                    )
                except Exception:  # retry without cache
                    val = await eval_llm._agenerate(
                        prompts.evaluate.format_messages(**subinputs),
                        request_timeout=env.WDOC_LLM_REQUEST_TIMEOUT,
                    )
                return val

            outs = [do_eval(inputs) for _ in range(query_eval_check_number)]
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            outs = loop.run_until_complete(asyncio.gather(*outs))

            for out in outs:
                assert len(out.generations) == 1, (
                    f"Query eval llm produced more than 1 evaluations: '{out.generations}'"
                )
                outputs.append(out.generations[0].text)
                finish_reason = out.generations[0].generation_info["finish_reason"]
                if finish_reason not in ["stop", "length"]:
                    logger.warning(
                        f"Unexpected finish_reason: '{finish_reason}' for generation '{outputs[-1]}'"
                    )
                if out.llm_output:
                    usage = out.llm_output["token_usage"]
                    # Work with this response's own counts. Subtracting the
                    # running totals (new_p / new_c) from each response's
                    # total_tokens would under-count reasoning tokens from the
                    # second response onward and could even go negative.
                    out_p = usage["prompt_tokens"]
                    out_c = usage["completion_tokens"]
                    new_p += out_p
                    new_c += out_c
                    new_r += usage["total_tokens"] - out_p - out_c

            assert outputs, "No generations found by query eval llm"
            outputs = [parse_eval_output(o) for o in outputs]

        if len(outputs) < query_eval_check_number and len(outputs) == 1:
            logger.warning(
                f"query eval model produced 1 output instead of {query_eval_check_number}). Output: '{outputs}'\nThis is usually because the model is wrongly specified by litellm as having a modifiable `n` parameter. To avoid this use another model or set the query_eval_check_number to 1."
            )
            # Mutates the caller's list so later invocations stop taking the
            # single-call "n" branch for this model.
            if "n" in eval_llm_params:
                eval_llm_params.remove("n")
            # Duplicate the single evaluation to reach the expected count.
            outputs = outputs * query_eval_check_number
        assert len(outputs) == query_eval_check_number, (
            f"Query eval model produced an unexpected number of outputs ({outputs} but expected {query_eval_check_number} outputs).\nInputs: {inputs}'"
        )

        # Token accounting and progress reporting live on the first callback.
        callback = eval_llm.callbacks[0]
        callback.prompt_tokens += new_p
        callback.completion_tokens += new_c
        callback.internal_reasoning_tokens += new_r
        callback.total_tokens += new_p + new_c + new_r
        if callback.pbar:
            callback.pbar[-1].update(1)
        return outputs

    evaluate_doc_chain = chain(evaluate_doc_chain)
    return evaluate_doc_chain