Source code for wdoc.utils.misc

"""
Miscellaneous functions etc.
"""

from blake3 import blake3
import re
import inspect
import json
import os
import sys
import uuid
import warnings
from dataclasses import dataclass, field
from datetime import timedelta
from functools import cache as memoize
from functools import partial, wraps
from pathlib import Path

from beartype.door import is_bearable
from beartype.typing import (
    Dict,
    Callable,
    List,
    Literal,
    Union,
    get_type_hints,
    Optional,
    Any,
)
from joblib import Memory
from joblib import hash as jhash
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.runnables import chain
from platformdirs import user_cache_dir
from loguru import logger

from wdoc.utils.env import env, is_input_piped, pytest_ongoing
from wdoc.utils.errors import UnexpectedDocDictArgument
from wdoc.utils.tasks.types import wdocTask

import lazy_import

# Heavy optional dependencies are wrapped with lazy_import so that their real
# import is deferred until they are first used.
litellm, chonkie = (
    lazy_import.lazy_module(_lazy_name) for _lazy_name in ("litellm", "chonkie")
)


def _silence_noisy_warnings() -> None:
    """Register "ignore" filters for known noisy third-party warnings.

    - bs4: anki content is not exactly html, so bs4 can complain that the
      input looks more like a filename than markup.
    - litellm: messages about counting tokens for OpenAI models.
    - httpx: messages starting with "Use 'content=".
    """
    noisy = (
        dict(
            category=UserWarning,
            module="bs4",
            message=".*The input looks more like a filename than markup.*",
        ),
        dict(module="litellm", message=".*Counting tokens for OpenAI model=.*"),
        dict(module="httpx", message="Use 'content=.*"),
    )
    # same registration order as before, so the resulting filter list is identical
    for spec in noisy:
        warnings.filterwarnings("ignore", **spec)


_silence_noisy_warnings()

try:
    import ftlangdetect

    def language_detector(text: str) -> float:
        """Return the confidence score of the language detected by ftlangdetect.

        The text is lowercased before detection. On any detection error, the
        error is logged and 1.0 is returned (i.e. "assume the text is fine").
        """
        try:
            return ftlangdetect.detect(text.lower())["score"]
        except Exception as e:
            logger.info(
                f"Error when running ftlangdetect: '{e}'. First 100 chars of the str were '{text[:100]}'. Assuming probability of 1."
            )
            return 1.0

    # sanity check that the detector actually works before relying on it
    assert isinstance(language_detector("This is a test"), float)
except Exception as err:
    if env.WDOC_VERBOSE:
        logger.warning(
            f"Couldn't import optional package 'ftlangdetect' from 'fasttext-langdetect', trying to import langdetect (but it's much slower): '{err}'"
        )
    # forget ftlangdetect since it is not usable here
    if "ftlangdetect" in sys.modules:
        del sys.modules["ftlangdetect"]

    try:
        import langdetect

        def language_detector(text: str) -> float:
            """Return the probability of the most likely language (langdetect).

            The text is lowercased before detection. On any detection error,
            the error is logged and 1.0 is returned.
            """
            try:
                return langdetect.detect_langs(text.lower())[0].prob
            except Exception as e:
                logger.info(
                    f"Error when running langdetect: '{e}'. First 100 chars of the str were '{text[:100]}'. Assuming probability of 1."
                )
                return 1.0

        assert isinstance(language_detector("This is a test"), float)
    except Exception as err:
        if env.WDOC_VERBOSE:
            logger.warning(
                f"Couldn't import optional package 'langdetect' either: '{err}'"
            )
        # no detector available: users of language_detector must check for None
        language_detector = None

# Cache directory: a local one when running the test-suite, the user cache dir otherwise
if (
    "OVERRIDE_USER_DIR_PYTEST_WDOC" in os.environ
    and os.environ["OVERRIDE_USER_DIR_PYTEST_WDOC"] == "true"
):
    assert pytest_ongoing, (
        "Detected env var OVERRIDE_USER_DIR_PYTEST_WDOC but not detecting a pytest environment!"
    )
    cache_dir = Path.cwd() / "wdoc_user_cache_dir"
    if cache_dir.exists():
        logger.debug(
            f"PYTEST detected so using cache_dir '{cache_dir.absolute()}' (already exists)"
        )
    else:
        logger.debug(
            f"PYTEST detected so using cache_dir '{cache_dir.absolute()}' (does not exists)"
        )
else:
    cache_dir = Path(user_cache_dir(appname="wdoc"))
cache_dir.mkdir(parents=True, exist_ok=True)

doc_loaders_cache_dir = cache_dir / "doc_loaders"
doc_loaders_cache_dir.mkdir(exist_ok=True)
doc_loaders_cache = Memory(doc_loaders_cache_dir, verbose=0)

hashdoc_cache_dir = cache_dir / "doc_hashing"
hashdoc_cache_dir.mkdir(exist_ok=True)
hashdoc_cache = Memory(hashdoc_cache_dir, verbose=0)

(cache_dir / "query_eval_llm").mkdir(exist_ok=True)
query_eval_cache = Memory(cache_dir / "query_eval_llm", verbose=0)

# remove cache files older than X days
if env.WDOC_EXPIRE_CACHE_DAYS:
    doc_loaders_cache.reduce_size(
        age_limit=timedelta(days=int(env.WDOC_EXPIRE_CACHE_DAYS))
    )
    hashdoc_cache.reduce_size(age_limit=timedelta(days=int(env.WDOC_EXPIRE_CACHE_DAYS)))
    query_eval_cache.reduce_size(
        age_limit=timedelta(days=int(env.WDOC_EXPIRE_CACHE_DAYS))
    )

# for reading length estimation
wpm = 250
average_word_length = 6

# separators used for the text splitter
recur_separator = ["\n\n\n\n", "\n\n\n", "\n\n", "\n", "...", ".", " ", ""]

min_token = 20
max_token = 10_000_000
min_lang_prob = 0.50

printed_unexpected_api_keys = [False]  # to print it only once

# loader specific arguments
filetype_arg_types = {
    "pdf_parsers": Union[str, List[str]],
    "anki_deck": str,
    "anki_notetype": str,
    "anki_profile": str,
    "anki_template": str,
    "anki_tag_filter": str,
    "anki_tag_render_filter": str,
    "json_dict_template": str,
    "json_dict_exclude_keys": List,
    "audio_backend": Literal["whisper", "deepgram"],
    "audio_unsilence": bool,
    "whisper_lang": str,
    "whisper_prompt": str,
    "deepgram_kwargs": dict,
    "youtube_language": str,
    "youtube_translation": str,
    "youtube_audio_backend": Literal["youtube", "whisper", "deepgram"],
    "load_functions": List,
    "doccheck_min_token": int,
    "doccheck_max_token": int,
    "doccheck_min_lang_prob": float,
    "online_media_url_regex": str,
    "online_media_resourcetype_regex": str,
    "loading_failure": Literal["crash", "warn"],
    "ddg_max_results": int,
    "ddg_region": str,
    "ddg_safesearch": Literal["on", "off", "moderate"],
}

# extra arguments supported when instanciating wdoc
extra_args_types = {
    "path": Union[str, Path],
    "embed_instruct": str,
    "include": str,
    "exclude": str,
    "filter_content": Union[List[str], str],
    "filter_metadata": Union[List[str], str],
    "source_tag": str,
    "pattern": str,
    "recursed_filetype": str,
}
extra_args_types.update(filetype_arg_types)


class DocDict(dict):
    """Dictionary restricted to the keys usable when loading a document.

    Values are also type-checked (with beartype's ``is_bearable``) against
    ``filetype_arg_types`` when a type is declared for the key and the value
    is not None.

    The ``strict`` argument defaults to the environment variable
    ``WDOC_STRICT_DOCDICT`` as read when this module is imported. Depending on
    its value, an unexpected key or badly typed value leads to:

    - ``True``: raising ``UnexpectedDocDictArgument``;
    - ``False``: logging a warning, but the key is kept anyway;
    - ``"strip"``: logging a warning, and the key is not added.
    """

    allowed_keys: set = set(
        [
            "path",
            "filetype",
            "file_hash",
            "source_tag",
            "recur_parent_id",
            "subitem_link",
        ]
        + list(filetype_arg_types.keys())
    )
    allowed_types: dict = filetype_arg_types
    __strict__ = env.WDOC_STRICT_DOCDICT

    def __hash__(self):
        "make it hashable, to check for duplicates"
        as_string = ""
        for k in sorted(self.keys()):
            as_string += "\n"
            as_string += str(k)
            # joblib's hash handles unhashable values; fall back to str otherwise
            try:
                as_string += jhash(self[k])
            except Exception:
                as_string += str(self[k])
        return hash(as_string)

    def __check_values__(self, key, value, strict) -> bool:
        """Validate a key/value pair and return whether it should be kept.

        Raises UnexpectedDocDictArgument if ``strict`` is True and the pair is
        invalid, and ValueError for an unknown ``strict`` value.
        """
        if key not in self.allowed_keys:
            mess = (
                f"Cannot set key '{key}' in a DocDict. Allowed keys are:\n-"
                + "\n-".join(sorted(self.allowed_keys))
                + "\nYou can use the env "
                "variable WDOC_STRICT_DOCDICT to avoid this issue."
            )
        elif (
            (key in self.allowed_types)
            and (value is not None)
            and (not is_bearable(value, self.allowed_types[key]))
        ):
            mess = (
                f"Type of key {key} should be {self.allowed_types[key]}, "
                f"not {type(value)}."
                "\nYou can use the env "
                "variable WDOC_STRICT_DOCDICT to avoid this issue."
            )
        else:
            return True

        if strict is True:
            raise UnexpectedDocDictArgument(mess)
        elif strict is False:
            logger.warning(mess)
            return True
        elif strict == "strip":
            logger.warning(mess)
            return False
        else:
            raise ValueError(strict)

    def __init__(self, docdict: dict, strict=env.WDOC_STRICT_DOCDICT) -> None:
        assert docdict, "Can't give an empty docdict as argument"
        assert strict in [True, False, "strip"], "Unexpected strict value"
        # Filter into a new dict: the caller's dict must not be mutated.
        kept = {k: v for k, v in docdict.items() if self.__check_values__(k, v, strict)}
        if strict != "strip":
            assert kept, "Can't create DocDict: no args nor kwargs after filtering!"
        super().__init__(kept)
        self.__strict__ = strict

    def __setitem__(self, key, value) -> None:
        assert self.__strict__ in [True, False, "strip"], "Unexpected strict value"
        # honor "strip" mode: rejected keys are not added
        if self.__check_values__(key, value, self.__strict__):
            super().__setitem__(key, value)
def optional_strip_unexp_args(func: Callable) -> Callable:
    """Strip unexpected keyword arguments before calling a loader function.

    If the environment variable WDOC_STRICT_DOCDICT is falsy, ``func`` is
    returned unchanged. Otherwise ``func`` is wrapped so that, on each call,
    every keyword argument that is not a parameter of the underlying function
    is logged and removed before calling ``func``.

    The wrapper only accepts keyword arguments (an AssertionError is raised
    on positional ones).
    """
    if not env.WDOC_STRICT_DOCDICT:
        return func

    # func can be a decorated/partial-like object exposing the real callable
    # as `.func` (possibly several layers deep): unwrap all layers so that the
    # inspected signature and annotations are the real ones.
    truefunc = func
    while hasattr(truefunc, "func"):
        truefunc = truefunc.func

    @wraps(truefunc)
    def wrapper(*args, **kwargs):
        assert not args, (
            f"We are not expecting args here, only kwargs. Received {args}"
        )
        # NB: sig.bind_partial(**kwargs) must not be used here: it raises
        # TypeError on exactly the unexpected kwargs we want to strip.
        expected = inspect.signature(truefunc).parameters
        kwargs2 = {k: v for k, v in kwargs.items() if k in expected}
        diffkwargs = [k for k in kwargs if k not in expected]
        if diffkwargs:
            mess = f"Unexpected args or kwargs in func {func}:" + "".join(
                f"\n-KWARG: {kwarg}" for kwarg in diffkwargs
            )
            logger.warning(mess)
        assert kwargs2, (
            f"No kwargs2 found for func {func}. There's probably an issue with the decorator"
        )
        return func(**kwargs2)

    return wrapper
def hasher(text: str) -> str:
    """Return the first 20 hex chars of the blake3 digest of ``text``.

    Used to hash the text content of each doc to cache the splitting and
    embeddings.
    """
    return blake3(text.encode()).hexdigest()[:20]


def file_hasher(doc: dict) -> str:
    """Hash a file's content, as described by a doc dict.

    If ``doc["path"]`` is a str/Path pointing to an existing file, the file's
    bytes are hashed (through a cache keyed on the resolved path and the file
    stats, to avoid re-reading unchanged files). Otherwise (no path, empty
    path, not a file, non str/Path value) the JSON dump of the dict itself is
    hashed; values that are not JSON serializable are stringified.
    """
    path = doc.get("path")
    hashable = (
        isinstance(path, (str, Path))
        and bool(str(path).strip())
        # directories "exist" too but can't be opened and read as a file
        and Path(path).is_file()
    )
    if not hashable:
        return hasher(json.dumps(doc, ensure_ascii=False, default=str))

    file = Path(path)
    stats = file.stat()
    return _file_hasher(
        abs_path=str(file.resolve().absolute()),
        stats=[stats.st_mtime, stats.st_ctime, stats.st_ino, stats.st_size],
    )


@hashdoc_cache.cache
def _file_hasher(abs_path: str, stats: List[Union[int, float]]) -> str:
    """Return the first 20 hex chars of the blake3 digest of a file's bytes.

    ``stats`` is not used in the body: it is part of the joblib cache key so
    that a modified file (different mtime/ctime/inode/size) is re-hashed.
    """
    digest = blake3()
    # stream the file instead of loading it whole: it can be a large media file
    with open(abs_path, "rb") as f:
        for block in iter(lambda: f.read(1 << 20), b""):
            digest.update(block)
    return digest.hexdigest()[:20]
def html_to_text(html: str, remove_image: bool = False) -> str:
    """Strip the html markup of ``html`` and return its text content.

    ``<img>`` tags are kept verbatim in the output unless ``remove_image`` is
    True, in which case they are dropped (a warning is logged if "<img" still
    appears in the result).
    """
    import bs4

    # List boundaries and <br> tags become newlines, otherwise adjacent items
    # would be glued together once the tags are gone.
    for old, new in (
        ("</li><li>", "<br>"),
        ("</ul><ul>", "<br>"),
        ("<br>", "\n"),
        ("</br>", "\n"),
    ):
        html = html.replace(old, new)

    pieces = []
    for node in bs4.BeautifulSoup(html, "html.parser").descendants:
        if node.name == "img" and not remove_image:
            tag_str = str(node)
            # bs4 presumably serializes the tag as "<img .../>": keep whichever
            # form actually appears in the input html
            short_form = tag_str[:-2] + ">"
            if tag_str in html:
                pieces.append(tag_str)
            elif short_form in html:
                pieces.append(short_form)
            elif env.WDOC_VERBOSE:
                so_far = " ".join(filter(None, pieces))
                logger.warning(
                    f"Image not properly parsed from bs4:\n{tag_str}\n{so_far}"
                )
            continue
        if isinstance(node, bs4.NavigableString):
            pieces.append(str(node).strip())

    text = " ".join(filter(None, pieces))
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")

    if remove_image and "<img" in text:
        logger.warning(f"Failed to remove <img from anki card: {text}")
    return text
@chain
def debug_chain(inputs: Union[dict, List]) -> Union[dict, List]:
    """Pass-through chain step that opens the debugger.

    Put it between ``|`` pipes of a chain: it logs the keys of dict-like
    inputs, calls ``breakpoint()`` and returns ``inputs`` unchanged.
    """
    if hasattr(inputs, "keys"):
        logger.warning(str(inputs.keys()))
    breakpoint()
    return inputs
def wrapped_model_name_matcher(model: str) -> str:
    """Find the best litellm match for a "backend/modelname" string.

    Parameters
    ----------
    model : str
        Model name of the form "backend/modelname".

    Returns
    -------
    str
        ``model`` itself if litellm lists it for its backend; otherwise the
        single model starting with ``modelname``, or the closest fuzzy match,
        prefixed with "backend/"; otherwise ``model`` unchanged (with a
        warning).

    Raises
    ------
    AssertionError
        If no ``*_API_KEY`` environment variable is set.
    Exception
        If the backend is unknown to litellm or has no API key set.
    """
    all_backends = list(litellm.models_by_provider.keys())

    # find the currently set api keys to avoid matching models from
    # unset providers
    backends = []
    for k in os.environ:
        if not k.endswith("_API_KEY"):
            continue
        backend = k.split("_API_KEY")[0].lower()
        if (
            backend not in all_backends
            and env.WDOC_VERBOSE
            and not printed_unexpected_api_keys[0]
        ):
            logger.debug(
                f"Found API_KEY for backend {backend} that is not a known backend for litellm."
            )
        backends.append(backend)
    if env.WDOC_VERBOSE:
        # only report unknown backends during the first call
        printed_unexpected_api_keys[0] = True

    assert backends, "No API keys found in environnment"

    # filter by providers
    backend, modelname = model.split("/", 1)
    if backend not in all_backends:
        raise Exception(
            f"Model {model} with backend {backend}: backend not found in "
            "litellm.\nList of litellm providers/backend:\n"
            f"{all_backends}"
        )
    if backend not in backends:
        raise Exception(
            f"Trying to use backend {backend} but no API KEY was found for it in the environnment."
        )

    candidates = litellm.models_by_provider[backend]
    if modelname in candidates:
        return model

    def _with_backend(name: str) -> str:
        # litellm candidate names may or may not already carry the
        # "backend/" prefix: add it only when missing
        return name if name.startswith(f"{backend}/") else f"{backend}/{name}"

    subcandidates = [m for m in candidates if m.startswith(modelname)]
    if len(subcandidates) == 1:
        return _with_backend(subcandidates[0])

    from difflib import get_close_matches

    match = get_close_matches(modelname, candidates, n=1)
    if match:
        return _with_backend(match[0])

    logger.warning(
        f"Couldn't match the modelname {model} to any known model. "
        "Continuing but this will probably crash wdoc further "
        "down the code."
    )
    return model
@memoize
def model_name_matcher(model: str) -> str:
    """Return the best known match for ``model`` (memoized).

    Wraps ``wrapped_model_name_matcher`` and asserts that the matched name
    (with or without its backend prefix) has a known cost in litellm.
    Matching is skipped entirely when the env variable
    WDOC_NO_MODELNAME_MATCHING is set.
    """
    assert "testing" not in model.lower(), (
        "Found 'testing' in model, this should not happen"
    )
    assert "/" in model, f"expected / in model '{model}'"

    if env.WDOC_NO_MODELNAME_MATCHING:
        return model

    matched = wrapped_model_name_matcher(model)
    if matched != model and env.WDOC_VERBOSE:
        logger.debug(f"Matched model name {model} to {matched}")

    known = litellm.model_cost
    assert matched in known or matched.split("/", 1)[1] in known, (
        f"Neither {matched} nor {matched.split('/', 1)[1]} found in litellm.model_cost"
    )
    return matched
@memoize
def get_openrouter_metadata() -> dict:
    """Fetch models metadata from openrouter (memoized).

    Used because litellm always takes too much time to add new models.

    Returns
    -------
    dict
        Maps "openrouter/<model id>" to the info dict returned by the API
        (without its "id" key), with every "pricing" value cast to float.
        Ids with ":suffix" parts (e.g. ":free") are also reachable under
        their suffix-less id, unless that id is itself listed by openrouter.

    Raises
    ------
    requests.HTTPError
        If openrouter answers with an error status.
    """
    import requests

    url = "https://openrouter.ai/api/v1/models"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    rep = response.json()

    # put it in a suitable format
    data = {}
    aliases = set()  # suffix-less ids added as aliases, real ids override them
    for info in rep["data"]:
        modelid = "openrouter/" + info.pop("id")
        assert modelid not in data or modelid in aliases, modelid
        aliases.discard(modelid)

        # pricing values are str originally
        pricing = info["pricing"]
        for k, v in pricing.items():
            pricing[k] = float(v)
        data[modelid] = info

        # for models that for example end with ":free", make them appear
        # under their name without the suffix too
        alias = modelid
        while ":" in alias:
            alias = alias.rsplit(":", 1)[0]
            if alias not in data:
                data[alias] = info
                aliases.add(alias)

    # Example of an item of rep["data"]:
    # {'id': 'microsoft/phi-4-reasoning-plus:free',
    #  'name': 'Microsoft: Phi 4 Reasoning Plus (free)',
    #  'created': 1746130961,
    #  'description': REMOVED
    #  'context_length': 32768,
    #  'architecture': {'modality': 'text->text',
    #                   'input_modalities': ['text'],
    #                   'output_modalities': ['text'],
    #                   'tokenizer': 'Other',
    #                   'instruct_type': None},
    #  'pricing': {'prompt': '0',
    #              'completion': '0',
    #              'request': '0',
    #              'image': '0',
    #              'web_search': '0',
    #              'internal_reasoning': '0'},
    #  'top_provider': {'context_length': 32768,
    #                   'max_completion_tokens': None,
    #                   'is_moderated': False},
    #  'per_request_limits': None,
    #  'supported_parameters': ['max_tokens', 'temperature', 'top_p',
    #                           'reasoning', 'include_reasoning', 'stop',
    #                           'frequency_penalty', 'presence_penalty',
    #                           'seed', 'top_k', 'min_p',
    #                           'repetition_penalty', 'logprobs',
    #                           'logit_bias', 'top_logprobs']}
    return data
@dataclass
class ModelName:
    """Store the different ways to phrase a model name.

    Only ``original`` ("backend/model") is given; the other fields are
    derived from it in ``__post_init__``:

    - backend: lowercased part before the first "/"
    - model: everything after the first "/"
    - sanitized: name without "/" used for cache paths (prefixed with
      "private_" in private mode)
    """

    original: str
    backend: str = field(init=False)
    model: str = field(init=False)
    sanitized: str = field(init=False)

    def __post_init__(self):
        assert "/" in self.original, (
            f"Modelname must contain a / to distinguish the backend from the model. Received '{self.original}'"
        )
        self.backend, self.model = self.original.split("/", 1)
        self.backend = self.backend.lower()

        # Use a sanitized name for the cache path
        self.sanitized = self.original
        if "/" in self.model:
            # If the model part is an existing path (e.g. a local model file),
            # name the cache after the file name + a hash of content and path.
            try:
                model_path = Path(self.model)
                if model_path.exists():
                    digest = blake3()
                    with open(str(model_path.resolve().absolute()), "rb") as f:
                        for block in iter(lambda: f.read(1 << 20), b""):
                            digest.update(block)
                    # must be bytes: previously `bytes + str` raised TypeError
                    # and this naming was silently never applied
                    digest.update(self.model.encode())
                    self.sanitized = model_path.name + "_" + digest.hexdigest()[:15]
            except Exception:
                pass
        self.sanitized = self.sanitized.replace("/", "_")
        if env.WDOC_PRIVATE_MODE:
            self.sanitized = "private_" + self.sanitized

    def is_testing(self) -> bool:
        "Return True if 'testing' appears in the model name (e.g. 'testing/testing')."
        return "testing" in self.original.lower()

    def __hash__(self):
        "necessary for memoizing"
        return hash(("ModelName", self.original))


@memoize
def get_model_price(model: ModelName) -> Dict[str, Union[float, int]]:
    """Return the per token prices of ``model`` (memoized).

    The returned dict contains at least "prompt", "completion" and
    "internal_reasoning". All prices are 0 for ollama and testing models, and
    for every model when WDOC_ALLOW_NO_PRICE is set. Openrouter models use
    openrouter's metadata; other models are looked up in litellm.model_cost
    under their original, model and sanitized names.

    Raises
    ------
    Exception
        If no price can be found.
    """
    # NOTE(review): get_splitter uses the backend name "cliparser" (no
    # underscore); confirm which spelling this guard is meant to catch.
    assert "cli_parser" not in model.backend, (
        f"Found a cli_parser model backend, this should not happen. Model is: '{model}'"
    )
    if env.WDOC_ALLOW_NO_PRICE:
        logger.warning(
            f"Disabling price computation for {model} because env var 'WDOC_ALLOW_NO_PRICE' is 'true'"
        )
        return {"prompt": 0, "completion": 0, "internal_reasoning": 0}
    if model.backend == "ollama" or model.is_testing():
        return {"prompt": 0, "completion": 0, "internal_reasoning": 0}

    if model.backend == "openrouter":
        metadata = get_openrouter_metadata()
        assert model.original in metadata, f"Missing {model} from openrouter"
        pricing = metadata[model.original]["pricing"]
        if "request" in pricing and pricing["request"]:
            logger.error(
                f"Found non 0 request for {model}, this is not supported by wdoc so the price will not be accurate"
            )
        if "internal_reasoning" not in pricing:
            logger.warning(
                f"Warning: no 'internal_reasoning' price found for model '{model.original}'. Setting it to 0 but this might make price estimation wrong. Detected pricing: '{pricing}'"
            )
            pricing["internal_reasoning"] = 0
        return pricing

    for key in ["original", "model", "sanitized"]:
        mod = getattr(model, key)
        if mod not in litellm.model_cost:
            continue
        pricing = litellm.model_cost[mod]
        output = {
            "prompt": pricing["input_cost_per_token"],
            "completion": pricing["output_cost_per_token"],
            # litellm only lists a reasoning price for some models
            "internal_reasoning": pricing.get("output_cost_per_reasoning_token", 0),
        }
        # keep litellm's other fields too
        for k, v in pricing.items():
            output.setdefault(k, v)
        return output

    raise Exception(
        f"Can't find the price of '{model}'\nUpdate litellm or set WDOC_ALLOW_NO_PRICE=True if you still want to use this model."
    )


@memoize
def get_model_max_tokens(modelname: ModelName) -> int:
    """Return the context size (max tokens) of a model (memoized).

    Openrouter models use openrouter's metadata "context_length". Otherwise
    litellm.model_cost is tried with the original name, the model part and
    the last "/" component of the model part. As a last resort,
    litellm.get_model_info is called on each public attribute of
    ``modelname``; the error of the last attempt is re-raised.
    """
    if modelname.backend == "openrouter":
        openrouter_data = get_openrouter_metadata()
        assert modelname.original in openrouter_data, (
            f"Missing model {modelname.original} from openrouter metadata"
        )
        return openrouter_data[modelname.original]["context_length"]

    for trial in (
        modelname.original,
        modelname.model,
        modelname.model.split("/")[-1],
    ):
        if trial in litellm.model_cost:
            return litellm.model_cost[trial]["max_tokens"]

    hailmary = [
        at for at in dir(modelname) if not (at.startswith("_") or at.endswith("_"))
    ]
    for trial3 in hailmary:
        try:
            return litellm.get_model_info(getattr(modelname, trial3))["max_tokens"]
        except Exception:
            if trial3 == hailmary[-1]:
                raise


def get_tkn_length(
    tosplit: Union[str, Document],
    modelname: Union[str, ModelName] = "gpt-4o-mini",
) -> int:
    """Count the tokens of a string or a Document with litellm's token counter.

    "openrouter/" is removed from the model name before counting. For a
    Document, the count is stored in (and reused from) its metadata under
    "tkn_length".
    """
    if isinstance(modelname, ModelName):
        modelname = modelname.original
    modelname = modelname.replace("openrouter/", "")
    if isinstance(tosplit, str):
        return litellm.token_counter(model=modelname, text=tosplit)

    # avoid recomputing token length for documents
    tl = tosplit.metadata.get("tkn_length", None)
    if tl and isinstance(tl, int):
        return tl
    tl = litellm.token_counter(model=modelname, text=tosplit.page_content)
    tosplit.metadata["tkn_length"] = tl
    return tl
class ChonkieSemanticSplitter(TextSplitter):
    """
    Text splitter using chonkie's semantic chunker.

    This splitter uses semantic boundaries from chonkie to create meaningful
    chunks, then merges them to reach the desired token count while
    respecting overlap. The semantic chunking (and the chunker itself) is
    memoized for efficiency.
    """

    DEFAULT_SEMANTIC_MODEL = "minishlab/potion-multilingual-128M"

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        length_function: Callable[[str], int],
        semantic_model_name: str = DEFAULT_SEMANTIC_MODEL,
    ):
        """
        Initialize the semantic splitter.

        Parameters
        ----------
        chunk_size : int
            Maximum number of tokens per chunk.
        chunk_overlap : int
            Number of tokens to overlap between chunks.
        length_function : Callable[[str], int]
            Function to compute token length of text.
        semantic_model_name : str
            Model name given to chonkie's SemanticChunker.
        """
        super().__init__()
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._semantic_model_name = semantic_model_name

    @staticmethod
    @memoize
    def _get_chunker(model_name: str):
        """Build chonkie's SemanticChunker once per model name (loads a model)."""
        from chonkie import SemanticChunker

        return SemanticChunker(model_name=model_name)

    @staticmethod
    @memoize
    def _get_semantic_units(text: str, model_name: str) -> tuple:
        """
        Get semantic units from chonkie, memoized for efficiency.

        Parameters
        ----------
        text : str
            Text to chunk semantically.
        model_name : str
            Model name for the chunker (used for cache key).

        Returns
        -------
        tuple
            Tuple of semantic unit strings (tuple for hashability).
        """
        chunker = ChonkieSemanticSplitter._get_chunker(model_name)
        chunks = chunker.chunk(text)
        # Handle both string chunks and chunk objects with .text attribute
        return tuple(
            str(chunk.text) if hasattr(chunk, "text") else str(chunk)
            for chunk in chunks
        )

    def split_text(self, text: str) -> List[str]:
        """
        Split text into chunks using semantic boundaries and token limits.

        Semantic units from chonkie are merged (joined with a space) until
        adding the next one would exceed chunk_size. A new chunk starts with
        the trailing units of the previous one whose total length fits in
        chunk_overlap.

        Parameters
        ----------
        text : str
            Text to split.

        Returns
        -------
        List[str]
            List of text chunks.
        """
        semantic_units = self._get_semantic_units(text, self._semantic_model_name)
        if not semantic_units:
            return []

        chunks = []
        current_chunk = []  # (unit, token length) pairs
        current_length = 0
        for unit in semantic_units:
            unit_length = self._length_function(unit)

            if current_length + unit_length > self._chunk_size and current_chunk:
                chunks.append(" ".join(u for u, _ in current_chunk))

                # keep the last units that fit in the overlap budget
                overlap = []
                overlap_length = 0
                if self._chunk_overlap > 0:
                    for prev_unit, prev_length in reversed(current_chunk):
                        if overlap_length + prev_length > self._chunk_overlap:
                            break
                        overlap.append((prev_unit, prev_length))
                        overlap_length += prev_length
                    overlap.reverse()
                current_chunk = overlap
                current_length = overlap_length

            current_chunk.append((unit, unit_length))
            current_length += unit_length

        if current_chunk:
            chunks.append(" ".join(u for u, _ in current_chunk))
        return chunks

    def transform_documents(self, documents: List[Document]) -> List[Document]:
        """
        Transform documents by splitting them into chunks.

        Each document's content is split with ``split_text`` and a new
        Document is created per chunk with a copy of the original metadata.

        Parameters
        ----------
        documents : List[Document]
            List of documents to transform.

        Returns
        -------
        List[Document]
            List of transformed document chunks.
        """
        return [
            Document(page_content=chunk, metadata=doc.metadata.copy())
            for doc in documents
            for chunk in self.split_text(doc.page_content)
        ]
# Already built splitters, as {task: {model original name: splitter}}
text_splitters: dict = {}
# Model whose context size is used by get_splitter by default
DEFAULT_SPLITTER_MODELNAME = ModelName(original="openai/gpt-4o-mini")
def get_splitter(
    task: wdocTask,
    modelname: ModelName = DEFAULT_SPLITTER_MODELNAME,
) -> "TextSplitter":
    """Return the text splitter to use for ``task`` and ``modelname``.

    - parse task with the "cliparser/cliparser" model: a
      RecursiveCharacterTextSplitter with a huge chunk size, i.e. practically
      no splitting (not cached);
    - testing models: the splitter of DEFAULT_SPLITTER_MODELNAME;
    - otherwise a ChonkieSemanticSplitter whose chunk size is 3/4 (query,
      search, parse) or 1/2 (summarize) of the model's context size, capped
      by WDOC_MAX_EMBED_CONTEXT (query/search only) and WDOC_MAX_CHUNK_SIZE.
      It is cached in ``text_splitters``.

    Raises
    ------
    Exception
        If the task is none of query, search, parse or summarize.
    """
    # avoid creating many times this object
    task_cache = text_splitters.setdefault(task, {})
    if modelname.original in task_cache:
        return task_cache[modelname.original]

    # parse task with the cliparser model: assume we want a single super large
    # document with no splitting
    if task.parse and modelname.original == "cliparser/cliparser":
        return RecursiveCharacterTextSplitter(
            separators=recur_separator,
            chunk_size=10_000_000,
            chunk_overlap=0,
            length_function=get_tkn_length,
        )

    if modelname.is_testing():
        return get_splitter(task=task, modelname=DEFAULT_SPLITTER_MODELNAME)

    try:
        if modelname.model == "gpt-4o-mini":
            # this is not the true limit of 4o-mini but a good placeholder for
            # if we are using the default model anyway, see get_tkn_length
            max_tokens = 4096
        else:
            max_tokens = get_model_max_tokens(modelname)
    except Exception as err:
        max_tokens = 4096
        logger.warning(
            f"Failed to get max_tokens limit for model {modelname.original}: '{err}'"
        )

    # Cap context sizes
    if (task.query or task.search) and max_tokens > env.WDOC_MAX_EMBED_CONTEXT:
        logger.warning(
            f"Capping max_tokens for model {modelname} to WDOC_MAX_EMBED_CONTEXT ({env.WDOC_MAX_EMBED_CONTEXT} instead of {max_tokens}) because in query mode and we can only guess the context size of the embedding model."
        )
        max_tokens = min(max_tokens, env.WDOC_MAX_EMBED_CONTEXT)
    if max_tokens > env.WDOC_MAX_CHUNK_SIZE:
        logger.debug(
            f"Capping max_tokens for model {modelname} to the WDOC_MAX_CHUNK_SIZE value ({env.WDOC_MAX_CHUNK_SIZE} instead of {max_tokens})."
        )
        max_tokens = min(max_tokens, env.WDOC_MAX_CHUNK_SIZE)

    if task.query or task.search or task.parse:
        chunk_size = int(3 / 4 * max_tokens)
    elif task.summarize:
        chunk_size = int(1 / 2 * max_tokens)
    else:
        raise Exception(task)

    text_splitter = ChonkieSemanticSplitter(
        chunk_size=chunk_size,
        # 500 tokens of overlap, but never more than half a chunk: with small
        # (capped) chunk sizes the carried-over overlap alone would otherwise
        # fill or exceed each new chunk
        chunk_overlap=min(500, chunk_size // 2),
        length_function=partial(get_tkn_length, modelname=modelname.original),
    )
    task_cache[modelname.original] = text_splitter
    return text_splitter
def check_docs_tkn_length(
    docs: List[Document],
    identifier: Any,
    min_token: int = min_token,
    max_token: int = max_token,
    min_lang_prob: float = min_lang_prob,
    check_language: bool = False,
) -> float:
    """Sanity-check loaded documents, as something probably went wrong otherwise.

    Raises an Exception if the total number of tokens is ``<= min_token`` or
    ``>= max_token``, or (only if ``check_language`` is True) if the average
    language probability is ``<= min_lang_prob``.

    Returns
    -------
    float
        The average language probability, or 1.0 when the language is not
        checked, no detector is available, or detection failed.
    """
    identifier = str(identifier)
    assert docs, f"Received empty doc, identifier was: '{identifier}'"

    size = sum(get_tkn_length(d) for d in docs)
    if size <= min_token:
        logger.warning(
            f"Example of page from document with too few tokens : {docs[len(docs) // 2].page_content}"
        )
        raise Exception(
            f"The number of token from '{identifier}' is {size} <= {min_token}, probably something went wrong?"
        )
    if size >= max_token:
        logger.warning(
            f"Example of page from document with too many tokens : {docs[len(docs) // 2].page_content}"
        )
        raise Exception(
            f"The number of token from '{identifier}' is {size} >= {max_token}, probably something went wrong?"
        )

    if check_language is False:
        return 1.0

    # Language detection is best effort: any error means "treat as valid".
    # (The language detectors lowercase the text themselves.)
    try:
        if not language_detector:  # bypass if language_detector not defined
            return 1.0
        probs = [language_detector(d.page_content.replace("\n", "<br>")) for d in docs]
        if not probs or probs[0] is None:
            return 1.0
        prob = sum(probs) / len(probs)
    except Exception as err:
        if "no features in text" in str(err).lower():
            logger.exception(
                f"language_detector couldn't find text features of text '{identifier}'. Treating it as valid document."
            )
        else:
            logger.exception(
                f"Error when using language_detector on '{identifier}': {err}. Treating it as valid document."
            )
        return 1.0

    # checked outside the try so this failure is never mistaken for a detector error
    if prob <= min_lang_prob:
        raise Exception(
            f"Low language probability for {identifier}: prob={prob:.3f}<{min_lang_prob}.\nExample page: {docs[len(docs) // 2]}"
        )
    return prob
def unlazyload_modules():
    """Make sure no module in ``sys.modules`` is still lazily loaded.

    Useful when we want to make sure not to lose time later and that
    everything works smoothly; for example, who knows what happens when
    multiprocessing with lazy loaded modules. Does nothing unless
    ``env.WDOC_IMPORT_TYPE`` is "both" or "lazy".

    Raises:
        Exception: if loading one of the lazy modules fails.
    """
    if env.WDOC_IMPORT_TYPE not in ["both", "lazy"]:
        logger.debug("Lazyloading is disabled so not unlazyloading modules.")
        return

    while True:
        found_one = False
        # Iterate over a snapshot: loading a module imports other modules,
        # which mutates sys.modules during iteration.
        for k, v in list(sys.modules.items()):
            try:
                description = str(v)
            except Exception as e:
                logger.warning(
                    f"Very weird error when loading a package, consider setting WDOC_IMPORT_TYPE to another value than '{env.WDOC_IMPORT_TYPE}'. Error message was '{e}'"
                )
                # str() already failed, calling it again would crash
                continue
            if "Lazily-loaded" not in description:
                continue
            try:
                dir(v)  # this is enough to trigger the loading
            except Exception as err:
                raise Exception(
                    f"Error when unlazyloading module '{k}'. Error: '{err}'"
                    "\nThis can be caused by beartype's typechecking"
                    "\nYou can also try setting the env variable "
                    "WDOC_IMPORT_TYPE to 'native' or 'thread'"
                ) from err
            found_one = True
        # Loading modules can register new lazy ones: rescan until a whole
        # pass finds nothing left to load.
        if not found_one:
            break


def disable_internet(allowed: dict) -> None:
    """
    To be extra sure that no connection goes out of the computer when
    --private is used, we overload the socket module to make it only able to
    reach local connections.

    Args:
        allowed: mapping whose values are URLs or "host[:port]" strings (the
            llm_api_bases) that must stay reachable. Their host part is
            whitelisted.

    Raises:
        AssertionError: if a post-patch sanity check fails, e.g. an allowed
            host resolves to an IP that is neither whitelisted nor private.
    """
    import socket

    logger.warning(
        "Disabling outgoing internet because private mode is on. "
        "The only allowed IPs from now on are the ones from the "
        "argument llm_api_bases. Note that this permanently filters "
        "outgoing python connections so might interfere with other "
        "python programs if you are importing wdoc instead "
        "of calling it from the shell"
    )

    # unlazyload all modules as otherwise the overloading can happen too late
    unlazyload_modules()

    def host_of(value: str) -> str:
        "keep only the host of an URL or 'host:port' string"
        if "//" in value:
            value = value.split("//")[1]
        return value.split("/")[0].split(":")[0]

    def ip_to_int(address: str) -> int:
        "IPv4 address as an integer, for range comparisons"
        return int.from_bytes(socket.inet_aton(address), "big")

    # list of certainly allowed IPs / hosts
    vals = [host_of(v) for v in allowed.values()]
    allowed_IPs = {"localhost", "127.0.0.1"}
    allowed_IPs.update(vals)

    # list of probably allowed IPs, as inclusive integer bounds
    private_ranges = [
        (ip_to_int(start), ip_to_int(end))
        for start, end in [
            ("10.0.0.0", "10.255.255.255"),
            ("172.16.0.0", "172.31.255.255"),
            ("192.168.0.0", "192.168.255.255"),
            ("127.0.0.0", "127.255.255.255"),
        ]
    ]

    @memoize
    def is_private(ip: str) -> bool:
        "detect if the connection would go to our computer or to a remote server"
        if ip in allowed_IPs:
            return True
        ip_int = ip_to_int(ip)
        return any(start <= ip_int <= end for start, end in private_ranges)

    def create_connection(address, *args, **kwargs):
        "overload socket.create_connection to forbid outgoing connections"
        ip = socket.gethostbyname(address[0])
        if not is_private(ip):
            raise RuntimeError("Network connections to the open internet are blocked")
        return socket._original_create_connection(address, *args, **kwargs)

    # Raw socket creation now returns None instead of a usable socket.
    # NOTE(review): the stdlib socket.create_connection builds its socket via
    # the socket module's own `socket` name, which this rebinds, so even
    # connections allowed by create_connection above may fail — confirm that
    # local connections still work in private mode.
    socket.socket = lambda *args, **kwargs: None
    socket._original_create_connection = socket.create_connection
    socket.create_connection = create_connection

    # sanity check
    assert is_private("localhost")
    assert is_private("10.0.1.32")
    assert is_private("192.168.2.35")
    assert is_private("127.12.13.15")

    # checking allowed ips are okay
    for v in vals:
        assert is_private(v), f"An address failed to be set as private: '{v}'"
    # Resolve the extracted host, not the raw value: gethostbyname cannot
    # resolve a full URL such as "http://localhost:11434".
    for al, host in zip(allowed.values(), vals):
        ip = socket.gethostbyname(host)
        assert is_private(ip), f"An address failed to be set as private: '{al}'"

    try:
        ip = socket.gethostbyname("www.google.com")
        skip = False
    except Exception as err:
        logger.warning(
            "Failed to get IP address of www.google.com to check if it is "
            "indeed blocked. You probably did this on purpose so not "
            f"crashing. Error: '{err}'"
        )
        skip = True
    if not skip:
        assert not is_private(ip), (
            f"Failed to set www.google.com as unreachable: IP is '{ip}'"
        )


def set_func_signature(func: Callable) -> Callable:
    """Expose the extra args of wdoc.__init__ in its signature.

    Instead of ``**cli_kwargs``, the signature lists every allowed argument
    (from ``extra_args_types``) as a keyword-only parameter defaulting to None.
    This is needed to get correct behavior from fire.Fire '--completion'.

    Args:
        func: a method whose last parameter is a ``**kwargs`` catch-all.

    Returns:
        A wrapper calling ``func`` unchanged, carrying the new
        ``__signature__`` and merged ``__annotations__``.
    """
    original_sig = inspect.signature(func)
    params = list(original_sig.parameters.values())
    assert params[-1].kind == inspect.Parameter.VAR_KEYWORD

    new_params = params[:-1]  # Remove **cli_kwargs
    new_params.extend(
        inspect.Parameter(
            name=arg,
            kind=inspect.Parameter.KEYWORD_ONLY,
            annotation=hint,
            default=None,
        )
        for arg, hint in extra_args_types.items()
    )
    new_sig = original_sig.replace(parameters=new_params)

    @wraps(func)
    def new_func(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    new_func.__signature__ = new_sig
    new_func.__annotations__ = get_type_hints(func) | extra_args_types
    return new_func


# Tag constants
THIN = "<think>"
THINE = "</think>"
ANSW = "<answer>"
ANSWE = "</answer>"

# Pre-compiled regex patterns
# greedy + DOTALL: captures everything between the first <think> and the last </think>
_THIN_REGEX = re.compile(f"{re.escape(THIN)}(.*){re.escape(THINE)}", re.DOTALL)
# matches any of the four tags, used to strip them from the answer
_THIN_SUB_REGEX = re.compile(
    f"{re.escape(THIN)}|{re.escape(THINE)}|{re.escape(ANSW)}|{re.escape(ANSWE)}"
)


def thinking_answer_parser(output: str, strict: bool = False) -> dict:
    """Separate the <think> and <answer> parts of an LLM output.

    Args:
        output: the raw LLM output.
        strict: if True, parsing errors are always raised. Otherwise combining
            answers could snowball into losing lots of text, so by default a
            fallback is returned instead.

    Returns:
        A dict with keys "thinking" (stripped, possibly empty) and "answer"
        (right-stripped, tags removed). If parsing fails and neither ``strict``
        nor ``env.WDOC_DEBUG`` is set, "thinking" is empty and "answer" holds
        the original output wrapped in a <note>/<output> warning.

    Raises:
        Exception: any parsing error, when ``strict`` or ``env.WDOC_DEBUG`` is set.
    """
    orig = output  # str is immutable, no copy needed
    try:
        # some models like the geminis don't return their thinking output, sometimes
        # by mistake they keep thinking anyway so we get THINE without THIN. Let's just add
        # it at the beginning of output
        if THINE in output and THIN not in output:
            output = THIN + output

        # some models can consider that <answer> implies </thinking> so we
        # add it manually
        if (
            THIN in output
            and ANSW in output
            and ANSWE in output
            and THINE not in output
        ):
            output = output.replace(ANSW, THINE + "\n" + ANSW)

        if (THIN not in output) and (ANSW not in output):
            assert THINE not in output, (
                f"Output contains no {THIN} nor {ANSW} but an unexpected {THINE}:\n'''\n{output}\n'''"
            )
            assert ANSWE not in output, (
                f"Output contains no {THIN} nor {ANSW} but an unexpected {ANSWE}:\n'''\n{output}\n'''"
            )
            logger.debug(f"LLM output contained neither {THIN} nor {ANSW}")
            return {"thinking": "", "answer": output}

        thinking = ""
        if THIN in output and THINE in output:
            # meaning we found the expected <think> </think> block
            thinking_match = _THIN_REGEX.search(output)
            if thinking_match:
                thinking = thinking_match.group(1)
        else:
            # check we don't have only one of the xml sides
            assert THIN not in output and THINE not in output, (
                f"Found only one of '{THIN}' or '{THINE}' in LLM output"
            )
            logger.debug("LLM output contained no thinking block")

        answer = ""
        if ANSW in output and ANSWE in output:
            # meaning we found the expected <answer> </answer> block:
            # drop the thinking text, then every tag
            answer_text = output.replace(thinking, "") if thinking else output
            answer = _THIN_SUB_REGEX.sub("", answer_text)
            logger.debug("LLM output contained answer block")
        else:
            # check we don't have only one of the xml sides
            assert ANSW not in output and ANSWE not in output, (
                f"Found only one of '{ANSW}' or '{ANSWE}' in LLM output"
            )
            if thinking:
                logger.debug(
                    "LLM output contained no answer block, assuming it's all but the thinking"
                )
                answer = (
                    output.replace(thinking, "").replace(THIN, "").replace(THINE, "")
                )
            else:
                logger.debug(
                    "LLM output contained no answer block, assuming it's all the output"
                )
                answer = output

        output = output.strip()
        thinking = thinking.strip()
        answer = answer.rstrip()

        assert THIN not in answer, (
            f"Parsed answer contained unexpected {THIN}:\n'''\n{output}\n'''"
        )
        assert THINE not in answer, (
            f"Parsed answer contained unexpected {THINE}:\n'''\n{output}\n'''"
        )
        assert ANSW not in answer, (
            f"Parsed answer contained unexpected {ANSW}:\n'''\n{output}\n'''"
        )
        assert ANSWE not in answer, (
            f"Parsed answer contained unexpected {ANSWE}:\n'''\n{output}\n'''"
        )
        assert answer, f"No answer could be parsed from LLM output: '{output}'"

        return {"thinking": thinking, "answer": answer}

    except Exception as err:
        if strict:
            raise
        logger.exception(
            f"Error when parsing LLM output to get thinking and answer part.\nError: '{err}'\nOriginal output: '{orig}'\nNote: if the output seems fine but ends abruptly instead of by </answer> you might want to tweak the max_token settings.\nWill continue if not using --debug"
        )
        if env.WDOC_DEBUG:
            raise
        assert output.strip(), "LLM output was empty"
        return {
            "thinking": "",
            "answer": (
                "<note>\n"
                "The following LLM answer might have had a problem during parsing\n"
                "</note>\n"
                "<output>\n"
                f"{orig}\n"
                "</output>"
            ),
        }


# this will contain wdoc's version to be used by langfuse's callback without circular imports
langfuse_callback_holder = []


def create_langfuse_callback(version: str) -> None:
    """Set up langfuse tracing if its environment variables are set.

    If LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY and LANGFUSE_HOST are all set
    (possibly copied from their WDOC_-prefixed variants), this registers the
    langfuse callbacks in litellm. It also appends a langchain langfuse
    CallbackHandler, tagged with ``version`` and a fresh session id, to
    ``langfuse_callback_holder``. Setup failures are logged, not raised, except
    when langfuse cannot be imported although WDOC_LANGFUSE_PUBLIC_KEY looks set.

    Args:
        version: wdoc's version, forwarded to the callback handler.
    """
    assert not env.WDOC_PRIVATE_MODE

    # replace langfuse's env variable if set for wdoc, this is already done in env.py but doing it here also at runtime
    for k in ["LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY", "LANGFUSE_HOST"]:
        newk = "WDOC_" + k
        if os.environ.get(newk):
            os.environ[k] = os.environ[newk]

    if not (
        "LANGFUSE_PUBLIC_KEY" in os.environ
        and "LANGFUSE_SECRET_KEY" in os.environ
        and "LANGFUSE_HOST" in os.environ
    ):
        return

    logger.debug("Activating langfuse callbacks")
    try:
        import langfuse  # noqa: F401
    except ImportError as e:
        if (
            "WDOC_LANGFUSE_PUBLIC_KEY" in os.environ
            and "redacted" not in os.environ.get("WDOC_LANGFUSE_PUBLIC_KEY", "")
        ):
            raise Exception(
                "Couldn't import langfuse even though WDOC_LANGFUSE environment variables appear set. Crashing."
            ) from e
        logger.warning(
            f"Failed to setup langfuse callback because of ImportError, make sure package 'langfuse' is installed. The error was: '{e}'"
        )
        # without the package the setup below can only fail again and log a
        # second, redundant warning
        return

    try:
        # use litellm's callbacks for chatlitellm backend
        # # and use langchain's callback for openai's backend
        # BUT as of october 2024 it seems buggy with chatlitellm, the modelname does not seem to be passed?
        try:
            from langfuse.callback import CallbackHandler as LangfuseCallback

            litellm.success_callback.append("langfuse")
            litellm.failure_callback.append("langfuse")
        except (ImportError, AttributeError):
            # import changed for langfuse v3: https://github.com/langfuse/langfuse/issues/7205
            from langfuse.langchain import CallbackHandler as LangfuseCallback

            # v3 switched to open telemetry: https://github.com/BerriAI/litellm/issues/11500
            litellm.success_callback.append("langfuse_otel")
            litellm.failure_callback.append("langfuse_otel")

        langfuse_callback = LangfuseCallback(
            secret_key=os.environ["LANGFUSE_SECRET_KEY"],
            public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
            host=os.environ["LANGFUSE_HOST"],
            session_id=str(uuid.uuid4()),
            version=version,
        )
        langfuse_callback_holder.append(langfuse_callback)
    except Exception as e:
        logger.warning(
            f"Failed to setup langfuse callback, make sure package 'langfuse' is installed. The error was: '{e}'"
        )


@memoize
def get_supported_model_params(modelname: ModelName) -> list:
    """Return the OpenAI-style parameters supported by a model.

    Testing models support none. Openrouter models are looked up in the
    openrouter metadata. Any other model is looked up in litellm under several
    candidate names, first without a provider hint, then with its own backend,
    then as an openrouter model. The first non-empty result wins, else ``[]``.
    Results are cached, so callers should not mutate the returned list.
    """
    if modelname.is_testing():
        return []

    if modelname.backend == "openrouter":
        metadata = get_openrouter_metadata()
        assert modelname.original in metadata, (
            f"Missing {modelname.original} from openrouter"
        )
        return metadata[modelname.original]["supported_parameters"]

    candidates = [
        modelname.original,
        modelname.model,
        model_name_matcher(modelname.original),
    ]
    provider_hints = [
        {},
        {"custom_llm_provider": modelname.backend},
        {"custom_llm_provider": "openrouter"},
    ]
    for hint in provider_hints:
        for candidate in candidates:
            params = litellm.get_supported_openai_params(candidate, **hint)
            if params:
                return params
    return []


def cache_file_in_memory(file_path: Union[str, Path], recursive: bool = False) -> bool:
    """
    Advise the Linux kernel to cache the given file in memory.

    Args:
        file_path: Path to the file or directory to cache (str accepted too)
        recursive: If True and file_path is a directory, cache all files within it

    Returns:
        bool: True if caching was successful, False otherwise (also on
        non-Linux systems, for missing paths, and for directories when
        ``recursive`` is False)
    """
    # Check if we're on Linux
    import platform

    if platform.system() != "Linux":
        # This function only works on Linux systems.
        return False

    file_path = Path(file_path)
    files_to_cache: List[Path] = []

    # Handle directory case
    if file_path.is_dir():
        if not recursive:
            # Warning: {file_path} is a directory. Set recursive=True to cache all files.
            return False
        # Collect all files recursively
        files_to_cache = [f for f in file_path.rglob("*") if f.is_file()]
    elif file_path.is_file():
        files_to_cache = [file_path]
    else:
        # Error: {file_path} does not exist.
        return False

    success = True
    for file in files_to_cache:
        try:
            fd = os.open(str(file), os.O_RDONLY)
            try:
                # offset 0 and length 0 cover the whole file
                os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_WILLNEED)
            finally:
                # always release the descriptor, even if fadvise failed
                os.close(fd)
        except Exception as e:
            logger.warning(f"Failed to cache {file}: {e}")
            success = False
    return success
def log_and_time_fn(fn: Callable) -> Callable:
    """Decorator logging (at debug level) when ``fn`` is entered and exited.

    The exit message also reports the wall-clock time spent in ``fn``, as the
    decorator's name promises. If ``fn`` raises, no exit message is logged and
    the exception propagates unchanged.
    """
    from time import perf_counter

    @wraps(fn)
    def wrapper(*args, **kwargs):
        logger.debug(f"Entering {fn}")
        start = perf_counter()
        val = fn(*args, **kwargs)
        logger.debug(f"Exiting {fn} after {perf_counter() - start:.3f}s")
        return val

    return wrapper
def get_piped_input() -> Optional[Any]:
    """
    Return the data piped into stdin, or None when stdin is not piped.

    This runs when importing wdoc, to avoid any issues with parallelism and
    threads. The content is later added to the command line that starts wdoc,
    directly in __main__.py.

    The data is returned as ``str`` when it can be decoded, otherwise as the
    raw ``bytes``. Afterwards, on non-Windows systems, ``sys.stdin`` is
    reopened from /dev/tty when possible, so that breakpoint() still works.
    """
    # Nothing to read when stdin is a terminal
    if not is_input_piped:
        return None

    raw_data = sys.stdin.buffer.read()
    try:
        data = raw_data.decode()
    except Exception:
        # keep the raw bytes when they can't be decoded
        data = raw_data

    # Best effort: hand stdin back to the terminal. Windows is left untouched
    # because it would need a different approach.
    if os.name != "nt":
        try:
            sys.stdin = open("/dev/tty")
        except Exception:
            # If we can't reopen stdin, at least return the data
            pass

    logger.debug("Loaded piped data")
    return data
def open_anki_gui(query: str) -> None:
    """
    Show the result of ``query`` in Anki's card browser through AnkiConnect.

    The optional `py_ankiconnect` package (from the `anki` extra) is used when
    it can be imported. Otherwise a plain `requests.post` is sent to the
    AnkiConnect endpoint. The endpoint is read from the
    `PY_ANKICONNECT_DEFAULT_HOST` / `PY_ANKICONNECT_DEFAULT_PORT` env vars,
    falling back to `http://127.0.0.1` and `8765`.

    Raises:
        requests.HTTPError: if the fallback request gets an error status.
        Exception: if AnkiConnect answers the fallback request with an error.
    """
    try:
        from py_ankiconnect import PyAnkiconnect

        PyAnkiconnect()(action="guiBrowse", query=query)
    except ImportError:
        import requests

        endpoint = "{}:{}".format(
            os.environ.get("PY_ANKICONNECT_DEFAULT_HOST", "http://127.0.0.1"),
            os.environ.get("PY_ANKICONNECT_DEFAULT_PORT", "8765"),
        )
        payload = {
            "action": "guiBrowse",
            "params": {"query": query},
            "version": 6,
        }
        response = requests.post(endpoint, json=payload, timeout=10)
        response.raise_for_status()

        error = response.json().get("error")
        if error is not None:
            raise Exception(f"AnkiConnect returned error: '{error}'")


# one-time notice at import time when running under pytest
if pytest_ongoing:
    logger.warning("Detected that wdoc is run in a pytest environment")