dashboard-nanobot/backend/core/speech_service.py

260 lines
9.4 KiB
Python

from __future__ import annotations
import os
import shutil
import subprocess
import tempfile
import threading
from pathlib import Path
from typing import Any, Dict, Optional
from core.settings import (
STT_AUDIO_FILTER,
STT_AUDIO_PREPROCESS,
STT_DEVICE,
STT_ENABLED,
STT_FORCE_SIMPLIFIED,
STT_INITIAL_PROMPT,
STT_MAX_AUDIO_SECONDS,
STT_MODEL,
STT_MODEL_DIR,
)
class SpeechServiceError(RuntimeError):
    """Base error raised for speech-to-text service failures."""
class SpeechDisabledError(SpeechServiceError):
    """Raised when speech-to-text is disabled by configuration."""
class SpeechDurationError(SpeechServiceError):
    """Raised when the audio exceeds the configured maximum duration."""
class WhisperSpeechService:
    """Speech-to-text service backed by whisper.cpp via pywhispercpp.

    The ggml model is resolved from settings and loaded lazily on first
    transcription, then cached until the configured model source changes.
    Audio may optionally be normalized to 16 kHz mono WAV through ffmpeg
    before transcription, and the configured maximum duration is enforced
    both before decoding (via a container probe) and after (via segment
    end timestamps).
    """

    def __init__(self) -> None:
        self._model: Any = None  # cached pywhispercpp Model, None until first load
        self._model_source: str = ""  # resolved absolute path of the loaded model file
        self._backend: str = ""  # backend name reported in transcription results
        self._model_lock = threading.Lock()  # serializes lazy model loading

    def _resolve_model_source(self) -> str:
        """Return the absolute path of the configured whisper model file.

        STT_MODEL may itself be a path (used directly) or a bare file name
        looked up under STT_MODEL_DIR. Only exact ``.bin`` file names are
        accepted; no alias or auto-detection is performed.

        Raises:
            SpeechServiceError: if configuration is empty/invalid or the
                model file does not exist.
        """
        model = str(STT_MODEL or "").strip()
        model_dir = str(STT_MODEL_DIR or "").strip()
        if not model:
            # Fixed typo in example name: "ggml-samll" -> "ggml-small"
            # (now consistent with the strict-mode message below).
            raise SpeechServiceError(
                "STT_MODEL is empty. Please set the full model file name, e.g. ggml-small-q8_0.bin."
            )
        # If STT_MODEL itself is an absolute/relative path, use it directly.
        if any(sep in model for sep in ("/", "\\")):
            direct = Path(model).expanduser()
            if not direct.exists() or not direct.is_file():
                raise SpeechServiceError(f"STT model file not found: {direct}")
            if direct.suffix.lower() != ".bin":
                raise SpeechServiceError(
                    "STT_MODEL must point to a whisper.cpp ggml .bin model file."
                )
            return str(direct.resolve())
        # Strict mode: only exact filename, no alias/auto detection.
        if Path(model).suffix.lower() != ".bin":
            raise SpeechServiceError(
                "STT_MODEL must be the exact model file name (with .bin), e.g. ggml-small-q8_0.bin."
            )
        if not model_dir:
            raise SpeechServiceError("STT_MODEL_DIR is empty.")
        root = Path(model_dir).expanduser()
        if not root.exists() or not root.is_dir():
            raise SpeechServiceError(f"STT_MODEL_DIR does not exist: {root}")
        candidate = root / model
        if not candidate.exists() or not candidate.is_file():
            raise SpeechServiceError(
                f"STT model file not found under STT_MODEL_DIR: {candidate}"
            )
        return str(candidate.resolve())

    def _load_model(self) -> Any:
        """Load (or return the cached) pywhispercpp model, thread-safely.

        Uses double-checked locking: the unlocked fast path skips lock
        contention once the model is loaded; the locked re-check prevents
        duplicate loads under concurrent first calls. A change in the
        resolved model source triggers a reload.

        Raises:
            SpeechServiceError: if pywhispercpp is not installed or the
                model source cannot be resolved.
        """
        model_source = self._resolve_model_source()
        if self._model is not None and self._model_source == model_source:
            return self._model
        with self._model_lock:
            if self._model is not None and self._model_source == model_source:
                return self._model
            try:
                from pywhispercpp.model import Model  # type: ignore
            except Exception as exc:
                raise SpeechServiceError(
                    "pywhispercpp is not installed in the active backend environment. "
                    "Run pip install -r backend/requirements.txt or rebuild the backend image."
                ) from exc
            self._model = Model(
                model_source,
                print_realtime=False,
                print_progress=False,
            )
            self._backend = "pywhispercpp"
            self._model_source = model_source
            return self._model

    @staticmethod
    def _preprocess_audio(file_path: str) -> str:
        """Optionally convert audio to 16 kHz mono PCM WAV via ffmpeg.

        Returns the path of a temporary cleaned file on success, or the
        original path when preprocessing is disabled, ffmpeg is missing,
        or conversion fails (best-effort). The caller is responsible for
        deleting the returned temporary file when it differs from the input.
        """
        target = str(file_path or "").strip()
        if not STT_AUDIO_PREPROCESS or not target or not os.path.isfile(target):
            return target
        if shutil.which("ffmpeg") is None:
            return target
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav", prefix=".speech_clean_")
        tmp_path = tmp.name
        tmp.close()
        cmd = [
            "ffmpeg",
            "-y",
            "-i",
            target,
            "-vn",  # drop any video stream
            "-ac",
            "1",  # mono
            "-ar",
            "16000",  # 16 kHz sample rate
        ]
        audio_filter = str(STT_AUDIO_FILTER or "").strip()
        if audio_filter:
            cmd.extend(["-af", audio_filter])
        cmd.extend(["-c:a", "pcm_s16le", tmp_path])
        try:
            completed = subprocess.run(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=False,
            )
            # Treat a non-zero exit or an empty/missing output file as failure.
            if completed.returncode != 0 or not os.path.exists(tmp_path) or os.path.getsize(tmp_path) <= 0:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                return target
            return tmp_path
        except Exception:
            # Best-effort: fall back to the original file on any failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            return target

    @staticmethod
    def _probe_audio_duration_seconds(file_path: str) -> Optional[float]:
        """Best-effort duration probe using PyAV; returns None when unknown.

        Prefers the container-level duration; falls back to the first audio
        stream that carries both a duration and a time base.
        """
        try:
            import av  # type: ignore

            with av.open(file_path) as container:
                if container.duration is not None:
                    # container.duration is in av.time_base units.
                    return max(0.0, float(container.duration / av.time_base))
                for stream in container.streams:
                    if stream.type != "audio":
                        continue
                    if stream.duration is not None and stream.time_base is not None:
                        return max(0.0, float(stream.duration * stream.time_base))
        except Exception:
            # PyAV missing or file unreadable: duration simply stays unknown.
            return None
        return None

    @staticmethod
    def _normalize_text(text: str) -> str:
        """Strip text; optionally convert Traditional to Simplified Chinese.

        Conversion only happens when STT_FORCE_SIMPLIFIED is set and the
        optional opencc_purepy package is importable; otherwise the stripped
        original text is returned unchanged.
        """
        content = str(text or "").strip()
        if not content or not STT_FORCE_SIMPLIFIED:
            return content
        try:
            from opencc_purepy import OpenCC  # type: ignore

            return str(OpenCC("t2s").convert(content) or "").strip() or content
        except Exception:
            # opencc is optional; keep the original text when unavailable.
            return content

    @staticmethod
    def _filter_supported_transcribe_kwargs(model: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Drop kwargs the loaded model does not advertise via get_params()."""
        try:
            available = set(model.get_params().keys())
        except Exception:
            # If introspection fails, pass everything through unchanged.
            return kwargs
        return {key: value for key, value in kwargs.items() if key in available}

    def transcribe_file(self, file_path: str, language: Optional[str] = None) -> Dict[str, Any]:
        """Transcribe an audio file and return its text plus metadata.

        Args:
            file_path: Path to the audio file to transcribe.
            language: Optional language hint; "auto"/"null"/"none"/empty
                means auto-detect.

        Returns:
            Dict with keys: text, language, duration_seconds,
            max_audio_seconds, model, device, backend.

        Raises:
            SpeechDisabledError: if STT is disabled.
            SpeechDurationError: if the audio exceeds STT_MAX_AUDIO_SECONDS
                (0.3 s tolerance).
            SpeechServiceError: for missing files, model problems, failed
                transcription, or empty results.
        """
        if not STT_ENABLED:
            raise SpeechDisabledError("Speech-to-text is disabled")
        target = str(file_path or "").strip()
        if not target or not os.path.isfile(target):
            raise SpeechServiceError("Audio file not found")
        duration_seconds = self._probe_audio_duration_seconds(target)
        if duration_seconds is not None and duration_seconds > float(STT_MAX_AUDIO_SECONDS) + 0.3:
            raise SpeechDurationError(f"Audio duration exceeds {STT_MAX_AUDIO_SECONDS} seconds")
        prepared_target = self._preprocess_audio(target)
        try:
            model = self._load_model()
            lang = str(language or "").strip().lower()
            normalized_lang: Optional[str] = None
            if lang and lang not in {"auto", "null", "none"}:
                normalized_lang = lang
            max_end = 0.0
            detected_language = ""
            texts = []
            kwargs: Dict[str, Any] = {
                "print_realtime": False,
                "print_progress": False,
                "no_context": True,
                "suppress_non_speech_tokens": True,
            }
            if normalized_lang:
                kwargs["language"] = normalized_lang
            initial_prompt = str(STT_INITIAL_PROMPT or "").strip()
            if initial_prompt:
                kwargs["initial_prompt"] = initial_prompt
            # Only pass parameters this whisper.cpp build actually supports.
            kwargs = self._filter_supported_transcribe_kwargs(model, kwargs)
            try:
                segments = model.transcribe(prepared_target, **kwargs)
            except Exception as exc:
                raise SpeechServiceError(
                    f"pywhispercpp transcription failed: {exc}. "
                    "If input is not wav, install ffmpeg in runtime image."
                ) from exc
            for segment in segments:
                txt = str(getattr(segment, "text", "") or "").strip()
                if txt:
                    texts.append(txt)
                if normalized_lang:
                    detected_language = normalized_lang
                try:
                    # Segment timestamps (t1) are in centiseconds.
                    max_end = max(max_end, float(getattr(segment, "t1", 0.0) or 0.0) / 100.0)
                except Exception:
                    pass
            # Second duration check, from decoded segment end times, in case
            # the pre-decode probe could not determine a duration.
            if max_end > float(STT_MAX_AUDIO_SECONDS) + 0.3:
                raise SpeechDurationError(f"Audio duration exceeds {STT_MAX_AUDIO_SECONDS} seconds")
            text = self._normalize_text(" ".join(texts).strip())
            if not text:
                raise SpeechServiceError("No speech detected")
            if duration_seconds is None:
                duration_seconds = max_end if max_end > 0 else None
            return {
                "text": text,
                "language": detected_language or None,
                "duration_seconds": duration_seconds,
                "max_audio_seconds": STT_MAX_AUDIO_SECONDS,
                "model": STT_MODEL,
                "device": STT_DEVICE,
                "backend": self._backend or "unknown",
            }
        finally:
            # Always clean up the ffmpeg-produced temp file, if one was made.
            if prepared_target != target and os.path.exists(prepared_target):
                try:
                    os.remove(prepared_target)
                except Exception:
                    pass