from typing import Any, Callable, Dict, List, Optional

from fastapi import HTTPException
from sqlmodel import Session

from clients.edge.base import EdgeClient
from models.bot import BotInstance
from providers.provision.base import ProvisionProvider
from providers.target import ProviderTarget


class EdgeProvisionProvider(ProvisionProvider):
    """Provision provider that forwards bot workspace configuration to an edge client.

    Every data source (provider target, edge client, runtime snapshot, channels,
    node metadata) is injected as a callable; this class only merges the
    configuration and forwards it to the resolved ``EdgeClient``.
    """

    # Runtime keys whose overrides are skipped when blank, so the existing
    # snapshot value is kept instead of being replaced by an empty string.
    _NON_EMPTY_TEXT_KEYS = frozenset({"api_key", "llm_provider", "llm_model"})
    # String spellings recognised when coercing delivery flags to booleans.
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})
    _FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})
    # (delivery payload key, runtime config key) pairs used as fallbacks.
    _DELIVERY_FLAG_KEYS = (
        ("sendProgress", "send_progress"),
        ("sendToolHints", "send_tool_hints"),
    )

    def __init__(
        self,
        *,
        read_provider_target: Callable[[str], ProviderTarget],
        resolve_edge_client: Callable[[ProviderTarget], EdgeClient],
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        read_bot_channels: Callable[[BotInstance], List[Dict[str, Any]]],
        read_node_metadata: Callable[[str], Dict[str, Any]],
    ) -> None:
        """Store the injected lookup callables.

        Args:
            read_provider_target: Maps a bot id to its ``ProviderTarget``.
            resolve_edge_client: Builds/looks up the ``EdgeClient`` for a target.
            read_runtime_snapshot: Returns the bot's current runtime config.
            read_bot_channels: Returns the bot's stored channel definitions.
            read_node_metadata: Returns metadata for a (normalised) node id.
        """
        self._read_provider_target = read_provider_target
        self._resolve_edge_client = resolve_edge_client
        self._read_runtime_snapshot = read_runtime_snapshot
        self._read_bot_channels = read_bot_channels
        self._read_node_metadata = read_node_metadata

    def sync_bot_workspace(
        self,
        *,
        session: Session,
        bot_id: str,
        channels_override: Optional[List[Dict[str, Any]]] = None,
        global_delivery_override: Optional[Dict[str, Any]] = None,
        runtime_overrides: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Send the bot's merged runtime config, channels and delivery flags to its edge node.

        Args:
            session: Database session used to load the ``BotInstance``.
            bot_id: Identifier of the bot to sync.
            channels_override: Channels to send instead of the bot's stored
                channels (``None`` means "use the stored channels").
            global_delivery_override: Delivery settings; ``sendProgress`` and
                ``sendToolHints`` are filled from the runtime config when absent.
            runtime_overrides: Values layered over the bot's runtime snapshot.

        Raises:
            HTTPException: 404 if the bot does not exist; 400 if the bot's
                provider target does not use the ``edge`` transport.
        """
        bot = session.get(BotInstance, bot_id)
        if bot is None:
            raise HTTPException(status_code=404, detail="Bot not found")

        merged_runtime = self._merge_runtime_overrides(
            self._read_runtime_snapshot(bot), runtime_overrides
        )
        target = self._read_provider_target(bot_id)
        # Node-level settings are applied last, so they take precedence over
        # both the snapshot and caller-supplied runtime overrides.
        merged_runtime.update(
            self._node_runtime_overrides(target.node_id, target.runtime_kind)
        )
        resolved_delivery = self._resolve_delivery(global_delivery_override, merged_runtime)

        channels = (
            channels_override
            if channels_override is not None
            else self._read_bot_channels(bot)
        )
        self._client_for_target(target).sync_bot_workspace(
            bot_id=bot_id,
            channels_override=channels,
            global_delivery_override=resolved_delivery,
            runtime_overrides=merged_runtime,
        )

    def _merge_runtime_overrides(
        self,
        snapshot: Dict[str, Any],
        runtime_overrides: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Return a copy of ``snapshot`` with ``runtime_overrides`` applied.

        Blank values for ``_NON_EMPTY_TEXT_KEYS`` are skipped; ``api_base`` is
        stripped and stored even when blank; any other key is copied verbatim.
        Non-dict ``runtime_overrides`` are ignored.
        """
        merged = dict(snapshot)
        if not isinstance(runtime_overrides, dict):
            return merged
        for key, value in runtime_overrides.items():
            if key in self._NON_EMPTY_TEXT_KEYS:
                text = str(value or "").strip()
                if text:
                    merged[key] = text
            elif key == "api_base":
                # Unlike the keys above, a blank api_base is stored as "".
                merged[key] = str(value or "").strip()
            else:
                merged[key] = value
        return merged

    def _resolve_delivery(
        self,
        global_delivery_override: Optional[Dict[str, Any]],
        merged_runtime: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return delivery settings, filling missing flags from the runtime config.

        Explicitly supplied ``sendProgress`` / ``sendToolHints`` values are kept
        as-is; missing ones are derived from ``send_progress`` /
        ``send_tool_hints`` (default ``False``) via ``_coerce_bool``.
        """
        resolved = dict(global_delivery_override or {})
        for delivery_key, runtime_key in self._DELIVERY_FLAG_KEYS:
            if delivery_key not in resolved:
                resolved[delivery_key] = self._coerce_bool(
                    merged_runtime.get(runtime_key, False)
                )
        return resolved

    @classmethod
    def _coerce_bool(cls, value: Any) -> bool:
        """Interpret ``value`` as a boolean, parsing common string spellings.

        ``bool("false")`` is ``True`` in Python, so strings are first matched
        against known true/false spellings (case-insensitive, stripped); any
        other value falls back to ``bool(value)``.
        """
        if isinstance(value, str):
            text = value.strip().lower()
            if text in cls._TRUE_STRINGS:
                return True
            if text in cls._FALSE_STRINGS:
                return False
        return bool(value)

    def _client_for_bot(self, bot_id: str) -> EdgeClient:
        """Return the edge client for ``bot_id``'s provider target."""
        target = self._read_provider_target(bot_id)
        return self._client_for_target(target)

    def _client_for_target(self, target: ProviderTarget) -> EdgeClient:
        """Return the edge client for ``target``.

        Raises:
            HTTPException: 400 if ``target.transport_kind`` is not ``"edge"``.
        """
        if target.transport_kind != "edge":
            raise HTTPException(
                status_code=400,
                detail=f"edge provision provider requires edge transport, got {target.transport_kind}",
            )
        return self._resolve_edge_client(target)

    def _node_runtime_overrides(self, node_id: str, runtime_kind: str) -> Dict[str, str]:
        """Build runtime overrides from the node's metadata.

        ``workspace_root`` is included for any runtime kind when non-blank.
        For the ``native`` runtime kind, ``native_sandbox_mode`` (unless it
        normalises to ``"inherit"``), ``native_command`` and ``native_workdir``
        (when non-blank) are included too.
        """
        metadata = dict(self._read_node_metadata(str(node_id or "").strip().lower()) or {})
        payload: Dict[str, str] = {}

        workspace_root = str(metadata.get("workspace_root") or "").strip()
        if workspace_root:
            payload["workspace_root"] = workspace_root

        if str(runtime_kind or "").strip().lower() != "native":
            return payload

        native_sandbox_mode = self._normalize_native_sandbox_mode(metadata.get("native_sandbox_mode"))
        # "inherit" means no explicit mode: leave the key out entirely.
        if native_sandbox_mode != "inherit":
            payload["native_sandbox_mode"] = native_sandbox_mode

        native_command = str(metadata.get("native_command") or "").strip()
        native_workdir = str(metadata.get("native_workdir") or "").strip()
        if native_command:
            payload["native_command"] = native_command
        if native_workdir:
            payload["native_workdir"] = native_workdir
        return payload

    @staticmethod
    def _normalize_native_sandbox_mode(raw_value: Any) -> str:
        """Map a raw sandbox-mode value to ``"workspace"``, ``"full_access"`` or ``"inherit"``.

        Matching is case-insensitive after stripping; unrecognised or empty
        values map to ``"inherit"``.
        """
        text = str(raw_value or "").strip().lower()
        if text in {"workspace", "sandbox", "strict"}:
            return "workspace"
        if text in {"full_access", "full-access", "danger-full-access", "escape"}:
            return "full_access"
        return "inherit"