260 lines
9.7 KiB
Python
260 lines
9.7 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import tempfile
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
from core.settings import STT_DEVICE, STT_MODEL, STT_MODEL_DIR
|
|
from services.platform_service import get_speech_runtime_settings
|
|
|
|
|
|
class SpeechServiceError(RuntimeError):
    """Base error for all speech-to-text failures raised by this module."""
|
|
|
|
|
|
class SpeechDisabledError(SpeechServiceError):
    """Raised when speech-to-text is turned off in the runtime settings."""
|
|
|
|
|
|
class SpeechDurationError(SpeechServiceError):
    """Raised when the audio exceeds the configured maximum duration."""
|
|
|
|
|
|
class WhisperSpeechService:
    """Speech-to-text service backed by whisper.cpp via pywhispercpp.

    The model is loaded lazily on first transcription and cached until
    ``reset_runtime`` is called or the configured model path changes.
    """

    def __init__(self) -> None:
        # Lazily-loaded pywhispercpp Model instance; None until first use.
        self._model: Any = None
        # Absolute path of the model file currently loaded into _model.
        self._model_source: str = ""
        # Identifier of the backend that produced _model ("pywhispercpp").
        self._backend: str = ""
        # Serializes lazy loading and reset of the cached model across threads.
        self._model_lock = threading.Lock()
|
|
|
|
def _resolve_model_source(self) -> str:
|
|
model = str(STT_MODEL or "").strip()
|
|
model_dir = str(STT_MODEL_DIR or "").strip()
|
|
|
|
if not model:
|
|
raise SpeechServiceError(
|
|
"STT_MODEL is empty. Please set the full model file name, e.g. ggml-samll-q8_0.bin."
|
|
)
|
|
|
|
# If STT_MODEL itself is an absolute/relative path, use it directly.
|
|
if any(sep in model for sep in ("/", "\\")):
|
|
direct = Path(model).expanduser()
|
|
if not direct.exists() or not direct.is_file():
|
|
raise SpeechServiceError(f"STT model file not found: {direct}")
|
|
if direct.suffix.lower() != ".bin":
|
|
raise SpeechServiceError(
|
|
"STT_MODEL must point to a whisper.cpp ggml .bin model file."
|
|
)
|
|
return str(direct.resolve())
|
|
|
|
# Strict mode: only exact filename, no alias/auto detection.
|
|
if Path(model).suffix.lower() != ".bin":
|
|
raise SpeechServiceError(
|
|
"STT_MODEL must be the exact model file name (with .bin), e.g. ggml-small-q8_0.bin."
|
|
)
|
|
|
|
if not model_dir:
|
|
raise SpeechServiceError("STT_MODEL_DIR is empty.")
|
|
root = Path(model_dir).expanduser()
|
|
if not root.exists() or not root.is_dir():
|
|
raise SpeechServiceError(f"STT_MODEL_DIR does not exist: {root}")
|
|
candidate = root / model
|
|
if not candidate.exists() or not candidate.is_file():
|
|
raise SpeechServiceError(
|
|
f"STT model file not found under STT_MODEL_DIR: {candidate}"
|
|
)
|
|
return str(candidate.resolve())
|
|
|
|
def reset_runtime(self) -> None:
|
|
with self._model_lock:
|
|
self._model = None
|
|
self._model_source = ""
|
|
self._backend = ""
|
|
|
|
    def _load_model(self) -> Any:
        """Return the cached pywhispercpp model, loading it on first use.

        Uses double-checked locking: an unlocked fast-path check first, then
        a second check under ``_model_lock`` before constructing the model,
        so concurrent callers load the model at most once.  A change of the
        resolved model path invalidates the cache and triggers a reload.

        Raises:
            SpeechServiceError: if the model path is invalid (propagated from
                ``_resolve_model_source``) or pywhispercpp is not importable.
        """
        model_source = self._resolve_model_source()
        # Fast path: model already loaded for the currently configured file.
        if self._model is not None and self._model_source == model_source:
            return self._model
        with self._model_lock:
            # Re-check under the lock: another thread may have loaded it
            # between our unlocked check and acquiring the lock.
            if self._model is not None and self._model_source == model_source:
                return self._model
            try:
                # Imported lazily so the service module can be imported even
                # when the optional pywhispercpp dependency is absent.
                from pywhispercpp.model import Model  # type: ignore
            except Exception as exc:
                raise SpeechServiceError(
                    "pywhispercpp is not installed in the active backend environment. "
                    "Run pip install -r backend/requirements.txt or rebuild the backend image."
                ) from exc
            self._model = Model(
                model_source,
                print_realtime=False,
                print_progress=False,
            )
            # NOTE(review): _model is published before _model_source below; a
            # concurrent unlocked fast-path read between these two assignments
            # sees a mismatched pair and falls through to the lock — harmless.
            self._backend = "pywhispercpp"
            self._model_source = model_source
            return self._model
|
|
|
|
@staticmethod
|
|
def _preprocess_audio(file_path: str) -> str:
|
|
settings = get_speech_runtime_settings()
|
|
target = str(file_path or "").strip()
|
|
if not settings["audio_preprocess"] or not target or not os.path.isfile(target):
|
|
return target
|
|
if shutil.which("ffmpeg") is None:
|
|
return target
|
|
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav", prefix=".speech_clean_")
|
|
tmp_path = tmp.name
|
|
tmp.close()
|
|
cmd = [
|
|
"ffmpeg",
|
|
"-y",
|
|
"-i",
|
|
target,
|
|
"-vn",
|
|
"-ac",
|
|
"1",
|
|
"-ar",
|
|
"16000",
|
|
]
|
|
audio_filter = str(settings["audio_filter"] or "").strip()
|
|
if audio_filter:
|
|
cmd.extend(["-af", audio_filter])
|
|
cmd.extend(["-c:a", "pcm_s16le", tmp_path])
|
|
try:
|
|
completed = subprocess.run(
|
|
cmd,
|
|
stdout=subprocess.DEVNULL,
|
|
stderr=subprocess.DEVNULL,
|
|
check=False,
|
|
)
|
|
if completed.returncode != 0 or not os.path.exists(tmp_path) or os.path.getsize(tmp_path) <= 0:
|
|
if os.path.exists(tmp_path):
|
|
os.remove(tmp_path)
|
|
return target
|
|
return tmp_path
|
|
except Exception:
|
|
if os.path.exists(tmp_path):
|
|
os.remove(tmp_path)
|
|
return target
|
|
|
|
@staticmethod
|
|
def _probe_audio_duration_seconds(file_path: str) -> Optional[float]:
|
|
try:
|
|
import av # type: ignore
|
|
|
|
with av.open(file_path) as container:
|
|
if container.duration is not None:
|
|
# container.duration is in av.time_base units.
|
|
return max(0.0, float(container.duration / av.time_base))
|
|
for stream in container.streams:
|
|
if stream.type != "audio":
|
|
continue
|
|
if stream.duration is not None and stream.time_base is not None:
|
|
return max(0.0, float(stream.duration * stream.time_base))
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
@staticmethod
|
|
def _normalize_text(text: str) -> str:
|
|
settings = get_speech_runtime_settings()
|
|
content = str(text or "").strip()
|
|
if not content or not settings["force_simplified"]:
|
|
return content
|
|
try:
|
|
from opencc_purepy import OpenCC # type: ignore
|
|
|
|
return str(OpenCC("t2s").convert(content) or "").strip() or content
|
|
except Exception:
|
|
return content
|
|
|
|
@staticmethod
|
|
def _filter_supported_transcribe_kwargs(model: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
|
try:
|
|
available = set(model.get_params().keys())
|
|
except Exception:
|
|
return kwargs
|
|
return {key: value for key, value in kwargs.items() if key in available}
|
|
|
|
    def transcribe_file(self, file_path: str, language: Optional[str] = None) -> Dict[str, Any]:
        """Transcribe an audio file and return the text plus metadata.

        Args:
            file_path: Path to the audio file on disk.
            language: Optional language hint; "auto"/"null"/"none"/empty means
                auto-detect by the model.

        Returns:
            Dict with keys: text, language, duration_seconds,
            max_audio_seconds, model, device, backend.

        Raises:
            SpeechDisabledError: when speech-to-text is disabled.
            SpeechDurationError: when the audio exceeds the configured limit.
            SpeechServiceError: on missing file, load failure, transcription
                failure, or when no speech is detected.
        """
        settings = get_speech_runtime_settings()
        if not settings["enabled"]:
            raise SpeechDisabledError("Speech-to-text is disabled")
        target = str(file_path or "").strip()
        if not target or not os.path.isfile(target):
            raise SpeechServiceError("Audio file not found")

        # Pre-flight duration check; +0.3s slack tolerates container
        # metadata rounding.
        duration_seconds = self._probe_audio_duration_seconds(target)
        if duration_seconds is not None and duration_seconds > float(settings["max_audio_seconds"]) + 0.3:
            raise SpeechDurationError(f"Audio duration exceeds {settings['max_audio_seconds']} seconds")

        prepared_target = self._preprocess_audio(target)
        try:
            model = self._load_model()
            lang = str(language or "").strip().lower()
            normalized_lang: Optional[str] = None
            # Treat "auto"/"null"/"none" the same as no hint at all.
            if lang and lang not in {"auto", "null", "none"}:
                normalized_lang = lang

            max_end = 0.0
            detected_language = ""
            texts = []
            kwargs: Dict[str, Any] = {
                "print_realtime": False,
                "print_progress": False,
                "no_context": True,
                "suppress_non_speech_tokens": True,
            }
            if normalized_lang:
                kwargs["language"] = normalized_lang
            initial_prompt = str(settings["initial_prompt"] or "").strip()
            if initial_prompt:
                kwargs["initial_prompt"] = initial_prompt
            # Keep only kwargs this pywhispercpp build actually supports.
            kwargs = self._filter_supported_transcribe_kwargs(model, kwargs)
            try:
                segments = model.transcribe(prepared_target, **kwargs)
            except Exception as exc:
                raise SpeechServiceError(
                    f"pywhispercpp transcription failed: {exc}. "
                    "If input is not wav, install ffmpeg in runtime image."
                ) from exc
            for segment in segments:
                txt = str(getattr(segment, "text", "") or "").strip()
                if txt:
                    texts.append(txt)
                    # NOTE(review): loop-invariant — only ever set to
                    # normalized_lang; reported language is the hint, not a
                    # model detection.
                    if normalized_lang:
                        detected_language = normalized_lang
                try:
                    # segment.t1 appears to be in 10 ms units (hence /100 to
                    # get seconds) — presumably the whisper.cpp convention.
                    max_end = max(max_end, float(getattr(segment, "t1", 0.0) or 0.0) / 100.0)
                except Exception:
                    pass
            # Second duration check, based on actual segment timestamps,
            # catches files whose container metadata was missing or wrong.
            if max_end > float(settings["max_audio_seconds"]) + 0.3:
                raise SpeechDurationError(f"Audio duration exceeds {settings['max_audio_seconds']} seconds")

            text = self._normalize_text(" ".join(texts).strip())
            if not text:
                raise SpeechServiceError("No speech detected")

            # Fall back to the last segment end when the probe gave nothing.
            if duration_seconds is None:
                duration_seconds = max_end if max_end > 0 else None

            return {
                "text": text,
                "language": detected_language or None,
                "duration_seconds": duration_seconds,
                "max_audio_seconds": settings["max_audio_seconds"],
                "model": STT_MODEL,
                "device": STT_DEVICE,
                "backend": self._backend or "unknown",
            }
        finally:
            # Always remove the temp wav created by _preprocess_audio, if any.
            if prepared_target != target and os.path.exists(prepared_target):
                try:
                    os.remove(prepared_target)
                except Exception:
                    pass
|