From e6efa1a07411192eed86d949ba4c81058f2b4d98 Mon Sep 17 00:00:00 2001 From: panyy Date: Mon, 29 Jun 2026 17:36:26 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E9=99=90=E5=88=B6=E5=BF=85=E9=A1=BBapi-ke?= =?UTF-8?q?y=E6=88=96=E7=99=BB=E5=BD=955?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/local_model_provider/model/embedding/model.py | 9 ++++++++- .../impl/local_model_provider/model/embedding/web.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/model.py b/apps/models_provider/impl/local_model_provider/model/embedding/model.py index 7ebc41cb1..a9e7aad69 100644 --- a/apps/models_provider/impl/local_model_provider/model/embedding/model.py +++ b/apps/models_provider/impl/local_model_provider/model/embedding/model.py @@ -6,6 +6,7 @@ @date:2025/11/5 15:26 @desc: """ +import os from typing import Dict from langchain_huggingface import HuggingFaceEmbeddings @@ -20,7 +21,13 @@ class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + cache_folder = model_credential.get('cache_folder') + model_path = model_name + if cache_folder and not os.path.isabs(model_name): + local_model_path = os.path.join(cache_folder, model_name) + if os.path.isdir(local_model_path): + model_path = local_model_path + return LocalEmbedding(model_name=model_path, cache_folder=cache_folder, model_kwargs={'device': model_credential.get('device')}, encode_kwargs={'normalize_embeddings': True} ) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/web.py b/apps/models_provider/impl/local_model_provider/model/embedding/web.py index bfc22bc9b..42c780adc 100644 --- a/apps/models_provider/impl/local_model_provider/model/embedding/web.py +++ b/apps/models_provider/impl/local_model_provider/model/embedding/web.py @@ -46,7 +46,7 @@ class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' prefix = CONFIG.get_admin_path() res = requests.post( - f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents', + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_documents', {'texts': texts}) result = res.json() if result.get('code', 500) == 200: