"""LiteLLM provider implementation for multi-provider support.""" import ast import hashlib import json import os import secrets import string from typing import Any import json_repair import litellm from litellm import acompletion from loguru import logger from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.registry import find_by_model, find_gateway # Standard chat-completion message keys. _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) _ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"}) _ALNUM = string.ascii_letters + string.digits _ARG_PATCH_VERBOSE_VALUES = {"1", "true", "yes", "on"} def _short_tool_id() -> str: """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral).""" return "".join(secrets.choice(_ALNUM) for _ in range(9)) def _should_log_argument_patch() -> bool: value = str(os.getenv("DASHBOARD_LITELLM_PATCH_VERBOSE") or "").strip().lower() return value in _ARG_PATCH_VERBOSE_VALUES def _coerce_tool_arguments_json(value: Any) -> str: """Return provider-safe JSON text for OpenAI-style function.arguments.""" if value is None: return "{}" if isinstance(value, str): text = value.strip() if not text: return "{}" try: json.loads(text) return text except Exception: pass try: parsed = ast.literal_eval(text) except Exception: parsed = None else: try: return json.dumps(parsed, ensure_ascii=False) except Exception: pass return json.dumps({"raw": text}, ensure_ascii=False) try: return json.dumps(value, ensure_ascii=False) except Exception: return json.dumps({"raw": str(value)}, ensure_ascii=False) class LiteLLMProvider(LLMProvider): """ LLM provider using LiteLLM for multi-provider support. Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through a unified interface. Provider-specific logic is driven by the registry (see providers/registry.py) — no if-elif chains needed here. """ def __init__( self, api_key: str | None = None, api_base: str | None = None, default_model: str = "anthropic/claude-opus-4-5", extra_headers: dict[str, str] | None = None, provider_name: str | None = None, ): super().__init__(api_key, api_base) self.default_model = default_model self.extra_headers = extra_headers or {} # Detect gateway / local deployment. # provider_name (from config key) is the primary signal; # api_key / api_base are fallback for auto-detection. 
        self._gateway = find_gateway(provider_name, api_key, api_base)

        # Configure environment variables
        if api_key:
            self._setup_env(api_key, api_base, default_model)
        if api_base:
            litellm.api_base = api_base

        # Disable LiteLLM logging noise
        litellm.suppress_debug_info = True
        # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
        litellm.drop_params = True

        self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY"))

    def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
        """Set environment variables based on detected provider."""
        spec = self._gateway or find_by_model(model)
        if not spec:
            return
        if not spec.env_key:
            # OAuth/provider-only specs (for example: openai_codex)
            return

        # Gateway/local overrides existing env; standard provider doesn't
        if self._gateway:
            os.environ[spec.env_key] = api_key
        else:
            os.environ.setdefault(spec.env_key, api_key)

        # Resolve env_extras placeholders:
        #   {api_key}  → user's API key
        #   {api_base} → user's api_base, falling back to spec.default_api_base
        effective_base = api_base or spec.default_api_base
        for env_name, env_val in spec.env_extras:
            resolved = env_val.replace("{api_key}", api_key)
            resolved = resolved.replace("{api_base}", effective_base)
            os.environ.setdefault(env_name, resolved)

    def _resolve_model(self, model: str) -> str:
        """Resolve model name by applying provider/gateway prefixes."""
        if self._gateway:
            prefix = self._gateway.litellm_prefix
            if self._gateway.strip_model_prefix:
                model = model.split("/")[-1]
            if prefix:
                model = f"{prefix}/{model}"
            return model

        # Standard mode: auto-prefix for known providers
        spec = find_by_model(model)
        if spec and spec.litellm_prefix:
            model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
            if not any(model.startswith(s) for s in spec.skip_prefixes):
                model = f"{spec.litellm_prefix}/{model}"
        return model

    @staticmethod
    def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
        """Normalize explicit provider prefixes like `github-copilot/...`."""
        if "/" not in model:
            return model
        prefix, remainder = model.split("/", 1)
        if prefix.lower().replace("-", "_") != spec_name:
            return model
        return f"{canonical_prefix}/{remainder}"

    def _supports_cache_control(self, model: str) -> bool:
        """Return True when the provider supports cache_control on content blocks."""
        if self._gateway is not None:
            return self._gateway.supports_prompt_caching
        spec = find_by_model(model)
        return spec is not None and spec.supports_prompt_caching

    def _apply_cache_control(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
        """Return copies of messages and tools with cache_control injected."""
        new_messages = []
        for msg in messages:
            if msg.get("role") == "system":
                content = msg["content"]
                if isinstance(content, str):
                    new_content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
                else:
                    new_content = list(content)
                    new_content[-1] = {**new_content[-1], "cache_control": {"type": "ephemeral"}}
                new_messages.append({**msg, "content": new_content})
            else:
                new_messages.append(msg)

        new_tools = tools
        if tools:
            new_tools = list(tools)
            new_tools[-1] = {**new_tools[-1], "cache_control": {"type": "ephemeral"}}
        return new_messages, new_tools

    def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
        """Apply model-specific parameter overrides from the registry."""
        model_lower = model.lower()
        spec = find_by_model(model)
        if spec:
            for pattern, overrides in spec.model_overrides:
                if pattern in model_lower:
                    kwargs.update(overrides)
                    return

    @staticmethod
    def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
        """Return provider-specific extra keys to preserve in request messages."""
        spec = find_by_model(original_model) or find_by_model(resolved_model)
        if (
            (spec and spec.name == "anthropic")
            or "claude" in original_model.lower()
            or resolved_model.startswith("anthropic/")
        ):
            return _ANTHROPIC_EXTRA_KEYS
        return frozenset()

    @staticmethod
    def _normalize_tool_call_id(tool_call_id: Any) -> Any:
        """Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
        if not isinstance(tool_call_id, str):
            return tool_call_id
        if len(tool_call_id) == 9 and tool_call_id.isalnum():
            return tool_call_id
        return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]

    @staticmethod
    def _sanitize_messages(
        messages: list[dict[str, Any]],
        extra_keys: frozenset[str] = frozenset(),
    ) -> list[dict[str, Any]]:
        """Strip non-standard keys and ensure assistant messages have a content key."""
        allowed = _ALLOWED_MSG_KEYS | extra_keys
        sanitized = LLMProvider._sanitize_request_messages(messages, allowed)

        id_map: dict[str, str] = {}
        patched_arguments = 0

        def map_id(value: Any) -> Any:
            if not isinstance(value, str):
                return value
            return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))

        for clean in sanitized:
            # Keep assistant tool_calls[].id and tool tool_call_id in sync after
            # shortening, otherwise strict providers reject the broken linkage.
            if isinstance(clean.get("tool_calls"), list):
                normalized_tool_calls = []
                for tc in clean["tool_calls"]:
                    if not isinstance(tc, dict):
                        normalized_tool_calls.append(tc)
                        continue
                    tc_clean = dict(tc)
                    tc_clean["id"] = map_id(tc_clean.get("id"))
                    function = tc_clean.get("function")
                    if isinstance(function, dict) and "arguments" in function:
                        function_clean = dict(function)
                        original_arguments = function_clean.get("arguments")
                        normalized_arguments = _coerce_tool_arguments_json(original_arguments)
                        if original_arguments != normalized_arguments:
                            patched_arguments += 1
                        function_clean["arguments"] = normalized_arguments
                        tc_clean["function"] = function_clean
                    normalized_tool_calls.append(tc_clean)
                clean["tool_calls"] = normalized_tool_calls
            if "tool_call_id" in clean and clean["tool_call_id"]:
                clean["tool_call_id"] = map_id(clean["tool_call_id"])

        if patched_arguments and _should_log_argument_patch():
            logger.info(
                "Normalized {} historical tool/function argument payload(s) to JSON strings for LiteLLM",
                patched_arguments,
            )
        return sanitized

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """
        Send a chat completion request via LiteLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions in OpenAI format.
            model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.
            reasoning_effort: Optional reasoning-effort hint forwarded to LiteLLM.
            tool_choice: Tool-choice directive; defaults to "auto" when tools are given.

        Returns:
            LLMResponse with content and/or tool calls.
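
        Example:
            Illustrative sketch only; the key and model below are placeholders,
            and the call must run inside an async context.

            provider = LiteLLMProvider(api_key="sk-...", default_model="anthropic/claude-sonnet-4-5")
            response = await provider.chat(
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=256,
            )
            print(response.content)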
""" original_model = model or self.default_model model = self._resolve_model(original_model) extra_msg_keys = self._extra_msg_keys(original_model, model) if self._supports_cache_control(original_model): messages, tools = self._apply_cache_control(messages, tools) # Clamp max_tokens to at least 1 — negative or zero values cause # LiteLLM to reject the request with "max_tokens must be at least 1". max_tokens = max(1, max_tokens) kwargs: dict[str, Any] = { "model": model, "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys), "max_tokens": max_tokens, "temperature": temperature, } if self._gateway: kwargs.update(self._gateway.litellm_kwargs) # Apply model-specific overrides (e.g. kimi-k2.5 temperature) self._apply_model_overrides(model, kwargs) if self._langsmith_enabled: kwargs.setdefault("callbacks", []).append("langsmith") # Pass api_key directly — more reliable than env vars alone if self.api_key: kwargs["api_key"] = self.api_key # Pass api_base for custom endpoints if self.api_base: kwargs["api_base"] = self.api_base # Pass extra headers (e.g. APP-Code for AiHubMix) if self.extra_headers: kwargs["extra_headers"] = self.extra_headers if reasoning_effort: kwargs["reasoning_effort"] = reasoning_effort kwargs["drop_params"] = True if tools: kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" try: response = await acompletion(**kwargs) return self._parse_response(response) except Exception as e: # Return error as content for graceful handling return LLMResponse( content=f"Error calling LLM: {str(e)}", finish_reason="error", ) def _parse_response(self, response: Any) -> LLMResponse: """Parse LiteLLM response into our standard format.""" choice = response.choices[0] message = choice.message content = message.content finish_reason = choice.finish_reason # Some providers (e.g. GitHub Copilot) split content and tool_calls # across multiple choices. Merge them so tool_calls are not lost. 
        raw_tool_calls = []
        for ch in response.choices:
            msg = ch.message
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                raw_tool_calls.extend(msg.tool_calls)
                if ch.finish_reason in ("tool_calls", "stop"):
                    finish_reason = ch.finish_reason
            if not content and msg.content:
                content = msg.content
        if len(response.choices) > 1:
            logger.debug(
                "LiteLLM response has {} choices, merged {} tool_calls",
                len(response.choices),
                len(raw_tool_calls),
            )

        tool_calls = []
        for tc in raw_tool_calls:
            # Parse arguments from JSON string if needed
            args = tc.function.arguments
            if isinstance(args, str):
                args = json_repair.loads(args)

            provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
            function_provider_specific_fields = (
                getattr(tc.function, "provider_specific_fields", None) or None
            )

            tool_calls.append(ToolCallRequest(
                id=_short_tool_id(),
                name=tc.function.name,
                arguments=args,
                provider_specific_fields=provider_specific_fields,
                function_provider_specific_fields=function_provider_specific_fields,
            ))

        usage = {}
        if hasattr(response, "usage") and response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }

        reasoning_content = getattr(message, "reasoning_content", None) or None
        thinking_blocks = getattr(message, "thinking_blocks", None) or None

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason or "stop",
            usage=usage,
            reasoning_content=reasoning_content,
            thinking_blocks=thinking_blocks,
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
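

# Illustrative usage sketch, not part of the provider API. The environment
# variable name and model below are placeholder assumptions; adjust them for
# your deployment before running this module as a script.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Placeholder credentials/model; requires a valid key for the provider.
        provider = LiteLLMProvider(
            api_key=os.getenv("LLM_API_KEY"),
            default_model="anthropic/claude-sonnet-4-5",
        )
        response = await provider.chat(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            max_tokens=128,
        )
        print(response.content)

    asyncio.run(_demo())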