feat: STT model params
parent
a2130edd08
commit
e8c36a6223
|
|
@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer):
|
||||||
|
|
||||||
audio_list = serializers.ListField(required=True,
|
audio_list = serializers.ListField(required=True,
|
||||||
label=_("The audio file cannot be empty"))
|
label=_("The audio file cannot be empty"))
|
||||||
|
model_params_setting = serializers.DictField(required=False,
|
||||||
|
label=_("Model parameter settings"))
|
||||||
|
|
||||||
|
|
||||||
class ISpeechToTextNode(INode):
|
class ISpeechToTextNode(INode):
|
||||||
|
|
@ -35,6 +37,6 @@ class ISpeechToTextNode(INode):
|
||||||
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||||
|
|
||||||
def execute(self, stt_model_id, chat_id,
|
def execute(self, stt_model_id, chat_id,
|
||||||
audio,
|
audio, model_params_setting=None,
|
||||||
**kwargs) -> NodeResult:
|
**kwargs) -> NodeResult:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,9 @@ class BaseSpeechToTextNode(ISpeechToTextNode):
|
||||||
if self.node_params.get('is_result', False):
|
if self.node_params.get('is_result', False):
|
||||||
self.answer_text = details.get('answer')
|
self.answer_text = details.get('answer')
|
||||||
|
|
||||||
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
|
def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult:
|
||||||
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
||||||
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
|
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting)
|
||||||
audio_list = audio
|
audio_list = audio
|
||||||
self.context['audio_list'] = audio
|
self.context['audio_list'] = audio
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -965,7 +965,7 @@ class ApplicationOperateSerializer(serializers.Serializer):
|
||||||
application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
|
application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
|
||||||
'-create_time').first()
|
'-create_time').first()
|
||||||
if application.stt_model_enable:
|
if application.stt_model_enable:
|
||||||
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id)
|
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting)
|
||||||
text = model.speech_to_text(instance.get('file'))
|
text = model.speech_to_text(instance.get('file'))
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
||||||
msgid "Update auth setting"
|
msgid "Update auth setting"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
msgid "If not passed, the default value is streaming_asr_demo"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
msgid "If not passed, the default value is 16000"
|
||||||
|
msgstr ""
|
||||||
|
|
||||||
|
msgid "Sample Rate"
|
||||||
msgstr ""
|
msgstr ""
|
||||||
|
|
@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
|
||||||
msgstr "获取图片失败"
|
msgstr "获取图片失败"
|
||||||
|
|
||||||
msgid "Update auth setting"
|
msgid "Update auth setting"
|
||||||
msgstr "更新认证设置"
|
msgstr "更新认证设置"
|
||||||
|
|
||||||
|
msgid "If not passed, the default value is streaming_asr_demo"
|
||||||
|
msgstr "如果未传入,则默认值为 streaming_asr_demo"
|
||||||
|
|
||||||
|
msgid "If not passed, the default value is 16000"
|
||||||
|
msgstr "如果未传入,则默认值为 16000"
|
||||||
|
|
||||||
|
msgid "Sample Rate"
|
||||||
|
msgstr "采样率"
|
||||||
|
|
@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
|
||||||
msgstr "獲取圖片失敗"
|
msgstr "獲取圖片失敗"
|
||||||
|
|
||||||
msgid "Update auth setting"
|
msgid "Update auth setting"
|
||||||
msgstr "更新認證設置"
|
msgstr "更新認證設置"
|
||||||
|
|
||||||
|
msgid "If not passed, the default value is streaming_asr_demo"
|
||||||
|
msgstr "如果未傳入,則預設值為 streaming_asr_demo"
|
||||||
|
|
||||||
|
msgid "If not passed, the default value is 16000"
|
||||||
|
msgstr "如果未傳入,則預設值為 16000"
|
||||||
|
|
||||||
|
msgid "Sample Rate"
|
||||||
|
msgstr "採樣率"
|
||||||
|
|
@ -4,11 +4,21 @@ import traceback
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm, PasswordInputField
|
from common.forms import BaseForm, PasswordInputField, TooltipLabel
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class AliyunBaiLianSTTModelParams(BaseForm):
|
||||||
|
sample_rate = forms.SliderField(
|
||||||
|
TooltipLabel(_('Sample Rate'), _('If not passed, the default value is 16000')),
|
||||||
|
required=True,
|
||||||
|
default_value=16000,
|
||||||
|
_step=4000, _min=0, _max=20000,precision=0
|
||||||
|
)
|
||||||
|
|
||||||
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
"""
|
"""
|
||||||
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
|
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
|
||||||
|
|
@ -55,7 +65,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential,**model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
@ -89,4 +99,4 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
:param model_name: Name of the model.
|
:param model_name: Name of the model.
|
||||||
:return: Parameter setting form (not implemented).
|
:return: Parameter setting form (not implemented).
|
||||||
"""
|
"""
|
||||||
pass
|
return AliyunBaiLianSTTModelParams()
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,14 @@ from models_provider.impl.base_stt import BaseSpeechToText
|
||||||
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
api_key: str
|
api_key: str
|
||||||
model: str
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.api_key = kwargs.get('api_key')
|
self.api_key = kwargs.get('api_key')
|
||||||
self.model = kwargs.get('model')
|
self.model = kwargs.get('model')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
|
|
@ -33,6 +36,7 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
return AliyunBaiLianSpeechToText(
|
return AliyunBaiLianSpeechToText(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_key=model_credential.get('api_key'),
|
api_key=model_credential.get('api_key'),
|
||||||
|
params=model_kwargs,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -43,10 +47,17 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
|
|
||||||
def speech_to_text(self, audio_file):
|
def speech_to_text(self, audio_file):
|
||||||
dashscope.api_key = self.api_key
|
dashscope.api_key = self.api_key
|
||||||
recognition = Recognition(model=self.model,
|
recognition_params = {
|
||||||
format='mp3',
|
'model': self.model,
|
||||||
sample_rate=16000,
|
'format': 'mp3',
|
||||||
callback=None)
|
'sample_rate': 16000,
|
||||||
|
'callback': None,
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
print(recognition_params)
|
||||||
|
recognition = Recognition(**recognition_params)
|
||||||
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||||
# 将上传的文件保存到临时文件中
|
# 将上传的文件保存到临时文件中
|
||||||
temp_file.write(audio_file.read())
|
temp_file.write(audio_file.read())
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
api_key: str
|
api_key: str
|
||||||
api_version: str
|
api_version: str
|
||||||
model: str
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.api_key = kwargs.get('api_key')
|
self.api_key = kwargs.get('api_key')
|
||||||
self.api_base = kwargs.get('api_base')
|
self.api_base = kwargs.get('api_base')
|
||||||
self.api_version = kwargs.get('api_version')
|
self.api_version = kwargs.get('api_version')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
|
|
@ -41,6 +43,7 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
api_base=model_credential.get('api_base'),
|
api_base=model_credential.get('api_base'),
|
||||||
api_key=model_credential.get('api_key'),
|
api_key=model_credential.get('api_key'),
|
||||||
api_version=model_credential.get('api_version'),
|
api_version=model_credential.get('api_version'),
|
||||||
|
params=model_kwargs,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -62,5 +65,13 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
audio_data = audio_file.read()
|
audio_data = audio_file.read()
|
||||||
buffer = io.BytesIO(audio_data)
|
buffer = io.BytesIO(audio_data)
|
||||||
buffer.name = "file.mp3" # this is the important line
|
buffer.name = "file.mp3" # this is the important line
|
||||||
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
|
|
||||||
|
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
|
||||||
|
transcription_params = {
|
||||||
|
'model': self.model,
|
||||||
|
'file': buffer,
|
||||||
|
'language': 'zh'
|
||||||
|
}
|
||||||
|
|
||||||
|
res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params)
|
||||||
return res.text
|
return res.text
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,17 @@ from django.utils.translation import gettext as _
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm, TooltipLabel
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAISTTModelParams(BaseForm):
|
||||||
|
language = forms.TextInputField(
|
||||||
|
TooltipLabel(_('language'), _('If not passed, the default value is zh')),
|
||||||
|
required=True,
|
||||||
|
default_value='zh',
|
||||||
|
)
|
||||||
|
|
||||||
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
@ -28,7 +35,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
@ -46,4 +53,5 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
|
||||||
|
return OpenAISTTModelParams()
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
api_base: str
|
api_base: str
|
||||||
api_key: str
|
api_key: str
|
||||||
model: str
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
|
|
@ -27,6 +28,8 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.api_key = kwargs.get('api_key')
|
self.api_key = kwargs.get('api_key')
|
||||||
self.api_base = kwargs.get('api_base')
|
self.api_base = kwargs.get('api_base')
|
||||||
|
self.model = kwargs.get('model')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
|
@ -39,6 +42,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_base=model_credential.get('api_base'),
|
api_base=model_credential.get('api_base'),
|
||||||
api_key=model_credential.get('api_key'),
|
api_key=model_credential.get('api_key'),
|
||||||
|
params = model_kwargs,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -58,6 +62,14 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
audio_data = audio_file.read()
|
audio_data = audio_file.read()
|
||||||
buffer = io.BytesIO(audio_data)
|
buffer = io.BytesIO(audio_data)
|
||||||
buffer.name = "file.mp3" # this is the important line
|
buffer.name = "file.mp3" # this is the important line
|
||||||
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
|
|
||||||
|
filter_params = {k: v for k,v in self.params.items() if k not in {'model_id','use_local','streaming'}}
|
||||||
|
transcription_params = {
|
||||||
|
'model': self.model,
|
||||||
|
'file': buffer,
|
||||||
|
'language': 'zh'
|
||||||
|
}
|
||||||
|
|
||||||
|
res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
|
||||||
return res.text
|
return res.text
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential,**model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
api_base: str
|
api_base: str
|
||||||
api_key: str
|
api_key: str
|
||||||
model: str
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.api_key = kwargs.get('api_key')
|
self.api_key = kwargs.get('api_key')
|
||||||
self.api_base = kwargs.get('api_base')
|
self.api_base = kwargs.get('api_base')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
|
@ -35,6 +37,7 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
model=model_name,
|
model=model_name,
|
||||||
api_base=model_credential.get('api_base'),
|
api_base=model_credential.get('api_base'),
|
||||||
api_key=model_credential.get('api_key'),
|
api_key=model_credential.get('api_key'),
|
||||||
|
params=model_kwargs,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -58,5 +61,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
audio_data = audio_file.read()
|
audio_data = audio_file.read()
|
||||||
buffer = io.BytesIO(audio_data)
|
buffer = io.BytesIO(audio_data)
|
||||||
buffer.name = "file.mp3" # this is the important line
|
buffer.name = "file.mp3" # this is the important line
|
||||||
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
|
|
||||||
|
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
|
||||||
|
transcription_params = {
|
||||||
|
'model': self.model,
|
||||||
|
'file': buffer,
|
||||||
|
'language': 'zh'
|
||||||
|
}
|
||||||
|
|
||||||
|
res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
|
||||||
return res.text
|
return res.text
|
||||||
|
|
|
||||||
|
|
@ -53,11 +53,14 @@ class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
base_url=base_url
|
base_url=base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
|
||||||
|
transcription_params = {
|
||||||
|
'model': self.model,
|
||||||
|
'file': audio_file,
|
||||||
|
'language': 'zh',
|
||||||
|
}
|
||||||
result = client.audio.transcriptions.create(
|
result = client.audio.transcriptions.create(
|
||||||
file=audio_file,
|
**transcription_params, extra_body=filter_params
|
||||||
model=self.model,
|
|
||||||
language=self.params.get('Language'),
|
|
||||||
response_format="json"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result.text
|
return result.text
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,17 @@ from django.utils.translation import gettext as _
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm, TooltipLabel
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
class VolcanicEngineSTTModelParams(BaseForm):
|
||||||
|
uid = forms.TextInputField(
|
||||||
|
TooltipLabel(_('User ID'),_('If not passed, the default value is streaming_asr_demo')),
|
||||||
|
required=True,
|
||||||
|
default_value='streaming_asr_demo'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
volcanic_api_url = forms.TextInputField('API URL', required=True,
|
volcanic_api_url = forms.TextInputField('API URL', required=True,
|
||||||
|
|
@ -31,7 +39,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
@ -49,4 +57,4 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
|
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return VolcanicEngineSTTModelParams()
|
||||||
|
|
|
||||||
|
|
@ -192,6 +192,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
volcanic_cluster: str
|
volcanic_cluster: str
|
||||||
volcanic_api_url: str
|
volcanic_api_url: str
|
||||||
volcanic_token: str
|
volcanic_token: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
@ -199,6 +200,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
self.volcanic_token = kwargs.get('volcanic_token')
|
self.volcanic_token = kwargs.get('volcanic_token')
|
||||||
self.volcanic_app_id = kwargs.get('volcanic_app_id')
|
self.volcanic_app_id = kwargs.get('volcanic_app_id')
|
||||||
self.volcanic_cluster = kwargs.get('volcanic_cluster')
|
self.volcanic_cluster = kwargs.get('volcanic_cluster')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
|
|
@ -216,10 +218,14 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
volcanic_token=model_credential.get('volcanic_token'),
|
volcanic_token=model_credential.get('volcanic_token'),
|
||||||
volcanic_app_id=model_credential.get('volcanic_app_id'),
|
volcanic_app_id=model_credential.get('volcanic_app_id'),
|
||||||
volcanic_cluster=model_credential.get('volcanic_cluster'),
|
volcanic_cluster=model_credential.get('volcanic_cluster'),
|
||||||
|
params=model_kwargs,
|
||||||
|
**model_kwargs,
|
||||||
**optional_params
|
**optional_params
|
||||||
)
|
)
|
||||||
|
|
||||||
def construct_request(self, reqid):
|
def construct_request(self, reqid):
|
||||||
|
|
||||||
|
params = self.params or {}
|
||||||
req = {
|
req = {
|
||||||
'app': {
|
'app': {
|
||||||
'appid': self.volcanic_app_id,
|
'appid': self.volcanic_app_id,
|
||||||
|
|
@ -227,24 +233,24 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
'token': self.volcanic_token,
|
'token': self.volcanic_token,
|
||||||
},
|
},
|
||||||
'user': {
|
'user': {
|
||||||
'uid': 'uid'
|
'uid': params.get("uid", "streaming_asr_demo")
|
||||||
},
|
},
|
||||||
'request': {
|
'request': {
|
||||||
'reqid': reqid,
|
'reqid': reqid,
|
||||||
'nbest': self.nbest,
|
'nbest': params.get('nbest', self.nbest),
|
||||||
'workflow': self.workflow,
|
'workflow': params.get('workflow', self.workflow),
|
||||||
'show_language': self.show_language,
|
'show_language': params.get('show_language', self.show_language),
|
||||||
'show_utterances': self.show_utterances,
|
'show_utterances': params.get('show_utterances', self.show_utterances),
|
||||||
'result_type': self.result_type,
|
'result_type': params.get('result_type', self.result_type),
|
||||||
"sequence": 1
|
'sequence': params.get('sequence', 1)
|
||||||
},
|
},
|
||||||
'audio': {
|
'audio': {
|
||||||
'format': self.format,
|
'format': params.get('format', self.format),
|
||||||
'rate': self.rate,
|
'rate': params.get('rate', self.rate),
|
||||||
'language': self.language,
|
'language': params.get('language', self.language),
|
||||||
'bits': self.bits,
|
'bits': params.get('bits', self.bits),
|
||||||
'channel': self.channel,
|
'channel': params.get('channel', self.channel),
|
||||||
'codec': self.codec
|
'codec': params.get('codec', self.codec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return req
|
return req
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,28 @@ from django.utils.translation import gettext as _
|
||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm, TooltipLabel
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XunFeiSTTModelParams(BaseForm):
|
||||||
|
language = forms.TextInputField(
|
||||||
|
TooltipLabel(_('language'), _('If not passed, the default value is zh_cn')),
|
||||||
|
required=True,
|
||||||
|
default_value='zh_cn'
|
||||||
|
)
|
||||||
|
domain = forms.TextInputField(
|
||||||
|
TooltipLabel(_('domain'), _('If not passed, the default value is iat')),
|
||||||
|
required=True,
|
||||||
|
default_value='iat'
|
||||||
|
)
|
||||||
|
accent = forms.TextInputField(
|
||||||
|
TooltipLabel(_('accent'), _('If not passed, the default value is mandarin')),
|
||||||
|
required=True,
|
||||||
|
default_value='mandarin'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat')
|
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat')
|
||||||
spark_app_id = forms.TextInputField('APP ID', required=True)
|
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||||
|
|
@ -30,7 +48,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
model = provider.get_model(model_type, model_name, model_credential)
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
model.check_auth()
|
model.check_auth()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
@ -48,4 +66,4 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return XunFeiSTTModelParams()
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
spark_api_key: str
|
spark_api_key: str
|
||||||
spark_api_secret: str
|
spark_api_secret: str
|
||||||
spark_api_url: str
|
spark_api_url: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
@ -41,6 +42,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
self.spark_app_id = kwargs.get('spark_app_id')
|
self.spark_app_id = kwargs.get('spark_app_id')
|
||||||
self.spark_api_key = kwargs.get('spark_api_key')
|
self.spark_api_key = kwargs.get('spark_api_key')
|
||||||
self.spark_api_secret = kwargs.get('spark_api_secret')
|
self.spark_api_secret = kwargs.get('spark_api_secret')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
|
|
@ -58,6 +60,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
spark_api_key=model_credential.get('spark_api_key'),
|
spark_api_key=model_credential.get('spark_api_key'),
|
||||||
spark_api_secret=model_credential.get('spark_api_secret'),
|
spark_api_secret=model_credential.get('spark_api_secret'),
|
||||||
spark_api_url=model_credential.get('spark_api_url'),
|
spark_api_url=model_credential.get('spark_api_url'),
|
||||||
|
params=model_kwargs,
|
||||||
**optional_params
|
**optional_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -132,6 +135,11 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
frameSize = 8000 # 每一帧的音频大小
|
frameSize = 8000 # 每一帧的音频大小
|
||||||
status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
|
status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
|
||||||
|
|
||||||
|
allowed_params = {'language','domain','accent','vad_eos','dwa','pd','ptt',
|
||||||
|
'pcm','ltc','rlang','vinfo','nunum','speex_size','nbest','wbest'}
|
||||||
|
|
||||||
|
business_params = {k: v for k,v in self.params.items() if k in allowed_params}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
buf = file.read(frameSize)
|
buf = file.read(frameSize)
|
||||||
# 文件结束
|
# 文件结束
|
||||||
|
|
@ -144,17 +152,14 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
d = {
|
d = {
|
||||||
"common": {"app_id": self.spark_app_id},
|
"common": {"app_id": self.spark_app_id},
|
||||||
"business": {
|
"business": {
|
||||||
"domain": "iat",
|
**business_params
|
||||||
"language": "zh_cn",
|
|
||||||
"accent": "mandarin",
|
|
||||||
"vinfo": 1,
|
|
||||||
"vad_eos": 10000
|
|
||||||
},
|
},
|
||||||
"data": {
|
"data": {
|
||||||
"status": 0, "format": "audio/L16;rate=16000",
|
"status": 0, "format": "audio/L16;rate=16000",
|
||||||
"audio": str(base64.b64encode(buf), 'utf-8'),
|
"audio": str(base64.b64encode(buf), 'utf-8'),
|
||||||
"encoding": "lame"}
|
"encoding": "lame"}
|
||||||
}
|
}
|
||||||
|
print(d)
|
||||||
d = json.dumps(d)
|
d = json.dumps(d)
|
||||||
await ws.send(d)
|
await ws.send(d)
|
||||||
status = STATUS_CONTINUE_FRAME
|
status = STATUS_CONTINUE_FRAME
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.api_key = kwargs.get('api_key')
|
self.api_key = kwargs.get('api_key')
|
||||||
self.api_base = kwargs.get('api_base')
|
self.api_base = kwargs.get('api_base')
|
||||||
|
self.model = kwargs.get('model')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
|
|
@ -57,5 +59,14 @@ class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
||||||
audio_data = audio_file.read()
|
audio_data = audio_file.read()
|
||||||
buffer = io.BytesIO(audio_data)
|
buffer = io.BytesIO(audio_data)
|
||||||
buffer.name = "file.mp3" # this is the important line
|
buffer.name = "file.mp3" # this is the important line
|
||||||
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
|
|
||||||
|
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
|
||||||
|
transcription_params = {
|
||||||
|
'model': self.model,
|
||||||
|
'file': buffer,
|
||||||
|
'language': 'zh',
|
||||||
|
**filter_params
|
||||||
|
}
|
||||||
|
|
||||||
|
res = client.audio.transcriptions.create(**transcription_params)
|
||||||
return res.text
|
return res.text
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue