UnisKB/apps/models_provider/impl/gemini_model_provider/model/llm.py

101 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File llm.py
@Author Brian Yang
@Date 5/13/24 7:40 AM
"""
from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast
from google.ai.generativelanguage_v1 import GenerateContentResponse
from google.ai.generativelanguage_v1beta.types import (
Tool as GoogleTool,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs import ChatGenerationChunk
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result, \
_FunctionDeclarationType
from langchain_google_genai._common import (
SafetySettingDict,
)
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
gemini_chat = GeminiChatModel(
model=model_name,
google_api_key=model_credential.get('api_key'),
**optional_params
)
return gemini_chat
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return self.get_last_generation_info().get('input_tokens', 0)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
try:
return self.get_last_generation_info().get('output_tokens', 0)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
generation_config: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
)
response: GenerateContentResponse = _chat_with_retry(
request=request,
generation_method=self.client.stream_generate_content,
**kwargs,
metadata=self.default_metadata,
)
for chunk in response:
_chat_result = _response_to_result(chunk, stream=True)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
if gen.message:
token_usage = gen.message.usage_metadata
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
if run_manager:
run_manager.on_llm_new_token(gen.text)
yield gen