diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 649d04de4..06120454c 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -140,14 +140,14 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List): raise Exception("知识库未向量模型不一致") if len(dataset_list) == 0: raise Exception("知识库设置错误,请重新设置知识库") - return ModelManage.get_model(str(dataset_list[0].id), - lambda _id: get_model(dataset_list[0].embedding_mode)) + return ModelManage.get_model(str(dataset_list[0].embedding_mode_id), + lambda _id: get_model(dataset_list[0].embedding_mode)) def get_embedding_model_by_dataset_id(dataset_id: str): dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() - return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode)) + return ModelManage.get_model(dataset.embedding_mode_id, lambda _id: get_model(dataset.embedding_mode)) def get_embedding_model_by_dataset(dataset): - return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode)) + return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode))