diff --git a/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py index 8577a1d5f..719e4201e 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py @@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer): audio_list = serializers.ListField(required=True, label=_("The audio file cannot be empty")) + model_params_setting = serializers.DictField(required=False, + label=_("Model parameter settings")) class ISpeechToTextNode(INode): @@ -35,6 +37,6 @@ class ISpeechToTextNode(INode): return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, stt_model_id, chat_id, - audio, + audio, model_params_setting=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py index 7912873c9..613599d0a 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py @@ -20,9 +20,9 @@ class BaseSpeechToTextNode(ISpeechToTextNode): if self.node_params.get('is_result', False): self.answer_text = details.get('answer') - def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: + def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult: workspace_id = self.workflow_manage.get_body().get('workspace_id') - stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id) + stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting) audio_list = audio self.context['audio_list'] = audio diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index 0a0914b5f..59e5e8df5 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -965,7 +965,7 @@ class ApplicationOperateSerializer(serializers.Serializer): application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by( '-create_time').first() if application.stt_model_enable: - model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id) + model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting) text = model.speech_to_text(instance.get('file')) return text diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index 48ec142a6..0db4c4b8d 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image" msgstr "" msgid "Update auth setting" +msgstr "" + +msgid "If not passed, the default value is streaming_asr_demo" +msgstr "" + +msgid "If not passed, the default value is 16000" +msgstr "" + +msgid "Sample Rate" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index 62885235e..d6a4d06fe 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image" msgstr "获取图片失败" msgid "Update auth setting" -msgstr "更新认证设置" \ No newline at end of file +msgstr "更新认证设置" + +msgid "If not passed, the default value is streaming_asr_demo" +msgstr "如果未传入,则默认值为 streaming_asr_demo" + +msgid "If not passed, the default value is 16000" +msgstr "如果未传入,则默认值为 16000" + +msgid "Sample Rate" +msgstr "采样率" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index d3bdca50a..9952bbc3f 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image" msgstr "獲取圖片失敗" msgid "Update auth setting" -msgstr "更新認證設置" \ No newline at end of file +msgstr "更新認證設置" + +msgid "If not passed, the default value is streaming_asr_demo" +msgstr "如果未傳入,則預設值為 streaming_asr_demo" + +msgid "If not passed, the default value is 16000" +msgstr "如果未傳入,則預設值為 16000" + +msgid "Sample Rate" +msgstr "採樣率" \ No newline at end of file diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py index a071f66a7..a6ee93912 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py @@ -4,11 +4,21 @@ import traceback from typing import Dict, Any from django.utils.translation import gettext as _ + +from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm, PasswordInputField +from common.forms import BaseForm, PasswordInputField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class AliyunBaiLianSTTModelParams(BaseForm): + sample_rate = forms.SliderField( + TooltipLabel(_('Sample Rate'), _('If not passed, the default value is 16000')), + required=True, + default_value=16000, + _step=4000, _min=0, _max=20000,precision=0 + ) + class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): """ Credential class for the Aliyun BaiLian STT (Speech-to-Text) model. @@ -55,7 +65,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential,**model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -89,4 +99,4 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): :param model_name: Name of the model. :return: Parameter setting form (not implemented). """ - pass + return AliyunBaiLianSTTModelParams() diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py index 7017caf79..ece41f3dd 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py @@ -13,11 +13,14 @@ from models_provider.impl.base_stt import BaseSpeechToText class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): api_key: str model: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.model = kwargs.get('model') + self.params = kwargs.get('params') + @staticmethod def is_cache_model(): @@ -33,6 +36,7 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): return AliyunBaiLianSpeechToText( model=model_name, api_key=model_credential.get('api_key'), + params=model_kwargs, **optional_params, ) @@ -43,10 +47,17 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): def speech_to_text(self, audio_file): dashscope.api_key = self.api_key - recognition = Recognition(model=self.model, - format='mp3', - sample_rate=16000, - callback=None) + recognition_params = { + 'model': self.model, + 'format': 'mp3', + 'sample_rate': 16000, + 'callback': None, + **self.params + } + print(recognition_params) + recognition = Recognition(**recognition_params) + + with tempfile.NamedTemporaryFile(delete=False) as temp_file: # 将上传的文件保存到临时文件中 temp_file.write(audio_file.read()) diff --git a/apps/models_provider/impl/azure_model_provider/credential/stt.py b/apps/models_provider/impl/azure_model_provider/credential/stt.py index 715078334..cd115473f 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/stt.py +++ b/apps/models_provider/impl/azure_model_provider/credential/stt.py @@ -29,7 +29,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() diff --git a/apps/models_provider/impl/azure_model_provider/model/stt.py b/apps/models_provider/impl/azure_model_provider/model/stt.py index 53f82e72f..c6364f373 100644 --- a/apps/models_provider/impl/azure_model_provider/model/stt.py +++ b/apps/models_provider/impl/azure_model_provider/model/stt.py @@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): api_key: str api_version: str model: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') self.api_version = kwargs.get('api_version') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -41,6 +43,7 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): api_base=model_credential.get('api_base'), api_key=model_credential.get('api_key'), api_version=model_credential.get('api_version'), + params=model_kwargs, **optional_params, ) @@ -62,5 +65,13 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh' + } + + res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params) return res.text diff --git a/apps/models_provider/impl/openai_model_provider/credential/stt.py b/apps/models_provider/impl/openai_model_provider/credential/stt.py index 6a1dd8474..b70785bc6 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/stt.py +++ b/apps/models_provider/impl/openai_model_provider/credential/stt.py @@ -6,10 +6,17 @@ from django.utils.translation import gettext as _ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class OpenAISTTModelParams(BaseForm): + language = forms.TextInputField( + TooltipLabel(_('language'), _('If not passed, the default value is zh')), + required=True, + default_value='zh', + ) + class OpenAISTTModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API URL', required=True) api_key = forms.PasswordInputField('API Key', required=True) @@ -28,7 +35,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -46,4 +53,5 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + + return OpenAISTTModelParams() diff --git a/apps/models_provider/impl/openai_model_provider/model/stt.py b/apps/models_provider/impl/openai_model_provider/model/stt.py index 6df1dff0a..329998556 100644 --- a/apps/models_provider/impl/openai_model_provider/model/stt.py +++ b/apps/models_provider/impl/openai_model_provider/model/stt.py @@ -18,6 +18,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): api_base: str api_key: str model: str + params: dict @staticmethod def is_cache_model(): @@ -27,6 +28,8 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): @@ -39,6 +42,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): model=model_name, api_base=model_credential.get('api_base'), api_key=model_credential.get('api_key'), + params = model_kwargs, **optional_params, ) @@ -58,6 +62,14 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k,v in self.params.items() if k not in {'model_id','use_local','streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh' + } + + res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params) return res.text diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py index 6ce4e8791..13e9cbe0e 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py @@ -28,7 +28,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential,**model_params) model.check_auth() except Exception as e: traceback.print_exc() diff --git a/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py b/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py index c946ed39c..b5eb10128 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py @@ -18,11 +18,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText): api_base: str api_key: str model: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') + self.params = kwargs.get('params') @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): @@ -35,6 +37,7 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText): model=model_name, api_base=model_credential.get('api_base'), api_key=model_credential.get('api_key'), + params=model_kwargs, **optional_params, ) @@ -58,5 +61,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh' + } + + res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params) return res.text diff --git a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py index f57c046e1..922d934a8 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py +++ b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py @@ -53,11 +53,14 @@ class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText): base_url=base_url ) + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': audio_file, + 'language': 'zh', + } result = client.audio.transcriptions.create( - file=audio_file, - model=self.model, - language=self.params.get('Language'), - response_format="json" + **transcription_params, extra_body=filter_params ) return result.text diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py index f7e9ecc87..12c18325f 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -6,9 +6,17 @@ from django.utils.translation import gettext as _ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class VolcanicEngineSTTModelParams(BaseForm): + uid = forms.TextInputField( + TooltipLabel(_('User ID'),_('If not passed, the default value is streaming_asr_demo')), + required=True, + default_value='streaming_asr_demo' + ) + + class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): volcanic_api_url = forms.TextInputField('API URL', required=True, @@ -31,7 +39,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -49,4 +57,4 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} def get_model_params_setting_form(self, model_name): - pass + return VolcanicEngineSTTModelParams() diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py index 5b1791952..bc0e5128f 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py @@ -192,6 +192,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): volcanic_cluster: str volcanic_api_url: str volcanic_token: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) @@ -199,6 +200,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): self.volcanic_token = kwargs.get('volcanic_token') self.volcanic_app_id = kwargs.get('volcanic_app_id') self.volcanic_cluster = kwargs.get('volcanic_cluster') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -216,10 +218,14 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): volcanic_token=model_credential.get('volcanic_token'), volcanic_app_id=model_credential.get('volcanic_app_id'), volcanic_cluster=model_credential.get('volcanic_cluster'), + params=model_kwargs, + **model_kwargs, **optional_params ) def construct_request(self, reqid): + + params = self.params or {} req = { 'app': { 'appid': self.volcanic_app_id, @@ -227,24 +233,24 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): 'token': self.volcanic_token, }, 'user': { - 'uid': 'uid' + 'uid': params.get("uid", "streaming_asr_demo") }, 'request': { 'reqid': reqid, - 'nbest': self.nbest, - 'workflow': self.workflow, - 'show_language': self.show_language, - 'show_utterances': self.show_utterances, - 'result_type': self.result_type, - "sequence": 1 + 'nbest': params.get('nbest', self.nbest), + 'workflow': params.get('workflow', self.workflow), + 'show_language': params.get('show_language', self.show_language), + 'show_utterances': params.get('show_utterances', self.show_utterances), + 'result_type': params.get('result_type', self.result_type), + 'sequence': params.get('sequence', 1) }, 'audio': { - 'format': self.format, - 'rate': self.rate, - 'language': self.language, - 'bits': self.bits, - 'channel': self.channel, - 'codec': self.codec + 'format': params.get('format', self.format), + 'rate': params.get('rate', self.rate), + 'language': params.get('language', self.language), + 'bits': params.get('bits', self.bits), + 'channel': params.get('channel', self.channel), + 'codec': params.get('codec', self.codec) } } return req diff --git a/apps/models_provider/impl/xf_model_provider/credential/stt.py b/apps/models_provider/impl/xf_model_provider/credential/stt.py index 56d697b36..67da706ba 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/stt.py +++ b/apps/models_provider/impl/xf_model_provider/credential/stt.py @@ -6,10 +6,28 @@ from django.utils.translation import gettext as _ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class XunFeiSTTModelParams(BaseForm): + language = forms.TextInputField( + TooltipLabel(_('language'), _('If not passed, the default value is zh_cn')), + required=True, + default_value='zh_cn' + ) + domain = forms.TextInputField( + TooltipLabel(_('domain'), _('If not passed, the default value is iat')), + required=True, + default_value='iat' + ) + accent = forms.TextInputField( + TooltipLabel(_('accent'), _('If not passed, the default value is mandarin')), + required=True, + default_value='mandarin' + ) + + class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat') spark_app_id = forms.TextInputField('APP ID', required=True) @@ -30,7 +48,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -48,4 +66,4 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} def get_model_params_setting_form(self, model_name): - pass + return XunFeiSTTModelParams() diff --git a/apps/models_provider/impl/xf_model_provider/model/stt.py b/apps/models_provider/impl/xf_model_provider/model/stt.py index 09f011f57..b43320746 100644 --- a/apps/models_provider/impl/xf_model_provider/model/stt.py +++ b/apps/models_provider/impl/xf_model_provider/model/stt.py @@ -34,6 +34,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): spark_api_key: str spark_api_secret: str spark_api_url: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) @@ -41,6 +42,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): self.spark_app_id = kwargs.get('spark_app_id') self.spark_api_key = kwargs.get('spark_api_key') self.spark_api_secret = kwargs.get('spark_api_secret') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -58,6 +60,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): spark_api_key=model_credential.get('spark_api_key'), spark_api_secret=model_credential.get('spark_api_secret'), spark_api_url=model_credential.get('spark_api_url'), + params=model_kwargs, **optional_params ) @@ -132,6 +135,11 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): frameSize = 8000 # 每一帧的音频大小 status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 + allowed_params = {'language','domain','accent','vad_eos','dwa','pd','ptt', + 'pcm','ltc','rlang','vinfo','nunum','speex_size','nbest','wbest'} + + business_params = {k: v for k,v in self.params.items() if k in allowed_params} + while True: buf = file.read(frameSize) # 文件结束 @@ -144,17 +152,14 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): d = { "common": {"app_id": self.spark_app_id}, "business": { - "domain": "iat", - "language": "zh_cn", - "accent": "mandarin", - "vinfo": 1, - "vad_eos": 10000 + **business_params }, "data": { "status": 0, "format": "audio/L16;rate=16000", "audio": str(base64.b64encode(buf), 'utf-8'), "encoding": "lame"} } + print(d) d = json.dumps(d) await ws.send(d) status = STATUS_CONTINUE_FRAME diff --git a/apps/models_provider/impl/xinference_model_provider/model/stt.py b/apps/models_provider/impl/xinference_model_provider/model/stt.py index 1994cd8fd..e614d42d4 100644 --- a/apps/models_provider/impl/xinference_model_provider/model/stt.py +++ b/apps/models_provider/impl/xinference_model_provider/model/stt.py @@ -22,6 +22,8 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -57,5 +59,14 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh', + **filter_params + } + + res = client.audio.transcriptions.create(**transcription_params) return res.text