diff --git a/bot-images/Dashboard.Dockerfile b/bot-images/Dashboard.Dockerfile index 2351729..4fe17d0 100644 --- a/bot-images/Dashboard.Dockerfile +++ b/bot-images/Dashboard.Dockerfile @@ -3,7 +3,6 @@ ENV PYTHONUNBUFFERED=1 ENV LANG=C.UTF-8 ENV LC_ALL=C.UTF-8 ENV PYTHONIOENCODING=utf-8 -ENV PYTHONPATH=/opt/dashboard-patches${PYTHONPATH:+:${PYTHONPATH}} # 1. 替换 Debian 源为国内镜像 RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \ @@ -20,155 +19,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ RUN python -m pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/ --upgrade \ pip setuptools wheel aiohttp -# 3.1 LiteLLM compatibility patch for DashScope coder/code models. -# DashScope coder models require `tool_calls[*].function.arguments` to be a JSON string. -# Some upstream stacks may replay historical tool calls as dicts / Python-literal strings. -# We patch LiteLLM entrypoints at runtime so old history can still be forwarded safely. 
-RUN mkdir -p /opt/dashboard-patches && cat > /opt/dashboard-patches/sitecustomize.py <<'PY' -import ast -import copy -import json -import os -import sys -from typing import Any - - -def _log(message: str) -> None: - if str(os.getenv("DASHBOARD_LITELLM_PATCH_VERBOSE") or "").strip().lower() not in {"1", "true", "yes", "on"}: - return - print(f"[dashboard-litellm-patch] {message}", file=sys.stderr, flush=True) - - -def _coerce_json_arguments(value: Any) -> str: - if value is None: - return "{}" - if isinstance(value, str): - text = value.strip() - if not text: - return "{}" - try: - json.loads(text) - return text - except Exception: - pass - try: - parsed = ast.literal_eval(text) - except Exception: - parsed = None - else: - try: - return json.dumps(parsed, ensure_ascii=False) - except Exception: - pass - return json.dumps({"raw": text}, ensure_ascii=False) - try: - return json.dumps(value, ensure_ascii=False) - except Exception: - return json.dumps({"raw": str(value)}, ensure_ascii=False) - - -def _sanitize_openai_messages(messages: Any) -> tuple[Any, int]: - if not isinstance(messages, list): - return messages, 0 - try: - cloned = copy.deepcopy(messages) - except Exception: - cloned = list(messages) - - changed = 0 - for message in cloned: - if not isinstance(message, dict): - continue - - tool_calls = message.get("tool_calls") - if isinstance(tool_calls, list): - for tool_call in tool_calls: - if not isinstance(tool_call, dict): - continue - function = tool_call.get("function") - if not isinstance(function, dict): - continue - arguments = function.get("arguments") - normalized = _coerce_json_arguments(arguments) - if arguments != normalized: - function["arguments"] = normalized - changed += 1 - - function_call = message.get("function_call") - if isinstance(function_call, dict): - arguments = function_call.get("arguments") - normalized = _coerce_json_arguments(arguments) - if arguments != normalized: - function_call["arguments"] = normalized - changed += 1 - - 
return cloned, changed - - -def _patch_litellm() -> None: - try: - import litellm # type: ignore - except Exception as exc: - _log(f"litellm import skipped: {exc}") - return - - def _wrap_sync(fn): - if not callable(fn) or getattr(fn, "_dashboard_litellm_patch", False): - return fn - - def wrapper(*args, **kwargs): - messages = kwargs.get("messages") - normalized_messages, changed = _sanitize_openai_messages(messages) - if changed: - kwargs["messages"] = normalized_messages - _log(f"sanitized {changed} tool/function argument payload(s) before sync completion") - return fn(*args, **kwargs) - - setattr(wrapper, "_dashboard_litellm_patch", True) - return wrapper - - def _wrap_async(fn): - if not callable(fn) or getattr(fn, "_dashboard_litellm_patch", False): - return fn - - async def wrapper(*args, **kwargs): - messages = kwargs.get("messages") - normalized_messages, changed = _sanitize_openai_messages(messages) - if changed: - kwargs["messages"] = normalized_messages - _log(f"sanitized {changed} tool/function argument payload(s) before async completion") - return await fn(*args, **kwargs) - - setattr(wrapper, "_dashboard_litellm_patch", True) - return wrapper - - for attr in ("completion", "completion_with_retries"): - if hasattr(litellm, attr): - setattr(litellm, attr, _wrap_sync(getattr(litellm, attr))) - - for attr in ("acompletion",): - if hasattr(litellm, attr): - setattr(litellm, attr, _wrap_async(getattr(litellm, attr))) - - try: - import litellm.main as litellm_main # type: ignore - except Exception: - litellm_main = None - - if litellm_main is not None: - for attr in ("completion", "completion_with_retries"): - if hasattr(litellm_main, attr): - setattr(litellm_main, attr, _wrap_sync(getattr(litellm_main, attr))) - for attr in ("acompletion",): - if hasattr(litellm_main, attr): - setattr(litellm_main, attr, _wrap_async(getattr(litellm_main, attr))) - - _log("LiteLLM monkey patch installed") - - -_patch_litellm() -PY - WORKDIR /app # 这一步会把您修改好的 
# bot-images/litellm_provider.py — added as a NEW FILE by this patch
# (reconstructed/re-flowed from the collapsed unified diff; "+" markers removed).
"""LiteLLM provider implementation for multi-provider support."""

import ast
import hashlib
import json
import os
import secrets
import string
from typing import Any

import json_repair
import litellm
from litellm import acompletion
from loguru import logger

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway

# Standard chat-completion message keys.
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
# Extra keys preserved only for Anthropic-family requests (see _extra_msg_keys).
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
# Alphabet used for provider-safe tool-call IDs.
_ALNUM = string.ascii_letters + string.digits
# Truthy spellings accepted for DASHBOARD_LITELLM_PATCH_VERBOSE.
_ARG_PATCH_VERBOSE_VALUES = {"1", "true", "yes", "on"}


def _short_tool_id() -> str:
    """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
    return "".join(secrets.choice(_ALNUM) for _ in range(9))


def _should_log_argument_patch() -> bool:
    """Return True when DASHBOARD_LITELLM_PATCH_VERBOSE is set to a truthy value.

    Gates the informational log emitted by _sanitize_messages when historical
    tool/function argument payloads had to be normalized to JSON strings.
    """
    value = str(os.getenv("DASHBOARD_LITELLM_PATCH_VERBOSE") or "").strip().lower()
    return value in _ARG_PATCH_VERBOSE_VALUES


def _coerce_tool_arguments_json(value: Any) -> str:
    """Return provider-safe JSON text for OpenAI-style function.arguments.

    Coercion order for strings: keep as-is if already valid JSON; otherwise try
    ast.literal_eval (handles Python-literal dicts replayed from history) and
    re-serialize; otherwise wrap the raw text as {"raw": ...}. Non-strings are
    JSON-serialized directly, falling back to {"raw": str(value)}.
    None / empty strings become "{}".
    """
    if value is None:
        return "{}"

    if isinstance(value, str):
        text = value.strip()
        if not text:
            return "{}"
        try:
            json.loads(text)
            return text
        except Exception:
            pass
        try:
            parsed = ast.literal_eval(text)
        except Exception:
            parsed = None
        else:
            try:
                return json.dumps(parsed, ensure_ascii=False)
            except Exception:
                pass
        return json.dumps({"raw": text}, ensure_ascii=False)

    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        return json.dumps({"raw": str(value)}, ensure_ascii=False)


class LiteLLMProvider(LLMProvider):
    """
    LLM provider using LiteLLM for multi-provider support.

    Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
    a unified interface. Provider-specific logic is driven by the registry
    (see providers/registry.py) — no if-elif chains needed here.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "anthropic/claude-opus-4-5",
        extra_headers: dict[str, str] | None = None,
        provider_name: str | None = None,
    ):
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}

        # Detect gateway / local deployment.
        # provider_name (from config key) is the primary signal;
        # api_key / api_base are fallback for auto-detection.
        self._gateway = find_gateway(provider_name, api_key, api_base)

        # Configure environment variables
        if api_key:
            self._setup_env(api_key, api_base, default_model)

        if api_base:
            # NOTE(review): sets module-level litellm.api_base — process-global,
            # so multiple provider instances with different bases would clobber
            # each other; the per-call kwargs["api_base"] in chat() is what
            # actually wins per request.
            litellm.api_base = api_base

        # Disable LiteLLM logging noise
        litellm.suppress_debug_info = True
        # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
        litellm.drop_params = True

        self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY"))

    def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
        """Set environment variables based on detected provider."""
        spec = self._gateway or find_by_model(model)
        if not spec:
            return
        if not spec.env_key:
            # OAuth/provider-only specs (for example: openai_codex)
            return

        # Gateway/local overrides existing env; standard provider doesn't
        if self._gateway:
            os.environ[spec.env_key] = api_key
        else:
            os.environ.setdefault(spec.env_key, api_key)

        # Resolve env_extras placeholders:
        #   {api_key}  → user's API key
        #   {api_base} → user's api_base, falling back to spec.default_api_base
        # NOTE(review): if both api_base and spec.default_api_base are None,
        # str.replace receives None and raises TypeError — presumably the
        # registry guarantees default_api_base for specs using {api_base};
        # verify against providers/registry.py.
        effective_base = api_base or spec.default_api_base
        for env_name, env_val in spec.env_extras:
            resolved = env_val.replace("{api_key}", api_key)
            resolved = resolved.replace("{api_base}", effective_base)
            os.environ.setdefault(env_name, resolved)

    def _resolve_model(self, model: str) -> str:
        """Resolve model name by applying provider/gateway prefixes."""
        if self._gateway:
            prefix = self._gateway.litellm_prefix
            if self._gateway.strip_model_prefix:
                model = model.split("/")[-1]
            if prefix:
                model = f"{prefix}/{model}"
            return model

        # Standard mode: auto-prefix for known providers
        spec = find_by_model(model)
        if spec and spec.litellm_prefix:
            model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
            if not any(model.startswith(s) for s in spec.skip_prefixes):
                model = f"{spec.litellm_prefix}/{model}"

        return model

    @staticmethod
    def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
        """Normalize explicit provider prefixes like `github-copilot/...`.

        Replaces a leading `<spec-name>/` segment (dashes treated as
        underscores, case-insensitive) with the canonical LiteLLM prefix;
        returns the model unchanged otherwise.
        """
        if "/" not in model:
            return model
        prefix, remainder = model.split("/", 1)
        if prefix.lower().replace("-", "_") != spec_name:
            return model
        return f"{canonical_prefix}/{remainder}"

    def _supports_cache_control(self, model: str) -> bool:
        """Return True when the provider supports cache_control on content blocks."""
        if self._gateway is not None:
            return self._gateway.supports_prompt_caching
        spec = find_by_model(model)
        return spec is not None and spec.supports_prompt_caching

    def _apply_cache_control(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
        """Return copies of messages and tools with cache_control injected.

        Marks each system message's last content block (and the last tool
        definition) with {"type": "ephemeral"}; non-system messages are
        passed through unchanged.
        NOTE(review): a system message whose content is an empty list would
        make new_content[-1] raise IndexError — presumably upstream never
        produces that shape; confirm against callers.
        """
        new_messages = []
        for msg in messages:
            if msg.get("role") == "system":
                content = msg["content"]
                if isinstance(content, str):
                    new_content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
                else:
                    new_content = list(content)
                    new_content[-1] = {**new_content[-1], "cache_control": {"type": "ephemeral"}}
                new_messages.append({**msg, "content": new_content})
            else:
                new_messages.append(msg)

        new_tools = tools
        if tools:
            new_tools = list(tools)
            new_tools[-1] = {**new_tools[-1], "cache_control": {"type": "ephemeral"}}

        return new_messages, new_tools

    def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
        """Apply model-specific parameter overrides from the registry.

        First matching pattern wins (substring match on the lowercased model).
        Mutates `kwargs` in place.
        """
        model_lower = model.lower()
        spec = find_by_model(model)
        if spec:
            for pattern, overrides in spec.model_overrides:
                if pattern in model_lower:
                    kwargs.update(overrides)
                    return

    @staticmethod
    def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
        """Return provider-specific extra keys to preserve in request messages."""
        spec = find_by_model(original_model) or find_by_model(resolved_model)
        if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
            return _ANTHROPIC_EXTRA_KEYS
        return frozenset()

    @staticmethod
    def _normalize_tool_call_id(tool_call_id: Any) -> Any:
        """Normalize tool_call_id to a provider-safe 9-char alphanumeric form.

        Already-conforming 9-char alphanumeric IDs pass through; anything else
        is deterministically mapped via sha1 so the same long ID always yields
        the same short ID. Non-strings are returned unchanged.
        """
        if not isinstance(tool_call_id, str):
            return tool_call_id
        if len(tool_call_id) == 9 and tool_call_id.isalnum():
            return tool_call_id
        return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]

    @staticmethod
    def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
        """Strip non-standard keys and ensure assistant messages have a content key."""
        allowed = _ALLOWED_MSG_KEYS | extra_keys
        sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
        id_map: dict[str, str] = {}
        patched_arguments = 0

        def map_id(value: Any) -> Any:
            # Memoized so assistant tool_calls[].id and the matching tool
            # message's tool_call_id normalize to the same short ID.
            if not isinstance(value, str):
                return value
            return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))

        for clean in sanitized:
            # Keep assistant tool_calls[].id and tool tool_call_id in sync after
            # shortening, otherwise strict providers reject the broken linkage.
            if isinstance(clean.get("tool_calls"), list):
                normalized_tool_calls = []
                for tc in clean["tool_calls"]:
                    if not isinstance(tc, dict):
                        normalized_tool_calls.append(tc)
                        continue
                    tc_clean = dict(tc)
                    tc_clean["id"] = map_id(tc_clean.get("id"))
                    function = tc_clean.get("function")
                    if isinstance(function, dict) and "arguments" in function:
                        function_clean = dict(function)
                        original_arguments = function_clean.get("arguments")
                        normalized_arguments = _coerce_tool_arguments_json(original_arguments)
                        if original_arguments != normalized_arguments:
                            patched_arguments += 1
                        function_clean["arguments"] = normalized_arguments
                        tc_clean["function"] = function_clean
                    normalized_tool_calls.append(tc_clean)
                clean["tool_calls"] = normalized_tool_calls

            if "tool_call_id" in clean and clean["tool_call_id"]:
                clean["tool_call_id"] = map_id(clean["tool_call_id"])

        if patched_arguments and _should_log_argument_patch():
            logger.info(
                "Normalized {} historical tool/function argument payload(s) to JSON strings for LiteLLM",
                patched_arguments,
            )
        return sanitized

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """
        Send a chat completion request via LiteLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions in OpenAI format.
            model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.
            reasoning_effort: Optional reasoning-effort hint passed through to
                LiteLLM (drop_params is forced on so unsupported providers
                silently ignore it).
            tool_choice: Optional tool-choice directive; defaults to "auto"
                when tools are supplied.

        Returns:
            LLMResponse with content and/or tool calls. On any exception the
            error text is returned as content with finish_reason="error"
            rather than raising.
        """
        original_model = model or self.default_model
        model = self._resolve_model(original_model)
        extra_msg_keys = self._extra_msg_keys(original_model, model)

        if self._supports_cache_control(original_model):
            messages, tools = self._apply_cache_control(messages, tools)

        # Clamp max_tokens to at least 1 — negative or zero values cause
        # LiteLLM to reject the request with "max_tokens must be at least 1".
        max_tokens = max(1, max_tokens)

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if self._gateway:
            kwargs.update(self._gateway.litellm_kwargs)

        # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
        self._apply_model_overrides(model, kwargs)

        if self._langsmith_enabled:
            kwargs.setdefault("callbacks", []).append("langsmith")

        # Pass api_key directly — more reliable than env vars alone
        if self.api_key:
            kwargs["api_key"] = self.api_key

        # Pass api_base for custom endpoints
        if self.api_base:
            kwargs["api_base"] = self.api_base

        # Pass extra headers (e.g. APP-Code for AiHubMix)
        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers

        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort
            kwargs["drop_params"] = True

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice or "auto"

        try:
            response = await acompletion(**kwargs)
            return self._parse_response(response)
        except Exception as e:
            # Return error as content for graceful handling
            return LLMResponse(
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
            )

    def _parse_response(self, response: Any) -> LLMResponse:
        """Parse LiteLLM response into our standard format."""
        choice = response.choices[0]
        message = choice.message
        content = message.content
        finish_reason = choice.finish_reason

        # Some providers (e.g. GitHub Copilot) split content and tool_calls
        # across multiple choices. Merge them so tool_calls are not lost.
        raw_tool_calls = []
        for ch in response.choices:
            msg = ch.message
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                raw_tool_calls.extend(msg.tool_calls)
                if ch.finish_reason in ("tool_calls", "stop"):
                    finish_reason = ch.finish_reason
            if not content and msg.content:
                content = msg.content

        if len(response.choices) > 1:
            logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
                         len(response.choices), len(raw_tool_calls))

        tool_calls = []
        for tc in raw_tool_calls:
            # Parse arguments from JSON string if needed
            args = tc.function.arguments
            if isinstance(args, str):
                args = json_repair.loads(args)

            provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
            function_provider_specific_fields = (
                getattr(tc.function, "provider_specific_fields", None) or None
            )

            # NOTE(review): the provider-assigned tc.id is discarded and a
            # fresh 9-char ID generated — presumably the caller uses the
            # ToolCallRequest.id consistently for both the assistant message
            # and the tool result so linkage survives; verify in the agent loop.
            tool_calls.append(ToolCallRequest(
                id=_short_tool_id(),
                name=tc.function.name,
                arguments=args,
                provider_specific_fields=provider_specific_fields,
                function_provider_specific_fields=function_provider_specific_fields,
            ))

        usage = {}
        if hasattr(response, "usage") and response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }

        reasoning_content = getattr(message, "reasoning_content", None) or None
        thinking_blocks = getattr(message, "thinking_blocks", None) or None

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason or "stop",
            usage=usage,
            reasoning_content=reasoning_content,
            thinking_blocks=thinking_blocks,
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model