diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py index cba5186a7..e9174d3a6 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py @@ -6,11 +6,18 @@ @Author :Brian Yang @Date :5/13/24 7:40 AM """ -from typing import List, Dict +from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast +from google.ai.generativelanguage_v1 import GenerateContentResponse +from google.generativeai.responder import ToolDict +from google.generativeai.types import FunctionDeclarationType, SafetySettingDict +from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.outputs import ChatGenerationChunk from langchain_google_genai import ChatGoogleGenerativeAI - +from langchain_google_genai._function_utils import _ToolConfigDict +from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result +from google.generativeai.types import Tool as GoogleTool from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -36,10 +43,49 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI): ) return gemini_chat + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + return self.get_last_generation_info().get('input_tokens', 0) def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) + return self.get_last_generation_info().get('output_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None, + functions: Optional[Sequence[FunctionDeclarationType]] = None, + safety_settings: Optional[SafetySettingDict] = None, + tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + generation_config: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + request = self._prepare_request( + messages, + stop=stop, + tools=tools, + functions=functions, + safety_settings=safety_settings, + tool_config=tool_config, + generation_config=generation_config, + ) + response: GenerateContentResponse = _chat_with_retry( + request=request, + generation_method=self.client.stream_generate_content, + **kwargs, + metadata=self.default_metadata, + ) + for chunk in response: + _chat_result = _response_to_result(chunk, stream=True) + gen = cast(ChatGenerationChunk, _chat_result.generations[0]) + if gen.message: + token_usage = gen.message.usage_metadata + self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) + if run_manager: + run_manager.on_llm_new_token(gen.text) + yield gen