dashboard-nanobot/backend/providers/provision/edge.py

106 lines
4.7 KiB
Python

from typing import Any, Callable, Dict, List, Optional
from fastapi import HTTPException
from sqlmodel import Session
from clients.edge.base import EdgeClient
from models.bot import BotInstance
from providers.provision.base import ProvisionProvider
from providers.target import ProviderTarget
class EdgeProvisionProvider(ProvisionProvider):
    """Provision provider that syncs bot workspace configuration through an ``EdgeClient``.

    Every data source (provider target, edge client, runtime snapshot, bot
    channels, node metadata) is injected as a callable; the provider itself only
    stores those callables.
    """

    # Runtime keys whose override is applied only when it is a non-blank string;
    # a blank override keeps the snapshot value.
    _NON_BLANK_TEXT_KEYS = frozenset({"api_key", "llm_provider", "llm_model"})
    # Runtime keys whose override is always applied as a stripped string, so a
    # blank override replaces the snapshot value with "".
    _CLEARABLE_TEXT_KEYS = frozenset({"api_base"})
    # Text spellings interpreted as False for boolean delivery flags.
    _FALSE_STRINGS = frozenset({"", "0", "false", "no", "off", "none", "null"})

    def __init__(
        self,
        *,
        read_provider_target: Callable[[str], ProviderTarget],
        resolve_edge_client: Callable[[ProviderTarget], EdgeClient],
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        read_bot_channels: Callable[[BotInstance], List[Dict[str, Any]]],
        read_node_metadata: Callable[[str], Dict[str, Any]],
    ) -> None:
        """Store the injected data-access callables.

        Args:
            read_provider_target: Maps a bot id to its ``ProviderTarget``.
            resolve_edge_client: Returns the ``EdgeClient`` for an edge target.
            read_runtime_snapshot: Returns the bot's runtime settings, used as
                the base that overrides are merged onto.
            read_bot_channels: Returns the bot's stored channel configs, used
                when no ``channels_override`` is given.
            read_node_metadata: Returns metadata for a node id (called with the
                id stripped and lower-cased).
        """
        self._read_provider_target = read_provider_target
        self._resolve_edge_client = resolve_edge_client
        self._read_runtime_snapshot = read_runtime_snapshot
        self._read_bot_channels = read_bot_channels
        self._read_node_metadata = read_node_metadata

    def sync_bot_workspace(
        self,
        *,
        session: Session,
        bot_id: str,
        channels_override: Optional[List[Dict[str, Any]]] = None,
        global_delivery_override: Optional[Dict[str, Any]] = None,
        runtime_overrides: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Send the bot's merged runtime config, channels and delivery flags to its edge node.

        Args:
            session: Database session used to load the ``BotInstance``.
            bot_id: Primary key of the bot to sync.
            channels_override: Channel list to send instead of the stored one;
                ``None`` falls back to ``read_bot_channels``.
            global_delivery_override: Explicit ``sendProgress`` /
                ``sendToolHints`` values; missing keys are derived from the
                merged runtime's ``send_progress`` / ``send_tool_hints``.
            runtime_overrides: Values layered over the runtime snapshot (see
                ``_merge_runtime_overrides``). Node metadata is applied last and
                wins over both.

        Raises:
            HTTPException: 404 if the bot does not exist; 400 if its provider
                target does not use the ``edge`` transport.
        """
        bot = session.get(BotInstance, bot_id)
        if bot is None:
            raise HTTPException(status_code=404, detail="Bot not found")
        target = self._read_provider_target(bot_id)
        # Validate the transport before reading node metadata or merging, so a
        # bot routed to a non-edge transport fails fast with a 400.
        client = self._client_for_target(target)

        merged_runtime = self._merge_runtime_overrides(
            self._read_runtime_snapshot(bot), runtime_overrides
        )
        # Node-level settings take precedence over the snapshot and caller overrides.
        merged_runtime.update(
            self._node_runtime_overrides(target.node_id, target.runtime_kind)
        )
        resolved_delivery = self._resolve_delivery(global_delivery_override, merged_runtime)

        if channels_override is None:
            channels_override = self._read_bot_channels(bot)
        client.sync_bot_workspace(
            bot_id=bot_id,
            channels_override=channels_override,
            global_delivery_override=resolved_delivery,
            runtime_overrides=merged_runtime,
        )

    @classmethod
    def _merge_runtime_overrides(
        cls,
        snapshot: Optional[Dict[str, Any]],
        runtime_overrides: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Return a copy of ``snapshot`` with ``runtime_overrides`` layered on top.

        ``api_key``/``llm_provider``/``llm_model`` are stripped and applied only
        when non-blank; ``api_base`` is always applied as a stripped string;
        any other key is copied through unchanged. Non-dict overrides are ignored.
        """
        merged: Dict[str, Any] = dict(snapshot or {})
        if not isinstance(runtime_overrides, dict):
            return merged
        for key, value in runtime_overrides.items():
            if key in cls._NON_BLANK_TEXT_KEYS:
                text = str(value or "").strip()
                if text:
                    merged[key] = text
            elif key in cls._CLEARABLE_TEXT_KEYS:
                merged[key] = str(value or "").strip()
            else:
                merged[key] = value
        return merged

    @classmethod
    def _resolve_delivery(
        cls,
        global_delivery_override: Optional[Dict[str, Any]],
        merged_runtime: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Copy the delivery override, filling missing flags from the merged runtime."""
        resolved = dict(global_delivery_override or {})
        if "sendProgress" not in resolved:
            resolved["sendProgress"] = cls._as_bool(merged_runtime.get("send_progress", False))
        if "sendToolHints" not in resolved:
            resolved["sendToolHints"] = cls._as_bool(merged_runtime.get("send_tool_hints", False))
        return resolved

    @classmethod
    def _as_bool(cls, value: Any) -> bool:
        """Interpret a flag value, treating text such as ``"false"``/``"0"``/``"off"`` as False.

        Plain ``bool("false")`` is True; runtime overrides for these keys are
        passed through unconverted, so a textual flag would otherwise be enabled.
        """
        if isinstance(value, str):
            return value.strip().lower() not in cls._FALSE_STRINGS
        return bool(value)

    def _client_for_bot(self, bot_id: str) -> EdgeClient:
        """Return the edge client for ``bot_id``'s provider target (400 if not edge)."""
        target = self._read_provider_target(bot_id)
        return self._client_for_target(target)

    def _client_for_target(self, target: ProviderTarget) -> EdgeClient:
        """Return the edge client for ``target``.

        Raises:
            HTTPException: 400 if ``target.transport_kind`` is not ``"edge"``.
        """
        if target.transport_kind != "edge":
            raise HTTPException(
                status_code=400,
                detail=f"edge provision provider requires edge transport, got {target.transport_kind}",
            )
        return self._resolve_edge_client(target)

    def _node_runtime_overrides(self, node_id: str, runtime_kind: str) -> Dict[str, str]:
        """Build runtime keys contributed by the node's metadata.

        ``workspace_root`` is included whenever it is non-blank. For the
        ``native`` runtime kind only, ``native_sandbox_mode`` (unless it
        normalizes to ``"inherit"``), ``native_command`` and ``native_workdir``
        are also included when non-blank.
        """
        metadata = dict(self._read_node_metadata(str(node_id or "").strip().lower()) or {})
        payload: Dict[str, str] = {}
        workspace_root = str(metadata.get("workspace_root") or "").strip()
        if workspace_root:
            payload["workspace_root"] = workspace_root
        if str(runtime_kind or "").strip().lower() != "native":
            return payload
        native_sandbox_mode = self._normalize_native_sandbox_mode(metadata.get("native_sandbox_mode"))
        if native_sandbox_mode != "inherit":
            payload["native_sandbox_mode"] = native_sandbox_mode
        native_command = str(metadata.get("native_command") or "").strip()
        native_workdir = str(metadata.get("native_workdir") or "").strip()
        if native_command:
            payload["native_command"] = native_command
        if native_workdir:
            payload["native_workdir"] = native_workdir
        return payload

    @staticmethod
    def _normalize_native_sandbox_mode(raw_value: Any) -> str:
        """Map a sandbox-mode alias to ``"workspace"``, ``"full_access"`` or ``"inherit"``.

        Unknown, empty or ``None`` values fall back to ``"inherit"``.
        """
        text = str(raw_value or "").strip().lower()
        if text in {"workspace", "sandbox", "strict"}:
            return "workspace"
        if text in {"full_access", "full-access", "danger-full-access", "escape"}:
            return "full_access"
        return "inherit"