UnisKB/apps/models_provider/impl/ollama_model_provider/model/llm.py

49 lines
1.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2024/3/6 11:48
@desc:
"""
from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_ollama.chat_models import ChatOllama
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
def get_base_url(url: str):
parse = urlparse(url)
result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
query='',
fragment='').geturl()
return result_url[:-1] if result_url.endswith("/") else result_url
class OllamaChatModel(MaxKBBaseModel, ChatOllama):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OllamaChatModel(model=model_name, base_url=base_url,
stream=True, **optional_params)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))