feat: STT model params

v3.2
zhangzhanwei 2025-09-25 15:45:23 +08:00 committed by zhanweizhang7
parent a2130edd08
commit e8c36a6223
20 changed files with 193 additions and 50 deletions

View File

@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer):
audio_list = serializers.ListField(required=True, audio_list = serializers.ListField(required=True,
label=_("The audio file cannot be empty")) label=_("The audio file cannot be empty"))
model_params_setting = serializers.DictField(required=False,
label=_("Model parameter settings"))
class ISpeechToTextNode(INode): class ISpeechToTextNode(INode):
@ -35,6 +37,6 @@ class ISpeechToTextNode(INode):
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, stt_model_id, chat_id, def execute(self, stt_model_id, chat_id,
audio, audio, model_params_setting=None,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
pass pass

View File

@ -20,9 +20,9 @@ class BaseSpeechToTextNode(ISpeechToTextNode):
if self.node_params.get('is_result', False): if self.node_params.get('is_result', False):
self.answer_text = details.get('answer') self.answer_text = details.get('answer')
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult:
workspace_id = self.workflow_manage.get_body().get('workspace_id') workspace_id = self.workflow_manage.get_body().get('workspace_id')
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id) stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting)
audio_list = audio audio_list = audio
self.context['audio_list'] = audio self.context['audio_list'] = audio

View File

@ -965,7 +965,7 @@ class ApplicationOperateSerializer(serializers.Serializer):
application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by( application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
'-create_time').first() '-create_time').first()
if application.stt_model_enable: if application.stt_model_enable:
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id) model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting)
text = model.speech_to_text(instance.get('file')) text = model.speech_to_text(instance.get('file'))
return text return text

View File

@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image"
msgstr "" msgstr ""
msgid "Update auth setting" msgid "Update auth setting"
msgstr ""
msgid "If not passed, the default value is streaming_asr_demo"
msgstr ""
msgid "If not passed, the default value is 16000"
msgstr ""
msgid "Sample Rate"
msgstr "" msgstr ""

View File

@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
msgstr "获取图片失败" msgstr "获取图片失败"
msgid "Update auth setting" msgid "Update auth setting"
msgstr "更新认证设置" msgstr "更新认证设置"
msgid "If not passed, the default value is streaming_asr_demo"
msgstr "如果未传入,则默认值为 streaming_asr_demo"
msgid "If not passed, the default value is 16000"
msgstr "如果未传入,则默认值为 16000"
msgid "Sample Rate"
msgstr "采样率"

View File

@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
msgstr "獲取圖片失敗" msgstr "獲取圖片失敗"
msgid "Update auth setting" msgid "Update auth setting"
msgstr "更新認證設置" msgstr "更新認證設置"
msgid "If not passed, the default value is streaming_asr_demo"
msgstr "如果未傳入,則預設值為 streaming_asr_demo"
msgid "If not passed, the default value is 16000"
msgstr "如果未傳入,則預設值為 16000"
msgid "Sample Rate"
msgstr "採樣率"

View File

@ -4,11 +4,21 @@ import traceback
from typing import Dict, Any from typing import Dict, Any
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField from common.forms import BaseForm, PasswordInputField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AliyunBaiLianSTTModelParams(BaseForm):
sample_rate = forms.SliderField(
TooltipLabel(_('Sample Rate'), _('If not passed, the default value is 16000')),
required=True,
default_value=16000,
_step=4000, _min=0, _max=20000,precision=0
)
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
""" """
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model. Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
@ -55,7 +65,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
return False return False
try: try:
model = provider.get_model(model_type, model_name, model_credential) model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth() model.check_auth()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -89,4 +99,4 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
:param model_name: Name of the model. :param model_name: Name of the model.
:return: Parameter setting form (not implemented). :return: Parameter setting form (not implemented).
""" """
pass return AliyunBaiLianSTTModelParams()

View File

@ -13,11 +13,14 @@ from models_provider.impl.base_stt import BaseSpeechToText
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str api_key: str
model: str model: str
params: dict
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = kwargs.get('api_key') self.api_key = kwargs.get('api_key')
self.model = kwargs.get('model') self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
@ -33,6 +36,7 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
return AliyunBaiLianSpeechToText( return AliyunBaiLianSpeechToText(
model=model_name, model=model_name,
api_key=model_credential.get('api_key'), api_key=model_credential.get('api_key'),
params=model_kwargs,
**optional_params, **optional_params,
) )
@ -43,10 +47,17 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
def speech_to_text(self, audio_file): def speech_to_text(self, audio_file):
dashscope.api_key = self.api_key dashscope.api_key = self.api_key
recognition = Recognition(model=self.model, recognition_params = {
format='mp3', 'model': self.model,
sample_rate=16000, 'format': 'mp3',
callback=None) 'sample_rate': 16000,
'callback': None,
**self.params
}
print(recognition_params)
recognition = Recognition(**recognition_params)
with tempfile.NamedTemporaryFile(delete=False) as temp_file: with tempfile.NamedTemporaryFile(delete=False) as temp_file:
# 将上传的文件保存到临时文件中 # 将上传的文件保存到临时文件中
temp_file.write(audio_file.read()) temp_file.write(audio_file.read())

View File

@ -29,7 +29,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
else: else:
return False return False
try: try:
model = provider.get_model(model_type, model_name, model_credential) model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth() model.check_auth()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()

View File

@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str api_key: str
api_version: str api_version: str
model: str model: str
params: dict
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = kwargs.get('api_key') self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base') self.api_base = kwargs.get('api_base')
self.api_version = kwargs.get('api_version') self.api_version = kwargs.get('api_version')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
@ -41,6 +43,7 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base=model_credential.get('api_base'), api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'), api_key=model_credential.get('api_key'),
api_version=model_credential.get('api_version'), api_version=model_credential.get('api_version'),
params=model_kwargs,
**optional_params, **optional_params,
) )
@ -62,5 +65,13 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read() audio_data = audio_file.read()
buffer = io.BytesIO(audio_data) buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}
res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params)
return res.text return res.text

View File

@ -6,10 +6,17 @@ from django.utils.translation import gettext as _
from common import forms from common import forms
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.forms import BaseForm from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAISTTModelParams(BaseForm):
language = forms.TextInputField(
TooltipLabel(_('language'), _('If not passed, the default value is zh')),
required=True,
default_value='zh',
)
class OpenAISTTModelCredential(BaseForm, BaseModelCredential): class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API URL', required=True) api_base = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True) api_key = forms.PasswordInputField('API Key', required=True)
@ -28,7 +35,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
else: else:
return False return False
try: try:
model = provider.get_model(model_type, model_name, model_credential) model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth() model.check_auth()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -46,4 +53,5 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))} return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name): def get_model_params_setting_form(self, model_name):
pass
return OpenAISTTModelParams()

View File

@ -18,6 +18,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str api_base: str
api_key: str api_key: str
model: str model: str
params: dict
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
@ -27,6 +28,8 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = kwargs.get('api_key') self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base') self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
@ -39,6 +42,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
model=model_name, model=model_name,
api_base=model_credential.get('api_base'), api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'), api_key=model_credential.get('api_key'),
params = model_kwargs,
**optional_params, **optional_params,
) )
@ -58,6 +62,14 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read() audio_data = audio_file.read()
buffer = io.BytesIO(audio_data) buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k,v in self.params.items() if k not in {'model_id','use_local','streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}
res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
return res.text return res.text

View File

@ -28,7 +28,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
else: else:
return False return False
try: try:
model = provider.get_model(model_type, model_name, model_credential) model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth() model.check_auth()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()

View File

@ -18,11 +18,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str api_base: str
api_key: str api_key: str
model: str model: str
params: dict
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = kwargs.get('api_key') self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base') self.api_base = kwargs.get('api_base')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
@ -35,6 +37,7 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
model=model_name, model=model_name,
api_base=model_credential.get('api_base'), api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'), api_key=model_credential.get('api_key'),
params=model_kwargs,
**optional_params, **optional_params,
) )
@ -58,5 +61,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read() audio_data = audio_file.read()
buffer = io.BytesIO(audio_data) buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}
res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
return res.text return res.text

View File

@ -53,11 +53,14 @@ class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
base_url=base_url base_url=base_url
) )
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': audio_file,
'language': 'zh',
}
result = client.audio.transcriptions.create( result = client.audio.transcriptions.create(
file=audio_file, **transcription_params, extra_body=filter_params
model=self.model,
language=self.params.get('Language'),
response_format="json"
) )
return result.text return result.text

View File

@ -6,9 +6,17 @@ from django.utils.translation import gettext as _
from common import forms from common import forms
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.forms import BaseForm from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineSTTModelParams(BaseForm):
uid = forms.TextInputField(
TooltipLabel(_('User ID'),_('If not passed, the default value is streaming_asr_demo')),
required=True,
default_value='streaming_asr_demo'
)
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
volcanic_api_url = forms.TextInputField('API URL', required=True, volcanic_api_url = forms.TextInputField('API URL', required=True,
@ -31,7 +39,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
else: else:
return False return False
try: try:
model = provider.get_model(model_type, model_name, model_credential) model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth() model.check_auth()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -49,4 +57,4 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
def get_model_params_setting_form(self, model_name): def get_model_params_setting_form(self, model_name):
pass return VolcanicEngineSTTModelParams()

View File

@ -192,6 +192,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
volcanic_cluster: str volcanic_cluster: str
volcanic_api_url: str volcanic_api_url: str
volcanic_token: str volcanic_token: str
params: dict
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -199,6 +200,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
self.volcanic_token = kwargs.get('volcanic_token') self.volcanic_token = kwargs.get('volcanic_token')
self.volcanic_app_id = kwargs.get('volcanic_app_id') self.volcanic_app_id = kwargs.get('volcanic_app_id')
self.volcanic_cluster = kwargs.get('volcanic_cluster') self.volcanic_cluster = kwargs.get('volcanic_cluster')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
@ -216,10 +218,14 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
volcanic_token=model_credential.get('volcanic_token'), volcanic_token=model_credential.get('volcanic_token'),
volcanic_app_id=model_credential.get('volcanic_app_id'), volcanic_app_id=model_credential.get('volcanic_app_id'),
volcanic_cluster=model_credential.get('volcanic_cluster'), volcanic_cluster=model_credential.get('volcanic_cluster'),
params=model_kwargs,
**model_kwargs,
**optional_params **optional_params
) )
def construct_request(self, reqid): def construct_request(self, reqid):
params = self.params or {}
req = { req = {
'app': { 'app': {
'appid': self.volcanic_app_id, 'appid': self.volcanic_app_id,
@ -227,24 +233,24 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
'token': self.volcanic_token, 'token': self.volcanic_token,
}, },
'user': { 'user': {
'uid': 'uid' 'uid': params.get("uid", "streaming_asr_demo")
}, },
'request': { 'request': {
'reqid': reqid, 'reqid': reqid,
'nbest': self.nbest, 'nbest': params.get('nbest', self.nbest),
'workflow': self.workflow, 'workflow': params.get('workflow', self.workflow),
'show_language': self.show_language, 'show_language': params.get('show_language', self.show_language),
'show_utterances': self.show_utterances, 'show_utterances': params.get('show_utterances', self.show_utterances),
'result_type': self.result_type, 'result_type': params.get('result_type', self.result_type),
"sequence": 1 'sequence': params.get('sequence', 1)
}, },
'audio': { 'audio': {
'format': self.format, 'format': params.get('format', self.format),
'rate': self.rate, 'rate': params.get('rate', self.rate),
'language': self.language, 'language': params.get('language', self.language),
'bits': self.bits, 'bits': params.get('bits', self.bits),
'channel': self.channel, 'channel': params.get('channel', self.channel),
'codec': self.codec 'codec': params.get('codec', self.codec)
} }
} }
return req return req

View File

@ -6,10 +6,28 @@ from django.utils.translation import gettext as _
from common import forms from common import forms
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.forms import BaseForm from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiSTTModelParams(BaseForm):
language = forms.TextInputField(
TooltipLabel(_('language'), _('If not passed, the default value is zh_cn')),
required=True,
default_value='zh_cn'
)
domain = forms.TextInputField(
TooltipLabel(_('domain'), _('If not passed, the default value is iat')),
required=True,
default_value='iat'
)
accent = forms.TextInputField(
TooltipLabel(_('accent'), _('If not passed, the default value is mandarin')),
required=True,
default_value='mandarin'
)
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat') spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat')
spark_app_id = forms.TextInputField('APP ID', required=True) spark_app_id = forms.TextInputField('APP ID', required=True)
@ -30,7 +48,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
else: else:
return False return False
try: try:
model = provider.get_model(model_type, model_name, model_credential) model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth() model.check_auth()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -48,4 +66,4 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
def get_model_params_setting_form(self, model_name): def get_model_params_setting_form(self, model_name):
pass return XunFeiSTTModelParams()

View File

@ -34,6 +34,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_api_key: str spark_api_key: str
spark_api_secret: str spark_api_secret: str
spark_api_url: str spark_api_url: str
params: dict
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -41,6 +42,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
self.spark_app_id = kwargs.get('spark_app_id') self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key') self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret') self.spark_api_secret = kwargs.get('spark_api_secret')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
@ -58,6 +60,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_api_key=model_credential.get('spark_api_key'), spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'), spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'), spark_api_url=model_credential.get('spark_api_url'),
params=model_kwargs,
**optional_params **optional_params
) )
@ -132,6 +135,11 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
frameSize = 8000 # 每一帧的音频大小 frameSize = 8000 # 每一帧的音频大小
status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
allowed_params = {'language','domain','accent','vad_eos','dwa','pd','ptt',
'pcm','ltc','rlang','vinfo','nunum','speex_size','nbest','wbest'}
business_params = {k: v for k,v in self.params.items() if k in allowed_params}
while True: while True:
buf = file.read(frameSize) buf = file.read(frameSize)
# 文件结束 # 文件结束
@ -144,17 +152,14 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
d = { d = {
"common": {"app_id": self.spark_app_id}, "common": {"app_id": self.spark_app_id},
"business": { "business": {
"domain": "iat", **business_params
"language": "zh_cn",
"accent": "mandarin",
"vinfo": 1,
"vad_eos": 10000
}, },
"data": { "data": {
"status": 0, "format": "audio/L16;rate=16000", "status": 0, "format": "audio/L16;rate=16000",
"audio": str(base64.b64encode(buf), 'utf-8'), "audio": str(base64.b64encode(buf), 'utf-8'),
"encoding": "lame"} "encoding": "lame"}
} }
print(d)
d = json.dumps(d) d = json.dumps(d)
await ws.send(d) await ws.send(d)
status = STATUS_CONTINUE_FRAME status = STATUS_CONTINUE_FRAME

View File

@ -22,6 +22,8 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = kwargs.get('api_key') self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base') self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod @staticmethod
def is_cache_model(): def is_cache_model():
@ -57,5 +59,14 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
audio_data = audio_file.read() audio_data = audio_file.read()
buffer = io.BytesIO(audio_data) buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh',
**filter_params
}
res = client.audio.transcriptions.create(**transcription_params)
return res.text return res.text