feat: application flow (#3152)
parent
0c9d8ccf71
commit
896fb5fa52
|
|
@ -0,0 +1,157 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: I_base_chat_pipeline.py
|
||||
@date:2024/1/9 17:25
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class ParagraphPipelineModel:
|
||||
|
||||
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
|
||||
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
|
||||
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
|
||||
self.id = _id
|
||||
self.document_id = document_id
|
||||
self.dataset_id = dataset_id
|
||||
self.content = content
|
||||
self.title = title
|
||||
self.status = status,
|
||||
self.is_active = is_active
|
||||
self.comprehensive_score = comprehensive_score
|
||||
self.similarity = similarity
|
||||
self.dataset_name = dataset_name
|
||||
self.document_name = document_name
|
||||
self.hit_handling_method = hit_handling_method
|
||||
self.directly_return_similarity = directly_return_similarity
|
||||
self.meta = meta
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'document_id': self.document_id,
|
||||
'dataset_id': self.dataset_id,
|
||||
'content': self.content,
|
||||
'title': self.title,
|
||||
'status': self.status,
|
||||
'is_active': self.is_active,
|
||||
'comprehensive_score': self.comprehensive_score,
|
||||
'similarity': self.similarity,
|
||||
'dataset_name': self.dataset_name,
|
||||
'document_name': self.document_name,
|
||||
'meta': self.meta,
|
||||
}
|
||||
|
||||
class builder:
|
||||
def __init__(self):
|
||||
self.similarity = None
|
||||
self.paragraph = {}
|
||||
self.comprehensive_score = None
|
||||
self.document_name = None
|
||||
self.dataset_name = None
|
||||
self.hit_handling_method = None
|
||||
self.directly_return_similarity = 0.9
|
||||
self.meta = {}
|
||||
|
||||
def add_paragraph(self, paragraph):
|
||||
if isinstance(paragraph, Paragraph):
|
||||
self.paragraph = {'id': paragraph.id,
|
||||
'document_id': paragraph.document_id,
|
||||
'dataset_id': paragraph.dataset_id,
|
||||
'content': paragraph.content,
|
||||
'title': paragraph.title,
|
||||
'status': paragraph.status,
|
||||
'is_active': paragraph.is_active,
|
||||
}
|
||||
else:
|
||||
self.paragraph = paragraph
|
||||
return self
|
||||
|
||||
def add_dataset_name(self, dataset_name):
|
||||
self.dataset_name = dataset_name
|
||||
return self
|
||||
|
||||
def add_document_name(self, document_name):
|
||||
self.document_name = document_name
|
||||
return self
|
||||
|
||||
def add_hit_handling_method(self, hit_handling_method):
|
||||
self.hit_handling_method = hit_handling_method
|
||||
return self
|
||||
|
||||
def add_directly_return_similarity(self, directly_return_similarity):
|
||||
self.directly_return_similarity = directly_return_similarity
|
||||
return self
|
||||
|
||||
def add_comprehensive_score(self, comprehensive_score: float):
|
||||
self.comprehensive_score = comprehensive_score
|
||||
return self
|
||||
|
||||
def add_similarity(self, similarity: float):
|
||||
self.similarity = similarity
|
||||
return self
|
||||
|
||||
def add_meta(self, meta: dict):
|
||||
self.meta = meta
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
|
||||
str(self.paragraph.get('dataset_id')),
|
||||
self.paragraph.get('content'), self.paragraph.get('title'),
|
||||
self.paragraph.get('status'),
|
||||
self.paragraph.get('is_active'),
|
||||
self.comprehensive_score, self.similarity, self.dataset_name,
|
||||
self.document_name, self.hit_handling_method, self.directly_return_similarity,
|
||||
self.meta)
|
||||
|
||||
|
||||
class IBaseChatPipelineStep:
|
||||
def __init__(self):
|
||||
# 当前步骤上下文,用于存储当前步骤信息
|
||||
self.context = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
|
||||
pass
|
||||
|
||||
def valid_args(self, manage):
|
||||
step_serializer_clazz = self.get_step_serializer(manage)
|
||||
step_serializer = step_serializer_clazz(data=manage.context)
|
||||
step_serializer.is_valid(raise_exception=True)
|
||||
self.context['step_args'] = step_serializer.data
|
||||
|
||||
def run(self, manage):
|
||||
"""
|
||||
|
||||
:param manage: 步骤管理器
|
||||
:return: 执行结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
self.context['start_time'] = start_time
|
||||
# 校验参数,
|
||||
self.valid_args(manage)
|
||||
self._run(manage)
|
||||
self.context['run_time'] = time.time() - start_time
|
||||
|
||||
def _run(self, manage):
|
||||
pass
|
||||
|
||||
def execute(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
"""
|
||||
运行详情
|
||||
:return: 步骤详情
|
||||
"""
|
||||
return None
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 17:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: pipeline_manage.py
|
||||
@date:2024/1/9 17:40
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Type, Dict
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from common.handle.base_to_response import BaseToResponse
|
||||
from common.handle.impl.response.system_to_response import SystemToResponse
|
||||
|
||||
|
||||
class PipelineManage:
|
||||
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
|
||||
base_to_response: BaseToResponse = SystemToResponse()):
|
||||
# 步骤执行器
|
||||
self.step_list = [step() for step in step_list]
|
||||
# 上下文
|
||||
self.context = {'message_tokens': 0, 'answer_tokens': 0}
|
||||
self.base_to_response = base_to_response
|
||||
|
||||
def run(self, context: Dict = None):
|
||||
self.context['start_time'] = time.time()
|
||||
if context is not None:
|
||||
for key, value in context.items():
|
||||
self.context[key] = value
|
||||
for step in self.step_list:
|
||||
step.run(self)
|
||||
|
||||
def get_details(self):
|
||||
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
|
||||
filter(lambda r: r is not None,
|
||||
[row.get_details(self) for row in self.step_list])], {})
|
||||
|
||||
def get_base_to_response(self):
|
||||
return self.base_to_response
|
||||
|
||||
class builder:
|
||||
def __init__(self):
|
||||
self.step_list: List[Type[IBaseChatPipelineStep]] = []
|
||||
self.base_to_response = SystemToResponse()
|
||||
|
||||
def append_step(self, step: Type[IBaseChatPipelineStep]):
|
||||
self.step_list.append(step)
|
||||
return self
|
||||
|
||||
def add_base_to_response(self, base_to_response: BaseToResponse):
|
||||
self.base_to_response = base_to_response
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_chat_step.py
|
||||
@date:2024/1/9 18:17
|
||||
@desc: 对话
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Type, List
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.serializers.application_serializers import NoReferencesSetting
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ModelField(serializers.Field):
|
||||
def to_internal_value(self, data):
|
||||
if not isinstance(data, BaseChatModel):
|
||||
self.fail(_('Model type error'), value=data)
|
||||
return data
|
||||
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class MessageField(serializers.Field):
|
||||
def to_internal_value(self, data):
|
||||
if not isinstance(data, BaseMessage):
|
||||
self.fail(_('Message type error'), value=data)
|
||||
return data
|
||||
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class PostResponseHandler:
|
||||
@abstractmethod
|
||||
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
|
||||
answer_text,
|
||||
manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class IChatStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 对话列表
|
||||
message_list = serializers.ListField(required=True, child=MessageField(required=True),
|
||||
error_messages=ErrMessage.list(_("Conversation list")))
|
||||
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
|
||||
# 对话id
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
|
||||
# 用户问题
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_("User Questions")))
|
||||
# 后置处理器
|
||||
post_response_handler = InstanceField(model_type=PostResponseHandler,
|
||||
error_messages=ErrMessage.base(_("Post-processor")))
|
||||
# 补全问题
|
||||
padding_problem_text = serializers.CharField(required=False,
|
||||
error_messages=ErrMessage.base(_("Completion Question")))
|
||||
# 是否使用流的形式输出
|
||||
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
|
||||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
|
||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
|
||||
# 未查询到引用分段
|
||||
no_references_setting = NoReferencesSetting(required=True,
|
||||
error_messages=ErrMessage.base(_("No reference segment settings")))
|
||||
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
|
||||
|
||||
model_setting = serializers.DictField(required=True, allow_null=True,
|
||||
error_messages=ErrMessage.dict(_("Model settings")))
|
||||
|
||||
model_params_setting = serializers.DictField(required=False, allow_null=True,
|
||||
error_messages=ErrMessage.dict(_("Model parameter settings")))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
message_list: List = self.initial_data.get('message_list')
|
||||
for message in message_list:
|
||||
if not isinstance(message, BaseMessage):
|
||||
raise Exception(_("message type error"))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PipelineManage):
|
||||
chat_result = self.execute(**self.context['step_args'], manage=manage)
|
||||
manage.context['chat_result'] = chat_result
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, message_list: List[BaseMessage],
|
||||
chat_id, problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
model_id: str = None,
|
||||
user_id: str = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
|
||||
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
|
||||
pass
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_chat_step.py
|
||||
@date:2024/1/9 18:25
|
||||
@desc: 对话step Base实现
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.http import StreamingHttpResponse
|
||||
from django.utils.translation import gettext as _
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage, AIMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from rest_framework import status
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||
from application.flow.tools import Reasoning
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def add_access_num(client_id=None, client_type=None, application_id=None):
|
||||
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
|
||||
application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
|
||||
application_id=application_id)
|
||||
.first())
|
||||
if application_public_access_client is not None:
|
||||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
|
||||
application_public_access_client.save()
|
||||
|
||||
|
||||
def write_context(step, manage, request_token, response_token, all_text):
|
||||
step.context['message_tokens'] = request_token
|
||||
step.context['answer_tokens'] = response_token
|
||||
current_time = time.time()
|
||||
step.context['answer_text'] = all_text
|
||||
step.context['run_time'] = current_time - step.context['start_time']
|
||||
manage.context['run_time'] = current_time - manage.context['start_time']
|
||||
manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
|
||||
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
||||
|
||||
|
||||
def event_content(response,
|
||||
chat_id,
|
||||
chat_record_id,
|
||||
paragraph_list: List[ParagraphPipelineModel],
|
||||
post_response_handler: PostResponseHandler,
|
||||
manage,
|
||||
step,
|
||||
chat_model,
|
||||
message_list: List[BaseMessage],
|
||||
problem_text: str,
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None,
|
||||
is_ai_chat: bool = None,
|
||||
model_setting=None):
|
||||
if model_setting is None:
|
||||
model_setting = {}
|
||||
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
|
||||
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
|
||||
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
|
||||
reasoning = Reasoning(reasoning_content_start,
|
||||
reasoning_content_end)
|
||||
all_text = ''
|
||||
reasoning_content = ''
|
||||
try:
|
||||
response_reasoning_content = False
|
||||
for chunk in response:
|
||||
reasoning_chunk = reasoning.get_reasoning_content(chunk)
|
||||
content_chunk = reasoning_chunk.get('content')
|
||||
if 'reasoning_content' in chunk.additional_kwargs:
|
||||
response_reasoning_content = True
|
||||
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
|
||||
all_text += content_chunk
|
||||
if reasoning_content_chunk is None:
|
||||
reasoning_content_chunk = ''
|
||||
reasoning_content += reasoning_content_chunk
|
||||
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
|
||||
[], content_chunk,
|
||||
False,
|
||||
0, 0, {'node_is_end': False,
|
||||
'view_type': 'many_view',
|
||||
'node_type': 'ai-chat-node',
|
||||
'real_node_id': 'ai-chat-node',
|
||||
'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
|
||||
reasoning_chunk = reasoning.get_end_reasoning_content()
|
||||
all_text += reasoning_chunk.get('content')
|
||||
reasoning_content_chunk = ""
|
||||
if not response_reasoning_content:
|
||||
reasoning_content_chunk = reasoning_chunk.get(
|
||||
'reasoning_content')
|
||||
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
|
||||
[], reasoning_chunk.get('content'),
|
||||
False,
|
||||
0, 0, {'node_is_end': False,
|
||||
'view_type': 'many_view',
|
||||
'node_type': 'ai-chat-node',
|
||||
'real_node_id': 'ai-chat-node',
|
||||
'reasoning_content'
|
||||
: reasoning_content_chunk if reasoning_content_enable else ''})
|
||||
# 获取token
|
||||
if is_ai_chat:
|
||||
try:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(all_text)
|
||||
except Exception as e:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
else:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
write_context(step, manage, request_token, response_token, all_text)
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, step, padding_problem_text, client_id,
|
||||
reasoning_content=reasoning_content if reasoning_content_enable else ''
|
||||
, asker=asker)
|
||||
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
|
||||
[], '', True,
|
||||
request_token, response_token,
|
||||
{'node_is_end': True, 'view_type': 'many_view',
|
||||
'node_type': 'ai-chat-node'})
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
except Exception as e:
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
all_text = 'Exception:' + str(e)
|
||||
write_context(step, manage, 0, 0, all_text)
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
|
||||
asker=asker)
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
|
||||
[], all_text,
|
||||
False,
|
||||
0, 0, {'node_is_end': False,
|
||||
'view_type': 'many_view',
|
||||
'node_type': 'ai-chat-node',
|
||||
'real_node_id': 'ai-chat-node',
|
||||
'reasoning_content': ''})
|
||||
|
||||
|
||||
class BaseChatStep(IChatStep):
|
||||
def execute(self, message_list: List[BaseMessage],
|
||||
chat_id,
|
||||
problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
model_id: str = None,
|
||||
user_id: str = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
stream: bool = True,
|
||||
client_id=None, client_type=None,
|
||||
no_references_setting=None,
|
||||
model_params_setting=None,
|
||||
model_setting=None,
|
||||
**kwargs):
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
|
||||
**model_params_setting) if model_id is not None else None
|
||||
if stream:
|
||||
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text, client_id, client_type, no_references_setting,
|
||||
model_setting)
|
||||
else:
|
||||
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text, client_id, client_type, no_references_setting,
|
||||
model_setting)
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
return {
|
||||
'step_type': 'chat_step',
|
||||
'run_time': self.context['run_time'],
|
||||
'model_id': str(manage.context['model_id']),
|
||||
'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
|
||||
self.context['answer_text']),
|
||||
'message_tokens': self.context['message_tokens'],
|
||||
'answer_tokens': self.context['answer_tokens'],
|
||||
'cost': 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_stream_result(message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
no_references_setting=None,
|
||||
problem_text=None):
|
||||
if paragraph_list is None:
|
||||
paragraph_list = []
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
|
||||
for paragraph in paragraph_list if (
|
||||
paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
return iter(directly_return_chunk_list), False
|
||||
elif len(paragraph_list) == 0 and no_references_setting.get(
|
||||
'status') == 'designated_answer':
|
||||
return iter(
|
||||
[AIMessageChunk(content=no_references_setting.get('value').replace('{question}', problem_text))]), False
|
||||
if chat_model is None:
|
||||
return iter([AIMessageChunk(
|
||||
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
|
||||
else:
|
||||
return chat_model.stream(message_list), True
|
||||
|
||||
def execute_stream(self, message_list: List[BaseMessage],
|
||||
chat_id,
|
||||
problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None,
|
||||
no_references_setting=None,
|
||||
model_setting=None):
|
||||
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting, problem_text)
|
||||
chat_record_id = uuid.uuid1()
|
||||
r = StreamingHttpResponse(
|
||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
||||
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def get_block_result(message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
no_references_setting=None,
|
||||
problem_text=None):
|
||||
if paragraph_list is None:
|
||||
paragraph_list = []
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
|
||||
for paragraph in paragraph_list if (
|
||||
paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
return directly_return_chunk_list[0], False
|
||||
elif len(paragraph_list) == 0 and no_references_setting.get(
|
||||
'status') == 'designated_answer':
|
||||
return AIMessage(no_references_setting.get('value').replace('{question}', problem_text)), False
|
||||
if chat_model is None:
|
||||
return AIMessage(
|
||||
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
|
||||
else:
|
||||
return chat_model.invoke(message_list), True
|
||||
|
||||
def execute_block(self, message_list: List[BaseMessage],
|
||||
chat_id,
|
||||
problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None, no_references_setting=None,
|
||||
model_setting=None):
|
||||
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
|
||||
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
|
||||
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
|
||||
reasoning = Reasoning(reasoning_content_start,
|
||||
reasoning_content_end)
|
||||
chat_record_id = uuid.uuid1()
|
||||
# 调用模型
|
||||
try:
|
||||
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting, problem_text)
|
||||
if is_ai_chat:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(chat_result.content)
|
||||
else:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
write_context(self, manage, request_token, response_token, chat_result.content)
|
||||
reasoning_result = reasoning.get_reasoning_content(chat_result)
|
||||
reasoning_result_end = reasoning.get_end_reasoning_content()
|
||||
content = reasoning_result.get('content') + reasoning_result_end.get('content')
|
||||
if 'reasoning_content' in chat_result.response_metadata:
|
||||
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
|
||||
'reasoning_content')
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
content, manage, self, padding_problem_text, client_id,
|
||||
reasoning_content=reasoning_content if reasoning_content_enable else '',
|
||||
asker=asker)
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
|
||||
content, True,
|
||||
request_token, response_token,
|
||||
{
|
||||
'reasoning_content': reasoning_content if reasoning_content_enable else '',
|
||||
'answer_list': [{
|
||||
'content': content,
|
||||
'reasoning_content': reasoning_content if reasoning_content_enable else ''
|
||||
}]})
|
||||
except Exception as e:
|
||||
all_text = 'Exception:' + str(e)
|
||||
write_context(self, manage, 0, 0, all_text)
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, self, padding_problem_text, client_id, reasoning_content='',
|
||||
asker=asker)
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
|
||||
0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_generate_human_message_step.py
|
||||
@date:2024/1/9 18:15
|
||||
@desc: 生成对话模板
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Type, List
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from langchain.schema import BaseMessage
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.models import ChatRecord
|
||||
from application.serializers.application_serializers import NoReferencesSetting
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 问题
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True),
|
||||
error_messages=ErrMessage.list(_("Paragraph List")))
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
error_messages=ErrMessage.list(_("History Questions")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
# 最大携带知识库段落长度
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
|
||||
_("Maximum length of the knowledge base paragraph")))
|
||||
# 模板
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char(_("System prompt words (role)")))
|
||||
# 补齐问题
|
||||
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Completion problem")))
|
||||
# 未查询到引用分段
|
||||
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PipelineManage):
|
||||
message_list = self.execute(**self.context['step_args'])
|
||||
manage.context['message_list'] = message_list
|
||||
|
||||
@abstractmethod
|
||||
def execute(self,
|
||||
problem_text: str,
|
||||
paragraph_list: List[ParagraphPipelineModel],
|
||||
history_chat_record: List[ChatRecord],
|
||||
dialogue_number: int,
|
||||
max_paragraph_char_number: int,
|
||||
prompt: str,
|
||||
padding_problem_text: str = None,
|
||||
no_references_setting=None,
|
||||
system=None,
|
||||
**kwargs) -> List[BaseMessage]:
|
||||
"""
|
||||
|
||||
:param problem_text: 原始问题文本
|
||||
:param paragraph_list: 段落列表
|
||||
:param history_chat_record: 历史对话记录
|
||||
:param dialogue_number: 多轮对话数量
|
||||
:param max_paragraph_char_number: 最大段落长度
|
||||
:param prompt: 模板
|
||||
:param padding_problem_text 用户修改文本
|
||||
:param kwargs: 其他参数
|
||||
:param no_references_setting: 无引用分段设置
|
||||
:param system 系统提示称
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_generate_human_message_step.py.py
|
||||
@date:2024/1/10 17:50
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain.schema import BaseMessage, HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
|
||||
IGenerateHumanMessageStep
|
||||
from application.models import ChatRecord
|
||||
from common.util.split_model import flat_map
|
||||
|
||||
|
||||
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
|
||||
|
||||
def execute(self, problem_text: str,
|
||||
paragraph_list: List[ParagraphPipelineModel],
|
||||
history_chat_record: List[ChatRecord],
|
||||
dialogue_number: int,
|
||||
max_paragraph_char_number: int,
|
||||
prompt: str,
|
||||
padding_problem_text: str = None,
|
||||
no_references_setting=None,
|
||||
system=None,
|
||||
**kwargs) -> List[BaseMessage]:
|
||||
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
|
||||
'value')
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))]
|
||||
if system is not None and len(system) > 0:
|
||||
return [SystemMessage(system), *flat_map(history_message),
|
||||
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
|
||||
no_references_setting)]
|
||||
|
||||
return [*flat_map(history_message),
|
||||
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
|
||||
no_references_setting)]
|
||||
|
||||
@staticmethod
|
||||
def to_human_message(prompt: str,
|
||||
problem: str,
|
||||
max_paragraph_char_number: int,
|
||||
paragraph_list: List[ParagraphPipelineModel],
|
||||
no_references_setting: Dict):
|
||||
if paragraph_list is None or len(paragraph_list) == 0:
|
||||
if no_references_setting.get('status') == 'ai_questioning':
|
||||
return HumanMessage(
|
||||
content=no_references_setting.get('value').replace('{question}', problem))
|
||||
else:
|
||||
return HumanMessage(content=prompt.replace('{data}', "").replace('{question}', problem))
|
||||
temp_data = ""
|
||||
data_list = []
|
||||
for p in paragraph_list:
|
||||
content = f"{p.title}:{p.content}"
|
||||
temp_data += content
|
||||
if len(temp_data) > max_paragraph_char_number:
|
||||
row_data = content[0:max_paragraph_char_number - len(temp_data)]
|
||||
data_list.append(f"<data>{row_data}</data>")
|
||||
break
|
||||
else:
|
||||
data_list.append(f"<data>{content}</data>")
|
||||
data = "\n".join(data_list)
|
||||
return HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem))
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_reset_problem_step.py
|
||||
@date:2024/1/9 18:12
|
||||
@desc: 重写处理问题
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Type, List
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.models import ChatRecord
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class IResetProblemStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 问题文本
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.float(_("question")))
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
error_messages=ErrMessage.list(_("History Questions")))
|
||||
# 大语言模型
|
||||
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
|
||||
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
|
||||
error_messages=ErrMessage.char(
|
||||
_("Question completion prompt")))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PipelineManage):
|
||||
padding_problem = self.execute(**self.context.get('step_args'))
|
||||
# 用户输入问题
|
||||
source_problem_text = self.context.get('step_args').get('problem_text')
|
||||
self.context['problem_text'] = source_problem_text
|
||||
self.context['padding_problem_text'] = padding_problem
|
||||
manage.context['problem_text'] = source_problem_text
|
||||
manage.context['padding_problem_text'] = padding_problem
|
||||
# 累加tokens
|
||||
manage.context['message_tokens'] = manage.context.get('message_tokens', 0) + self.context.get('message_tokens',
|
||||
0)
|
||||
manage.context['answer_tokens'] = manage.context.get('answer_tokens', 0) + self.context.get('answer_tokens', 0)
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
|
||||
problem_optimization_prompt=None,
|
||||
user_id=None,
|
||||
**kwargs):
|
||||
pass
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_reset_problem_step.py
|
||||
@date:2024/1/10 14:35
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
|
||||
from application.models import ChatRecord
|
||||
from common.util.split_model import flat_map
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
prompt = _(
|
||||
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
|
||||
|
||||
|
||||
class BaseResetProblemStep(IResetProblemStep):
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
|
||||
problem_optimization_prompt=None,
|
||||
user_id=None,
|
||||
**kwargs) -> str:
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
|
||||
if chat_model is None:
|
||||
return problem_text
|
||||
start_index = len(history_chat_record) - 3
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))]
|
||||
reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt
|
||||
message_list = [*flat_map(history_message),
|
||||
HumanMessage(content=reset_prompt.replace('{question}', problem_text))]
|
||||
response = chat_model.invoke(message_list)
|
||||
padding_problem = problem_text
|
||||
if response.content.__contains__("<data>") and response.content.__contains__('</data>'):
|
||||
padding_problem_data = response.content[
|
||||
response.content.index('<data>') + 6:response.content.index('</data>')]
|
||||
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
|
||||
padding_problem = padding_problem_data
|
||||
elif len(response.content) > 0:
|
||||
padding_problem = response.content
|
||||
|
||||
try:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(padding_problem)
|
||||
except Exception as e:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
self.context['message_tokens'] = request_token
|
||||
self.context['answer_tokens'] = response_token
|
||||
return padding_problem
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
return {
|
||||
'step_type': 'problem_padding',
|
||||
'run_time': self.context['run_time'],
|
||||
'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
|
||||
'message_tokens': self.context.get('message_tokens', 0),
|
||||
'answer_tokens': self.context.get('answer_tokens', 0),
|
||||
'cost': 0,
|
||||
'padding_problem_text': self.context.get('padding_problem_text'),
|
||||
'problem_text': self.context.get("step_args").get('problem_text'),
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:24
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_search_dataset_step.py
|
||||
@date:2024/1/9 18:10
|
||||
@desc: 检索知识库
|
||||
"""
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type
|
||||
|
||||
from django.core import validators
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ISearchDatasetStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 原始问题文本
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
|
||||
# 系统补全问题文本
|
||||
padding_problem_text = serializers.CharField(required=False,
|
||||
error_messages=ErrMessage.char(_("System completes question text")))
|
||||
# 需要查询的数据集id列表
|
||||
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list(_("Dataset id list")))
|
||||
# 需要排除的文档id
|
||||
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list(_("List of document ids to exclude")))
|
||||
# 需要排除向量id
|
||||
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list(_("List of exclusion vector ids")))
|
||||
# 需要查询的条数
|
||||
top_n = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.integer(_("Reference segment number")))
|
||||
# 相似度 0-1之间
|
||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||||
error_messages=ErrMessage.float(_("Similarity")))
|
||||
search_mode = serializers.CharField(required=True, validators=[
|
||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message=_("The type only supports embedding|keywords|blend"), code=500)
|
||||
], error_messages=ErrMessage.char(_("Retrieval Mode")))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PipelineManage):
|
||||
paragraph_list = self.execute(**self.context['step_args'])
|
||||
manage.context['paragraph_list'] = paragraph_list
|
||||
self.context['paragraph_list'] = paragraph_list
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
"""
|
||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||
:param similarity: 相关性
|
||||
:param top_n: 查询多少条
|
||||
:param problem_text: 用户问题
|
||||
:param dataset_id_list: 需要查询的数据集id列表
|
||||
:param exclude_document_id_list: 需要排除的文档id
|
||||
:param exclude_paragraph_id_list: 需要排除段落id
|
||||
:param padding_problem_text 补全问题
|
||||
:param search_mode 检索模式
|
||||
:param user_id 用户id
|
||||
:return: 段落列表
|
||||
"""
|
||||
pass
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_search_dataset_step.py
|
||||
@date:2024/1/10 10:33
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.utils.formatting import lazy_format
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph, DataSet
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception(_("Model does not exist"))
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
|
||||
raise Exception(message)
|
||||
return model
|
||||
|
||||
|
||||
def get_embedding_id(dataset_id_list):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
|
||||
return dataset_list[0].embedding_mode_id
|
||||
|
||||
|
||||
class BaseSearchDatasetStep(ISearchDatasetStep):
|
||||
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
if len(dataset_id_list) == 0:
|
||||
return []
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id, user_id)
|
||||
self.context['model_name'] = model.name
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
|
||||
if embedding_list is None:
|
||||
return []
|
||||
paragraph_list = self.list_paragraph(embedding_list, vector)
|
||||
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
|
||||
filter_embedding_list = [embedding for embedding in embedding_list if
|
||||
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
|
||||
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
|
||||
find_embedding = filter_embedding_list[-1]
|
||||
return (ParagraphPipelineModel.builder()
|
||||
.add_paragraph(paragraph)
|
||||
.add_similarity(find_embedding.get('similarity'))
|
||||
.add_comprehensive_score(find_embedding.get('comprehensive_score'))
|
||||
.add_dataset_name(paragraph.get('dataset_name'))
|
||||
.add_document_name(paragraph.get('document_name'))
|
||||
.add_hit_handling_method(paragraph.get('hit_handling_method'))
|
||||
.add_directly_return_similarity(paragraph.get('directly_return_similarity'))
|
||||
.add_meta(paragraph.get('meta'))
|
||||
.build())
|
||||
|
||||
@staticmethod
|
||||
def get_similarity(paragraph, embedding_list: List):
|
||||
filter_embedding_list = [embedding for embedding in embedding_list if
|
||||
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
|
||||
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
|
||||
find_embedding = filter_embedding_list[-1]
|
||||
return find_embedding.get('comprehensive_score')
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def list_paragraph(embedding_list: List, vector):
|
||||
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
|
||||
if paragraph_id_list is None or len(paragraph_id_list) == 0:
|
||||
return []
|
||||
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
|
||||
get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||
with_table_name=True)
|
||||
# 如果向量库中存在脏数据 直接删除
|
||||
if len(paragraph_list) != len(paragraph_id_list):
|
||||
exist_paragraph_list = [row.get('id') for row in paragraph_list]
|
||||
for paragraph_id in paragraph_id_list:
|
||||
if not exist_paragraph_list.__contains__(paragraph_id):
|
||||
vector.delete_by_paragraph_id(paragraph_id)
|
||||
# 如果存在直接返回的则取直接返回段落
|
||||
hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
|
||||
(paragraph.get(
|
||||
'hit_handling_method') == 'directly_return' and BaseSearchDatasetStep.get_similarity(
|
||||
paragraph, embedding_list) >= paragraph.get(
|
||||
'directly_return_similarity'))]
|
||||
if len(hit_handling_method_paragraph) > 0:
|
||||
# 找到评分最高的
|
||||
return [sorted(hit_handling_method_paragraph,
|
||||
key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]]
|
||||
return paragraph_list
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
step_args = self.context['step_args']
|
||||
|
||||
return {
|
||||
'step_type': 'search_step',
|
||||
'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']],
|
||||
'run_time': self.context['run_time'],
|
||||
'problem_text': step_args.get(
|
||||
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
|
||||
'model_name': self.context.get('model_name'),
|
||||
'message_tokens': 0,
|
||||
'answer_tokens': 0,
|
||||
'cost': 0
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: common.py
|
||||
@date:2024/12/11 17:57
|
||||
@desc:
|
||||
"""
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
|
||||
reasoning_content):
|
||||
self.view_type = view_type
|
||||
self.content = content
|
||||
self.reasoning_content = reasoning_content
|
||||
self.runtime_node_id = runtime_node_id
|
||||
self.chat_record_id = chat_record_id
|
||||
self.child_node = child_node
|
||||
self.real_node_id = real_node_id
|
||||
|
||||
def to_dict(self):
|
||||
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
|
||||
'chat_record_id': self.chat_record_id,
|
||||
'child_node': self.child_node,
|
||||
'reasoning_content': self.reasoning_content,
|
||||
'real_node_id': self.real_node_id}
|
||||
|
||||
|
||||
class NodeChunk:
|
||||
def __init__(self):
|
||||
self.status = 0
|
||||
self.chunk_list = []
|
||||
|
||||
def add_chunk(self, chunk):
|
||||
self.chunk_list.append(chunk)
|
||||
|
||||
def end(self, chunk=None):
|
||||
if chunk is not None:
|
||||
self.add_chunk(chunk)
|
||||
self.status = 200
|
||||
|
||||
def is_end(self):
|
||||
return self.status == 200
|
||||
|
|
@ -0,0 +1,451 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "base-node",
|
||||
"type": "base-node",
|
||||
"x": 360,
|
||||
"y": 2810,
|
||||
"properties": {
|
||||
"config": {
|
||||
|
||||
},
|
||||
"height": 825.6,
|
||||
"stepName": "基本信息",
|
||||
"node_data": {
|
||||
"desc": "",
|
||||
"name": "maxkbapplication",
|
||||
"prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
|
||||
},
|
||||
"input_field_list": [
|
||||
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "start-node",
|
||||
"type": "start-node",
|
||||
"x": 430,
|
||||
"y": 3660,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
},
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"height": 276,
|
||||
"stepName": "开始",
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"type": "search-dataset-node",
|
||||
"x": 840,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "检索结果的分段列表",
|
||||
"value": "paragraph_list"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段列表",
|
||||
"value": "is_hit_handling_method_list"
|
||||
},
|
||||
{
|
||||
"label": "检索结果",
|
||||
"value": "data"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段内容",
|
||||
"value": "directly_return"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 794,
|
||||
"stepName": "知识库检索",
|
||||
"node_data": {
|
||||
"dataset_id_list": [
|
||||
|
||||
],
|
||||
"dataset_setting": {
|
||||
"top_n": 3,
|
||||
"similarity": 0.6,
|
||||
"search_mode": "embedding",
|
||||
"max_paragraph_char_number": 5000
|
||||
},
|
||||
"question_reference_address": [
|
||||
"start-node",
|
||||
"question"
|
||||
],
|
||||
"source_dataset_id_list": [
|
||||
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"type": "condition-node",
|
||||
"x": 1490,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"width": 600,
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "分支名称",
|
||||
"value": "branch_name"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 543.675,
|
||||
"stepName": "判断器",
|
||||
"node_data": {
|
||||
"branch": [
|
||||
{
|
||||
"id": "1009",
|
||||
"type": "IF",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"is_hit_handling_method_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4908",
|
||||
"type": "ELSE IF 1",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"paragraph_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "161",
|
||||
"type": "ELSE",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"branch_condition_list": [
|
||||
{
|
||||
"index": 0,
|
||||
"height": 121.225,
|
||||
"id": "1009"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"height": 121.225,
|
||||
"id": "4908"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"height": 44,
|
||||
"id": "161"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"type": "reply-node",
|
||||
"x": 2170,
|
||||
"y": 2480,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 378,
|
||||
"stepName": "指定回复",
|
||||
"node_data": {
|
||||
"fields": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"directly_return"
|
||||
],
|
||||
"content": "",
|
||||
"reply_type": "referencing",
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3200,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI 对话",
|
||||
"node_data": {
|
||||
"prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3970,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI 对话1",
|
||||
"node_data": {
|
||||
"prompt": "{{开始.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "start-node",
|
||||
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"startPoint": {
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 700,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 570,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "start-node_right",
|
||||
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
|
||||
},
|
||||
{
|
||||
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"startPoint": {
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1110,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1090,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
|
||||
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
|
||||
},
|
||||
{
|
||||
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1900,
|
||||
"y": 2480
|
||||
},
|
||||
{
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
|
||||
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
|
||||
},
|
||||
{
|
||||
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3200
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
|
||||
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
|
||||
},
|
||||
{
|
||||
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3970
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
|
||||
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,451 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "base-node",
|
||||
"type": "base-node",
|
||||
"x": 360,
|
||||
"y": 2810,
|
||||
"properties": {
|
||||
"config": {
|
||||
|
||||
},
|
||||
"height": 825.6,
|
||||
"stepName": "Base",
|
||||
"node_data": {
|
||||
"desc": "",
|
||||
"name": "maxkbapplication",
|
||||
"prologue": "Hello, I am the MaxKB assistant. You can ask me about MaxKB usage issues.\n-What are the main functions of MaxKB?\n-What major language models does MaxKB support?\n-What document types does MaxKB support?"
|
||||
},
|
||||
"input_field_list": [
|
||||
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "start-node",
|
||||
"type": "start-node",
|
||||
"x": 430,
|
||||
"y": 3660,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
},
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"height": 276,
|
||||
"stepName": "Start",
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"type": "search-dataset-node",
|
||||
"x": 840,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "检索结果的分段列表",
|
||||
"value": "paragraph_list"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段列表",
|
||||
"value": "is_hit_handling_method_list"
|
||||
},
|
||||
{
|
||||
"label": "检索结果",
|
||||
"value": "data"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段内容",
|
||||
"value": "directly_return"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 794,
|
||||
"stepName": "Knowledge Search",
|
||||
"node_data": {
|
||||
"dataset_id_list": [
|
||||
|
||||
],
|
||||
"dataset_setting": {
|
||||
"top_n": 3,
|
||||
"similarity": 0.6,
|
||||
"search_mode": "embedding",
|
||||
"max_paragraph_char_number": 5000
|
||||
},
|
||||
"question_reference_address": [
|
||||
"start-node",
|
||||
"question"
|
||||
],
|
||||
"source_dataset_id_list": [
|
||||
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"type": "condition-node",
|
||||
"x": 1490,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"width": 600,
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "分支名称",
|
||||
"value": "branch_name"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 543.675,
|
||||
"stepName": "Conditional Branch",
|
||||
"node_data": {
|
||||
"branch": [
|
||||
{
|
||||
"id": "1009",
|
||||
"type": "IF",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"is_hit_handling_method_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4908",
|
||||
"type": "ELSE IF 1",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"paragraph_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "161",
|
||||
"type": "ELSE",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"branch_condition_list": [
|
||||
{
|
||||
"index": 0,
|
||||
"height": 121.225,
|
||||
"id": "1009"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"height": 121.225,
|
||||
"id": "4908"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"height": 44,
|
||||
"id": "161"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"type": "reply-node",
|
||||
"x": 2170,
|
||||
"y": 2480,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 378,
|
||||
"stepName": "Specified Reply",
|
||||
"node_data": {
|
||||
"fields": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"directly_return"
|
||||
],
|
||||
"content": "",
|
||||
"reply_type": "referencing",
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3200,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI Chat",
|
||||
"node_data": {
|
||||
"prompt": "Known information:\n{{Knowledge Search.data}}\nQuestion:\n{{Start.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3970,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI Chat1",
|
||||
"node_data": {
|
||||
"prompt": "{{Start.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "start-node",
|
||||
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"startPoint": {
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 700,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 570,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "start-node_right",
|
||||
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
|
||||
},
|
||||
{
|
||||
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"startPoint": {
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1110,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1090,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
|
||||
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
|
||||
},
|
||||
{
|
||||
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1900,
|
||||
"y": 2480
|
||||
},
|
||||
{
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
|
||||
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
|
||||
},
|
||||
{
|
||||
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3200
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
|
||||
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
|
||||
},
|
||||
{
|
||||
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3970
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
|
||||
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,451 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "base-node",
|
||||
"type": "base-node",
|
||||
"x": 360,
|
||||
"y": 2810,
|
||||
"properties": {
|
||||
"config": {
|
||||
|
||||
},
|
||||
"height": 825.6,
|
||||
"stepName": "基本信息",
|
||||
"node_data": {
|
||||
"desc": "",
|
||||
"name": "maxkbapplication",
|
||||
"prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
|
||||
},
|
||||
"input_field_list": [
|
||||
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "start-node",
|
||||
"type": "start-node",
|
||||
"x": 430,
|
||||
"y": 3660,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
},
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"height": 276,
|
||||
"stepName": "开始",
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"type": "search-dataset-node",
|
||||
"x": 840,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "检索结果的分段列表",
|
||||
"value": "paragraph_list"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段列表",
|
||||
"value": "is_hit_handling_method_list"
|
||||
},
|
||||
{
|
||||
"label": "检索结果",
|
||||
"value": "data"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段内容",
|
||||
"value": "directly_return"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 794,
|
||||
"stepName": "知识库检索",
|
||||
"node_data": {
|
||||
"dataset_id_list": [
|
||||
|
||||
],
|
||||
"dataset_setting": {
|
||||
"top_n": 3,
|
||||
"similarity": 0.6,
|
||||
"search_mode": "embedding",
|
||||
"max_paragraph_char_number": 5000
|
||||
},
|
||||
"question_reference_address": [
|
||||
"start-node",
|
||||
"question"
|
||||
],
|
||||
"source_dataset_id_list": [
|
||||
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"type": "condition-node",
|
||||
"x": 1490,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"width": 600,
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "分支名称",
|
||||
"value": "branch_name"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 543.675,
|
||||
"stepName": "判断器",
|
||||
"node_data": {
|
||||
"branch": [
|
||||
{
|
||||
"id": "1009",
|
||||
"type": "IF",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"is_hit_handling_method_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4908",
|
||||
"type": "ELSE IF 1",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"paragraph_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "161",
|
||||
"type": "ELSE",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"branch_condition_list": [
|
||||
{
|
||||
"index": 0,
|
||||
"height": 121.225,
|
||||
"id": "1009"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"height": 121.225,
|
||||
"id": "4908"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"height": 44,
|
||||
"id": "161"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"type": "reply-node",
|
||||
"x": 2170,
|
||||
"y": 2480,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 378,
|
||||
"stepName": "指定回复",
|
||||
"node_data": {
|
||||
"fields": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"directly_return"
|
||||
],
|
||||
"content": "",
|
||||
"reply_type": "referencing",
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3200,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI 对话",
|
||||
"node_data": {
|
||||
"prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3970,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI 对话1",
|
||||
"node_data": {
|
||||
"prompt": "{{开始.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "start-node",
|
||||
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"startPoint": {
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 700,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 570,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "start-node_right",
|
||||
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
|
||||
},
|
||||
{
|
||||
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"startPoint": {
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1110,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1090,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
|
||||
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
|
||||
},
|
||||
{
|
||||
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1900,
|
||||
"y": 2480
|
||||
},
|
||||
{
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
|
||||
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
|
||||
},
|
||||
{
|
||||
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3200
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
|
||||
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
|
||||
},
|
||||
{
|
||||
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3970
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
|
||||
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,451 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "base-node",
|
||||
"type": "base-node",
|
||||
"x": 360,
|
||||
"y": 2810,
|
||||
"properties": {
|
||||
"config": {
|
||||
|
||||
},
|
||||
"height": 825.6,
|
||||
"stepName": "基本資訊",
|
||||
"node_data": {
|
||||
"desc": "",
|
||||
"name": "maxkbapplication",
|
||||
"prologue": "您好,我是MaxKB小助手,您可以向我提出MaxKB使用問題。\n- MaxKB主要功能有什麼?\n- MaxKB支持哪些大語言模型?\n- MaxKB支持哪些文檔類型?"
|
||||
},
|
||||
"input_field_list": [
|
||||
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "start-node",
|
||||
"type": "start-node",
|
||||
"x": 430,
|
||||
"y": 3660,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
},
|
||||
"fields": [
|
||||
{
|
||||
"label": "用户问题",
|
||||
"value": "question"
|
||||
}
|
||||
],
|
||||
"height": 276,
|
||||
"stepName": "開始",
|
||||
"globalFields": [
|
||||
{
|
||||
"label": "当前时间",
|
||||
"value": "time"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"type": "search-dataset-node",
|
||||
"x": 840,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "检索结果的分段列表",
|
||||
"value": "paragraph_list"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段列表",
|
||||
"value": "is_hit_handling_method_list"
|
||||
},
|
||||
{
|
||||
"label": "检索结果",
|
||||
"value": "data"
|
||||
},
|
||||
{
|
||||
"label": "满足直接回答的分段内容",
|
||||
"value": "directly_return"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 794,
|
||||
"stepName": "知識庫檢索",
|
||||
"node_data": {
|
||||
"dataset_id_list": [
|
||||
|
||||
],
|
||||
"dataset_setting": {
|
||||
"top_n": 3,
|
||||
"similarity": 0.6,
|
||||
"search_mode": "embedding",
|
||||
"max_paragraph_char_number": 5000
|
||||
},
|
||||
"question_reference_address": [
|
||||
"start-node",
|
||||
"question"
|
||||
],
|
||||
"source_dataset_id_list": [
|
||||
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"type": "condition-node",
|
||||
"x": 1490,
|
||||
"y": 3210,
|
||||
"properties": {
|
||||
"width": 600,
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "分支名称",
|
||||
"value": "branch_name"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 543.675,
|
||||
"stepName": "判斷器",
|
||||
"node_data": {
|
||||
"branch": [
|
||||
{
|
||||
"id": "1009",
|
||||
"type": "IF",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"is_hit_handling_method_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4908",
|
||||
"type": "ELSE IF 1",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"paragraph_list"
|
||||
],
|
||||
"value": "1",
|
||||
"compare": "len_ge"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "161",
|
||||
"type": "ELSE",
|
||||
"condition": "and",
|
||||
"conditions": [
|
||||
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"branch_condition_list": [
|
||||
{
|
||||
"index": 0,
|
||||
"height": 121.225,
|
||||
"id": "1009"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"height": 121.225,
|
||||
"id": "4908"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"height": 44,
|
||||
"id": "161"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"type": "reply-node",
|
||||
"x": 2170,
|
||||
"y": 2480,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 378,
|
||||
"stepName": "指定回覆",
|
||||
"node_data": {
|
||||
"fields": [
|
||||
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"directly_return"
|
||||
],
|
||||
"content": "",
|
||||
"reply_type": "referencing",
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3200,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI 對話",
|
||||
"node_data": {
|
||||
"prompt": "已知資訊:\n{{知識庫檢索.data}}\n問題:\n{{開始.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"type": "ai-chat-node",
|
||||
"x": 2160,
|
||||
"y": 3970,
|
||||
"properties": {
|
||||
"config": {
|
||||
"fields": [
|
||||
{
|
||||
"label": "AI 回答内容",
|
||||
"value": "answer"
|
||||
}
|
||||
]
|
||||
},
|
||||
"height": 763,
|
||||
"stepName": "AI 對話1",
|
||||
"node_data": {
|
||||
"prompt": "{{開始.question}}",
|
||||
"system": "",
|
||||
"model_id": "",
|
||||
"dialogue_number": 0,
|
||||
"is_result": true
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "start-node",
|
||||
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"startPoint": {
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 590,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 700,
|
||||
"y": 3660
|
||||
},
|
||||
{
|
||||
"x": 570,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 680,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "start-node_right",
|
||||
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
|
||||
},
|
||||
{
|
||||
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
|
||||
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"startPoint": {
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1000,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1110,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1090,
|
||||
"y": 3210
|
||||
},
|
||||
{
|
||||
"x": 1200,
|
||||
"y": 3210
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
|
||||
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
|
||||
},
|
||||
{
|
||||
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3073.775
|
||||
},
|
||||
{
|
||||
"x": 1900,
|
||||
"y": 2480
|
||||
},
|
||||
{
|
||||
"x": 2010,
|
||||
"y": 2480
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
|
||||
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
|
||||
},
|
||||
{
|
||||
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3203
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3200
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3200
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
|
||||
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
|
||||
},
|
||||
{
|
||||
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
|
||||
"type": "app-edge",
|
||||
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
|
||||
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
|
||||
"startPoint": {
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
"endPoint": {
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
},
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"pointsList": [
|
||||
{
|
||||
"x": 1780,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3293.6124999999997
|
||||
},
|
||||
{
|
||||
"x": 1890,
|
||||
"y": 3970
|
||||
},
|
||||
{
|
||||
"x": 2000,
|
||||
"y": 3970
|
||||
}
|
||||
],
|
||||
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
|
||||
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,256 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_step_node.py
|
||||
@date:2024/6/3 14:57
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from hashlib import sha1
|
||||
from typing import Type, Dict, List
|
||||
|
||||
from django.core import cache
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
from rest_framework.exceptions import ValidationError, ErrorDetail
|
||||
|
||||
from application.flow.common import Answer, NodeChunk
|
||||
from application.models import ChatRecord
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable is not None:
|
||||
for key in step_variable:
|
||||
node.context[key] = step_variable[key]
|
||||
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable:
|
||||
answer = step_variable['answer']
|
||||
yield answer
|
||||
node.answer_text = answer
|
||||
if global_variable is not None:
|
||||
for key in global_variable:
|
||||
workflow.context[key] = global_variable[key]
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
def is_interrupt(node, step_variable: Dict, global_variable: Dict):
|
||||
return node.type == 'form-node' and not node.context.get('is_submit', False)
|
||||
|
||||
|
||||
class WorkFlowPostHandler:
|
||||
def __init__(self, chat_info, client_id, client_type):
|
||||
self.chat_info = chat_info
|
||||
self.client_id = client_id
|
||||
self.client_type = client_type
|
||||
|
||||
def handler(self, chat_id,
|
||||
chat_record_id,
|
||||
answer,
|
||||
workflow):
|
||||
question = workflow.params['question']
|
||||
details = workflow.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||
answer_text_list = workflow.get_answer_text_list()
|
||||
answer_text = '\n\n'.join(
|
||||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||
answer_text_list)
|
||||
if workflow.chat_record is not None:
|
||||
chat_record = workflow.chat_record
|
||||
chat_record.answer_text = answer_text
|
||||
chat_record.details = details
|
||||
chat_record.message_tokens = message_tokens
|
||||
chat_record.answer_tokens = answer_tokens
|
||||
chat_record.answer_text_list = answer_text_list
|
||||
chat_record.run_time = time.time() - workflow.context['start_time']
|
||||
else:
|
||||
chat_record = ChatRecord(id=chat_record_id,
|
||||
chat_id=chat_id,
|
||||
problem_text=question,
|
||||
answer_text=answer_text,
|
||||
details=details,
|
||||
message_tokens=message_tokens,
|
||||
answer_tokens=answer_tokens,
|
||||
answer_text_list=answer_text_list,
|
||||
run_time=time.time() - workflow.context['start_time'],
|
||||
index=0)
|
||||
asker = workflow.context.get('asker', None)
|
||||
self.chat_info.append_chat_record(chat_record, self.client_id, asker)
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
self.chat_info, timeout=60 * 30)
|
||||
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
application_public_access_client = (QuerySet(ApplicationPublicAccessClient)
|
||||
.filter(client_id=self.client_id,
|
||||
application_id=self.chat_info.application.id).first())
|
||||
if application_public_access_client is not None:
|
||||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
|
||||
application_public_access_client.save()
|
||||
|
||||
|
||||
class NodeResult:
|
||||
def __init__(self, node_variable: Dict, workflow_variable: Dict,
|
||||
_write_context=write_context, _is_interrupt=is_interrupt):
|
||||
self._write_context = _write_context
|
||||
self.node_variable = node_variable
|
||||
self.workflow_variable = workflow_variable
|
||||
self._is_interrupt = _is_interrupt
|
||||
|
||||
def write_context(self, node, workflow):
|
||||
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
|
||||
|
||||
def is_assertion_result(self):
|
||||
return 'branch_id' in self.node_variable
|
||||
|
||||
def is_interrupt_exec(self, current_node):
|
||||
"""
|
||||
是否中断执行
|
||||
@param current_node:
|
||||
@return:
|
||||
"""
|
||||
return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
|
||||
|
||||
|
||||
class ReferenceAddressSerializer(serializers.Serializer):
|
||||
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
|
||||
fields = serializers.ListField(
|
||||
child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
|
||||
error_messages=ErrMessage.list("节点字段数组"))
|
||||
|
||||
|
||||
class FlowParamsSerializer(serializers.Serializer):
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
error_messages=ErrMessage.list("历史对答"))
|
||||
|
||||
question = serializers.CharField(required=True, error_messages=ErrMessage.list("用户问题"))
|
||||
|
||||
chat_id = serializers.CharField(required=True, error_messages=ErrMessage.list("对话id"))
|
||||
|
||||
chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))
|
||||
|
||||
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("流式输出"))
|
||||
|
||||
client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id"))
|
||||
|
||||
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
|
||||
|
||||
|
||||
class INode:
|
||||
view_type = 'many_view'
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, details, workflow_manage):
|
||||
pass
|
||||
|
||||
def get_answer_list(self) -> List[Answer] | None:
|
||||
if self.answer_text is None:
|
||||
return None
|
||||
reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
|
||||
return [
|
||||
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
|
||||
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
|
||||
|
||||
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
|
||||
get_node_params=lambda node: node.properties.get('node_data')):
|
||||
# 当前步骤上下文,用于存储当前步骤信息
|
||||
self.status = 200
|
||||
self.err_message = ''
|
||||
self.node = node
|
||||
self.node_params = get_node_params(node)
|
||||
self.workflow_params = workflow_params
|
||||
self.workflow_manage = workflow_manage
|
||||
self.node_params_serializer = None
|
||||
self.flow_params_serializer = None
|
||||
self.context = {}
|
||||
self.answer_text = None
|
||||
self.id = node.id
|
||||
if up_node_id_list is None:
|
||||
up_node_id_list = []
|
||||
self.up_node_id_list = up_node_id_list
|
||||
self.node_chunk = NodeChunk()
|
||||
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
|
||||
"".join([*sorted(up_node_id_list),
|
||||
node.id]))),
|
||||
"utf-8")).hexdigest()
|
||||
|
||||
def valid_args(self, node_params, flow_params):
|
||||
flow_params_serializer_class = self.get_flow_params_serializer_class()
|
||||
node_params_serializer_class = self.get_node_params_serializer_class()
|
||||
if flow_params_serializer_class is not None and flow_params is not None:
|
||||
self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
|
||||
self.flow_params_serializer.is_valid(raise_exception=True)
|
||||
if node_params_serializer_class is not None:
|
||||
self.node_params_serializer = node_params_serializer_class(data=node_params)
|
||||
self.node_params_serializer.is_valid(raise_exception=True)
|
||||
if self.node.properties.get('status', 200) != 200:
|
||||
raise ValidationError(ErrorDetail(f'节点{self.node.properties.get("stepName")} 不可用'))
|
||||
|
||||
def get_reference_field(self, fields: List[str]):
|
||||
return self.get_field(self.context, fields)
|
||||
|
||||
@staticmethod
|
||||
def get_field(obj, fields: List[str]):
|
||||
for field in fields:
|
||||
value = obj.get(field)
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
obj = value
|
||||
return obj
|
||||
|
||||
@abstractmethod
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
pass
|
||||
|
||||
def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FlowParamsSerializer
|
||||
|
||||
def get_write_error_context(self, e):
|
||||
self.status = 500
|
||||
self.answer_text = str(e)
|
||||
self.err_message = str(e)
|
||||
self.context['run_time'] = time.time() - self.context['start_time']
|
||||
|
||||
def write_error_context(answer, status=200):
|
||||
pass
|
||||
|
||||
return write_error_context
|
||||
|
||||
def run(self) -> NodeResult:
|
||||
"""
|
||||
:return: 执行结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
self.context['start_time'] = start_time
|
||||
result = self._run()
|
||||
self.context['run_time'] = time.time() - start_time
|
||||
return result
|
||||
|
||||
def _run(self):
|
||||
result = self.execute()
|
||||
return result
|
||||
|
||||
def execute(self, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
"""
|
||||
运行详情
|
||||
:return: 步骤详情
|
||||
"""
|
||||
return {}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
from .ai_chat_step_node import *
|
||||
from .application_node import BaseApplicationNode
|
||||
from .condition_node import *
|
||||
from .direct_reply_node import *
|
||||
from .form_node import *
|
||||
from .function_lib_node import *
|
||||
from .function_node import *
|
||||
from .question_node import *
|
||||
from .reranker_node import *
|
||||
|
||||
from .document_extract_node import *
|
||||
from .image_understand_step_node import *
|
||||
from .image_generate_step_node import *
|
||||
|
||||
from .search_dataset_node import *
|
||||
from .speech_to_text_step_node import BaseSpeechToTextNode
|
||||
from .start_node import *
|
||||
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
|
||||
from .variable_assign_node import BaseVariableAssignNode
|
||||
from .mcp_node import BaseMcpNode
|
||||
|
||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode,
|
||||
BaseConditionNode, BaseReplyNode,
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
|
||||
BaseDocumentExtractNode,
|
||||
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
|
||||
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
|
||||
|
||||
|
||||
def get_node(node_type):
|
||||
find_list = [node for node in node_list if node.type == node_type]
|
||||
if len(find_list) > 0:
|
||||
return find_list[0]
|
||||
return None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:29
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_chat_node.py
|
||||
@date:2024/6/4 13:58
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ChatNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char(_("Role Setting")))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
|
||||
_("Number of multi-round conversations")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
model_params_setting = serializers.DictField(required=False,
|
||||
error_messages=ErrMessage.dict(_("Model parameter settings")))
|
||||
model_setting = serializers.DictField(required=False,
|
||||
error_messages=ErrMessage.dict('Model settings'))
|
||||
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char(_("Context Type")))
|
||||
mcp_enable = serializers.BooleanField(required=False,
|
||||
error_messages=ErrMessage.boolean(_("Whether to enable MCP")))
|
||||
mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server")))
|
||||
|
||||
|
||||
class IChatNode(INode):
|
||||
type = 'ai-chat-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ChatNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
|
||||
chat_record_id,
|
||||
model_params_setting=None,
|
||||
dialogue_type=None,
|
||||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_servers=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:34
|
||||
@desc:
|
||||
"""
|
||||
from .base_chat_node import BaseChatNode
|
||||
|
|
@ -0,0 +1,288 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_question_node.py
|
||||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from functools import reduce
|
||||
from types import AsyncGeneratorType
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
from application.flow.tools import Reasoning
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model_credential
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
tool_message_template = """
|
||||
<details>
|
||||
<summary>
|
||||
<strong>Called MCP Tool: <em>%s</em></strong>
|
||||
</summary>
|
||||
|
||||
```json
|
||||
%s
|
||||
```
|
||||
</details>
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
|
||||
reasoning_content: str):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
node.context['reasoning_content'] = reasoning_content
|
||||
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||
node.answer_text = answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = ''
|
||||
reasoning_content = ''
|
||||
model_setting = node.context.get('model_setting',
|
||||
{'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
|
||||
'reasoning_content_start': '<think>'})
|
||||
reasoning = Reasoning(model_setting.get('reasoning_content_start', '<think>'),
|
||||
model_setting.get('reasoning_content_end', '</think>'))
|
||||
response_reasoning_content = False
|
||||
|
||||
for chunk in response:
|
||||
reasoning_chunk = reasoning.get_reasoning_content(chunk)
|
||||
content_chunk = reasoning_chunk.get('content')
|
||||
if 'reasoning_content' in chunk.additional_kwargs:
|
||||
response_reasoning_content = True
|
||||
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
|
||||
answer += content_chunk
|
||||
if reasoning_content_chunk is None:
|
||||
reasoning_content_chunk = ''
|
||||
reasoning_content += reasoning_content_chunk
|
||||
yield {'content': content_chunk,
|
||||
'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable',
|
||||
False) else ''}
|
||||
|
||||
reasoning_chunk = reasoning.get_end_reasoning_content()
|
||||
answer += reasoning_chunk.get('content')
|
||||
reasoning_content_chunk = ""
|
||||
if not response_reasoning_content:
|
||||
reasoning_content_chunk = reasoning_chunk.get(
|
||||
'reasoning_content')
|
||||
yield {'content': reasoning_chunk.get('content'),
|
||||
'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable',
|
||||
False) else ''}
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
||||
|
||||
|
||||
async def _yield_mcp_response(chat_model, message_list, mcp_servers):
|
||||
async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
|
||||
agent = create_react_agent(chat_model, client.get_tools())
|
||||
response = agent.astream({"messages": message_list}, stream_mode='messages')
|
||||
async for chunk in response:
|
||||
if isinstance(chunk[0], ToolMessage):
|
||||
content = tool_message_template % (chunk[0].name, chunk[0].content)
|
||||
chunk[0].content = content
|
||||
yield chunk[0]
|
||||
if isinstance(chunk[0], AIMessageChunk):
|
||||
yield chunk[0]
|
||||
|
||||
|
||||
def mcp_response_generator(chat_model, message_list, mcp_servers):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
|
||||
while True:
|
||||
try:
|
||||
chunk = loop.run_until_complete(anext_async(async_gen))
|
||||
yield chunk
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f'exception: {e}')
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def anext_async(agen):
|
||||
return await agen.__anext__()
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
model_setting = node.context.get('model_setting',
|
||||
{'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
|
||||
'reasoning_content_start': '<think>'})
|
||||
reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end'))
|
||||
reasoning_result = reasoning.get_reasoning_content(response)
|
||||
reasoning_result_end = reasoning.get_end_reasoning_content()
|
||||
content = reasoning_result.get('content') + reasoning_result_end.get('content')
|
||||
if 'reasoning_content' in response.response_metadata:
|
||||
reasoning_content = response.response_metadata.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
|
||||
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
|
||||
|
||||
|
||||
def get_default_model_params_setting(model_id):
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
||||
model_params_setting = credential.get_model_params_setting_form(
|
||||
model.model_name).get_default_form_data()
|
||||
return model_params_setting
|
||||
|
||||
|
||||
def get_node_message(chat_record, runtime_node_id):
|
||||
node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id)
|
||||
if node_details is None:
|
||||
return []
|
||||
return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))]
|
||||
|
||||
|
||||
def get_workflow_message(chat_record):
|
||||
return [chat_record.get_human_message(), chat_record.get_ai_message()]
|
||||
|
||||
|
||||
def get_message(chat_record, dialogue_type, runtime_node_id):
|
||||
return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message(
|
||||
chat_record)
|
||||
|
||||
|
||||
class BaseChatNode(IChatNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['question'] = details.get('question')
|
||||
self.context['reasoning_content'] = details.get('reasoning_content')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
model_params_setting=None,
|
||||
dialogue_type=None,
|
||||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_servers=None,
|
||||
**kwargs) -> NodeResult:
|
||||
if dialogue_type is None:
|
||||
dialogue_type = 'WORKFLOW'
|
||||
|
||||
if model_params_setting is None:
|
||||
model_params_setting = get_default_model_params_setting(model_id)
|
||||
if model_setting is None:
|
||||
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
|
||||
'reasoning_content_start': '<think>'}
|
||||
self.context['model_setting'] = model_setting
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
|
||||
self.runtime_node_id)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question.content
|
||||
system = self.workflow_manage.generate_prompt(system)
|
||||
self.context['system'] = system
|
||||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
|
||||
if mcp_enable and mcp_servers is not None:
|
||||
r = mcp_response_generator(chat_model, message_list, mcp_servers)
|
||||
return NodeResult(
|
||||
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
|
||||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
get_message(history_chat_record[index], dialogue_type, runtime_node_id)
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
for message in history_message:
|
||||
if isinstance(message.content, str):
|
||||
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
|
||||
return history_message
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def generate_message_list(self, system: str, prompt: str, history_message):
|
||||
if system is not None and len(system) > 0:
|
||||
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
|
||||
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
else:
|
||||
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.context.get('system'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'reasoning_content': self.context.get('reasoning_content'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# coding=utf-8
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
# coding=utf-8
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ApplicationNodeSerializer(serializers.Serializer):
|
||||
application_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Application ID")))
|
||||
question_reference_address = serializers.ListField(required=True,
|
||||
error_messages=ErrMessage.list(_("User Questions")))
|
||||
api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("API Input Fields")))
|
||||
user_input_field_list = serializers.ListField(required=False,
|
||||
error_messages=ErrMessage.uuid(_("User Input Fields")))
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
|
||||
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
|
||||
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio")))
|
||||
child_node = serializers.DictField(required=False, allow_null=True,
|
||||
error_messages=ErrMessage.dict(_("Child Nodes")))
|
||||
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
|
||||
|
||||
|
||||
class IApplicationNode(INode):
|
||||
type = 'application-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ApplicationNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
question = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('question_reference_address')[0],
|
||||
self.node_params_serializer.data.get('question_reference_address')[1:])
|
||||
kwargs = {}
|
||||
for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
|
||||
value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else ''
|
||||
kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(value,
|
||||
api_input_field['value'][
|
||||
1:]) if value != '' else ''
|
||||
|
||||
for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
|
||||
value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else ''
|
||||
kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(value,
|
||||
user_input_field['value'][
|
||||
1:]) if value != '' else ''
|
||||
# 判断是否包含这个属性
|
||||
app_document_list = self.node_params_serializer.data.get('document_list', [])
|
||||
if app_document_list and len(app_document_list) > 0:
|
||||
app_document_list = self.workflow_manage.get_reference_field(
|
||||
app_document_list[0],
|
||||
app_document_list[1:])
|
||||
for document in app_document_list:
|
||||
if 'file_id' not in document:
|
||||
raise ValueError(
|
||||
_("Parameter value error: The uploaded document lacks file_id, and the document upload fails"))
|
||||
app_image_list = self.node_params_serializer.data.get('image_list', [])
|
||||
if app_image_list and len(app_image_list) > 0:
|
||||
app_image_list = self.workflow_manage.get_reference_field(
|
||||
app_image_list[0],
|
||||
app_image_list[1:])
|
||||
for image in app_image_list:
|
||||
if 'file_id' not in image:
|
||||
raise ValueError(
|
||||
_("Parameter value error: The uploaded image lacks file_id, and the image upload fails"))
|
||||
|
||||
app_audio_list = self.node_params_serializer.data.get('audio_list', [])
|
||||
if app_audio_list and len(app_audio_list) > 0:
|
||||
app_audio_list = self.workflow_manage.get_reference_field(
|
||||
app_audio_list[0],
|
||||
app_audio_list[1:])
|
||||
for audio in app_audio_list:
|
||||
if 'file_id' not in audio:
|
||||
raise ValueError(
|
||||
_("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails."))
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
|
||||
app_document_list=app_document_list, app_image_list=app_image_list,
|
||||
app_audio_list=app_audio_list,
|
||||
message=str(question), **kwargs)
|
||||
|
||||
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
|
||||
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# coding=utf-8
|
||||
from .base_application_node import BaseApplicationNode
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
# coding=utf-8
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
from application.flow.common import Answer
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.application_node.i_application_node import IApplicationNode
|
||||
from application.models import Chat
|
||||
|
||||
|
||||
def string_to_uuid(input_str):
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))
|
||||
|
||||
|
||||
def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
|
||||
return node_variable.get('is_interrupt_exec', False)
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
|
||||
reasoning_content: str):
|
||||
result = node_variable.get('result')
|
||||
node.context['application_node_dict'] = node_variable.get('application_node_dict')
|
||||
node.context['node_dict'] = node_variable.get('node_dict', {})
|
||||
node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec')
|
||||
node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
|
||||
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
|
||||
node.context['answer'] = answer
|
||||
node.context['result'] = answer
|
||||
node.context['reasoning_content'] = reasoning_content
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||
node.answer_text = answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = ''
|
||||
reasoning_content = ''
|
||||
usage = {}
|
||||
node_child_node = {}
|
||||
application_node_dict = node.context.get('application_node_dict', {})
|
||||
is_interrupt_exec = False
|
||||
for chunk in response:
|
||||
# 先把流转成字符串
|
||||
response_content = chunk.decode('utf-8')[6:]
|
||||
response_content = json.loads(response_content)
|
||||
content = response_content.get('content', '')
|
||||
runtime_node_id = response_content.get('runtime_node_id', '')
|
||||
chat_record_id = response_content.get('chat_record_id', '')
|
||||
child_node = response_content.get('child_node')
|
||||
view_type = response_content.get('view_type')
|
||||
node_type = response_content.get('node_type')
|
||||
real_node_id = response_content.get('real_node_id')
|
||||
node_is_end = response_content.get('node_is_end', False)
|
||||
_reasoning_content = response_content.get('reasoning_content', '')
|
||||
if node_type == 'form-node':
|
||||
is_interrupt_exec = True
|
||||
answer += content
|
||||
reasoning_content += _reasoning_content
|
||||
node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
|
||||
'child_node': child_node}
|
||||
|
||||
if real_node_id is not None:
|
||||
application_node = application_node_dict.get(real_node_id, None)
|
||||
if application_node is None:
|
||||
|
||||
application_node_dict[real_node_id] = {'content': content,
|
||||
'runtime_node_id': runtime_node_id,
|
||||
'chat_record_id': chat_record_id,
|
||||
'child_node': child_node,
|
||||
'index': len(application_node_dict),
|
||||
'view_type': view_type,
|
||||
'reasoning_content': _reasoning_content}
|
||||
else:
|
||||
application_node['content'] += content
|
||||
application_node['reasoning_content'] += _reasoning_content
|
||||
|
||||
yield {'content': content,
|
||||
'node_type': node_type,
|
||||
'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
|
||||
'reasoning_content': _reasoning_content,
|
||||
'child_node': child_node,
|
||||
'real_node_id': real_node_id,
|
||||
'node_is_end': node_is_end,
|
||||
'view_type': view_type}
|
||||
usage = response_content.get('usage', {})
|
||||
node_variable['result'] = {'usage': usage}
|
||||
node_variable['is_interrupt_exec'] = is_interrupt_exec
|
||||
node_variable['child_node'] = node_child_node
|
||||
node_variable['application_node_dict'] = application_node_dict
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result', {}).get('data', {})
|
||||
node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'),
|
||||
'prompt_tokens': response.get('prompt_tokens')}}
|
||||
answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
|
||||
reasoning_content = response.get('reasoning_content', '')
|
||||
answer_list = response.get('answer_list', [])
|
||||
node_variable['application_node_dict'] = {answer.get('real_node_id'): {**answer, 'index': index} for answer, index
|
||||
in
|
||||
zip(answer_list, range(len(answer_list)))}
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
||||
|
||||
|
||||
def reset_application_node_dict(application_node_dict, runtime_node_id, node_data):
|
||||
try:
|
||||
if application_node_dict is None:
|
||||
return
|
||||
for key in application_node_dict:
|
||||
application_node = application_node_dict[key]
|
||||
if application_node.get('runtime_node_id') == runtime_node_id:
|
||||
content: str = application_node.get('content')
|
||||
match = re.search('<form_rander>.*?</form_rander>', content)
|
||||
if match:
|
||||
form_setting_str = match.group().replace('<form_rander>', '').replace('</form_rander>', '')
|
||||
form_setting = json.loads(form_setting_str)
|
||||
form_setting['is_submit'] = True
|
||||
form_setting['form_data'] = node_data
|
||||
value = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
|
||||
res = re.sub('<form_rander>.*?</form_rander>',
|
||||
'${value}', content)
|
||||
application_node['content'] = res.replace('${value}', value)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
class BaseApplicationNode(IApplicationNode):
|
||||
def get_answer_list(self) -> List[Answer] | None:
|
||||
if self.answer_text is None:
|
||||
return None
|
||||
application_node_dict = self.context.get('application_node_dict')
|
||||
if application_node_dict is None or len(application_node_dict) == 0:
|
||||
return [
|
||||
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'],
|
||||
self.context.get('child_node'), self.runtime_node_id, '')]
|
||||
else:
|
||||
return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id,
|
||||
self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'),
|
||||
'chat_record_id': n.get('chat_record_id')
|
||||
, 'child_node': n.get('child_node')}, n.get('real_node_id'),
|
||||
n.get('reasoning_content', ''))
|
||||
for n in
|
||||
sorted(application_node_dict.values(), key=lambda item: item.get('index'))]
|
||||
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['result'] = details.get('answer')
|
||||
self.context['question'] = details.get('question')
|
||||
self.context['type'] = details.get('type')
|
||||
self.context['reasoning_content'] = details.get('reasoning_content')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
|
||||
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
|
||||
**kwargs) -> NodeResult:
|
||||
from application.serializers.chat_message_serializers import ChatMessageSerializer
|
||||
# 生成嵌入应用的chat_id
|
||||
current_chat_id = string_to_uuid(chat_id + application_id)
|
||||
Chat.objects.get_or_create(id=current_chat_id, defaults={
|
||||
'application_id': application_id,
|
||||
'abstract': message[0:1024],
|
||||
'client_id': client_id,
|
||||
})
|
||||
if app_document_list is None:
|
||||
app_document_list = []
|
||||
if app_image_list is None:
|
||||
app_image_list = []
|
||||
if app_audio_list is None:
|
||||
app_audio_list = []
|
||||
runtime_node_id = None
|
||||
record_id = None
|
||||
child_node_value = None
|
||||
if child_node is not None:
|
||||
runtime_node_id = child_node.get('runtime_node_id')
|
||||
record_id = child_node.get('chat_record_id')
|
||||
child_node_value = child_node.get('child_node')
|
||||
application_node_dict = self.context.get('application_node_dict')
|
||||
reset_application_node_dict(application_node_dict, runtime_node_id, node_data)
|
||||
|
||||
response = ChatMessageSerializer(
|
||||
data={'chat_id': current_chat_id, 'message': message,
|
||||
're_chat': re_chat,
|
||||
'stream': stream,
|
||||
'application_id': application_id,
|
||||
'client_id': client_id,
|
||||
'client_type': client_type,
|
||||
'document_list': app_document_list,
|
||||
'image_list': app_image_list,
|
||||
'audio_list': app_audio_list,
|
||||
'runtime_node_id': runtime_node_id,
|
||||
'chat_record_id': record_id,
|
||||
'child_node': child_node_value,
|
||||
'node_data': node_data,
|
||||
'form_data': kwargs}).chat()
|
||||
if response.status_code == 200:
|
||||
if stream:
|
||||
content_generator = response.streaming_content
|
||||
return NodeResult({'result': content_generator, 'question': message}, {},
|
||||
_write_context=write_context_stream, _is_interrupt=_is_interrupt_exec)
|
||||
else:
|
||||
data = json.loads(response.content)
|
||||
return NodeResult({'result': data, 'question': message}, {},
|
||||
_write_context=write_context, _is_interrupt=_is_interrupt_exec)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
global_fields = []
|
||||
for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
|
||||
value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else ''
|
||||
global_fields.append({
|
||||
'label': api_input_field['variable'],
|
||||
'key': api_input_field['variable'],
|
||||
'value': self.workflow_manage.get_reference_field(
|
||||
value,
|
||||
api_input_field['value'][1:]
|
||||
) if value != '' else ''
|
||||
})
|
||||
|
||||
for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
|
||||
value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else ''
|
||||
global_fields.append({
|
||||
'label': user_input_field['label'],
|
||||
'key': user_input_field['field'],
|
||||
'value': self.workflow_manage.get_reference_field(
|
||||
value,
|
||||
user_input_field['value'][1:]
|
||||
) if value != '' else ''
|
||||
})
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"info": self.node.properties.get('node_data'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'reasoning_content': self.context.get('reasoning_content'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'global_fields': global_fields,
|
||||
'document_list': self.workflow_manage.document_list,
|
||||
'image_list': self.workflow_manage.image_list,
|
||||
'audio_list': self.workflow_manage.audio_list,
|
||||
'application_node_dict': self.context.get('application_node_dict')
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from .contain_compare import *
|
||||
from .equal_compare import *
|
||||
from .ge_compare import *
|
||||
from .gt_compare import *
|
||||
from .is_not_null_compare import *
|
||||
from .is_not_true import IsNotTrueCompare
|
||||
from .is_null_compare import *
|
||||
from .is_true import IsTrueCompare
|
||||
from .le_compare import *
|
||||
from .len_equal_compare import *
|
||||
from .len_ge_compare import *
|
||||
from .len_gt_compare import *
|
||||
from .len_le_compare import *
|
||||
from .len_lt_compare import *
|
||||
from .lt_compare import *
|
||||
from .not_contain_compare import *
|
||||
|
||||
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
|
||||
LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(),
|
||||
IsNullCompare(),
|
||||
IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare()]
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: compare.py
|
||||
@date:2024/6/7 14:37
|
||||
@desc:
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class Compare:
|
||||
@abstractmethod
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compare(self, source_value, compare, target_value):
|
||||
pass
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: contain_compare.py
|
||||
@date:2024/6/11 10:02
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class ContainCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'contain':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
if isinstance(source_value, str):
|
||||
return str(target_value) in source_value
|
||||
return any([str(item) == str(target_value) for item in source_value])
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: equal_compare.py
|
||||
@date:2024/6/7 14:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class EqualCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'eq':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
return str(source_value) == str(target_value)
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class GECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'ge':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) >= float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class GTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'gt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) > float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: is_not_null_compare.py
|
||||
@date:2024/6/28 10:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsNotNullCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'is_not_null':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
return source_value is not None and len(source_value) > 0
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: is_not_true.py
|
||||
@date:2025/4/7 13:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsNotTrueCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'is_not_true':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return source_value is False
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: is_null_compare.py
|
||||
@date:2024/6/28 10:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsNullCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'is_null':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
return source_value is None or len(source_value) == 0
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: IsTrue.py
|
||||
@date:2025/4/7 13:38
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsTrueCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'is_true':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return source_value is True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'le':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) <= float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: equal_compare.py
|
||||
@date:2024/6/7 14:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenEqualCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_eq':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) == int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenGECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_ge':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) >= int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenGTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_gt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) > int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenLECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_le':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) <= int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenLTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_lt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) < int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'lt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) < float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: contain_compare.py
|
||||
@date:2024/6/11 10:02
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class NotContainCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'not_contain':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
if isinstance(source_value, str):
|
||||
return str(target_value) not in source_value
|
||||
return not any([str(item) == str(target_value) for item in source_value])
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_condition_node.py
|
||||
@date:2024/6/7 9:54
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ConditionSerializer(serializers.Serializer):
|
||||
compare = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Comparator")))
|
||||
value = serializers.CharField(required=True, error_messages=ErrMessage.char(_("value")))
|
||||
field = serializers.ListField(required=True, error_messages=ErrMessage.char(_("Fields")))
|
||||
|
||||
|
||||
class ConditionBranchSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch id")))
|
||||
type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch Type")))
|
||||
condition = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Condition or|and")))
|
||||
conditions = ConditionSerializer(many=True)
|
||||
|
||||
|
||||
class ConditionNodeParamsSerializer(serializers.Serializer):
|
||||
branch = ConditionBranchSerializer(many=True)
|
||||
|
||||
|
||||
class IConditionNode(INode):
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ConditionNodeParamsSerializer
|
||||
|
||||
type = 'condition-node'
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:35
|
||||
@desc:
|
||||
"""
|
||||
from .base_condition_node import BaseConditionNode
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_condition_node.py
|
||||
@date:2024/6/7 11:29
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.condition_node.compare import compare_handle_list
|
||||
from application.flow.step_node.condition_node.i_condition_node import IConditionNode
|
||||
|
||||
|
||||
class BaseConditionNode(IConditionNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['branch_id'] = details.get('branch_id')
|
||||
self.context['branch_name'] = details.get('branch_name')
|
||||
|
||||
def execute(self, **kwargs) -> NodeResult:
|
||||
branch_list = self.node_params_serializer.data['branch']
|
||||
branch = self._execute(branch_list)
|
||||
r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})
|
||||
return r
|
||||
|
||||
def _execute(self, branch_list: List):
|
||||
for branch in branch_list:
|
||||
if self.branch_assertion(branch):
|
||||
return branch
|
||||
|
||||
def branch_assertion(self, branch):
|
||||
condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
|
||||
branch.get('conditions')]
|
||||
condition = branch.get('condition')
|
||||
return all(condition_list) if condition == 'and' else any(condition_list)
|
||||
|
||||
def assertion(self, field_list: List[str], compare: str, value):
|
||||
try:
|
||||
value = self.workflow_manage.generate_prompt(value)
|
||||
except Exception as e:
|
||||
pass
|
||||
field_value = None
|
||||
try:
|
||||
field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
|
||||
except Exception as e:
|
||||
pass
|
||||
for compare_handler in compare_handle_list:
|
||||
if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
|
||||
return compare_handler.compare(field_value, compare, value)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'branch_id': self.context.get('branch_id'),
|
||||
'branch_name': self.context.get('branch_name'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 17:50
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_reply_node.py
|
||||
@date:2024/6/11 16:25
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ReplyNodeParamsSerializer(serializers.Serializer):
|
||||
reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Response Type")))
|
||||
fields = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Reference Field")))
|
||||
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char(_("Direct answer content")))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if self.data.get('reply_type') == 'referencing':
|
||||
if 'fields' not in self.data:
|
||||
raise AppApiException(500, _("Reference field cannot be empty"))
|
||||
if len(self.data.get('fields')) < 2:
|
||||
raise AppApiException(500, _("Reference field error"))
|
||||
else:
|
||||
if 'content' not in self.data or self.data.get('content') is None:
|
||||
raise AppApiException(500, _("Content cannot be empty"))
|
||||
|
||||
|
||||
class IReplyNode(INode):
|
||||
type = 'reply-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ReplyNodeParamsSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 17:49
|
||||
@desc:
|
||||
"""
|
||||
from .base_reply_node import *
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_reply_node.py
|
||||
@date:2024/6/11 17:25
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
|
||||
|
||||
|
||||
class BaseReplyNode(IReplyNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
||||
if reply_type == 'referencing':
|
||||
result = self.get_reference_content(fields)
|
||||
else:
|
||||
result = self.generate_reply_content(content)
|
||||
return NodeResult({'answer': result}, {})
|
||||
|
||||
def generate_reply_content(self, prompt):
|
||||
return self.workflow_manage.generate_prompt(prompt)
|
||||
|
||||
def get_reference_content(self, fields: List[str]):
|
||||
return str(self.workflow_manage.get_reference_field(
|
||||
fields[0],
|
||||
fields[1:]))
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'answer': self.context.get('answer'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .impl import *
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class DocumentExtractNodeSerializer(serializers.Serializer):
|
||||
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
|
||||
|
||||
|
||||
class IDocumentExtractNode(INode):
|
||||
type = 'document-extract-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return DocumentExtractNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0],
|
||||
self.node_params_serializer.data.get('document_list')[1:])
|
||||
return self.execute(document=res, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, document, chat_id, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .base_document_extract_node import BaseDocumentExtractNode
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
# coding=utf-8
|
||||
import io
|
||||
import mimetypes
|
||||
|
||||
from django.core.files.uploadedfile import InMemoryUploadedFile
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
|
||||
from dataset.models import File
|
||||
from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
|
||||
|
||||
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
||||
content_type, _ = mimetypes.guess_type(file_name)
|
||||
if content_type is None:
|
||||
# 如果未能识别,设置为默认的二进制文件类型
|
||||
content_type = "application/octet-stream"
|
||||
# 创建一个内存中的字节流对象
|
||||
file_stream = io.BytesIO(file_bytes)
|
||||
|
||||
# 获取文件大小
|
||||
file_size = len(file_bytes)
|
||||
|
||||
# 创建 InMemoryUploadedFile 对象
|
||||
uploaded_file = InMemoryUploadedFile(
|
||||
file=file_stream,
|
||||
field_name=None,
|
||||
name=file_name,
|
||||
content_type=content_type,
|
||||
size=file_size,
|
||||
charset=None,
|
||||
)
|
||||
return uploaded_file
|
||||
|
||||
|
||||
splitter = '\n`-----------------------------------`\n'
|
||||
|
||||
class BaseDocumentExtractNode(IDocumentExtractNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['content'] = details.get('content')
|
||||
|
||||
|
||||
def execute(self, document, chat_id, **kwargs):
|
||||
get_buffer = FileBufferHandle().get_buffer
|
||||
|
||||
self.context['document_list'] = document
|
||||
content = []
|
||||
if document is None or not isinstance(document, list):
|
||||
return NodeResult({'content': ''}, {})
|
||||
|
||||
application = self.workflow_manage.work_flow_post_handler.chat_info.application
|
||||
|
||||
# doc文件中的图片保存
|
||||
def save_image(image_list):
|
||||
for image in image_list:
|
||||
meta = {
|
||||
'debug': False if application.id else True,
|
||||
'chat_id': chat_id,
|
||||
'application_id': str(application.id) if application.id else None,
|
||||
'file_id': str(image.id)
|
||||
}
|
||||
file = bytes_to_uploaded_file(image.image, image.image_name)
|
||||
FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
|
||||
for doc in document:
|
||||
file = QuerySet(File).filter(id=doc['file_id']).first()
|
||||
buffer = io.BytesIO(file.get_byte().tobytes())
|
||||
buffer.name = doc['name'] # this is the important line
|
||||
|
||||
for split_handle in (parse_table_handle_list + split_handles):
|
||||
if split_handle.support(buffer, get_buffer):
|
||||
# 回到文件头
|
||||
buffer.seek(0)
|
||||
file_content = split_handle.get_content(buffer, save_image)
|
||||
content.append('### ' + doc['name'] + '\n' + file_content)
|
||||
break
|
||||
|
||||
return NodeResult({'content': splitter.join(content)}, {})
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
content = self.context.get('content', '').split(splitter)
|
||||
# 不保存content全部内容,因为content内容可能会很大
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'content': [file_content[:500] for file_content in content],
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'document_list': self.context.get('document_list')
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/11/4 14:48
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_form_node.py
|
||||
@date:2024/11/4 14:48
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class FormNodeParamsSerializer(serializers.Serializer):
|
||||
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Form Configuration")))
|
||||
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Form output content')))
|
||||
form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
|
||||
|
||||
|
||||
class IFormNode(INode):
|
||||
type = 'form-node'
|
||||
view_type = 'single_view'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FormNodeParamsSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/11/4 14:49
|
||||
@desc:
|
||||
"""
|
||||
from .base_form_node import BaseFormNode
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_form_node.py
|
||||
@date:2024/11/4 14:52
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from application.flow.common import Answer
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.form_node.i_form_node import IFormNode
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable is not None:
|
||||
for key in step_variable:
|
||||
node.context[key] = step_variable[key]
|
||||
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
|
||||
result = step_variable['result']
|
||||
yield result
|
||||
node.answer_text = result
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
class BaseFormNode(IFormNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
form_data = details.get('form_data', None)
|
||||
self.context['result'] = details.get('result')
|
||||
self.context['form_content_format'] = details.get('form_content_format')
|
||||
self.context['form_field_list'] = details.get('form_field_list')
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['start_time'] = details.get('start_time')
|
||||
self.context['form_data'] = form_data
|
||||
self.context['is_submit'] = details.get('is_submit')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('result')
|
||||
if form_data is not None:
|
||||
for key in form_data:
|
||||
self.context[key] = form_data[key]
|
||||
|
||||
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
|
||||
if form_data is not None:
|
||||
self.context['is_submit'] = True
|
||||
self.context['form_data'] = form_data
|
||||
for key in form_data:
|
||||
self.context[key] = form_data.get(key)
|
||||
else:
|
||||
self.context['is_submit'] = False
|
||||
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
|
||||
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
|
||||
"is_submit": self.context.get("is_submit", False)}
|
||||
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
|
||||
context = self.workflow_manage.get_workflow_content()
|
||||
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
|
||||
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
|
||||
value = prompt_template.format(form=form, context=context)
|
||||
return NodeResult(
|
||||
{'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
def get_answer_list(self) -> List[Answer] | None:
|
||||
form_content_format = self.context.get('form_content_format')
|
||||
form_field_list = self.context.get('form_field_list')
|
||||
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
|
||||
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
|
||||
'form_data': self.context.get('form_data', {}),
|
||||
"is_submit": self.context.get("is_submit", False)}
|
||||
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
|
||||
context = self.workflow_manage.get_workflow_content()
|
||||
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
|
||||
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
|
||||
value = prompt_template.format(form=form, context=context)
|
||||
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
|
||||
self.runtime_node_id, '')]
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
form_content_format = self.context.get('form_content_format')
|
||||
form_field_list = self.context.get('form_field_list')
|
||||
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
|
||||
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
|
||||
'form_data': self.context.get('form_data', {}),
|
||||
"is_submit": self.context.get("is_submit", False)}
|
||||
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
|
||||
context = self.workflow_manage.get_workflow_content()
|
||||
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
|
||||
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
|
||||
value = prompt_template.format(form=form, context=context)
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"result": value,
|
||||
"form_content_format": self.context.get('form_content_format'),
|
||||
"form_field_list": self.context.get('form_field_list'),
|
||||
'form_data': self.context.get('form_data'),
|
||||
'start_time': self.context.get('start_time'),
|
||||
'is_submit': self.context.get('is_submit'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/8/8 17:45
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_function_lib_node.py
|
||||
@date:2024/8/8 16:21
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.field.common import ObjectField
|
||||
from common.util.field_message import ErrMessage
|
||||
from function_lib.models.function import FunctionLib
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class InputField(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
|
||||
value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
|
||||
|
||||
|
||||
class FunctionLibNodeParamsSerializer(serializers.Serializer):
|
||||
function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Library ID')))
|
||||
input_field_list = InputField(required=True, many=True)
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
f_lib = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id')).first()
|
||||
if f_lib is None:
|
||||
raise Exception(_('The function has been deleted'))
|
||||
|
||||
|
||||
class IFunctionLibNode(INode):
|
||||
type = 'function-lib-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FunctionLibNodeParamsSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/8/8 17:48
|
||||
@desc:
|
||||
"""
|
||||
from .base_function_lib_node import BaseFunctionLibNodeNode
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_function_lib_node.py
|
||||
@date:2024/8/8 17:49
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.function_code import FunctionExecutor
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from function_lib.models.function import FunctionLib
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable is not None:
|
||||
for key in step_variable:
|
||||
node.context[key] = step_variable[key]
|
||||
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
|
||||
result = str(step_variable['result']) + '\n'
|
||||
yield result
|
||||
node.answer_text = result
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
def get_field_value(debug_field_list, name, is_required):
|
||||
result = [field for field in debug_field_list if field.get('name') == name]
|
||||
if len(result) > 0:
|
||||
return result[-1]['value']
|
||||
if is_required:
|
||||
raise AppApiException(500, _('Field: {name} No value set').format(name=name))
|
||||
return None
|
||||
|
||||
|
||||
def valid_reference_value(_type, value, name):
|
||||
if _type == 'int':
|
||||
instance_type = int | float
|
||||
elif _type == 'float':
|
||||
instance_type = float | int
|
||||
elif _type == 'dict':
|
||||
instance_type = dict
|
||||
elif _type == 'array':
|
||||
instance_type = list
|
||||
elif _type == 'string':
|
||||
instance_type = str
|
||||
else:
|
||||
raise Exception(_('Field: {name} Type: {_type} Value: {value} Unsupported types').format(name=name,
|
||||
_type=_type))
|
||||
if not isinstance(value, instance_type):
|
||||
raise Exception(
|
||||
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
|
||||
value=value))
|
||||
|
||||
|
||||
def convert_value(name: str, value, _type, is_required, source, node):
|
||||
if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
|
||||
return None
|
||||
if not is_required and source == 'reference' and (value is None or len(value) == 0):
|
||||
return None
|
||||
if source == 'reference':
|
||||
value = node.workflow_manage.get_reference_field(
|
||||
value[0],
|
||||
value[1:])
|
||||
valid_reference_value(_type, value, name)
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
return value
|
||||
try:
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
if _type == 'dict':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
raise Exception(_('type error'))
|
||||
if _type == 'array':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
raise Exception(_('type error'))
|
||||
return value
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
|
||||
value=value))
|
||||
|
||||
|
||||
def valid_function(function_lib, user_id):
|
||||
if function_lib is None:
|
||||
raise Exception(_('Function does not exist'))
|
||||
if function_lib.permission_type == 'PRIVATE' and str(function_lib.user_id) != str(user_id):
|
||||
raise Exception(_('No permission to use this function {name}').format(name=function_lib.name))
|
||||
if not function_lib.is_active:
|
||||
raise Exception(_('Function {name} is unavailable').format(name=function_lib.name))
|
||||
|
||||
|
||||
class BaseFunctionLibNodeNode(IFunctionLibNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['result'] = details.get('result')
|
||||
if self.node_params.get('is_result'):
|
||||
self.answer_text = str(details.get('result'))
|
||||
|
||||
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
|
||||
function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
|
||||
valid_function(function_lib, self.flow_params_serializer.data.get('user_id'))
|
||||
params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
|
||||
field.get('is_required'),
|
||||
field.get('source'), self)
|
||||
for field in
|
||||
[{'value': get_field_value(input_field_list, field.get('name'), field.get('is_required'),
|
||||
), **field}
|
||||
for field in
|
||||
function_lib.input_field_list]}
|
||||
|
||||
self.context['params'] = params
|
||||
# 合并初始化参数
|
||||
if function_lib.init_params is not None:
|
||||
all_params = json.loads(rsa_long_decrypt(function_lib.init_params)) | params
|
||||
else:
|
||||
all_params = params
|
||||
result = function_executor.exec_code(function_lib.code, all_params)
|
||||
return NodeResult({'result': result}, {}, _write_context=write_context)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"result": self.context.get('result'),
|
||||
"params": self.context.get('params'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/8/13 10:43
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_function_lib_node.py
|
||||
@date:2024/8/8 16:21
|
||||
@desc:
|
||||
"""
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from django.core import validators
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.field.common import ObjectField
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.utils.formatting import lazy_format
|
||||
|
||||
|
||||
class InputField(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
|
||||
is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Is this field required")))
|
||||
type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("type")), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^string|int|dict|array|float$"),
|
||||
message=_("The field only supports string|int|dict|array|float"), code=500)
|
||||
])
|
||||
source = serializers.CharField(required=True, error_messages=ErrMessage.char(_("source")), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^custom|reference$"),
|
||||
message=_("The field only supports custom|reference"), code=500)
|
||||
])
|
||||
value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
is_required = self.data.get('is_required')
|
||||
if is_required and self.data.get('value') is None:
|
||||
message = lazy_format(_('{field}, this field is required.'), field=self.data.get("name"))
|
||||
raise AppApiException(500, message)
|
||||
|
||||
|
||||
class FunctionNodeParamsSerializer(serializers.Serializer):
|
||||
input_field_list = InputField(required=True, many=True)
|
||||
code = serializers.CharField(required=True, error_messages=ErrMessage.char(_("function")))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class IFunctionNode(INode):
|
||||
type = 'function-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FunctionNodeParamsSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, input_field_list, code, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/8/13 11:19
|
||||
@desc:
|
||||
"""
|
||||
from .base_function_node import BaseFunctionNodeNode
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_function_lib_node.py
|
||||
@date:2024/8/8 17:49
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.function_node.i_function_node import IFunctionNode
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.function_code import FunctionExecutor
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable is not None:
|
||||
for key in step_variable:
|
||||
node.context[key] = step_variable[key]
|
||||
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
|
||||
result = str(step_variable['result']) + '\n'
|
||||
yield result
|
||||
node.answer_text = result
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
def valid_reference_value(_type, value, name):
|
||||
if _type == 'int':
|
||||
instance_type = int | float
|
||||
elif _type == 'float':
|
||||
instance_type = float | int
|
||||
elif _type == 'dict':
|
||||
instance_type = dict
|
||||
elif _type == 'array':
|
||||
instance_type = list
|
||||
elif _type == 'string':
|
||||
instance_type = str
|
||||
else:
|
||||
raise Exception(500, f'字段:{name}类型:{_type} 不支持的类型')
|
||||
if not isinstance(value, instance_type):
|
||||
raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
|
||||
|
||||
|
||||
def convert_value(name: str, value, _type, is_required, source, node):
|
||||
if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
|
||||
return None
|
||||
if source == 'reference':
|
||||
value = node.workflow_manage.get_reference_field(
|
||||
value[0],
|
||||
value[1:])
|
||||
valid_reference_value(_type, value, name)
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
return value
|
||||
try:
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
if _type == 'dict':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
raise Exception("类型错误")
|
||||
if _type == 'array':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
raise Exception("类型错误")
|
||||
return value
|
||||
except Exception as e:
|
||||
raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
|
||||
|
||||
|
||||
class BaseFunctionNodeNode(IFunctionNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['result'] = details.get('result')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = str(details.get('result'))
|
||||
|
||||
def execute(self, input_field_list, code, **kwargs) -> NodeResult:
|
||||
params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
|
||||
field.get('is_required'), field.get('source'), self)
|
||||
for field in input_field_list}
|
||||
result = function_executor.exec_code(code, params)
|
||||
self.context['params'] = params
|
||||
return NodeResult({'result': result}, {}, _write_context=write_context)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"result": self.context.get('result'),
|
||||
"params": self.context.get('params'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ImageGenerateNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word (positive)")))
|
||||
|
||||
negative_prompt = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Prompt word (negative)")),
|
||||
allow_null=True, allow_blank=True, )
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=False, default=0,
|
||||
error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
|
||||
dialogue_type = serializers.CharField(required=False, default='NODE',
|
||||
error_messages=ErrMessage.char(_("Conversation storage type")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
model_params_setting = serializers.JSONField(required=False, default=dict,
|
||||
error_messages=ErrMessage.json(_("Model parameter settings")))
|
||||
|
||||
|
||||
class IImageGenerateNode(INode):
|
||||
type = 'image-generate-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ImageGenerateNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||
model_params_setting,
|
||||
chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .base_image_generate_node import BaseImageGenerateNode
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
# coding=utf-8
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
|
||||
from common.util.common import bytes_to_uploaded_file
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
class BaseImageGenerateNode(IImageGenerateNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['question'] = details.get('question')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||
model_params_setting,
|
||||
chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
print(model_params_setting)
|
||||
application = self.workflow_manage.work_flow_post_handler.chat_info.application
|
||||
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question
|
||||
message_list = self.generate_message_list(question, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
print(message_list)
|
||||
image_urls = tti_model.generate_image(question, negative_prompt)
|
||||
# 保存图片
|
||||
file_urls = []
|
||||
for image_url in image_urls:
|
||||
file_name = 'generated_image.png'
|
||||
file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
|
||||
meta = {
|
||||
'debug': False if application.id else True,
|
||||
'chat_id': chat_id,
|
||||
'application_id': str(application.id) if application.id else None,
|
||||
}
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_urls.append(file_url)
|
||||
self.context['image_list'] = [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls]
|
||||
answer = ' '.join([f"" for path in file_urls])
|
||||
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
|
||||
'image': [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls],
|
||||
'history_message': history_message, 'question': question}, {})
|
||||
|
||||
def generate_history_ai_message(self, chat_record):
|
||||
for val in chat_record.details.values():
|
||||
if self.node.id == val['node_id'] and 'image_list' in val:
|
||||
if val['dialogue_type'] == 'WORKFLOW':
|
||||
return chat_record.get_ai_message()
|
||||
image_list = val['image_list']
|
||||
return AIMessage(content=[
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list]
|
||||
])
|
||||
return chat_record.get_ai_message()
|
||||
|
||||
def get_history_message(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[self.generate_history_human_message(history_chat_record[index]),
|
||||
self.generate_history_ai_message(history_chat_record[index])]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_history_human_message(self, chat_record):
|
||||
|
||||
for data in chat_record.details.values():
|
||||
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
return HumanMessage(content=data['question'])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return self.workflow_manage.generate_prompt(prompt)
|
||||
|
||||
def generate_message_list(self, question: str, history_message):
|
||||
return [
|
||||
*history_message,
|
||||
question
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image_list'),
|
||||
'dialogue_type': self.context.get('dialogue_type')
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ImageUnderstandNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char(_("Role Setting")))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
|
||||
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Conversation storage type")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
|
||||
|
||||
model_params_setting = serializers.JSONField(required=False, default=dict,
|
||||
error_messages=ErrMessage.json(_("Model parameter settings")))
|
||||
|
||||
|
||||
class IImageUnderstandNode(INode):
|
||||
type = 'image-understand-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ImageUnderstandNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0],
|
||||
self.node_params_serializer.data.get('image_list')[1:])
|
||||
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
|
||||
model_params_setting,
|
||||
chat_record_id,
|
||||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .base_image_understand_node import BaseImageUnderstandNode
|
||||
|
|
@ -0,0 +1,224 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
|
||||
from dataset.models import File
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
from imghdr import what
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||
node.answer_text = answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = ''
|
||||
for chunk in response:
|
||||
answer += chunk.content
|
||||
yield chunk.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = response.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def file_id_to_base64(file_id: str):
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
file_bytes = file.get_byte()
|
||||
base64_image = base64.b64encode(file_bytes).decode("utf-8")
|
||||
return [base64_image, what(None, file_bytes.tobytes())]
|
||||
|
||||
|
||||
class BaseImageUnderstandNode(IImageUnderstandNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['question'] = details.get('question')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
|
||||
model_params_setting,
|
||||
chat_record_id,
|
||||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
# 处理不正确的参数
|
||||
if image is None or not isinstance(image, list):
|
||||
image = []
|
||||
print(model_params_setting)
|
||||
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
|
||||
# 执行详情中的历史消息不需要图片内容
|
||||
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question.content
|
||||
# 生成消息列表, 真实的history_message
|
||||
message_list = self.generate_message_list(image_model, system, prompt,
|
||||
self.get_history_message(history_chat_record, dialogue_number), image)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['image_list'] = image
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
if stream:
|
||||
r = image_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = image_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
def get_history_message_for_details(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[self.generate_history_human_message_for_details(history_chat_record[index]),
|
||||
self.generate_history_ai_message(history_chat_record[index])]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_history_ai_message(self, chat_record):
|
||||
for val in chat_record.details.values():
|
||||
if self.node.id == val['node_id'] and 'image_list' in val:
|
||||
if val['dialogue_type'] == 'WORKFLOW':
|
||||
return chat_record.get_ai_message()
|
||||
return AIMessage(content=val['answer'])
|
||||
return chat_record.get_ai_message()
|
||||
|
||||
def generate_history_human_message_for_details(self, chat_record):
|
||||
for data in chat_record.details.values():
|
||||
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = [image.get('file_id') for image in image_list]
|
||||
return HumanMessage(content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
|
||||
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
def get_history_message(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[self.generate_history_human_message(history_chat_record[index]),
|
||||
self.generate_history_ai_message(history_chat_record[index])]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_history_human_message(self, chat_record):
|
||||
|
||||
for data in chat_record.details.values():
|
||||
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
|
||||
return HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
|
||||
base64_image in image_base64_list]
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||
if image is not None and len(image) > 0:
|
||||
# 处理多张图片
|
||||
images = []
|
||||
for img in image:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
image_bytes = file.get_byte()
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_format = what(None, image_bytes.tobytes())
|
||||
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
messages = [HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
|
||||
*images
|
||||
])]
|
||||
else:
|
||||
messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
if system is not None and len(system) > 0:
|
||||
return [
|
||||
SystemMessage(self.workflow_manage.generate_prompt(system)),
|
||||
*history_message,
|
||||
*messages
|
||||
]
|
||||
else:
|
||||
return [
|
||||
*history_message,
|
||||
*messages
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.node_params.get('system'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image_list'),
|
||||
'dialogue_type': self.context.get('dialogue_type')
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class McpNodeSerializer(serializers.Serializer):
|
||||
mcp_servers = serializers.JSONField(required=True,
|
||||
error_messages=ErrMessage.char(_("Mcp servers")))
|
||||
|
||||
mcp_server = serializers.CharField(required=True,
|
||||
error_messages=ErrMessage.char(_("Mcp server")))
|
||||
|
||||
mcp_tool = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Mcp tool")))
|
||||
|
||||
tool_params = serializers.DictField(required=True,
|
||||
error_messages=ErrMessage.char(_("Tool parameters")))
|
||||
|
||||
|
||||
class IMcpNode(INode):
|
||||
type = 'mcp-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return McpNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .base_mcp_node import BaseMcpNode
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
# coding=utf-8
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.mcp_node.i_mcp_node import IMcpNode
|
||||
|
||||
|
||||
class BaseMcpNode(IMcpNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['result'] = details.get('result')
|
||||
self.context['tool_params'] = details.get('tool_params')
|
||||
self.context['mcp_tool'] = details.get('mcp_tool')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('result')
|
||||
|
||||
def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
|
||||
servers = json.loads(mcp_servers)
|
||||
params = json.loads(json.dumps(tool_params))
|
||||
params = self.handle_variables(params)
|
||||
|
||||
async def call_tool(s, session, t, a):
|
||||
async with MultiServerMCPClient(s) as client:
|
||||
s = await client.sessions[session].call_tool(t, a)
|
||||
return s
|
||||
|
||||
res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params))
|
||||
return NodeResult(
|
||||
{'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {})
|
||||
|
||||
def handle_variables(self, tool_params):
|
||||
# 处理参数中的变量
|
||||
for k, v in tool_params.items():
|
||||
if type(v) == str:
|
||||
tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k])
|
||||
if type(v) == dict:
|
||||
self.handle_variables(v)
|
||||
if (type(v) == list) and (type(v[0]) == str):
|
||||
tool_params[k] = self.get_reference_content(v)
|
||||
return tool_params
|
||||
|
||||
def get_reference_content(self, fields: List[str]):
|
||||
return str(self.workflow_manage.get_reference_field(
|
||||
fields[0],
|
||||
fields[1:]))
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'type': self.node.type,
|
||||
'mcp_tool': self.context.get('mcp_tool'),
|
||||
'tool_params': self.context.get('tool_params'),
|
||||
'result': self.context.get('result'),
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:30
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_chat_node.py
|
||||
@date:2024/6/4 13:58
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class QuestionNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char(_("Role Setting")))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer(_("Model parameter settings")))
|
||||
|
||||
|
||||
class IQuestionNode(INode):
|
||||
type = 'question-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return QuestionNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
model_params_setting=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:35
|
||||
@desc:
|
||||
"""
|
||||
from .base_question_node import BaseQuestionNode
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_question_node.py
|
||||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import re
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.question_node.i_question_node import IQuestionNode
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model_credential
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||
node.answer_text = answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = ''
|
||||
for chunk in response:
|
||||
answer += chunk.content
|
||||
yield chunk.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = response.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def get_default_model_params_setting(model_id):
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
||||
model_params_setting = credential.get_model_params_setting_form(
|
||||
model.model_name).get_default_form_data()
|
||||
return model_params_setting
|
||||
|
||||
|
||||
class BaseQuestionNode(IQuestionNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['question'] = details.get('question')
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['message_tokens'] = details.get('message_tokens')
|
||||
self.context['answer_tokens'] = details.get('answer_tokens')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
model_params_setting=None,
|
||||
**kwargs) -> NodeResult:
|
||||
if model_params_setting is None:
|
||||
model_params_setting = get_default_model_params_setting(model_id)
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question.content
|
||||
system = self.workflow_manage.generate_prompt(system)
|
||||
self.context['system'] = system
|
||||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
for message in history_message:
|
||||
if isinstance(message.content, str):
|
||||
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
|
||||
return history_message
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def generate_message_list(self, system: str, prompt: str, history_message):
|
||||
if system is None or len(system) == 0:
|
||||
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
|
||||
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
else:
|
||||
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.context.get('system'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/9/4 11:37
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_reranker_node.py
|
||||
@date:2024/9/4 10:40
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class RerankerSettingSerializer(serializers.Serializer):
|
||||
# 需要查询的条数
|
||||
top_n = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.integer(_("Reference segment number")))
|
||||
# 相似度 0-1之间
|
||||
similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
|
||||
error_messages=ErrMessage.float(_("Reference segment number")))
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.float(_("Maximum number of words in a quoted segment")))
|
||||
|
||||
|
||||
class RerankerStepNodeSerializer(serializers.Serializer):
|
||||
reranker_setting = RerankerSettingSerializer(required=True)
|
||||
|
||||
question_reference_address = serializers.ListField(required=True)
|
||||
reranker_model_id = serializers.UUIDField(required=True)
|
||||
reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class IRerankerNode(INode):
|
||||
type = 'reranker-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return RerankerStepNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
question = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('question_reference_address')[0],
|
||||
self.node_params_serializer.data.get('question_reference_address')[1:])
|
||||
reranker_list = [self.workflow_manage.get_reference_field(
|
||||
reference[0],
|
||||
reference[1:]) for reference in
|
||||
self.node_params_serializer.data.get('reranker_reference_list')]
|
||||
return self.execute(**self.node_params_serializer.data, question=str(question),
|
||||
|
||||
reranker_list=reranker_list)
|
||||
|
||||
def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/9/4 11:39
|
||||
@desc:
|
||||
"""
|
||||
from .base_reranker_node import *
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_reranker_node.py
|
||||
@date:2024/9/4 11:41
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def merge_reranker_list(reranker_list, result=None):
|
||||
if result is None:
|
||||
result = []
|
||||
for document in reranker_list:
|
||||
if isinstance(document, list):
|
||||
merge_reranker_list(document, result)
|
||||
elif isinstance(document, dict):
|
||||
content = document.get('title', '') + document.get('content', '')
|
||||
title = document.get("title")
|
||||
dataset_name = document.get("dataset_name")
|
||||
document_name = document.get('document_name')
|
||||
result.append(
|
||||
Document(page_content=str(document) if len(content) == 0 else content,
|
||||
metadata={'title': title, 'dataset_name': dataset_name, 'document_name': document_name}))
|
||||
else:
|
||||
result.append(Document(page_content=str(document), metadata={}))
|
||||
return result
|
||||
|
||||
|
||||
def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
|
||||
use_len = 0
|
||||
result = []
|
||||
for index in range(len(document_list)):
|
||||
document = document_list[index]
|
||||
if use_len >= max_paragraph_char_number or index >= top_n or document.metadata.get(
|
||||
'relevance_score') < similarity:
|
||||
break
|
||||
content = document.page_content[0:max_paragraph_char_number - use_len]
|
||||
use_len = use_len + len(content)
|
||||
result.append({'page_content': content, 'metadata': document.metadata})
|
||||
return result
|
||||
|
||||
|
||||
def reset_result_list(result_list: List[Document], document_list: List[Document]):
|
||||
r = []
|
||||
document_list = document_list.copy()
|
||||
for result in result_list:
|
||||
filter_result_list = [document for document in document_list if document.page_content == result.page_content]
|
||||
if len(filter_result_list) > 0:
|
||||
item = filter_result_list[0]
|
||||
document_list.remove(item)
|
||||
r.append(Document(page_content=item.page_content,
|
||||
metadata={**item.metadata, 'relevance_score': result.metadata.get('relevance_score')}))
|
||||
else:
|
||||
r.append(result)
|
||||
return r
|
||||
|
||||
|
||||
class BaseRerankerNode(IRerankerNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['document_list'] = details.get('document_list', [])
|
||||
self.context['question'] = details.get('question')
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['result_list'] = details.get('result_list')
|
||||
self.context['result'] = details.get('result')
|
||||
|
||||
def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
|
||||
**kwargs) -> NodeResult:
|
||||
documents = merge_reranker_list(reranker_list)
|
||||
top_n = reranker_setting.get('top_n', 3)
|
||||
self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
|
||||
document in documents]
|
||||
self.context['question'] = question
|
||||
reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
|
||||
self.flow_params_serializer.data.get('user_id'),
|
||||
top_n=top_n)
|
||||
result = reranker_model.compress_documents(
|
||||
documents,
|
||||
question)
|
||||
similarity = reranker_setting.get('similarity', 0.6)
|
||||
max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000)
|
||||
result = reset_result_list(result, documents)
|
||||
r = filter_result(result, max_paragraph_char_number, top_n, similarity)
|
||||
return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {})
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'document_list': self.context.get('document_list'),
|
||||
"question": self.context.get('question'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
|
||||
'result_list': self.context.get('result_list'),
|
||||
'result': self.context.get('result'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:30
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_search_dataset_node.py
|
||||
@date:2024/6/3 17:52
|
||||
@desc:
|
||||
"""
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from django.core import validators
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.common import flat_map
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class DatasetSettingSerializer(serializers.Serializer):
|
||||
# 需要查询的条数
|
||||
top_n = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.integer(_("Reference segment number")))
|
||||
# 相似度 0-1之间
|
||||
similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
|
||||
error_messages=ErrMessage.float(_('similarity')))
|
||||
search_mode = serializers.CharField(required=True, validators=[
|
||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message=_("The type only supports embedding|keywords|blend"), code=500)
|
||||
], error_messages=ErrMessage.char(_("Retrieval Mode")))
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.float(_("Maximum number of words in a quoted segment")))
|
||||
|
||||
|
||||
class SearchDatasetStepNodeSerializer(serializers.Serializer):
|
||||
# 需要查询的数据集id列表
|
||||
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list(_("Dataset id list")))
|
||||
dataset_setting = DatasetSettingSerializer(required=True)
|
||||
|
||||
question_reference_address = serializers.ListField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
def get_paragraph_list(chat_record, node_id):
|
||||
return flat_map([chat_record.details[key].get('paragraph_list', []) for key in chat_record.details if
|
||||
(chat_record.details[
|
||||
key].get('type', '') == 'search-dataset-node') and chat_record.details[key].get(
|
||||
'paragraph_list', []) is not None and key == node_id])
|
||||
|
||||
|
||||
class ISearchDatasetStepNode(INode):
|
||||
type = 'search-dataset-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return SearchDatasetStepNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
question = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('question_reference_address')[0],
|
||||
self.node_params_serializer.data.get('question_reference_address')[1:])
|
||||
exclude_paragraph_id_list = []
|
||||
if self.flow_params_serializer.data.get('re_chat', False):
|
||||
history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
|
||||
paragraph_id_list = [p.get('id') for p in flat_map(
|
||||
[get_paragraph_list(chat_record, self.runtime_node_id) for chat_record in history_chat_record if
|
||||
chat_record.problem_text == question])]
|
||||
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||
|
||||
return self.execute(**self.node_params_serializer.data, question=str(question),
|
||||
exclude_paragraph_id_list=exclude_paragraph_id_list)
|
||||
|
||||
def execute(self, dataset_id_list, dataset_setting, question,
|
||||
exclude_paragraph_id_list=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:35
|
||||
@desc:
|
||||
"""
|
||||
from .base_search_dataset_node import BaseSearchDatasetNode
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_search_dataset_node.py
|
||||
@date:2024/6/4 11:56
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.db import connection
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Document, Paragraph, DataSet
|
||||
from embedding.models import SearchMode
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_embedding_id(dataset_id_list):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
raise Exception("关联知识库的向量模型不一致,无法召回分段。")
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception("知识库设置错误,请重新设置知识库")
|
||||
return dataset_list[0].embedding_mode_id
|
||||
|
||||
|
||||
def get_none_result(question):
|
||||
return NodeResult(
|
||||
{'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
|
||||
'directly_return': ''}, {})
|
||||
|
||||
|
||||
def reset_title(title):
|
||||
if title is None or len(title.strip()) == 0:
|
||||
return ""
|
||||
else:
|
||||
return f"#### {title}\n"
|
||||
|
||||
|
||||
class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
result = details.get('paragraph_list', [])
|
||||
dataset_setting = self.node_params_serializer.data.get('dataset_setting')
|
||||
directly_return = '\n'.join(
|
||||
[f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in result if
|
||||
paragraph.get('is_hit_handling_method')])
|
||||
self.context['paragraph_list'] = result
|
||||
self.context['question'] = details.get('question')
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['is_hit_handling_method_list'] = [row for row in result if row.get('is_hit_handling_method')]
|
||||
self.context['data'] = '\n'.join(
|
||||
[f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in
|
||||
result])[0:dataset_setting.get('max_paragraph_char_number', 5000)]
|
||||
self.context['directly_return'] = directly_return
|
||||
|
||||
def execute(self, dataset_id_list, dataset_setting, question,
|
||||
exclude_paragraph_id_list=None,
|
||||
**kwargs) -> NodeResult:
|
||||
self.context['question'] = question
|
||||
if len(dataset_id_list) == 0:
|
||||
return get_none_result(question)
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
embedding_value = embedding_model.embed_query(question)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)]
|
||||
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
|
||||
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
|
||||
# 手动关闭数据库连接
|
||||
connection.close()
|
||||
if embedding_list is None:
|
||||
return get_none_result(question)
|
||||
paragraph_list = self.list_paragraph(embedding_list, vector)
|
||||
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
|
||||
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
|
||||
return NodeResult({'paragraph_list': result,
|
||||
'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')],
|
||||
'data': '\n'.join(
|
||||
[f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in
|
||||
result])[0:dataset_setting.get('max_paragraph_char_number', 5000)],
|
||||
'directly_return': '\n'.join(
|
||||
[paragraph.get('content') for paragraph in
|
||||
result if
|
||||
paragraph.get('is_hit_handling_method')]),
|
||||
'question': question},
|
||||
|
||||
{})
|
||||
|
||||
@staticmethod
|
||||
def reset_paragraph(paragraph: Dict, embedding_list: List):
|
||||
filter_embedding_list = [embedding for embedding in embedding_list if
|
||||
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
|
||||
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
|
||||
find_embedding = filter_embedding_list[-1]
|
||||
return {
|
||||
**paragraph,
|
||||
'similarity': find_embedding.get('similarity'),
|
||||
'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
|
||||
'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return',
|
||||
'update_time': paragraph.get('update_time').strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'create_time': paragraph.get('create_time').strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'id': str(paragraph.get('id')),
|
||||
'dataset_id': str(paragraph.get('dataset_id')),
|
||||
'document_id': str(paragraph.get('document_id'))
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def list_paragraph(embedding_list: List, vector):
|
||||
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
|
||||
if paragraph_id_list is None or len(paragraph_id_list) == 0:
|
||||
return []
|
||||
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
|
||||
get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||
with_table_name=True)
|
||||
# 如果向量库中存在脏数据 直接删除
|
||||
if len(paragraph_list) != len(paragraph_id_list):
|
||||
exist_paragraph_list = [row.get('id') for row in paragraph_list]
|
||||
for paragraph_id in paragraph_id_list:
|
||||
if not exist_paragraph_list.__contains__(paragraph_id):
|
||||
vector.delete_by_paragraph_id(paragraph_id)
|
||||
return paragraph_list
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
'question': self.context.get('question'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'paragraph_list': self.context.get('paragraph_list'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class SpeechToTextNodeSerializer(serializers.Serializer):
|
||||
stt_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
audio_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("The audio file cannot be empty")))
|
||||
|
||||
|
||||
class ISpeechToTextNode(INode):
|
||||
type = 'speech-to-text-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return SpeechToTextNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('audio_list')[0],
|
||||
self.node_params_serializer.data.get('audio_list')[1:])
|
||||
for audio in res:
|
||||
if 'file_id' not in audio:
|
||||
raise ValueError(_("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails"))
|
||||
|
||||
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, stt_model_id, chat_id,
|
||||
audio,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue