"""
Code related to loading the LLM instance, with an appropriate price
counting callback.
"""
import os
from beartype.typing import Any, Dict, List, Optional, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.caches import BaseCache
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.base import BaseMessage
from langchain_core.outputs.llm_result import LLMResult
from loguru import logger
from wdoc.utils.env import env
from wdoc.utils.misc import ModelName, get_model_max_tokens, langfuse_callback_holder
# lorem ipsum is output by "testing" llm
LOREM_IPSUM = (
"Lorem ipsum dolor sit amet, consectetur adipiscing "
"elit, sed do eiusmod tempor incididunt ut labore et "
"dolore magna aliqua. Ut enim ad minim veniam, quis "
"nostrud exercitation ullamco laboris nisi ut aliquip "
"ex ea commodo consequat. Duis aute irure dolor in "
"reprehenderit in voluptate velit esse cillum dolore eu "
"fugiat nulla pariatur. Excepteur sint occaecat cupidatat "
"non proident, sunt in culpa qui officia deserunt mollit "
"anim id est laborum."
)
def load_llm(
modelname: ModelName,
llm_verbosity: bool,
llm_cache: Union[None, bool, BaseCache],
api_base: Optional[str],
private: bool,
tags: List[str],
**extra_model_args,
) -> Union[
"langchain_litellm.ChatLiteLLM",
"langchain_community.chat_models.fake.FakeListChatModel",
]:
"""load language model"""
import litellm
from langchain_litellm import ChatLiteLLM
from langchain_community.chat_models.fake import FakeListChatModel
if extra_model_args is None:
extra_model_args = {}
# setup tags for langfuse or others
if tags is None:
tags = []
if "wdoc" not in tags:
tags.append("wdoc")
if "tags" in extra_model_args:
assert isinstance(extra_model_args["tags"], list)
for t in extra_model_args["tags"]:
if t not in tags:
tags.append(t)
env_tags = env.WDOC_LITELLM_TAGS
if env_tags:
if isinstance(env_tags, str):
env_tags = env_tags.split(",")
assert isinstance(env_tags, list), env_tags
tags.extend(env_tags)
assert "cache" not in extra_model_args
if modelname.is_testing():
assert modelname.original == "testing/testing"
if llm_verbosity:
logger.info("Loading a fake LLM using the testing/ backend")
llm = FakeListChatModel(
verbose=llm_verbosity,
responses=[f"Fake answer n°{i}: {LOREM_IPSUM}" for i in range(1, 100)],
callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
+ langfuse_callback_holder,
disable_streaming=True, # Not needed and might break cache
cache=False,
**extra_model_args,
)
return llm
else:
assert "testing" not in modelname.original.lower()
logger.debug("Loading model via litellm")
if private:
assert api_base, "If private is set, api_base must be set too"
assert env.WDOC_PRIVATE_MODE
assert "WDOC_PRIVATE_MODE" in os.environ, (
"Missing env variable WDOC_PRIVATE_MODE"
)
assert os.environ["WDOC_PRIVATE_MODE"] == "true", (
"Wrong value for env variable WDOC_PRIVATE_MODE"
)
if api_base:
logger.warning(f"Will use custom api_base {api_base}")
if not (
f"{modelname.backend.upper()}_API_KEY" in os.environ
and os.environ[f"{modelname.backend.upper()}_API_KEY"]
):
if modelname.backend == "ollama":
logger.debug(
"No OLLAMA_API_KEY env var found but ignoring because using ollama"
)
pass
elif not api_base:
raise Exception(
f"No environment variable named {modelname.backend.upper()}_API_KEY found"
)
else:
logger.debug(
f"No environment variable named {modelname.backend.upper()}_API_KEY found. Continuing because some setups are fine with this."
)
# extra check for private mode
if private:
assert os.environ["WDOC_PRIVATE_MODE"] == "true"
logger.warning(
f"private is on so overwriting {modelname.backend.upper()}_API_KEY from environment variables"
)
assert (
os.environ[f"{modelname.backend.upper()}_API_KEY"]
== "REDACTED_BECAUSE_WDOC_IN_PRIVATE_MODE"
)
assert (
os.environ[f"{modelname.backend.upper()}_API_KEY"]
== "REDACTED_BECAUSE_WDOC_IN_PRIVATE_MODE"
)
assert not langfuse_callback_holder, (
"Private argument but langfuse_handler appears set. Something went wrong so crashing just to be safe."
)
else:
assert not env.WDOC_PRIVATE_MODE
assert (
"WDOC_PRIVATE_MODE" not in os.environ
or os.environ["WDOC_PRIVATE_MODE"] == "false"
)
assert private == env.WDOC_PRIVATE_MODE
max_tokens = get_model_max_tokens(modelname)
logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
if "max_tokens" not in extra_model_args:
# intentionaly limiting max tokens because it can cause bugs
if modelname.backend != "ollama":
extra_model_args["max_tokens"] = int(max_tokens * 0.9)
else:
if max_tokens <= 10_000:
extra_model_args["max_tokens"] = int(max_tokens * 0.9)
else:
logger.debug(
f"Detected an ollama model with large max_tokens ({max_tokens}), they usually overestimate their context window capabilities so we reduce it if the user does not specify a max_tokens kwarg"
)
extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))
# if using openrouter, prioritise groq for speed and sort by throughput
# source: https://openrouter.ai/docs/features/provider-routing#ordering-specific-providers
if modelname.backend == "openrouter":
ort_model_kwargs = extra_model_args.get("model_kwargs", {})
ort_provider_kwargs = ort_model_kwargs.get("provider", {})
ort_order = ort_provider_kwargs.get("order", [])
ort_order = ort_order if ort_order else ["groq"]
ort_sort = ort_provider_kwargs.get("sort", "")
ort_sort = ort_sort if ort_sort else "throughput"
ort_provider_kwargs["order"] = ort_order
ort_provider_kwargs["sort"] = ort_sort
ort_model_kwargs["provider"] = ort_provider_kwargs
extra_model_args["model_kwargs"] = ort_model_kwargs
logger.debug(
"Using openrouter so sorting models by throughput and prioritizing groq out of principle"
)
logger.debug(f"Using ChatLiteLLM backend for model {modelname.original}")
llm = ChatLiteLLM(
model_name=modelname.original,
disable_streaming=True, # Not needed and might break cache
api_base=api_base,
cache=llm_cache,
verbose=llm_verbosity,
tags=tags,
callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
+ langfuse_callback_holder,
user=env.WDOC_LITELLM_USER,
**extra_model_args,
)
litellm.drop_params = True
if private:
assert llm.api_base, "private is set but no api_base for llm were found"
assert llm.api_base == api_base, (
"private is set but found unexpected llm.api_base value: '{litellm.api_base}'"
)
# fix: the SQLiteCache's str appearance is cancelling its own cache lookup!
# Simply turning '<wdoc.utils.customs.fix_llm_caching.SQLiteCacheFixed object at 0x72cd5a6d4050>' into '<wdoc.utils.customs.fix_llm_caching.SQLiteCacheFixed object at REDACTED>'
# edit: normally should be fixed as of https://github.com/langchain-ai/langchain/pull/32156/files
if llm.cache:
cur = str(llm.cache)
if "object at " in cur:
logger.debug(
"The SQLiteCache of langchain is broken so we will modify ths __repr__ and __str__ of its class"
)
fixed = cur.split(" at ")[0] + "REDACTED"
llm.cache.__class__.__repr__ = lambda _=None: fixed
llm.cache.__class__.__str__ = lambda _=None: fixed
# also try to make it so that setting extra model params does not reuse the cache
if extra_model_args:
cur_string = llm._get_llm_string()
alt_string = llm._get_llm_string(extra_model_args)
if cur_string != alt_string:
logger.debug(
"Modifying the llm._get_llm_string because langchain broke it over a year ago..."
)
llm._get_llm_string = lambda *args, **kwargs: alt_string
cur = str(llm._get_llm_string)
if "object at " in cur:
fixed = cur.split(" at ")[0] + "REDACTED"
llm._get_llm_string.__class__.__repr__ = lambda _=None: fixed
llm._get_llm_string.__class__.__str__ = lambda _=None: fixed
logger.debug(
f"Extra model args there were used for {modelname.original}: '{extra_model_args}'"
)
return llm
[docs]
class PriceCountingCallback(BaseCallbackHandler):
"source: https://python.langchain.com/docs/modules/callbacks/"
def __init__(self, verbose, *args, **kwargs):
super().__init__(*args, **kwargs)
self.verbose = verbose
self.total_tokens = 0
self.prompt_tokens = 0
self.completion_tokens = 0
self.internal_reasoning_tokens = 0
self.methods_called = []
self.authorized_methods = [
"on_llm_start",
"on_chat_model_start",
"on_llm_end",
"on_llm_error",
"on_chain_start",
"on_chain_end",
"on_chain_error",
]
self.pbar = []
def __repr__(self) -> str:
# setting __repr__ and __str__ is important because it can
# maybe be used for caching?
return "PriceCountingCallback"
def __str__(self) -> str:
return "PriceCountingCallback"
def _check_methods_called(self) -> bool:
assert all(meth in dir(self) for meth in self.methods_called), (
"unexpected method names!"
)
wrong = [
meth for meth in self.methods_called if meth not in self.authorized_methods
]
if wrong:
raise Exception(f"Unauthorized_method were called: {','.join(wrong)}")
return True
[docs]
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
self.methods_called.append("on_llm_start")
if self.verbose:
logger.debug(serialized)
logger.debug(prompts)
logger.debug(kwargs)
self._check_methods_called()
[docs]
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> Any:
"""Run when Chat Model starts running."""
self.methods_called.append("on_chat_model_start")
if self.verbose:
logger.debug(serialized)
logger.debug(messages)
logger.debug(kwargs)
self._check_methods_called()
[docs]
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
"""Run when LLM ends running."""
self.methods_called.append("on_llm_end")
if self.verbose:
logger.debug(response)
logger.debug(kwargs)
if response.llm_output is None or response.llm_output["token_usage"] is None:
if self.verbose:
logger.debug("None llm_output, returning.")
return
new_p = response.llm_output["token_usage"]["prompt_tokens"]
new_c = response.llm_output["token_usage"]["completion_tokens"]
new_r = response.llm_output["token_usage"]["total_tokens"] - new_p - new_c
self.prompt_tokens += new_p
self.completion_tokens += new_c
self.total_tokens += new_p + new_c + new_r
assert (
self.total_tokens
== self.prompt_tokens
+ self.completion_tokens
+ self.internal_reasoning_tokens
)
self._check_methods_called()
[docs]
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when LLM errors."""
self.methods_called.append("on_llm_error")
if self.verbose:
logger.debug(error)
logger.debug(kwargs)
self._check_methods_called()
[docs]
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
"""Run when chain starts running."""
if self.verbose:
logger.debug(serialized)
logger.debug(inputs)
logger.debug(kwargs)
self.methods_called.append("on_chain_start")
self._check_methods_called()
[docs]
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""
if self.verbose:
logger.debug(outputs)
logger.debug(kwargs)
self.methods_called.append("on_chain_end")
self._check_methods_called()
if self.pbar:
self.pbar[-1].update(1)
[docs]
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
"""Run when chain errors."""
if self.verbose:
logger.debug(error)
logger.debug(kwargs)
self.methods_called.append("on_chain_error")
self._check_methods_called()
[docs]
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
self.methods_called.append("on_llm_new_token")
self._check_methods_called()
raise NotImplementedError("Not expecting streaming")
[docs]
def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""
self.methods_called.append("on_text")
self._check_methods_called()
raise NotImplementedError("Not expecting to call self.on_text")
[docs]
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.methods_called.append("on_agent_action")
self._check_methods_called()
raise NotImplementedError("Not expecting agent call")
[docs]
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
self.methods_called.append("on_agent_finish")
self._check_methods_called()
raise NotImplementedError("Not expecting agent call")