106 lines
4.7 KiB
Python
106 lines
4.7 KiB
Python
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from fastapi import HTTPException
|
|
from sqlmodel import Session
|
|
|
|
from clients.edge.base import EdgeClient
|
|
from models.bot import BotInstance
|
|
from providers.provision.base import ProvisionProvider
|
|
from providers.target import ProviderTarget
|
|
|
|
|
|
class EdgeProvisionProvider(ProvisionProvider):
    """Provision provider that syncs a bot's workspace to an edge node.

    All data access is injected as callables; this class only merges runtime
    settings, resolves delivery flags and forwards the result to the
    ``EdgeClient`` obtained for the bot's provider target.
    """

    # Override keys applied only when they carry a non-blank value; a blank
    # override leaves the value from the runtime snapshot untouched.
    _NON_EMPTY_RUNTIME_KEYS = frozenset({"api_key", "llm_provider", "llm_model"})

    def __init__(
        self,
        *,
        read_provider_target: Callable[[str], ProviderTarget],
        resolve_edge_client: Callable[[ProviderTarget], EdgeClient],
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        read_bot_channels: Callable[[BotInstance], List[Dict[str, Any]]],
        read_node_metadata: Callable[[str], Dict[str, Any]],
    ) -> None:
        """Store the injected lookups.

        Args:
            read_provider_target: Maps a bot id to its ``ProviderTarget``.
            resolve_edge_client: Builds/returns the ``EdgeClient`` for a target.
            read_runtime_snapshot: Returns the bot's current runtime settings.
            read_bot_channels: Returns the bot's stored channel configs.
            read_node_metadata: Returns metadata for a (normalized) node id.
        """
        self._read_provider_target = read_provider_target
        self._resolve_edge_client = resolve_edge_client
        self._read_runtime_snapshot = read_runtime_snapshot
        self._read_bot_channels = read_bot_channels
        self._read_node_metadata = read_node_metadata

    def sync_bot_workspace(
        self,
        *,
        session: Session,
        bot_id: str,
        channels_override: Optional[List[Dict[str, Any]]] = None,
        global_delivery_override: Optional[Dict[str, Any]] = None,
        runtime_overrides: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Push merged runtime, delivery and channel config for one bot to its edge node.

        Args:
            session: Session used to load the ``BotInstance``.
            bot_id: Id of the bot to sync.
            channels_override: Channels to send instead of the stored ones.
                ``None`` means "read them via ``read_bot_channels``"; any other
                value (including an empty list) is forwarded as-is.
            global_delivery_override: Delivery flags to send. Missing
                ``sendProgress`` / ``sendToolHints`` entries are filled from the
                merged runtime's ``send_progress`` / ``send_tool_hints``.
            runtime_overrides: Per-call settings layered over the bot's
                runtime snapshot (see ``_merge_runtime_overrides``).

        Raises:
            HTTPException: 404 if the bot does not exist; 400 if the bot's
                provider target does not use the edge transport.
        """
        bot = session.get(BotInstance, bot_id)
        if bot is None:
            raise HTTPException(status_code=404, detail="Bot not found")

        merged_runtime = self._merge_runtime_overrides(
            self._read_runtime_snapshot(bot),
            runtime_overrides,
        )

        target = self._read_provider_target(bot_id)
        # Node metadata is applied last, so for the keys it sets it wins over
        # both the snapshot and any per-call runtime overrides.
        merged_runtime.update(self._node_runtime_overrides(target.node_id, target.runtime_kind))

        resolved_delivery = self._resolve_delivery(global_delivery_override, merged_runtime)

        # Resolve (and transport-check) the client before reading channels,
        # matching the original evaluation order.
        client = self._client_for_target(target)
        channels = channels_override if channels_override is not None else self._read_bot_channels(bot)
        client.sync_bot_workspace(
            bot_id=bot_id,
            channels_override=channels,
            global_delivery_override=resolved_delivery,
            runtime_overrides=merged_runtime,
        )

    @classmethod
    def _merge_runtime_overrides(
        cls,
        snapshot: Dict[str, Any],
        overrides: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Return a shallow copy of ``snapshot`` with ``overrides`` applied.

        - ``api_key`` / ``llm_provider`` / ``llm_model``: stripped, and skipped
          entirely when blank (so a blank value never erases a stored one).
        - ``api_base``: stripped and always applied (a blank value clears it).
        - any other key: applied verbatim.

        Non-dict ``overrides`` (including ``None``) are ignored.
        """
        merged = dict(snapshot)
        if not isinstance(overrides, dict):
            return merged
        for key, value in overrides.items():
            if key in cls._NON_EMPTY_RUNTIME_KEYS:
                text = str(value or "").strip()
                if text:
                    merged[key] = text
            elif key == "api_base":
                merged[key] = str(value or "").strip()
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _resolve_delivery(
        override: Optional[Dict[str, Any]],
        runtime: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Copy ``override`` and fill missing delivery flags from ``runtime``.

        Explicit ``sendProgress`` / ``sendToolHints`` entries in ``override``
        are kept; otherwise they default to the truthiness of the runtime's
        ``send_progress`` / ``send_tool_hints`` (``False`` when absent).
        """
        delivery = dict(override or {})
        delivery.setdefault("sendProgress", bool(runtime.get("send_progress", False)))
        delivery.setdefault("sendToolHints", bool(runtime.get("send_tool_hints", False)))
        return delivery

    def _client_for_bot(self, bot_id: str) -> EdgeClient:
        """Return the edge client for ``bot_id``'s provider target."""
        target = self._read_provider_target(bot_id)
        return self._client_for_target(target)

    def _client_for_target(self, target: ProviderTarget) -> EdgeClient:
        """Return the edge client for ``target``.

        Raises:
            HTTPException: 400 if ``target.transport_kind`` is not ``"edge"``.
        """
        if target.transport_kind != "edge":
            raise HTTPException(status_code=400, detail=f"edge provision provider requires edge transport, got {target.transport_kind}")
        return self._resolve_edge_client(target)

    def _node_runtime_overrides(self, node_id: str, runtime_kind: str) -> Dict[str, str]:
        """Build runtime overrides from the node's metadata.

        ``workspace_root`` is included for every runtime kind. For the
        ``"native"`` runtime kind only, ``native_sandbox_mode`` (normalized,
        omitted when it resolves to ``"inherit"``), ``native_command`` and
        ``native_workdir`` are added as well. Blank values are omitted.

        Args:
            node_id: Node identifier; stripped and lower-cased before lookup.
            runtime_kind: Runtime kind of the target; compared case-insensitively.
        """
        # ``or {}`` tolerates a lookup that returns None for unknown nodes.
        metadata = dict(self._read_node_metadata(str(node_id or "").strip().lower()) or {})
        payload: Dict[str, str] = {}

        workspace_root = str(metadata.get("workspace_root") or "").strip()
        if workspace_root:
            payload["workspace_root"] = workspace_root

        if str(runtime_kind or "").strip().lower() != "native":
            return payload

        native_sandbox_mode = self._normalize_native_sandbox_mode(metadata.get("native_sandbox_mode"))
        if native_sandbox_mode != "inherit":
            payload["native_sandbox_mode"] = native_sandbox_mode

        native_command = str(metadata.get("native_command") or "").strip()
        native_workdir = str(metadata.get("native_workdir") or "").strip()
        if native_command:
            payload["native_command"] = native_command
        if native_workdir:
            payload["native_workdir"] = native_workdir
        return payload

    @staticmethod
    def _normalize_native_sandbox_mode(raw_value: Any) -> str:
        """Map a raw sandbox-mode value to ``"workspace"``, ``"full_access"`` or ``"inherit"``.

        Matching is case-insensitive after stripping; unrecognized or empty
        values map to ``"inherit"``.
        """
        text = str(raw_value or "").strip().lower()
        if text in {"workspace", "sandbox", "strict"}:
            return "workspace"
        if text in {"full_access", "full-access", "danger-full-access", "escape"}:
            return "full_access"
        return "inherit"
|