from typing import Any, Callable, Dict, List, Optional

from fastapi import HTTPException
from sqlmodel import Session

from clients.edge.base import EdgeClient
from models.bot import BotInstance
from providers.runtime.base import RuntimeProvider
from providers.target import ProviderTarget, provider_target_to_dict


class EdgeRuntimeProvider(RuntimeProvider):
    """Runtime provider that drives bots through an edge-transport client.

    Every lookup is injected as a callable. This class resolves a bot's
    ``ProviderTarget``, obtains the matching ``EdgeClient`` and forwards
    lifecycle and monitoring calls to it. Only targets whose
    ``transport_kind`` is ``"edge"`` are accepted.
    """

    def __init__(
        self,
        *,
        read_provider_target: Callable[[str], ProviderTarget],
        resolve_edge_client: Callable[[ProviderTarget], EdgeClient],
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        resolve_env_params: Callable[[str], Dict[str, str]],
        read_bot_channels: Callable[[BotInstance], List[Dict[str, Any]]],
        read_node_metadata: Callable[[str], Dict[str, Any]],
    ) -> None:
        """Store the injected collaborators.

        Args:
            read_provider_target: Maps a bot id to its ``ProviderTarget``.
            resolve_edge_client: Builds/looks up the ``EdgeClient`` for a target.
            read_runtime_snapshot: Returns runtime settings for a bot (read keys
                include ``send_progress``, ``send_tool_hints``, ``cpu_cores``,
                ``memory_mb``, ``storage_gb``).
            resolve_env_params: Returns environment variables for a bot id.
            read_bot_channels: Returns the channel configuration for a bot.
            read_node_metadata: Returns metadata for a (lower-cased) node id.
        """
        self._read_provider_target = read_provider_target
        self._resolve_edge_client = resolve_edge_client
        self._read_runtime_snapshot = read_runtime_snapshot
        self._resolve_env_params = resolve_env_params
        self._read_bot_channels = read_bot_channels
        self._read_node_metadata = read_node_metadata

    async def start_bot(self, *, session: Session, bot: BotInstance) -> Dict[str, Any]:
        """Sync the bot's workspace to its edge node, start it and mark it RUNNING.

        Returns:
            Whatever the edge client's ``start_bot`` returns.

        Raises:
            HTTPException: 400 if the bot has no id or its target is not an
                edge transport; 403 if the bot is disabled.
        """
        bot_id = self._require_enabled_bot_id(bot)
        # Guard against a collaborator returning None, matching the
        # ``dict(... or {})`` pattern used for the other lookups in this class.
        runtime_snapshot = dict(self._read_runtime_snapshot(bot) or {})
        target = self._read_provider_target(bot_id)
        client = self._client_for_target(target)
        node_runtime_overrides = self._node_runtime_overrides(target.node_id, target.runtime_kind)
        # Later mappings win: node metadata overrides target fields, which
        # override the stored runtime snapshot.
        workspace_runtime = {
            **runtime_snapshot,
            **provider_target_to_dict(target),
            **node_runtime_overrides,
        }
        client.sync_bot_workspace(
            bot_id=bot_id,
            channels_override=self._read_bot_channels(bot),
            global_delivery_override={
                "sendProgress": bool(runtime_snapshot.get("send_progress")),
                "sendToolHints": bool(runtime_snapshot.get("send_tool_hints")),
            },
            runtime_overrides=workspace_runtime,
        )
        result = await client.start_bot(
            bot=bot,
            start_payload={
                "image_tag": bot.image_tag,
                "runtime_kind": target.runtime_kind,
                "env_vars": self._resolve_env_params(bot_id),
                "cpu_cores": runtime_snapshot.get("cpu_cores"),
                "memory_mb": runtime_snapshot.get("memory_mb"),
                "storage_gb": runtime_snapshot.get("storage_gb"),
                **node_runtime_overrides,
            },
        )
        self._persist_status(session, bot, "RUNNING")
        return result

    def stop_bot(self, *, session: Session, bot: BotInstance) -> Dict[str, Any]:
        """Stop the bot through its edge client and mark it STOPPED.

        Returns:
            Whatever the edge client's ``stop_bot`` returns.

        Raises:
            HTTPException: 400 if the bot has no id or its target is not an
                edge transport; 403 if the bot is disabled.
        """
        # NOTE(review): disabled bots are rejected here too, so a bot that is
        # disabled while running cannot be stopped via this path. Confirm that
        # this is intended.
        bot_id = self._require_enabled_bot_id(bot)
        result = self._client_for_bot(bot_id).stop_bot(bot=bot)
        self._persist_status(session, bot, "STOPPED")
        return result

    def deliver_command(self, *, bot_id: str, command: str, media: Optional[List[str]] = None) -> Optional[str]:
        """Forward a command (and optional media references) to the bot."""
        return self._client_for_bot(bot_id).deliver_command(bot_id=bot_id, command=command, media=media)

    def get_recent_logs(self, *, bot_id: str, tail: int = 300) -> List[str]:
        """Return up to ``tail`` recent log lines as reported by the edge client."""
        return self._client_for_bot(bot_id).get_recent_logs(bot_id=bot_id, tail=tail)

    def ensure_monitor(self, *, bot_id: str) -> bool:
        """Ask the edge client to ensure a monitor exists; return its result as a bool."""
        return bool(self._client_for_bot(bot_id).ensure_monitor(bot_id=bot_id))

    def get_monitor_packets(self, *, bot_id: str, after_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
        """Return monitor packets after ``after_seq`` (at most ``limit``); ``[]`` if none."""
        return list(self._client_for_bot(bot_id).get_monitor_packets(bot_id=bot_id, after_seq=after_seq, limit=limit) or [])

    def get_runtime_status(self, *, bot_id: str) -> str:
        """Return the upper-cased runtime status, defaulting to ``"STOPPED"`` when empty."""
        return str(self._client_for_bot(bot_id).get_runtime_status(bot_id=bot_id) or "STOPPED").upper()

    def get_resource_snapshot(self, *, bot_id: str) -> Dict[str, Any]:
        """Return a copy of the bot's resource snapshot; ``{}`` if the client returns nothing."""
        return dict(self._client_for_bot(bot_id).get_resource_snapshot(bot_id=bot_id) or {})

    @staticmethod
    def _require_enabled_bot_id(bot: BotInstance) -> str:
        """Return the bot's stripped id; reject missing ids (400) and disabled bots (403)."""
        bot_id = str(bot.id or "").strip()
        if not bot_id:
            raise HTTPException(status_code=400, detail="Bot id is required")
        # A bot object without an ``enabled`` attribute is treated as enabled.
        if not bool(getattr(bot, "enabled", True)):
            raise HTTPException(status_code=403, detail="Bot is disabled. Enable it first.")
        return bot_id

    @staticmethod
    def _persist_status(session: Session, bot: BotInstance, status: str) -> None:
        """Record ``status`` as the bot's docker_status and commit it."""
        bot.docker_status = status
        session.add(bot)
        session.commit()

    def _client_for_bot(self, bot_id: str) -> EdgeClient:
        """Resolve the bot's provider target and return its edge client."""
        target = self._read_provider_target(bot_id)
        return self._client_for_target(target)

    def _client_for_target(self, target: ProviderTarget) -> EdgeClient:
        """Return the edge client for ``target``; raise 400 for non-edge transports."""
        if target.transport_kind != "edge":
            raise HTTPException(status_code=400, detail=f"edge runtime provider requires edge transport, got {target.transport_kind}")
        return self._resolve_edge_client(target)

    def _node_runtime_overrides(self, node_id: str, runtime_kind: str) -> Dict[str, str]:
        """Build runtime overrides from the node's metadata.

        ``workspace_root`` is included for every runtime kind. The
        ``native_sandbox_mode``, ``native_command`` and ``native_workdir``
        keys are only added when ``runtime_kind`` is ``"native"``. Empty or
        whitespace-only values are omitted, and so is a sandbox mode that
        normalizes to ``"inherit"``.
        """
        metadata = dict(self._read_node_metadata(str(node_id or "").strip().lower()) or {})
        payload: Dict[str, str] = {}
        workspace_root = str(metadata.get("workspace_root") or "").strip()
        if workspace_root:
            payload["workspace_root"] = workspace_root
        if str(runtime_kind or "").strip().lower() != "native":
            return payload
        native_sandbox_mode = self._normalize_native_sandbox_mode(metadata.get("native_sandbox_mode"))
        if native_sandbox_mode != "inherit":
            payload["native_sandbox_mode"] = native_sandbox_mode
        native_command = str(metadata.get("native_command") or "").strip()
        native_workdir = str(metadata.get("native_workdir") or "").strip()
        if native_command:
            payload["native_command"] = native_command
        if native_workdir:
            payload["native_workdir"] = native_workdir
        return payload

    @staticmethod
    def _normalize_native_sandbox_mode(raw_value: Any) -> str:
        """Map a free-form sandbox setting to ``workspace``, ``full_access`` or ``inherit``.

        Matching is case-insensitive and ignores surrounding whitespace.
        Unrecognized or empty values yield ``"inherit"``.
        """
        text = str(raw_value or "").strip().lower()
        if text in {"workspace", "sandbox", "strict"}:
            return "workspace"
        if text in {"full_access", "full-access", "danger-full-access", "escape"}:
            return "full_access"
        return "inherit"