From 35d94626897076bc0ddf98fb391c7427b81c4c6b Mon Sep 17 00:00:00 2001 From: wxg0103 <46886316+wxg0103@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:07:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0xinference=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=AF=B9=E6=8E=A5=20(#959)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../constants/model_provider_constants.py | 2 + .../aws_bedrock_model_provider.py | 59 +- .../models_provider/impl/base_chat_open_ai.py | 78 +++ .../impl/deepseek_model_provider/model/llm.py | 14 +- .../impl/kimi_model_provider/model/llm.py | 11 +- .../impl/ollama_model_provider/model/llm.py | 12 +- .../impl/openai_model_provider/model/llm.py | 20 +- .../impl/xf_model_provider/model/llm.py | 40 +- .../xinference_model_provider/__init__.py | 1 + .../credential/embedding.py | 38 ++ .../credential/llm.py | 41 ++ .../icon/xinference_icon_svg | 5 + .../model/embedding.py | 24 + .../xinference_model_provider/model/llm.py | 39 ++ .../xinference_model_provider.py | 528 ++++++++++++++++++ .../impl/zhipu_model_provider/model/llm.py | 84 ++- pyproject.toml | 1 + 17 files changed, 923 insertions(+), 74 deletions(-) create mode 100644 apps/setting/models_provider/impl/base_chat_open_ai.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/__init__.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/icon/xinference_icon_svg create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/model/llm.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py diff --git 
a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 6d4ff9cc0..6f691d276 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -21,6 +21,7 @@ from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine VolcanicEngineModelProvider from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider +from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider @@ -40,3 +41,4 @@ class ModelProvideConstants(Enum): model_tencent_provider = TencentModelProvider() model_aws_bedrock_provider = BedrockModelProvider() model_local_provider = LocalModelProvider() + model_xinference_provider = XinferenceModelProvider() diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py index a6187f995..a3a969564 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py @@ -31,13 +31,56 @@ def _get_aws_bedrock_icon_path(): def _initialize_model_info(): - model_info_list = [_create_model_info( - 'amazon.titan-text-premier-v1:0', - 'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。', - ModelTypeConst.LLM, - BedrockLLMModelCredential, - BedrockModel - ), 
+ model_info_list = [ + _create_model_info( + 'anthropic.claude-v2:1', + 'Claude 2 的更新,采用双倍的上下文窗口,并在长文档和 RAG 上下文中提高可靠性、幻觉率和循证准确性。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), + _create_model_info( + 'anthropic.claude-v2', + 'Anthropic 功能强大的模型,可处理各种任务,从复杂的对话和创意内容生成到详细的指令服从。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), + _create_model_info( + 'anthropic.claude-3-haiku-20240307-v1:0', + 'Claude 3 Haiku 是 Anthropic 最快速、最紧凑的模型,具有近乎即时的响应能力。该模型可以快速回答简单的查询和请求。客户将能够构建模仿人类交互的无缝人工智能体验。 Claude 3 Haiku 可以处理图像和返回文本输出,并且提供 200K 上下文窗口。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), + _create_model_info( + 'anthropic.claude-3-sonnet-20240229-v1:0', + 'Anthropic 推出的 Claude 3 Sonnet 模型在智能和速度之间取得理想的平衡,尤其是在处理企业工作负载方面。该模型提供最大的效用,同时价格低于竞争产品,并且其经过精心设计,是大规模部署人工智能的可靠选择。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), + _create_model_info( + 'anthropic.claude-3-5-sonnet-20240620-v1:0', + 'Claude 3.5 Sonnet提高了智能的行业标准,在广泛的评估中超越了竞争对手的型号和Claude 3 Opus,具有我们中端型号的速度和成本效益。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), + _create_model_info( + 'anthropic.claude-instant-v1', + '一种更快速、更实惠但仍然非常强大的模型,它可以处理一系列任务,包括随意对话、文本分析、摘要和文档问题回答。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), + _create_model_info( + 'amazon.titan-text-premier-v1:0', + 'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。', + ModelTypeConst.LLM, + BedrockLLMModelCredential, + BedrockModel + ), _create_model_info( 'amazon.titan-text-lite-v1', 'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型', @@ -59,7 +102,7 @@ def _initialize_model_info(): _create_model_info( 'mistral.mistral-7b-instruct-v0:2', '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。', - ModelTypeConst.EMBEDDING, + ModelTypeConst.LLM, BedrockLLMModelCredential, BedrockModel), _create_model_info( diff 
--git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py new file mode 100644 index 000000000..1774c75b1 --- /dev/null +++ b/apps/setting/models_provider/impl/base_chat_open_ai.py @@ -0,0 +1,78 @@ +# coding=utf-8 + +from typing import List, Dict, Optional, Any, Iterator, Type +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk +from langchain_core.outputs import ChatGenerationChunk +from langchain_openai import ChatOpenAI +from langchain_openai.chat_models.base import _convert_delta_to_message_chunk + + +class BaseChatOpenAI(ChatOpenAI): + + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + return self.get_last_generation_info().get('prompt_tokens', 0) + + def get_num_tokens(self, text: str) -> int: + return self.get_last_generation_info().get('completion_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + payload = self._get_request_payload(messages, stop=stop, **kwargs) + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk + if self.include_response_headers: + raw_response = self.client.with_raw_response.create(**payload) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = self.client.create(**payload) + base_generation_info = {} + with response: + is_first_chunk = True + for chunk in response: + if not isinstance(chunk, dict): + chunk = chunk.model_dump() + if len(chunk["choices"]) == 0: + if token_usage := chunk.get("usage"): + 
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) + logprobs = None + else: + continue + else: + choice = chunk["choices"][0] + if choice["delta"] is None: + continue + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + generation_info = {**base_generation_info} if is_first_chunk else {} + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + default_chunk_class = message_chunk.__class__ + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk, logprobs=logprobs + ) + is_first_chunk = False + yield generation_chunk diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py index 086b02e60..8bc6f30d2 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py @@ -8,14 +8,11 @@ """ from typing import List, Dict -from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai import ChatOpenAI - -from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI): +class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: 
Dict[str, object], **model_kwargs): deepseek_chat_open_ai = DeepSeekChatModel( @@ -25,10 +22,3 @@ class DeepSeekChatModel(MaxKBBaseModel, ChatOpenAI): ) return deepseek_chat_open_ai - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py index 3e4d4f282..652788cc5 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py @@ -8,14 +8,14 @@ """ from typing import List, Dict -from langchain_community.chat_models import ChatOpenAI from langchain_core.messages import BaseMessage, get_buffer_string from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class KimiChatModel(MaxKBBaseModel, ChatOpenAI): +class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): kimi_chat_open_ai = KimiChatModel( @@ -25,10 +25,3 @@ class KimiChatModel(MaxKBBaseModel, ChatOpenAI): ) return kimi_chat_open_ai - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py 
b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py index fb1e77cdd..2a21c31b9 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -9,11 +9,11 @@ from typing import List, Dict from urllib.parse import urlparse, ParseResult -from langchain_community.chat_models import ChatOpenAI from langchain_core.messages import BaseMessage, get_buffer_string from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI def get_base_url(url: str): @@ -24,7 +24,7 @@ def get_base_url(url: str): return result_url[:-1] if result_url.endswith("/") else result_url -class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): +class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): api_base = model_credential.get('api_base', '') @@ -32,11 +32,3 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') return OllamaChatModel(model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key')) - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 7ad5f49f9..9c3a9c116 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ 
b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -8,27 +8,19 @@ """ from typing import List, Dict -from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai import ChatOpenAI - -from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): +class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): azure_chat_open_ai = OpenAIChatModel( model=model_name, openai_api_base=model_credential.get('api_base'), - openai_api_key=model_credential.get('api_key') + openai_api_key=model_credential.get('api_key'), + streaming=model_kwargs.get('streaming', False), + max_tokens=model_kwargs.get('max_tokens', 5), + temperature=model_kwargs.get('temperature', 0.5), ) return azure_chat_open_ai - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py index a58941286..1dccd29e3 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py @@ -6,16 +6,15 @@ @date:2024/04/19 15:55 @desc: """ - +import json from typing import List, Optional, Any, Iterator, Dict -from langchain_community.chat_models import ChatSparkLLM -from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk 
+from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \ + ChatSparkLLM from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string +from langchain_core.messages import BaseMessage, AIMessageChunk from langchain_core.outputs import ChatGenerationChunk -from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -31,16 +30,19 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM): spark_api_key=model_credential.get('spark_api_key'), spark_api_secret=model_credential.get('spark_api_secret'), spark_api_url=model_credential.get('spark_api_url'), - spark_llm_domain=model_name + spark_llm_domain=model_name, + temperature=model_kwargs.get('temperature', 0.5), + max_tokens=model_kwargs.get('max_tokens', 5), ) + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + return self.get_last_generation_info().get('prompt_tokens', 0) def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) + return self.get_last_generation_info().get('completion_tokens', 0) def _stream( self, @@ -58,11 +60,17 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM): True, ) for content in self.client.subscribe(timeout=self.request_timeout): - if "data" not in content: + if "data" in content: + delta = content["data"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + cg_chunk = ChatGenerationChunk(message=chunk) + elif "usage" in content: + generation_info = content["usage"] + self.__dict__.setdefault('_last_generation_info', 
{}).update(generation_info) continue - delta = content["data"] - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - cg_chunk = ChatGenerationChunk(message=chunk) - if run_manager: - run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) + else: + continue + if cg_chunk is not None: + if run_manager: + run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk) yield cg_chunk diff --git a/apps/setting/models_provider/impl/xinference_model_provider/__init__.py b/apps/setting/models_provider/impl/xinference_model_provider/__init__.py new file mode 100644 index 000000000..9bad5790a --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py new file mode 100644 index 000000000..200183e6c --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py @@ -0,0 +1,38 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding') + except Exception as e: + raise 
AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = provider.get_model_info_by_name(model_list, model_name) + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + if len(exist) == 0: + model.start_down_model_thread() + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") + model.embed_query('你好') + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + def build_model(self, model_info: Dict[str, object]): + for key in ['model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + return self + + api_base = forms.TextInputField('API 域名', required=True) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py new file mode 100644 index 000000000..d6442de32 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -0,0 +1,41 @@ +# coding=utf-8 + +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base'), model_type) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = provider.get_model_info_by_name(model_list, model_name) + if 
len(exist) == 0: + raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型") + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def build_model(self, model_info: Dict[str, object]): + for key in ['api_key', 'model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + self.api_key = model_info.get('api_key') + return self + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/icon/xinference_icon_svg b/apps/setting/models_provider/impl/xinference_model_provider/icon/xinference_icon_svg new file mode 100644 index 000000000..fc553ee3c --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/icon/xinference_icon_svg @@ -0,0 +1,5 @@ + + + + diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py new file mode 100644 index 000000000..1cf34aaf8 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py @@ -0,0 +1,24 @@ +# coding=utf-8 +import threading +from typing import Dict + +from langchain_community.embeddings import XinferenceEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return XinferenceEmbedding( + model_uid=model_name, + server_url=model_credential.get('api_base'), + ) + + def down_model(self): + self.client.launch_model(model_name=self.model_uid, 
model_type="embedding") + + def start_down_model_thread(self): + thread = threading.Thread(target=self.down_model) + thread.daemon = True + thread.start() diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py new file mode 100644 index 000000000..ed9e4e3c6 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py @@ -0,0 +1,39 @@ +# coding=utf-8 + +from typing import List, Dict +from urllib.parse import urlparse, ParseResult +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +def get_base_url(url: str): + parse = urlparse(url) + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url + + +class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + api_base = model_credential.get('api_base', '') + base_url = get_base_url(api_base) + base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + optional_params = {} + if 'max_tokens' in model_kwargs: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs: + optional_params['temperature'] = model_kwargs['temperature'] + return XinferenceChatModel( + model=model_name, + openai_api_base=base_url, + openai_api_key=model_credential.get('api_key'), + streaming=model_kwargs.get('streaming', False), + **optional_params + ) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py new file mode 
100644 index 000000000..22b1068c3 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py @@ -0,0 +1,528 @@ +# coding=utf-8 +import os +from urllib.parse import urlparse, ParseResult + +import requests + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.xinference_model_provider.credential.embedding import \ + XinferenceEmbeddingModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential +from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding +from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel +from smartdoc.conf import PROJECT_DIR + +xinference_llm_model_credential = XinferenceLLMModelCredential() +model_info_list = [ + ModelInfo( + 'aquila2', + 'Aquila2 是一个具有 340 亿参数的大规模语言模型,支持中英文双语。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'aquila2-chat', + 'Aquila2 Chat 是一个聊天模型版本的 Aquila2,支持中英文双语。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'aquila2-chat-16k', + 'Aquila2 Chat 16K 是一个聊天模型版本的 Aquila2,支持长达 16K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'baichuan', + 'Baichuan 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'baichuan-2', + 'Baichuan 2 是 Baichuan 的更新版本,具有更高的性能。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'baichuan-2-chat', + 'Baichuan 2 Chat 是一个聊天模型版本的 Baichuan 2。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 
'baichuan-chat', + 'Baichuan Chat 是一个聊天模型版本的 Baichuan。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'c4ai-command-r-v01', + 'C4AI Command R V01 是一个用于执行命令的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'chatglm', + 'ChatGLM 是一个聊天模型,特别擅长中文对话。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'chatglm2', + 'ChatGLM2 是 ChatGLM 的更新版本,具有更好的性能。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'chatglm2-32k', + 'ChatGLM2 32K 是一个聊天模型版本的 ChatGLM2,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'chatglm3', + 'ChatGLM3 是 ChatGLM 的第三个版本,具有更高的性能。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'chatglm3-128k', + 'ChatGLM3 128K 是一个聊天模型版本的 ChatGLM3,支持长达 128K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'chatglm3-32k', + 'ChatGLM3 32K 是一个聊天模型版本的 ChatGLM3,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'code-llama', + 'Code Llama 是一个专门用于代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'code-llama-instruct', + 'Code Llama Instruct 是 Code Llama 的指令微调版本,专为执行特定任务而设计。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'code-llama-python', + 'Code Llama Python 是一个专门用于 Python 代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codegeex4', + 'CodeGeeX4 是一个用于代码生成的语言模型,具有较高的性能。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codeqwen1.5', + 'CodeQwen 1.5 是一个用于代码生成的语言模型,具有较高的性能。', + ModelTypeConst.LLM, 
+ xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codeqwen1.5-chat', + 'CodeQwen 1.5 Chat 是一个聊天模型版本的 CodeQwen 1.5。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codeshell', + 'CodeShell 是一个用于代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codeshell-chat', + 'CodeShell Chat 是一个聊天模型版本的 CodeShell。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codestral-v0.1', + 'CodeStral V0.1 是一个用于代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'cogvlm2', + 'CogVLM2 是一个视觉语言模型,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'csg-wukong-chat-v0.1', + 'CSG Wukong Chat V0.1 是一个聊天模型版本的 CSG Wukong。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek', + 'Deepseek 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-chat', + 'Deepseek Chat 是一个聊天模型版本的 Deepseek。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-coder', + 'Deepseek Coder 是一个专为代码生成设计的模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-coder-instruct', + 'Deepseek Coder Instruct 是 Deepseek Coder 的指令微调版本,专为执行特定任务而设计。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-vl-chat', + 'Deepseek VL Chat 是 Deepseek 的视觉语言聊天模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'falcon', + 'Falcon 是一个开源的 Transformer 解码器模型,具有 400 亿参数,旨在生成高质量的文本。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + 
ModelInfo( + 'falcon-instruct', + 'Falcon Instruct 是 Falcon 语言模型的指令微调版本,专为执行特定任务而设计。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gemma-2-it', + 'GEMMA-2-IT 是一个基于 GEMMA-2 的意大利语模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gemma-it', + 'GEMMA-IT 是一个基于 GEMMA 的意大利语模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt-3.5-turbo', + 'GPT-3.5 Turbo 是一个高效能的通用语言模型,适用于多种应用场景。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt-4', + 'GPT-4 是一个强大的多模态模型,不仅支持文本输入,还支持图像输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt-4-vision-preview', + 'GPT-4 Vision Preview 是 GPT-4 的视觉预览版本,支持图像输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt4all', + 'GPT4All 是一个开源的多模态模型,支持文本和图像输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'llama2', + 'Llama2 是一个具有 700 亿参数的大规模语言模型,支持多种语言。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'llama2-chat', + 'Llama2 Chat 是一个聊天模型版本的 Llama2,支持多种语言。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'llama2-chat-32k', + 'Llama2 Chat 32K 是一个聊天模型版本的 Llama2,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'moss', + 'MOSS 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'moss-chat', + 'MOSS Chat 是一个聊天模型版本的 MOSS。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen', + 'Qwen 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + 
XinferenceChatModel + ), + ModelInfo( + 'qwen-chat', + 'Qwen Chat 是一个聊天模型版本的 Qwen。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-chat-32k', + 'Qwen Chat 32K 是一个聊天模型版本的 Qwen,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-code', + 'Qwen Code 是一个专门用于代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-code-chat', + 'Qwen Code Chat 是一个聊天模型版本的 Qwen Code。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-vl', + 'Qwen VL 是 Qwen 的视觉语言模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-vl-chat', + 'Qwen VL Chat 是 Qwen VL 的聊天模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2', + 'Spark2 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2-chat', + 'Spark2 Chat 是一个聊天模型版本的 Spark2。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2-chat-32k', + 'Spark2 Chat 32K 是一个聊天模型版本的 Spark2,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2-code', + 'Spark2 Code 是一个专门用于代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2-code-chat', + 'Spark2 Code Chat 是一个聊天模型版本的 Spark2 Code。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2-vl', + 'Spark2 VL 是 Spark2 的视觉语言模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'spark2-vl-chat', + 'Spark2 VL Chat 是 Spark2 VL 的聊天模型版本,能够处理图像和文本输入。', + 
ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), +] + +xinference_embedding_model_credential = XinferenceEmbeddingModelCredential() + +# 生成embedding_model_info列表 +embedding_model_info = [ + ModelInfo('bce-embedding-base_v1', 'BCE 嵌入模型的基础版本。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-en', 'BGE 英语基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-en-v1.5', 'BGE 英语基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-zh', 'BGE 中文基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-zh-v1.5', 'BGE 中文基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-en', 'BGE 英语大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-en-v1.5', 'BGE 英语大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-zh', 'BGE 中文大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-zh-noinstruct', 'BGE 中文大型版本的嵌入模型,无指令调整。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-zh-v1.5', 'BGE 中文大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-m3', 'BGE M3 版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('bge-small-en-v1.5', 'BGE 英语小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-small-zh', 'BGE 中文小型版本的嵌入模型。', 
ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-small-zh-v1.5', 'BGE 中文小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('e5-large-v2', 'E5 大型版本 2 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('gte-base', 'GTE 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('gte-large', 'GTE 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('jina-embeddings-v2-base-en', 'Jina 嵌入模型的英语基础版本 2。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('jina-embeddings-v2-base-zh', 'Jina 嵌入模型的中文基础版本 2。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('jina-embeddings-v2-small-en', 'Jina 嵌入模型的英语小型版本 2。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('m3e-base', 'M3E 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('m3e-large', 'M3E 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('m3e-small', 'M3E 小型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('multilingual-e5-large', '多语言大型版本的 E5 嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('text2vec-base-chinese', 'Text2Vec 的中文基础版本嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('text2vec-base-chinese-paraphrase', 'Text2Vec 的中文基础版本的同义句嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + 
def get_base_url(url: str) -> str:
    """Reduce *url* to ``scheme://netloc/path`` with no trailing slash.

    The params, query string and fragment of the input are discarded. Every
    trailing ``/`` is removed (not just the last one) so that callers can
    append segments such as ``/v1`` without producing ``//`` in the result.

    :param url: address of the service as entered by the user
    :return: the normalised base URL
    """
    parsed = urlparse(url)
    base = ParseResult(
        scheme=parsed.scheme,
        netloc=parsed.netloc,
        path=parsed.path,
        params='',
        query='',
        fragment='',
    ).geturl()
    # rstrip handles "host/", "host//", ... uniformly; slicing off one
    # character would leave a dangling slash for the latter.
    return base.rstrip('/')
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
    """ZhipuAI chat model that remembers the token usage of the last streamed reply.

    ``_stream`` stores the ``usage`` object sent by the API in the instance
    ``__dict__`` under ``_last_generation_info``; the token-count methods read
    it back and fall back to a local tokenizer estimate when nothing has been
    recorded yet.
    """

    @staticmethod
    def is_cache_model():
        """Return ``False``.

        NOTE(review): presumably tells the provider layer not to reuse a cached
        instance, since usage is stored per instance - confirm against
        ``MaxKBBaseModel``.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``ZhipuChatModel`` for *model_name*.

        :param model_type: unused here; kept for the provider interface
        :param model_name: name passed to the API as ``model``
        :param model_credential: mapping that supplies ``api_key``
        :param model_kwargs: optional settings; only ``max_tokens`` is read
        :return: the configured model instance
        """
        init_kwargs = {
            'temperature': 0.5,
            'api_key': model_credential.get('api_key'),
            'model': model_name,
        }
        # Forward a token limit only when the caller asked for one. A hard
        # default (previously 5) silently truncated every reply.
        max_tokens = model_kwargs.get('max_tokens')
        if max_tokens is not None:
            init_kwargs['max_tokens'] = max_tokens
        return ZhipuChatModel(**init_kwargs)

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict recorded by the last ``_stream`` call, or ``None``."""
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the prompt token count.

        Uses ``prompt_tokens`` from the recorded usage when available;
        otherwise estimates by tokenizing *messages* locally, because no usage
        exists before the first streamed response.
        """
        usage = self.get_last_generation_info()
        if usage is None:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
        return usage.get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the completion token count.

        Uses ``completion_tokens`` from the recorded usage when available;
        otherwise estimates by tokenizing *text* locally.
        """
        usage = self.get_last_generation_info()
        if usage is None:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
        return usage.get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat response in chunks.

        Posts the conversation to ``zhipuai_api_base`` as a server-sent-event
        stream, yields one ``ChatGenerationChunk`` per event that carries a
        choice, and records any ``usage`` object for the token-count methods.

        :raises ValueError: if the API key or API base is not configured
        """
        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        if self.zhipuai_api_base is None:
            raise ValueError("Did not find zhipu_api_base.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        default_chunk_class = AIMessageChunk
        import httpx

        with httpx.Client(headers=headers, timeout=60) as client:
            with connect_sse(
                    client, "POST", self.zhipuai_api_base, json=payload
            ) as event_source:
                for sse in event_source.iter_sse():
                    event = json.loads(sse.data)
                    if len(event["choices"]) == 0:
                        continue
                    choice = event["choices"][0]
                    generation_info = {}
                    if "usage" in event:
                        generation_info = event["usage"]
                        # Stored via __dict__ rather than attribute assignment,
                        # as in the accessor above.
                        self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    finish_reason = choice.get("finish_reason", None)

                    generation_chunk = ChatGenerationChunk(
                        message=message_chunk, generation_info=generation_info
                    )
                    yield generation_chunk
                    if run_manager:
                        run_manager.on_llm_new_token(generation_chunk.text, chunk=generation_chunk)
                    if finish_reason is not None:
                        break