feat: STT model params

v3.2
zhangzhanwei 2025-09-25 15:45:23 +08:00 committed by zhanweizhang7
parent a2130edd08
commit e8c36a6223
20 changed files with 193 additions and 50 deletions

View File

@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer):
audio_list = serializers.ListField(required=True,
label=_("The audio file cannot be empty"))
model_params_setting = serializers.DictField(required=False,
label=_("Model parameter settings"))
class ISpeechToTextNode(INode):
@ -35,6 +37,6 @@ class ISpeechToTextNode(INode):
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, stt_model_id, chat_id,
            audio, model_params_setting=None,
            **kwargs) -> NodeResult:
    """Run the speech-to-text step; this base version does nothing.

    :param stt_model_id: Identifier of the speech-to-text model to use.
    :param chat_id: Identifier of the current chat.
    :param audio: Audio input passed by the caller as ``audio=...``.
    :param model_params_setting: Optional model parameter settings. It
        defaults to ``None``, so an implementation must replace it with an
        empty mapping before unpacking it with ``**``.
    :param kwargs: Any remaining node and flow parameters.
    :return: ``None`` here; presumably overridden by concrete nodes to
        return a ``NodeResult`` (TODO confirm against subclasses).
    """
    pass

View File

@ -20,9 +20,9 @@ class BaseSpeechToTextNode(ISpeechToTextNode):
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult:
workspace_id = self.workflow_manage.get_body().get('workspace_id')
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting)
audio_list = audio
self.context['audio_list'] = audio

View File

@ -965,7 +965,7 @@ class ApplicationOperateSerializer(serializers.Serializer):
application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
'-create_time').first()
if application.stt_model_enable:
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id)
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting)
text = model.speech_to_text(instance.get('file'))
return text

View File

@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image"
msgstr ""
msgid "Update auth setting"
msgstr ""
msgid "If not passed, the default value is streaming_asr_demo"
msgstr ""
msgid "If not passed, the default value is 16000"
msgstr ""
msgid "Sample Rate"
msgstr ""

View File

@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
msgstr "获取图片失败"
msgid "Update auth setting"
msgstr "更新认证设置"
msgstr "更新认证设置"
msgid "If not passed, the default value is streaming_asr_demo"
msgstr "如果未传入,则默认值为 streaming_asr_demo"
msgid "If not passed, the default value is 16000"
msgstr "如果未传入,则默认值为 16000"
msgid "Sample Rate"
msgstr "采样率"

View File

@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
msgstr "獲取圖片失敗"
msgid "Update auth setting"
msgstr "更新認證設置"
msgstr "更新認證設置"
msgid "If not passed, the default value is streaming_asr_demo"
msgstr "如果未傳入,則預設值為 streaming_asr_demo"
msgid "If not passed, the default value is 16000"
msgstr "如果未傳入,則預設值為 16000"
msgid "Sample Rate"
msgstr "採樣率"

View File

@ -4,11 +4,21 @@ import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from common.forms import BaseForm, PasswordInputField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AliyunBaiLianSTTModelParams(BaseForm):
    """Parameter form for the Aliyun BaiLian speech-to-text model.

    Declares one required slider, ``sample_rate``, ranging from 0 to 20000
    in steps of 4000 with a default of 16000.
    """

    sample_rate = forms.SliderField(
        TooltipLabel(
            _('Sample Rate'),
            _('If not passed, the default value is 16000'),
        ),
        required=True,
        default_value=16000,
        _min=0,
        _max=20000,
        _step=4000,
        precision=0,
    )
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
@ -55,7 +65,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
@ -89,4 +99,4 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
:param model_name: Name of the model.
:return: Parameter setting form (not implemented).
"""
pass
return AliyunBaiLianSTTModelParams()

View File

@ -13,11 +13,14 @@ from models_provider.impl.base_stt import BaseSpeechToText
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str
model: str
params: dict
def __init__(self, **kwargs):
    """Store the API key, model name and optional model parameters.

    :param kwargs: Forwarded unchanged to the parent initializer. The keys
        ``api_key``, ``model`` and ``params`` are also read here.
    """
    super().__init__(**kwargs)
    self.api_key = kwargs.get('api_key')
    self.model = kwargs.get('model')
    # speech_to_text unpacks self.params with ``**``; a missing or None
    # value would raise TypeError there, so fall back to an empty dict.
    self.params = kwargs.get('params') or {}
@staticmethod
def is_cache_model():
@ -33,6 +36,7 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
return AliyunBaiLianSpeechToText(
model=model_name,
api_key=model_credential.get('api_key'),
params=model_kwargs,
**optional_params,
)
@ -43,10 +47,17 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
def speech_to_text(self, audio_file):
dashscope.api_key = self.api_key
recognition = Recognition(model=self.model,
format='mp3',
sample_rate=16000,
callback=None)
recognition_params = {
'model': self.model,
'format': 'mp3',
'sample_rate': 16000,
'callback': None,
**self.params
}
print(recognition_params)
recognition = Recognition(**recognition_params)
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
# 将上传的文件保存到临时文件中
temp_file.write(audio_file.read())

View File

@ -29,7 +29,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()

View File

@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str
api_version: str
model: str
params: dict
def __init__(self, **kwargs):
    """Store the Azure OpenAI connection settings and model parameters.

    :param kwargs: Forwarded unchanged to the parent initializer. The keys
        ``api_key``, ``api_base``, ``api_version``, ``model`` and ``params``
        are also read here.
    """
    super().__init__(**kwargs)
    self.api_key = kwargs.get('api_key')
    self.api_base = kwargs.get('api_base')
    self.api_version = kwargs.get('api_version')
    # speech_to_text reads self.model, so assign it explicitly.
    # NOTE(review): the parent initializer may already populate it - confirm.
    self.model = kwargs.get('model')
    # speech_to_text calls self.params.items(); None would raise
    # AttributeError there, so fall back to an empty dict.
    self.params = kwargs.get('params') or {}
@staticmethod
def is_cache_model():
@ -41,6 +43,7 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
api_version=model_credential.get('api_version'),
params=model_kwargs,
**optional_params,
)
@ -62,5 +65,13 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}
res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params)
return res.text

View File

@ -6,10 +6,17 @@ from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAISTTModelParams(BaseForm):
    """Parameter form for the OpenAI speech-to-text model.

    Declares one required text input, ``language``, defaulting to ``'zh'``.
    """

    language = forms.TextInputField(
        TooltipLabel(
            _('language'),
            _('If not passed, the default value is zh'),
        ),
        default_value='zh',
        required=True,
    )
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
@ -28,7 +35,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
@ -46,4 +53,5 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
    """Return the parameter settings form for an OpenAI STT model.

    :param model_name: Name of the model; not used when building the form.
    :return: A new ``OpenAISTTModelParams`` instance.
    """
    return OpenAISTTModelParams()

View File

@ -18,6 +18,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str
api_key: str
model: str
params: dict
@staticmethod
def is_cache_model():
@ -27,6 +28,8 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
@ -39,6 +42,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
params = model_kwargs,
**optional_params,
)
@ -58,6 +62,14 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k,v in self.params.items() if k not in {'model_id','use_local','streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}
res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
return res.text

View File

@ -28,7 +28,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()

View File

@ -18,11 +18,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str
api_key: str
model: str
params: dict
def __init__(self, **kwargs):
    """Store the SiliconCloud connection settings and model parameters.

    :param kwargs: Forwarded unchanged to the parent initializer. The keys
        ``api_key``, ``api_base``, ``model`` and ``params`` are also read
        here.
    """
    super().__init__(**kwargs)
    self.api_key = kwargs.get('api_key')
    self.api_base = kwargs.get('api_base')
    # speech_to_text reads self.model, so assign it explicitly.
    # NOTE(review): the parent initializer may already populate it - confirm.
    self.model = kwargs.get('model')
    # speech_to_text calls self.params.items(); None would raise
    # AttributeError there, so fall back to an empty dict.
    self.params = kwargs.get('params') or {}
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
@ -35,6 +37,7 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
params=model_kwargs,
**optional_params,
)
@ -58,5 +61,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}
res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
return res.text

View File

@ -53,11 +53,14 @@ class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
base_url=base_url
)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': audio_file,
'language': 'zh',
}
result = client.audio.transcriptions.create(
file=audio_file,
model=self.model,
language=self.params.get('Language'),
response_format="json"
**transcription_params, extra_body=filter_params
)
return result.text

View File

@ -6,9 +6,17 @@ from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineSTTModelParams(BaseForm):
    """Parameter form for the Volcanic Engine speech-to-text model.

    Declares one required text input, ``uid``, defaulting to
    ``'streaming_asr_demo'``.
    """

    uid = forms.TextInputField(
        TooltipLabel(
            _('User ID'),
            _('If not passed, the default value is streaming_asr_demo'),
        ),
        default_value='streaming_asr_demo',
        required=True,
    )
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
volcanic_api_url = forms.TextInputField('API URL', required=True,
@ -31,7 +39,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
@ -49,4 +57,4 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
def get_model_params_setting_form(self, model_name):
    """Return the parameter settings form for a Volcanic Engine STT model.

    :param model_name: Name of the model; not used when building the form.
    :return: A new ``VolcanicEngineSTTModelParams`` instance.
    """
    return VolcanicEngineSTTModelParams()

View File

@ -192,6 +192,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
volcanic_cluster: str
volcanic_api_url: str
volcanic_token: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -199,6 +200,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
self.volcanic_token = kwargs.get('volcanic_token')
self.volcanic_app_id = kwargs.get('volcanic_app_id')
self.volcanic_cluster = kwargs.get('volcanic_cluster')
self.params = kwargs.get('params')
@staticmethod
def is_cache_model():
@ -216,10 +218,14 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
volcanic_token=model_credential.get('volcanic_token'),
volcanic_app_id=model_credential.get('volcanic_app_id'),
volcanic_cluster=model_credential.get('volcanic_cluster'),
params=model_kwargs,
**model_kwargs,
**optional_params
)
def construct_request(self, reqid):
params = self.params or {}
req = {
'app': {
'appid': self.volcanic_app_id,
@ -227,24 +233,24 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
'token': self.volcanic_token,
},
'user': {
'uid': 'uid'
'uid': params.get("uid", "streaming_asr_demo")
},
'request': {
'reqid': reqid,
'nbest': self.nbest,
'workflow': self.workflow,
'show_language': self.show_language,
'show_utterances': self.show_utterances,
'result_type': self.result_type,
"sequence": 1
'nbest': params.get('nbest', self.nbest),
'workflow': params.get('workflow', self.workflow),
'show_language': params.get('show_language', self.show_language),
'show_utterances': params.get('show_utterances', self.show_utterances),
'result_type': params.get('result_type', self.result_type),
'sequence': params.get('sequence', 1)
},
'audio': {
'format': self.format,
'rate': self.rate,
'language': self.language,
'bits': self.bits,
'channel': self.channel,
'codec': self.codec
'format': params.get('format', self.format),
'rate': params.get('rate', self.rate),
'language': params.get('language', self.language),
'bits': params.get('bits', self.bits),
'channel': params.get('channel', self.channel),
'codec': params.get('codec', self.codec)
}
}
return req

View File

@ -6,10 +6,28 @@ from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiSTTModelParams(BaseForm):
    """Parameter form for the XunFei speech-to-text model.

    Declares three required text inputs: ``language`` (default ``'zh_cn'``),
    ``domain`` (default ``'iat'``) and ``accent`` (default ``'mandarin'``).
    """

    language = forms.TextInputField(
        TooltipLabel(
            _('language'),
            _('If not passed, the default value is zh_cn'),
        ),
        default_value='zh_cn',
        required=True,
    )

    domain = forms.TextInputField(
        TooltipLabel(
            _('domain'),
            _('If not passed, the default value is iat'),
        ),
        default_value='iat',
        required=True,
    )

    accent = forms.TextInputField(
        TooltipLabel(
            _('accent'),
            _('If not passed, the default value is mandarin'),
        ),
        default_value='mandarin',
        required=True,
    )
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat')
spark_app_id = forms.TextInputField('APP ID', required=True)
@ -30,7 +48,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
@ -48,4 +66,4 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
def get_model_params_setting_form(self, model_name):
    """Return the parameter settings form for a XunFei STT model.

    :param model_name: Name of the model; not used when building the form.
    :return: A new ``XunFeiSTTModelParams`` instance.
    """
    return XunFeiSTTModelParams()

View File

@ -34,6 +34,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_api_key: str
spark_api_secret: str
spark_api_url: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -41,6 +42,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret')
self.params = kwargs.get('params')
@staticmethod
def is_cache_model():
@ -58,6 +60,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
params=model_kwargs,
**optional_params
)
@ -132,6 +135,11 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
frameSize = 8000 # 每一帧的音频大小
status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
allowed_params = {'language','domain','accent','vad_eos','dwa','pd','ptt',
'pcm','ltc','rlang','vinfo','nunum','speex_size','nbest','wbest'}
business_params = {k: v for k,v in self.params.items() if k in allowed_params}
while True:
buf = file.read(frameSize)
# 文件结束
@ -144,17 +152,14 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
d = {
"common": {"app_id": self.spark_app_id},
"business": {
"domain": "iat",
"language": "zh_cn",
"accent": "mandarin",
"vinfo": 1,
"vad_eos": 10000
**business_params
},
"data": {
"status": 0, "format": "audio/L16;rate=16000",
"audio": str(base64.b64encode(buf), 'utf-8'),
"encoding": "lame"}
}
print(d)
d = json.dumps(d)
await ws.send(d)
status = STATUS_CONTINUE_FRAME

View File

@ -22,6 +22,8 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def is_cache_model():
@ -57,5 +59,14 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh',
**filter_params
}
res = client.audio.transcriptions.create(**transcription_params)
return res.text