feat: add function to retrieve default parameters for embedding models
--bug=1063177 --user=刘瑞斌 【知识库】-知识库使用的模型更换维度参数值并重新向量化后,命中测试、检索报错 https://www.tapd.cn/62980211/s/1792117v3.2
parent
2de6bd2018
commit
ed19db07d1
|
|
@ -112,6 +112,21 @@ class ProblemParagraphManage:
|
||||||
], problem_paragraph_mapping_list
|
], problem_paragraph_mapping_list
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_embedding_model_default_params(model):
    """Collect the default parameter values declared on a model.

    Every entry of ``model.model_params_form`` is read with ``.get``: the
    entry's ``'field'`` becomes the key and its ``'default_value'`` the
    value. Entries whose ``'default_value'`` is ``None`` are left out.

    String defaults that parse as integers are converted to ``int``; any
    other value (non-numeric strings, numbers, ...) is passed through as-is.

    :param model: object exposing ``model_params_form`` (an iterable of
        mappings, or ``None``/empty when the model declares no parameters)
    :return: dict mapping each field to its default value
    """

    def convert_to_int(value):
        # Only strings are candidates for conversion; a string that is not a
        # valid integer literal is kept unchanged instead of raising.
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                return value
        return value

    # A model without a parameter form (None) simply has no defaults;
    # iterating None directly would raise TypeError.
    params_form = model.model_params_form or []

    return {
        p.get('field'): convert_to_int(p.get('default_value'))
        for p in params_form
        if p.get('default_value') is not None
    }
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
||||||
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
|
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
|
||||||
|
|
@ -119,17 +134,29 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
||||||
raise Exception(_('The knowledge base is inconsistent with the vector model'))
|
raise Exception(_('The knowledge base is inconsistent with the vector model'))
|
||||||
if len(knowledge_list) == 0:
|
if len(knowledge_list) == 0:
|
||||||
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
|
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
|
||||||
return ModelManage.get_model(str(knowledge_list[0].embedding_model_id),
|
|
||||||
lambda _id: get_model(knowledge_list[0].embedding_model))
|
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
|
||||||
|
|
||||||
|
return ModelManage.get_model(
|
||||||
|
str(knowledge_list[0].embedding_model_id),
|
||||||
|
lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model_by_knowledge_id(knowledge_id: str):
    """Return the embedding model attached to the knowledge base with this id.

    The ``Knowledge`` row is fetched together with its ``embedding_model``
    relation, the model's default parameters are collected, and the result of
    ``ModelManage.get_model`` is returned.

    :param knowledge_id: id of the knowledge base to look up
    :return: whatever ``ModelManage.get_model`` returns for the model id
    """
    knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
    default_params = get_embedding_model_default_params(knowledge.embedding_model)
    model_key = str(knowledge.embedding_model_id)

    def build_model(_id):
        # Factory handed to ModelManage; presumably only invoked when no
        # instance is registered under model_key yet -- confirm in ModelManage.
        return get_model(knowledge.embedding_model, **default_params)

    return ModelManage.get_model(model_key, build_model)
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model_by_knowledge(knowledge):
    """Return the embedding model attached to an already loaded knowledge base.

    The model's default parameters are collected first and forwarded as
    keyword arguments to ``get_model`` through the factory given to
    ``ModelManage.get_model``.

    :param knowledge: object exposing ``embedding_model`` and
        ``embedding_model_id``
    :return: whatever ``ModelManage.get_model`` returns for the model id
    """
    default_params = get_embedding_model_default_params(knowledge.embedding_model)
    model_key = str(knowledge.embedding_model_id)

    def build_model(_id):
        # Factory handed to ModelManage; presumably only invoked when no
        # instance is registered under model_key yet -- confirm in ModelManage.
        return get_model(knowledge.embedding_model, **default_params)

    return ModelManage.get_model(model_key, build_model)
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model_id_by_knowledge_id(knowledge_id):
|
def get_embedding_model_id_by_knowledge_id(knowledge_id):
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
|
||||||
UpdateEmbeddingDocumentIdArgs
|
UpdateEmbeddingDocumentIdArgs
|
||||||
from common.utils.logger import maxkb_logger
|
from common.utils.logger import maxkb_logger
|
||||||
from knowledge.models import Document, TaskType, State
|
from knowledge.models import Document, TaskType, State
|
||||||
from knowledge.serializers.common import drop_knowledge_index
|
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
|
||||||
from models_provider.models import Model
|
from models_provider.models import Model
|
||||||
from models_provider.tools import get_model
|
from models_provider.tools import get_model
|
||||||
from ops import celery_app
|
from ops import celery_app
|
||||||
|
|
@ -26,21 +26,9 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
|
||||||
try:
|
try:
|
||||||
model = QuerySet(Model).filter(id=model_id).first()
|
model = QuerySet(Model).filter(id=model_id).first()
|
||||||
|
|
||||||
def convert_to_int(value):
|
default_params = get_embedding_model_default_params(model)
|
||||||
if isinstance(value, str):
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError:
|
|
||||||
return value
|
|
||||||
return value
|
|
||||||
|
|
||||||
s = {
|
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
|
||||||
p.get('field'): convert_to_int(p.get('default_value'))
|
|
||||||
for p in model.model_params_form
|
|
||||||
if p.get('default_value') is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exception_handler(e)
|
exception_handler(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue