diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py index 8b7fa053e..d1eef33d4 100644 --- a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py +++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py @@ -47,14 +47,15 @@ class BaseRerankerNode(IRerankerNode): def execute(self, question, reranker_setting, reranker_list, reranker_model_id, **kwargs) -> NodeResult: documents = merge_reranker_list(reranker_list) + top_n = reranker_setting.get('top_n', 3) self.context['document_list'] = documents self.context['question'] = question reranker_model = get_model_instance_by_model_user_id(reranker_model_id, - self.flow_params_serializer.data.get('user_id')) + self.flow_params_serializer.data.get('user_id'), + top_n=top_n) result = reranker_model.compress_documents( [Document(page_content=document) for document in documents if document is not None and len(document) > 0], question) - top_n = reranker_setting.get('top_n', 3) similarity = reranker_setting.get('similarity', 0.6) max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000) r = filter_result(result, max_paragraph_char_number, top_n, similarity) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py index f32e1ee94..ed2db0f91 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py @@ -26,7 +26,7 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): return XInferenceReranker(server_url=model_credential.get('server_url'), model_uid=model_name, - api_key=model_credential.get('api_key')) + api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3)) top_n: Optional[int] = 3