feat: application flow (#3152)

v3.2
shaohuzhang1 2025-05-27 18:24:28 +08:00 committed by GitHub
parent 0c9d8ccf71
commit 896fb5fa52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
122 changed files with 8221 additions and 58 deletions

View File

@ -0,0 +1,157 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file I_base_chat_pipeline.py
@date2024/1/9 17:25
@desc:
"""
import time
from abc import abstractmethod
from typing import Type
from rest_framework import serializers
from dataset.models import Paragraph
class ParagraphPipelineModel:
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
self.content = content
self.title = title
self.status = status,
self.is_active = is_active
self.comprehensive_score = comprehensive_score
self.similarity = similarity
self.dataset_name = dataset_name
self.document_name = document_name
self.hit_handling_method = hit_handling_method
self.directly_return_similarity = directly_return_similarity
self.meta = meta
def to_dict(self):
return {
'id': self.id,
'document_id': self.document_id,
'dataset_id': self.dataset_id,
'content': self.content,
'title': self.title,
'status': self.status,
'is_active': self.is_active,
'comprehensive_score': self.comprehensive_score,
'similarity': self.similarity,
'dataset_name': self.dataset_name,
'document_name': self.document_name,
'meta': self.meta,
}
class builder:
def __init__(self):
self.similarity = None
self.paragraph = {}
self.comprehensive_score = None
self.document_name = None
self.dataset_name = None
self.hit_handling_method = None
self.directly_return_similarity = 0.9
self.meta = {}
def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
self.paragraph = {'id': paragraph.id,
'document_id': paragraph.document_id,
'dataset_id': paragraph.dataset_id,
'content': paragraph.content,
'title': paragraph.title,
'status': paragraph.status,
'is_active': paragraph.is_active,
}
else:
self.paragraph = paragraph
return self
def add_dataset_name(self, dataset_name):
self.dataset_name = dataset_name
return self
def add_document_name(self, document_name):
self.document_name = document_name
return self
def add_hit_handling_method(self, hit_handling_method):
self.hit_handling_method = hit_handling_method
return self
def add_directly_return_similarity(self, directly_return_similarity):
self.directly_return_similarity = directly_return_similarity
return self
def add_comprehensive_score(self, comprehensive_score: float):
self.comprehensive_score = comprehensive_score
return self
def add_similarity(self, similarity: float):
self.similarity = similarity
return self
def add_meta(self, meta: dict):
self.meta = meta
return self
def build(self):
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
str(self.paragraph.get('dataset_id')),
self.paragraph.get('content'), self.paragraph.get('title'),
self.paragraph.get('status'),
self.paragraph.get('is_active'),
self.comprehensive_score, self.similarity, self.dataset_name,
self.document_name, self.hit_handling_method, self.directly_return_similarity,
self.meta)
class IBaseChatPipelineStep:
def __init__(self):
# 当前步骤上下文,用于存储当前步骤信息
self.context = {}
@abstractmethod
def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
pass
def valid_args(self, manage):
step_serializer_clazz = self.get_step_serializer(manage)
step_serializer = step_serializer_clazz(data=manage.context)
step_serializer.is_valid(raise_exception=True)
self.context['step_args'] = step_serializer.data
def run(self, manage):
"""
:param manage: 步骤管理器
:return: 执行结果
"""
start_time = time.time()
self.context['start_time'] = start_time
# 校验参数,
self.valid_args(manage)
self._run(manage)
self.context['run_time'] = time.time() - start_time
def _run(self, manage):
pass
def execute(self, **kwargs):
pass
def get_details(self, manage, **kwargs):
"""
运行详情
:return: 步骤详情
"""
return None

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 17:23
@desc:
"""

View File

@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file pipeline_manage.py
@date2024/1/9 17:40
@desc:
"""
import time
from functools import reduce
from typing import List, Type, Dict
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
class PipelineManage:
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
base_to_response: BaseToResponse = SystemToResponse()):
# 步骤执行器
self.step_list = [step() for step in step_list]
# 上下文
self.context = {'message_tokens': 0, 'answer_tokens': 0}
self.base_to_response = base_to_response
def run(self, context: Dict = None):
self.context['start_time'] = time.time()
if context is not None:
for key, value in context.items():
self.context[key] = value
for step in self.step_list:
step.run(self)
def get_details(self):
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
filter(lambda r: r is not None,
[row.get_details(self) for row in self.step_list])], {})
def get_base_to_response(self):
return self.base_to_response
class builder:
def __init__(self):
self.step_list: List[Type[IBaseChatPipelineStep]] = []
self.base_to_response = SystemToResponse()
def append_step(self, step: Type[IBaseChatPipelineStep]):
self.step_list.append(step)
return self
def add_base_to_response(self, base_to_response: BaseToResponse):
self.base_to_response = base_to_response
return self
def build(self):
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,110 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_step.py
@date2024/1/9 18:17
@desc: 对话
"""
from abc import abstractmethod
from typing import Type, List
from django.utils.translation import gettext_lazy as _
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.serializers.application_serializers import NoReferencesSetting
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
class ModelField(serializers.Field):
def to_internal_value(self, data):
if not isinstance(data, BaseChatModel):
self.fail(_('Model type error'), value=data)
return data
def to_representation(self, value):
return value
class MessageField(serializers.Field):
def to_internal_value(self, data):
if not isinstance(data, BaseMessage):
self.fail(_('Message type error'), value=data)
return data
def to_representation(self, value):
return value
class PostResponseHandler:
@abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text,
manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
pass
class IChatStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 对话列表
message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list(_("Conversation list")))
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
# 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
# 对话id
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
# 用户问题
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_("User Questions")))
# 后置处理器
post_response_handler = InstanceField(model_type=PostResponseHandler,
error_messages=ErrMessage.base(_("Post-processor")))
# 补全问题
padding_problem_text = serializers.CharField(required=False,
error_messages=ErrMessage.base(_("Completion Question")))
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
# 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True,
error_messages=ErrMessage.base(_("No reference segment settings")))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
model_setting = serializers.DictField(required=True, allow_null=True,
error_messages=ErrMessage.dict(_("Model settings")))
model_params_setting = serializers.DictField(required=False, allow_null=True,
error_messages=ErrMessage.dict(_("Model parameter settings")))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
message_list: List = self.initial_data.get('message_list')
for message in message_list:
if not isinstance(message, BaseMessage):
raise Exception(_("message type error"))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PipelineManage):
chat_result = self.execute(**self.context['step_args'], manage=manage)
manage.context['chat_result'] = chat_result
@abstractmethod
def execute(self, message_list: List[BaseMessage],
chat_id, problem_text,
post_response_handler: PostResponseHandler,
model_id: str = None,
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass

View File

@ -0,0 +1,334 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_chat_step.py
@date2024/1/9 18:25
@desc: 对话step Base实现
"""
import logging
import time
import traceback
import uuid
from typing import List
from django.db.models import QuerySet
from django.http import StreamingHttpResponse
from django.utils.translation import gettext as _
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import HumanMessage, AIMessage
from langchain_core.messages import AIMessageChunk
from rest_framework import status
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from setting.models_provider.tools import get_model_instance_by_model_user_id
def add_access_num(client_id=None, client_type=None, application_id=None):
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
application_id=application_id)
.first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
application_public_access_client.save()
def write_context(step, manage, request_token, response_token, all_text):
step.context['message_tokens'] = request_token
step.context['answer_tokens'] = response_token
current_time = time.time()
step.context['answer_text'] = all_text
step.context['run_time'] = current_time - step.context['start_time']
manage.context['run_time'] = current_time - manage.context['start_time']
manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
def event_content(response,
chat_id,
chat_record_id,
paragraph_list: List[ParagraphPipelineModel],
post_response_handler: PostResponseHandler,
manage,
step,
chat_model,
message_list: List[BaseMessage],
problem_text: str,
padding_problem_text: str = None,
client_id=None, client_type=None,
is_ai_chat: bool = None,
model_setting=None):
if model_setting is None:
model_setting = {}
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
reasoning = Reasoning(reasoning_content_start,
reasoning_content_end)
all_text = ''
reasoning_content = ''
try:
response_reasoning_content = False
for chunk in response:
reasoning_chunk = reasoning.get_reasoning_content(chunk)
content_chunk = reasoning_chunk.get('content')
if 'reasoning_content' in chunk.additional_kwargs:
response_reasoning_content = True
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
else:
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
all_text += content_chunk
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
reasoning_content += reasoning_content_chunk
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], content_chunk,
False,
0, 0, {'node_is_end': False,
'view_type': 'many_view',
'node_type': 'ai-chat-node',
'real_node_id': 'ai-chat-node',
'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
reasoning_chunk = reasoning.get_end_reasoning_content()
all_text += reasoning_chunk.get('content')
reasoning_content_chunk = ""
if not response_reasoning_content:
reasoning_content_chunk = reasoning_chunk.get(
'reasoning_content')
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], reasoning_chunk.get('content'),
False,
0, 0, {'node_is_end': False,
'view_type': 'many_view',
'node_type': 'ai-chat-node',
'real_node_id': 'ai-chat-node',
'reasoning_content'
: reasoning_content_chunk if reasoning_content_enable else ''})
# 获取token
if is_ai_chat:
try:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
except Exception as e:
request_token = 0
response_token = 0
else:
request_token = 0
response_token = 0
write_context(step, manage, request_token, response_token, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else ''
, asker=asker)
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], '', True,
request_token, response_token,
{'node_is_end': True, 'view_type': 'many_view',
'node_type': 'ai-chat-node'})
add_access_num(client_id, client_type, manage.context.get('application_id'))
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
all_text = 'Exception:' + str(e)
write_context(step, manage, 0, 0, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
asker=asker)
add_access_num(client_id, client_type, manage.context.get('application_id'))
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], all_text,
False,
0, 0, {'node_is_end': False,
'view_type': 'many_view',
'node_type': 'ai-chat-node',
'real_node_id': 'ai-chat-node',
'reasoning_content': ''})
class BaseChatStep(IChatStep):
def execute(self, message_list: List[BaseMessage],
chat_id,
problem_text,
post_response_handler: PostResponseHandler,
model_id: str = None,
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
stream: bool = True,
client_id=None, client_type=None,
no_references_setting=None,
model_params_setting=None,
model_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
**model_params_setting) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting,
model_setting)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting,
model_setting)
def get_details(self, manage, **kwargs):
return {
'step_type': 'chat_step',
'run_time': self.context['run_time'],
'model_id': str(manage.context['model_id']),
'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
self.context['answer_text']),
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens'],
'cost': 0,
}
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
@staticmethod
def get_stream_result(message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None,
problem_text=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
for paragraph in paragraph_list if (
paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return iter(directly_return_chunk_list), False
elif len(paragraph_list) == 0 and no_references_setting.get(
'status') == 'designated_answer':
return iter(
[AIMessageChunk(content=no_references_setting.get('value').replace('{question}', problem_text))]), False
if chat_model is None:
return iter([AIMessageChunk(
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
else:
return chat_model.stream(message_list), True
def execute_stream(self, message_list: List[BaseMessage],
chat_id,
problem_text,
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None,
no_references_setting=None,
model_setting=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
@staticmethod
def get_block_result(message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None,
problem_text=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
for paragraph in paragraph_list if (
paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return directly_return_chunk_list[0], False
elif len(paragraph_list) == 0 and no_references_setting.get(
'status') == 'designated_answer':
return AIMessage(no_references_setting.get('value').replace('{question}', problem_text)), False
if chat_model is None:
return AIMessage(
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
else:
return chat_model.invoke(message_list), True
def execute_block(self, message_list: List[BaseMessage],
chat_id,
problem_text,
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None,
model_setting=None):
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
reasoning = Reasoning(reasoning_content_start,
reasoning_content_end)
chat_record_id = uuid.uuid1()
# 调用模型
try:
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text)
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
else:
request_token = 0
response_token = 0
write_context(self, manage, request_token, response_token, chat_result.content)
reasoning_result = reasoning.get_reasoning_content(chat_result)
reasoning_result_end = reasoning.get_end_reasoning_content()
content = reasoning_result.get('content') + reasoning_result_end.get('content')
if 'reasoning_content' in chat_result.response_metadata:
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
else:
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
'reasoning_content')
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
content, manage, self, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else '',
asker=asker)
add_access_num(client_id, client_type, manage.context.get('application_id'))
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
content, True,
request_token, response_token,
{
'reasoning_content': reasoning_content if reasoning_content_enable else '',
'answer_list': [{
'content': content,
'reasoning_content': reasoning_content if reasoning_content_enable else ''
}]})
except Exception as e:
all_text = 'Exception:' + str(e)
write_context(self, manage, 0, 0, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, self, padding_problem_text, client_id, reasoning_content='',
asker=asker)
add_access_num(client_id, client_type, manage.context.get('application_id'))
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,81 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_generate_human_message_step.py
@date2024/1/9 18:15
@desc: 生成对话模板
"""
from abc import abstractmethod
from typing import Type, List
from django.utils.translation import gettext_lazy as _
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.models import ChatRecord
from application.serializers.application_serializers import NoReferencesSetting
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 问题
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
# 段落列表
paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True),
error_messages=ErrMessage.list(_("Paragraph List")))
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list(_("History Questions")))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
# 最大携带知识库段落长度
max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
_("Maximum length of the knowledge base paragraph")))
# 模板
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char(_("System prompt words (role)")))
# 补齐问题
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Completion problem")))
# 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PipelineManage):
message_list = self.execute(**self.context['step_args'])
manage.context['message_list'] = message_list
@abstractmethod
def execute(self,
problem_text: str,
paragraph_list: List[ParagraphPipelineModel],
history_chat_record: List[ChatRecord],
dialogue_number: int,
max_paragraph_char_number: int,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
system=None,
**kwargs) -> List[BaseMessage]:
"""
:param problem_text: 原始问题文本
:param paragraph_list: 段落列表
:param history_chat_record: 历史对话记录
:param dialogue_number: 多轮对话数量
:param max_paragraph_char_number: 最大段落长度
:param prompt: 模板
:param padding_problem_text 用户修改文本
:param kwargs: 其他参数
:param no_references_setting: 无引用分段设置
:param system 系统提示称
:return:
"""
pass

View File

@ -0,0 +1,73 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_generate_human_message_step.py.py
@date2024/1/10 17:50
@desc:
"""
from typing import List, Dict
from langchain.schema import BaseMessage, HumanMessage
from langchain_core.messages import SystemMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep
from application.models import ChatRecord
from common.util.split_model import flat_map
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
def execute(self, problem_text: str,
paragraph_list: List[ParagraphPipelineModel],
history_chat_record: List[ChatRecord],
dialogue_number: int,
max_paragraph_char_number: int,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
system=None,
**kwargs) -> List[BaseMessage]:
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
'value')
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
start_index = len(history_chat_record) - dialogue_number
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
if system is not None and len(system) > 0:
return [SystemMessage(system), *flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]
return [*flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]
@staticmethod
def to_human_message(prompt: str,
problem: str,
max_paragraph_char_number: int,
paragraph_list: List[ParagraphPipelineModel],
no_references_setting: Dict):
if paragraph_list is None or len(paragraph_list) == 0:
if no_references_setting.get('status') == 'ai_questioning':
return HumanMessage(
content=no_references_setting.get('value').replace('{question}', problem))
else:
return HumanMessage(content=prompt.replace('{data}', "").replace('{question}', problem))
temp_data = ""
data_list = []
for p in paragraph_list:
content = f"{p.title}:{p.content}"
temp_data += content
if len(temp_data) > max_paragraph_char_number:
row_data = content[0:max_paragraph_char_number - len(temp_data)]
data_list.append(f"<data>{row_data}</data>")
break
else:
data_list.append(f"<data>{content}</data>")
data = "\n".join(data_list)
return HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem))

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_reset_problem_step.py
@date2024/1/9 18:12
@desc: 重写处理问题
"""
from abc import abstractmethod
from typing import Type, List
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.models import ChatRecord
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
class IResetProblemStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 问题文本
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.float(_("question")))
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list(_("History Questions")))
# 大语言模型
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
error_messages=ErrMessage.char(
_("Question completion prompt")))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PipelineManage):
padding_problem = self.execute(**self.context.get('step_args'))
# 用户输入问题
source_problem_text = self.context.get('step_args').get('problem_text')
self.context['problem_text'] = source_problem_text
self.context['padding_problem_text'] = padding_problem
manage.context['problem_text'] = source_problem_text
manage.context['padding_problem_text'] = padding_problem
# 累加tokens
manage.context['message_tokens'] = manage.context.get('message_tokens', 0) + self.context.get('message_tokens',
0)
manage.context['answer_tokens'] = manage.context.get('answer_tokens', 0) + self.context.get('answer_tokens', 0)
@abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None,
user_id=None,
**kwargs):
pass

View File

@ -0,0 +1,68 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_reset_problem_step.py
@date2024/1/10 14:35
@desc:
"""
from typing import List
from django.utils.translation import gettext as _
from langchain.schema import HumanMessage
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
from application.models import ChatRecord
from common.util.split_model import flat_map
from setting.models_provider.tools import get_model_instance_by_model_user_id
prompt = _(
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None,
user_id=None,
**kwargs) -> str:
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
if chat_model is None:
return problem_text
start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt
message_list = [*flat_map(history_message),
HumanMessage(content=reset_prompt.replace('{question}', problem_text))]
response = chat_model.invoke(message_list)
padding_problem = problem_text
if response.content.__contains__("<data>") and response.content.__contains__('</data>'):
padding_problem_data = response.content[
response.content.index('<data>') + 6:response.content.index('</data>')]
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
padding_problem = padding_problem_data
elif len(response.content) > 0:
padding_problem = response.content
try:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(padding_problem)
except Exception as e:
request_token = 0
response_token = 0
self.context['message_tokens'] = request_token
self.context['answer_tokens'] = response_token
return padding_problem
def get_details(self, manage, **kwargs):
return {
'step_type': 'problem_padding',
'run_time': self.context['run_time'],
'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
'message_tokens': self.context.get('message_tokens', 0),
'answer_tokens': self.context.get('answer_tokens', 0),
'cost': 0,
'padding_problem_text': self.context.get('padding_problem_text'),
'problem_text': self.context.get("step_args").get('problem_text'),
}

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:24
@desc:
"""

View File

@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_search_dataset_step.py
@date2024/1/9 18:10
@desc: 检索知识库
"""
import re
from abc import abstractmethod
from typing import List, Type
from django.core import validators
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from common.util.field_message import ErrMessage
class ISearchDatasetStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 原始问题文本
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
# 系统补全问题文本
padding_problem_text = serializers.CharField(required=False,
error_messages=ErrMessage.char(_("System completes question text")))
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list(_("Dataset id list")))
# 需要排除的文档id
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list(_("List of document ids to exclude")))
# 需要排除向量id
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list(_("List of exclusion vector ids")))
# 需要查询的条数
top_n = serializers.IntegerField(required=True,
error_messages=ErrMessage.integer(_("Reference segment number")))
# 相似度 0-1之间
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float(_("Similarity")))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message=_("The type only supports embedding|keywords|blend"), code=500)
], error_messages=ErrMessage.char(_("Retrieval Mode")))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
def _run(self, manage: PipelineManage):
paragraph_list = self.execute(**self.context['step_args'])
manage.context['paragraph_list'] = paragraph_list
self.context['paragraph_list'] = paragraph_list
@abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
:param similarity: 相关性
:param top_n: 查询多少条
:param problem_text: 用户问题
:param dataset_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:param search_mode 检索模式
:param user_id 用户id
:return: 段落列表
"""
pass

View File

@ -0,0 +1,138 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_search_dataset_step.py
@date2024/1/10 10:33
@desc:
"""
import os
from typing import List, Dict
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework.utils.formatting import lazy_format
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception(_("Model does not exist"))
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
raise Exception(message)
return model
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
if len(dataset_list) == 0:
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
return dataset_list[0].embedding_mode_id
class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
return result
@staticmethod
def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
filter_embedding_list = [embedding for embedding in embedding_list if
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
find_embedding = filter_embedding_list[-1]
return (ParagraphPipelineModel.builder()
.add_paragraph(paragraph)
.add_similarity(find_embedding.get('similarity'))
.add_comprehensive_score(find_embedding.get('comprehensive_score'))
.add_dataset_name(paragraph.get('dataset_name'))
.add_document_name(paragraph.get('document_name'))
.add_hit_handling_method(paragraph.get('hit_handling_method'))
.add_directly_return_similarity(paragraph.get('directly_return_similarity'))
.add_meta(paragraph.get('meta'))
.build())
@staticmethod
def get_similarity(paragraph, embedding_list: List):
filter_embedding_list = [embedding for embedding in embedding_list if
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
find_embedding = filter_embedding_list[-1]
return find_embedding.get('comprehensive_score')
return 0
@staticmethod
def list_paragraph(embedding_list: List, vector):
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
# 如果向量库中存在脏数据 直接删除
if len(paragraph_list) != len(paragraph_id_list):
exist_paragraph_list = [row.get('id') for row in paragraph_list]
for paragraph_id in paragraph_id_list:
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
# 如果存在直接返回的则取直接返回段落
hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
(paragraph.get(
'hit_handling_method') == 'directly_return' and BaseSearchDatasetStep.get_similarity(
paragraph, embedding_list) >= paragraph.get(
'directly_return_similarity'))]
if len(hit_handling_method_paragraph) > 0:
# 找到评分最高的
return [sorted(hit_handling_method_paragraph,
key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]]
return paragraph_list
def get_details(self, manage, **kwargs):
step_args = self.context['step_args']
return {
'step_type': 'search_step',
'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']],
'run_time': self.context['run_time'],
'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
'model_name': self.context.get('model_name'),
'message_tokens': 0,
'answer_tokens': 0,
'cost': 0
}

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""

View File

@ -0,0 +1,44 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file common.py
@date2024/12/11 17:57
@desc:
"""
class Answer:
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
reasoning_content):
self.view_type = view_type
self.content = content
self.reasoning_content = reasoning_content
self.runtime_node_id = runtime_node_id
self.chat_record_id = chat_record_id
self.child_node = child_node
self.real_node_id = real_node_id
def to_dict(self):
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
'chat_record_id': self.chat_record_id,
'child_node': self.child_node,
'reasoning_content': self.reasoning_content,
'real_node_id': self.real_node_id}
class NodeChunk:
def __init__(self):
self.status = 0
self.chunk_list = []
def add_chunk(self, chunk):
self.chunk_list.append(chunk)
def end(self, chunk=None):
if chunk is not None:
self.add_chunk(chunk)
self.status = 200
def is_end(self):
return self.status == 200

View File

@ -0,0 +1,451 @@
{
"nodes": [
{
"id": "base-node",
"type": "base-node",
"x": 360,
"y": 2810,
"properties": {
"config": {
},
"height": 825.6,
"stepName": "基本信息",
"node_data": {
"desc": "",
"name": "maxkbapplication",
"prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
},
"input_field_list": [
]
}
},
{
"id": "start-node",
"type": "start-node",
"x": 430,
"y": 3660,
"properties": {
"config": {
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
},
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"height": 276,
"stepName": "开始",
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
}
},
{
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"type": "search-dataset-node",
"x": 840,
"y": 3210,
"properties": {
"config": {
"fields": [
{
"label": "检索结果的分段列表",
"value": "paragraph_list"
},
{
"label": "满足直接回答的分段列表",
"value": "is_hit_handling_method_list"
},
{
"label": "检索结果",
"value": "data"
},
{
"label": "满足直接回答的分段内容",
"value": "directly_return"
}
]
},
"height": 794,
"stepName": "知识库检索",
"node_data": {
"dataset_id_list": [
],
"dataset_setting": {
"top_n": 3,
"similarity": 0.6,
"search_mode": "embedding",
"max_paragraph_char_number": 5000
},
"question_reference_address": [
"start-node",
"question"
],
"source_dataset_id_list": [
]
}
}
},
{
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"type": "condition-node",
"x": 1490,
"y": 3210,
"properties": {
"width": 600,
"config": {
"fields": [
{
"label": "分支名称",
"value": "branch_name"
}
]
},
"height": 543.675,
"stepName": "判断器",
"node_data": {
"branch": [
{
"id": "1009",
"type": "IF",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"is_hit_handling_method_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "4908",
"type": "ELSE IF 1",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"paragraph_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "161",
"type": "ELSE",
"condition": "and",
"conditions": [
]
}
]
},
"branch_condition_list": [
{
"index": 0,
"height": 121.225,
"id": "1009"
},
{
"index": 1,
"height": 121.225,
"id": "4908"
},
{
"index": 2,
"height": 44,
"id": "161"
}
]
}
},
{
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"type": "reply-node",
"x": 2170,
"y": 2480,
"properties": {
"config": {
"fields": [
{
"label": "内容",
"value": "answer"
}
]
},
"height": 378,
"stepName": "指定回复",
"node_data": {
"fields": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"directly_return"
],
"content": "",
"reply_type": "referencing",
"is_result": true
}
}
},
{
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"type": "ai-chat-node",
"x": 2160,
"y": 3200,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI 对话",
"node_data": {
"prompt": "已知信息:\n{{知识库检索.data}}\n问题\n{{开始.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
},
{
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"type": "ai-chat-node",
"x": 2160,
"y": 3970,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI 对话1",
"node_data": {
"prompt": "{{开始.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
}
],
"edges": [
{
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
"type": "app-edge",
"sourceNodeId": "start-node",
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"startPoint": {
"x": 590,
"y": 3660
},
"endPoint": {
"x": 680,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 590,
"y": 3660
},
{
"x": 700,
"y": 3660
},
{
"x": 570,
"y": 3210
},
{
"x": 680,
"y": 3210
}
],
"sourceAnchorId": "start-node_right",
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
},
{
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
"type": "app-edge",
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"startPoint": {
"x": 1000,
"y": 3210
},
"endPoint": {
"x": 1200,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 1000,
"y": 3210
},
{
"x": 1110,
"y": 3210
},
{
"x": 1090,
"y": 3210
},
{
"x": 1200,
"y": 3210
}
],
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
},
{
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"startPoint": {
"x": 1780,
"y": 3073.775
},
"endPoint": {
"x": 2010,
"y": 2480
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3073.775
},
{
"x": 1890,
"y": 3073.775
},
{
"x": 1900,
"y": 2480
},
{
"x": 2010,
"y": 2480
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
},
{
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"startPoint": {
"x": 1780,
"y": 3203
},
"endPoint": {
"x": 2000,
"y": 3200
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3203
},
{
"x": 1890,
"y": 3203
},
{
"x": 1890,
"y": 3200
},
{
"x": 2000,
"y": 3200
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
},
{
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"startPoint": {
"x": 1780,
"y": 3293.6124999999997
},
"endPoint": {
"x": 2000,
"y": 3970
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3970
},
{
"x": 2000,
"y": 3970
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
}
]
}

View File

@ -0,0 +1,451 @@
{
"nodes": [
{
"id": "base-node",
"type": "base-node",
"x": 360,
"y": 2810,
"properties": {
"config": {
},
"height": 825.6,
"stepName": "Base",
"node_data": {
"desc": "",
"name": "maxkbapplication",
"prologue": "Hello, I am the MaxKB assistant. You can ask me about MaxKB usage issues.\n-What are the main functions of MaxKB?\n-What major language models does MaxKB support?\n-What document types does MaxKB support?"
},
"input_field_list": [
]
}
},
{
"id": "start-node",
"type": "start-node",
"x": 430,
"y": 3660,
"properties": {
"config": {
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
},
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"height": 276,
"stepName": "Start",
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
}
},
{
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"type": "search-dataset-node",
"x": 840,
"y": 3210,
"properties": {
"config": {
"fields": [
{
"label": "检索结果的分段列表",
"value": "paragraph_list"
},
{
"label": "满足直接回答的分段列表",
"value": "is_hit_handling_method_list"
},
{
"label": "检索结果",
"value": "data"
},
{
"label": "满足直接回答的分段内容",
"value": "directly_return"
}
]
},
"height": 794,
"stepName": "Knowledge Search",
"node_data": {
"dataset_id_list": [
],
"dataset_setting": {
"top_n": 3,
"similarity": 0.6,
"search_mode": "embedding",
"max_paragraph_char_number": 5000
},
"question_reference_address": [
"start-node",
"question"
],
"source_dataset_id_list": [
]
}
}
},
{
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"type": "condition-node",
"x": 1490,
"y": 3210,
"properties": {
"width": 600,
"config": {
"fields": [
{
"label": "分支名称",
"value": "branch_name"
}
]
},
"height": 543.675,
"stepName": "Conditional Branch",
"node_data": {
"branch": [
{
"id": "1009",
"type": "IF",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"is_hit_handling_method_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "4908",
"type": "ELSE IF 1",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"paragraph_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "161",
"type": "ELSE",
"condition": "and",
"conditions": [
]
}
]
},
"branch_condition_list": [
{
"index": 0,
"height": 121.225,
"id": "1009"
},
{
"index": 1,
"height": 121.225,
"id": "4908"
},
{
"index": 2,
"height": 44,
"id": "161"
}
]
}
},
{
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"type": "reply-node",
"x": 2170,
"y": 2480,
"properties": {
"config": {
"fields": [
{
"label": "内容",
"value": "answer"
}
]
},
"height": 378,
"stepName": "Specified Reply",
"node_data": {
"fields": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"directly_return"
],
"content": "",
"reply_type": "referencing",
"is_result": true
}
}
},
{
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"type": "ai-chat-node",
"x": 2160,
"y": 3200,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI Chat",
"node_data": {
"prompt": "Known information:\n{{Knowledge Search.data}}\nQuestion:\n{{Start.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
},
{
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"type": "ai-chat-node",
"x": 2160,
"y": 3970,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI Chat1",
"node_data": {
"prompt": "{{Start.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
}
],
"edges": [
{
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
"type": "app-edge",
"sourceNodeId": "start-node",
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"startPoint": {
"x": 590,
"y": 3660
},
"endPoint": {
"x": 680,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 590,
"y": 3660
},
{
"x": 700,
"y": 3660
},
{
"x": 570,
"y": 3210
},
{
"x": 680,
"y": 3210
}
],
"sourceAnchorId": "start-node_right",
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
},
{
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
"type": "app-edge",
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"startPoint": {
"x": 1000,
"y": 3210
},
"endPoint": {
"x": 1200,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 1000,
"y": 3210
},
{
"x": 1110,
"y": 3210
},
{
"x": 1090,
"y": 3210
},
{
"x": 1200,
"y": 3210
}
],
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
},
{
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"startPoint": {
"x": 1780,
"y": 3073.775
},
"endPoint": {
"x": 2010,
"y": 2480
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3073.775
},
{
"x": 1890,
"y": 3073.775
},
{
"x": 1900,
"y": 2480
},
{
"x": 2010,
"y": 2480
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
},
{
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"startPoint": {
"x": 1780,
"y": 3203
},
"endPoint": {
"x": 2000,
"y": 3200
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3203
},
{
"x": 1890,
"y": 3203
},
{
"x": 1890,
"y": 3200
},
{
"x": 2000,
"y": 3200
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
},
{
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"startPoint": {
"x": 1780,
"y": 3293.6124999999997
},
"endPoint": {
"x": 2000,
"y": 3970
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3970
},
{
"x": 2000,
"y": 3970
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
}
]
}

View File

@ -0,0 +1,451 @@
{
"nodes": [
{
"id": "base-node",
"type": "base-node",
"x": 360,
"y": 2810,
"properties": {
"config": {
},
"height": 825.6,
"stepName": "基本信息",
"node_data": {
"desc": "",
"name": "maxkbapplication",
"prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
},
"input_field_list": [
]
}
},
{
"id": "start-node",
"type": "start-node",
"x": 430,
"y": 3660,
"properties": {
"config": {
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
},
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"height": 276,
"stepName": "开始",
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
}
},
{
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"type": "search-dataset-node",
"x": 840,
"y": 3210,
"properties": {
"config": {
"fields": [
{
"label": "检索结果的分段列表",
"value": "paragraph_list"
},
{
"label": "满足直接回答的分段列表",
"value": "is_hit_handling_method_list"
},
{
"label": "检索结果",
"value": "data"
},
{
"label": "满足直接回答的分段内容",
"value": "directly_return"
}
]
},
"height": 794,
"stepName": "知识库检索",
"node_data": {
"dataset_id_list": [
],
"dataset_setting": {
"top_n": 3,
"similarity": 0.6,
"search_mode": "embedding",
"max_paragraph_char_number": 5000
},
"question_reference_address": [
"start-node",
"question"
],
"source_dataset_id_list": [
]
}
}
},
{
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"type": "condition-node",
"x": 1490,
"y": 3210,
"properties": {
"width": 600,
"config": {
"fields": [
{
"label": "分支名称",
"value": "branch_name"
}
]
},
"height": 543.675,
"stepName": "判断器",
"node_data": {
"branch": [
{
"id": "1009",
"type": "IF",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"is_hit_handling_method_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "4908",
"type": "ELSE IF 1",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"paragraph_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "161",
"type": "ELSE",
"condition": "and",
"conditions": [
]
}
]
},
"branch_condition_list": [
{
"index": 0,
"height": 121.225,
"id": "1009"
},
{
"index": 1,
"height": 121.225,
"id": "4908"
},
{
"index": 2,
"height": 44,
"id": "161"
}
]
}
},
{
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"type": "reply-node",
"x": 2170,
"y": 2480,
"properties": {
"config": {
"fields": [
{
"label": "内容",
"value": "answer"
}
]
},
"height": 378,
"stepName": "指定回复",
"node_data": {
"fields": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"directly_return"
],
"content": "",
"reply_type": "referencing",
"is_result": true
}
}
},
{
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"type": "ai-chat-node",
"x": 2160,
"y": 3200,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI 对话",
"node_data": {
"prompt": "已知信息:\n{{知识库检索.data}}\n问题\n{{开始.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
},
{
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"type": "ai-chat-node",
"x": 2160,
"y": 3970,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI 对话1",
"node_data": {
"prompt": "{{开始.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
}
],
"edges": [
{
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
"type": "app-edge",
"sourceNodeId": "start-node",
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"startPoint": {
"x": 590,
"y": 3660
},
"endPoint": {
"x": 680,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 590,
"y": 3660
},
{
"x": 700,
"y": 3660
},
{
"x": 570,
"y": 3210
},
{
"x": 680,
"y": 3210
}
],
"sourceAnchorId": "start-node_right",
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
},
{
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
"type": "app-edge",
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"startPoint": {
"x": 1000,
"y": 3210
},
"endPoint": {
"x": 1200,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 1000,
"y": 3210
},
{
"x": 1110,
"y": 3210
},
{
"x": 1090,
"y": 3210
},
{
"x": 1200,
"y": 3210
}
],
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
},
{
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"startPoint": {
"x": 1780,
"y": 3073.775
},
"endPoint": {
"x": 2010,
"y": 2480
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3073.775
},
{
"x": 1890,
"y": 3073.775
},
{
"x": 1900,
"y": 2480
},
{
"x": 2010,
"y": 2480
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
},
{
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"startPoint": {
"x": 1780,
"y": 3203
},
"endPoint": {
"x": 2000,
"y": 3200
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3203
},
{
"x": 1890,
"y": 3203
},
{
"x": 1890,
"y": 3200
},
{
"x": 2000,
"y": 3200
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
},
{
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"startPoint": {
"x": 1780,
"y": 3293.6124999999997
},
"endPoint": {
"x": 2000,
"y": 3970
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3970
},
{
"x": 2000,
"y": 3970
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
}
]
}

View File

@ -0,0 +1,451 @@
{
"nodes": [
{
"id": "base-node",
"type": "base-node",
"x": 360,
"y": 2810,
"properties": {
"config": {
},
"height": 825.6,
"stepName": "基本資訊",
"node_data": {
"desc": "",
"name": "maxkbapplication",
"prologue": "您好我是MaxKB小助手您可以向我提出MaxKB使用問題。\n- MaxKB主要功能有什麼\n- MaxKB支持哪些大語言模型\n- MaxKB支持哪些文檔類型"
},
"input_field_list": [
]
}
},
{
"id": "start-node",
"type": "start-node",
"x": 430,
"y": 3660,
"properties": {
"config": {
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
},
"fields": [
{
"label": "用户问题",
"value": "question"
}
],
"height": 276,
"stepName": "開始",
"globalFields": [
{
"label": "当前时间",
"value": "time"
}
]
}
},
{
"id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"type": "search-dataset-node",
"x": 840,
"y": 3210,
"properties": {
"config": {
"fields": [
{
"label": "检索结果的分段列表",
"value": "paragraph_list"
},
{
"label": "满足直接回答的分段列表",
"value": "is_hit_handling_method_list"
},
{
"label": "检索结果",
"value": "data"
},
{
"label": "满足直接回答的分段内容",
"value": "directly_return"
}
]
},
"height": 794,
"stepName": "知識庫檢索",
"node_data": {
"dataset_id_list": [
],
"dataset_setting": {
"top_n": 3,
"similarity": 0.6,
"search_mode": "embedding",
"max_paragraph_char_number": 5000
},
"question_reference_address": [
"start-node",
"question"
],
"source_dataset_id_list": [
]
}
}
},
{
"id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"type": "condition-node",
"x": 1490,
"y": 3210,
"properties": {
"width": 600,
"config": {
"fields": [
{
"label": "分支名称",
"value": "branch_name"
}
]
},
"height": 543.675,
"stepName": "判斷器",
"node_data": {
"branch": [
{
"id": "1009",
"type": "IF",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"is_hit_handling_method_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "4908",
"type": "ELSE IF 1",
"condition": "and",
"conditions": [
{
"field": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"paragraph_list"
],
"value": "1",
"compare": "len_ge"
}
]
},
{
"id": "161",
"type": "ELSE",
"condition": "and",
"conditions": [
]
}
]
},
"branch_condition_list": [
{
"index": 0,
"height": 121.225,
"id": "1009"
},
{
"index": 1,
"height": 121.225,
"id": "4908"
},
{
"index": 2,
"height": 44,
"id": "161"
}
]
}
},
{
"id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"type": "reply-node",
"x": 2170,
"y": 2480,
"properties": {
"config": {
"fields": [
{
"label": "内容",
"value": "answer"
}
]
},
"height": 378,
"stepName": "指定回覆",
"node_data": {
"fields": [
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"directly_return"
],
"content": "",
"reply_type": "referencing",
"is_result": true
}
}
},
{
"id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"type": "ai-chat-node",
"x": 2160,
"y": 3200,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI 對話",
"node_data": {
"prompt": "已知資訊:\n{{知識庫檢索.data}}\n問題\n{{開始.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
},
{
"id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"type": "ai-chat-node",
"x": 2160,
"y": 3970,
"properties": {
"config": {
"fields": [
{
"label": "AI 回答内容",
"value": "answer"
}
]
},
"height": 763,
"stepName": "AI 對話1",
"node_data": {
"prompt": "{{開始.question}}",
"system": "",
"model_id": "",
"dialogue_number": 0,
"is_result": true
}
}
}
],
"edges": [
{
"id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
"type": "app-edge",
"sourceNodeId": "start-node",
"targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"startPoint": {
"x": 590,
"y": 3660
},
"endPoint": {
"x": 680,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 590,
"y": 3660
},
{
"x": 700,
"y": 3660
},
{
"x": 570,
"y": 3210
},
{
"x": 680,
"y": 3210
}
],
"sourceAnchorId": "start-node_right",
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
},
{
"id": "35cb86dd-f328-429e-a973-12fd7218b696",
"type": "app-edge",
"sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
"targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"startPoint": {
"x": 1000,
"y": 3210
},
"endPoint": {
"x": 1200,
"y": 3210
},
"properties": {
},
"pointsList": [
{
"x": 1000,
"y": 3210
},
{
"x": 1110,
"y": 3210
},
{
"x": 1090,
"y": 3210
},
{
"x": 1200,
"y": 3210
}
],
"sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
"targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
},
{
"id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
"startPoint": {
"x": 1780,
"y": 3073.775
},
"endPoint": {
"x": 2010,
"y": 2480
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3073.775
},
{
"x": 1890,
"y": 3073.775
},
{
"x": 1900,
"y": 2480
},
{
"x": 2010,
"y": 2480
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
"targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
},
{
"id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
"startPoint": {
"x": 1780,
"y": 3203
},
"endPoint": {
"x": 2000,
"y": 3200
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3203
},
{
"x": 1890,
"y": 3203
},
{
"x": 1890,
"y": 3200
},
{
"x": 2000,
"y": 3200
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
"targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
},
{
"id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
"type": "app-edge",
"sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
"targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
"startPoint": {
"x": 1780,
"y": 3293.6124999999997
},
"endPoint": {
"x": 2000,
"y": 3970
},
"properties": {
},
"pointsList": [
{
"x": 1780,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3293.6124999999997
},
{
"x": 1890,
"y": 3970
},
{
"x": 2000,
"y": 3970
}
],
"sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
"targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
}
]
}

View File

@ -0,0 +1,256 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_step_node.py
@date2024/6/3 14:57
@desc:
"""
import time
import uuid
from abc import abstractmethod
from hashlib import sha1
from typing import Type, Dict, List
from django.core import cache
from django.db.models import QuerySet
from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail
from application.flow.common import Answer, NodeChunk
from application.models import ChatRecord
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
chat_cache = cache.caches['chat_cache']
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable:
answer = step_variable['answer']
yield answer
node.answer_text = answer
if global_variable is not None:
for key in global_variable:
workflow.context[key] = global_variable[key]
node.context['run_time'] = time.time() - node.context['start_time']
def is_interrupt(node, step_variable: Dict, global_variable: Dict):
return node.type == 'form-node' and not node.context.get('is_submit', False)
class WorkFlowPostHandler:
def __init__(self, chat_info, client_id, client_type):
self.chat_info = chat_info
self.client_id = client_id
self.client_type = client_type
def handler(self, chat_id,
chat_record_id,
answer,
workflow):
question = workflow.params['question']
details = workflow.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = workflow.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
if workflow.chat_record is not None:
chat_record = workflow.chat_record
chat_record.answer_text = answer_text
chat_record.details = details
chat_record.message_tokens = message_tokens
chat_record.answer_tokens = answer_tokens
chat_record.answer_text_list = answer_text_list
chat_record.run_time = time.time() - workflow.context['start_time']
else:
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
problem_text=question,
answer_text=answer_text,
details=details,
message_tokens=message_tokens,
answer_tokens=answer_tokens,
answer_text_list=answer_text_list,
run_time=time.time() - workflow.context['start_time'],
index=0)
asker = workflow.context.get('asker', None)
self.chat_info.append_chat_record(chat_record, self.client_id, asker)
# 重新设置缓存
chat_cache.set(chat_id,
self.chat_info, timeout=60 * 30)
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_public_access_client = (QuerySet(ApplicationPublicAccessClient)
.filter(client_id=self.client_id,
application_id=self.chat_info.application.id).first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
application_public_access_client.save()
class NodeResult:
def __init__(self, node_variable: Dict, workflow_variable: Dict,
_write_context=write_context, _is_interrupt=is_interrupt):
self._write_context = _write_context
self.node_variable = node_variable
self.workflow_variable = workflow_variable
self._is_interrupt = _is_interrupt
def write_context(self, node, workflow):
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
def is_assertion_result(self):
return 'branch_id' in self.node_variable
def is_interrupt_exec(self, current_node):
"""
是否中断执行
@param current_node:
@return:
"""
return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
class ReferenceAddressSerializer(serializers.Serializer):
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
fields = serializers.ListField(
child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
error_messages=ErrMessage.list("节点字段数组"))
class FlowParamsSerializer(serializers.Serializer):
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
question = serializers.CharField(required=True, error_messages=ErrMessage.list("用户问题"))
chat_id = serializers.CharField(required=True, error_messages=ErrMessage.list("对话id"))
chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("流式输出"))
client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
class INode:
view_type = 'many_view'
@abstractmethod
def save_context(self, details, workflow_manage):
pass
def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
return [
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
# 当前步骤上下文,用于存储当前步骤信息
self.status = 200
self.err_message = ''
self.node = node
self.node_params = get_node_params(node)
self.workflow_params = workflow_params
self.workflow_manage = workflow_manage
self.node_params_serializer = None
self.flow_params_serializer = None
self.context = {}
self.answer_text = None
self.id = node.id
if up_node_id_list is None:
up_node_id_list = []
self.up_node_id_list = up_node_id_list
self.node_chunk = NodeChunk()
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
"".join([*sorted(up_node_id_list),
node.id]))),
"utf-8")).hexdigest()
def valid_args(self, node_params, flow_params):
flow_params_serializer_class = self.get_flow_params_serializer_class()
node_params_serializer_class = self.get_node_params_serializer_class()
if flow_params_serializer_class is not None and flow_params is not None:
self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
self.flow_params_serializer.is_valid(raise_exception=True)
if node_params_serializer_class is not None:
self.node_params_serializer = node_params_serializer_class(data=node_params)
self.node_params_serializer.is_valid(raise_exception=True)
if self.node.properties.get('status', 200) != 200:
raise ValidationError(ErrorDetail(f'节点{self.node.properties.get("stepName")} 不可用'))
def get_reference_field(self, fields: List[str]):
return self.get_field(self.context, fields)
@staticmethod
def get_field(obj, fields: List[str]):
for field in fields:
value = obj.get(field)
if value is None:
return None
else:
obj = value
return obj
@abstractmethod
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
pass
def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
return FlowParamsSerializer
def get_write_error_context(self, e):
self.status = 500
self.answer_text = str(e)
self.err_message = str(e)
self.context['run_time'] = time.time() - self.context['start_time']
def write_error_context(answer, status=200):
pass
return write_error_context
def run(self) -> NodeResult:
"""
:return: 执行结果
"""
start_time = time.time()
self.context['start_time'] = start_time
result = self._run()
self.context['run_time'] = time.time() - start_time
return result
def _run(self):
result = self.execute()
return result
def execute(self, **kwargs) -> NodeResult:
pass
def get_details(self, index: int, **kwargs):
"""
运行详情
:return: 步骤详情
"""
return {}

View File

@ -0,0 +1,42 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""
from .ai_chat_step_node import *
from .application_node import BaseApplicationNode
from .condition_node import *
from .direct_reply_node import *
from .form_node import *
from .function_lib_node import *
from .function_node import *
from .question_node import *
from .reranker_node import *
from .document_extract_node import *
from .image_understand_step_node import *
from .image_generate_step_node import *
from .search_dataset_node import *
from .speech_to_text_step_node import BaseSpeechToTextNode
from .start_node import *
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
from .variable_assign_node import BaseVariableAssignNode
from .mcp_node import BaseMcpNode
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode,
BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
def get_node(node_type):
find_list = [node for node in node_list if node.type == node_type]
if len(find_list) > 0:
return find_list[0]
return None

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:29
@desc:
"""
from .impl import *

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_node.py
@date2024/6/4 13:58
@desc:
"""
from typing import Type
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class ChatNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Role Setting")))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
_("Number of multi-round conversations")))
is_result = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean(_('Whether to return content')))
model_params_setting = serializers.DictField(required=False,
error_messages=ErrMessage.dict(_("Model parameter settings")))
model_setting = serializers.DictField(required=False,
error_messages=ErrMessage.dict('Model settings'))
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Context Type")))
mcp_enable = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean(_("Whether to enable MCP")))
mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server")))
class IChatNode(INode):
type = 'ai-chat-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ChatNodeSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
chat_record_id,
model_params_setting=None,
dialogue_type=None,
model_setting=None,
mcp_enable=False,
mcp_servers=None,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:34
@desc:
"""
from .base_chat_node import BaseChatNode

View File

@ -0,0 +1,288 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_question_node.py
@date2024/6/4 14:30
@desc:
"""
import asyncio
import json
import re
import time
from functools import reduce
from types import AsyncGeneratorType
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.prebuilt import create_react_agent
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id
tool_message_template = """
<details>
<summary>
<strong>Called MCP Tool: <em>%s</em></strong>
</summary>
```json
%s
```
</details>
"""
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
reasoning_content: str):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
node.context['reasoning_content'] = reasoning_content
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
node.answer_text = answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = ''
reasoning_content = ''
model_setting = node.context.get('model_setting',
{'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
'reasoning_content_start': '<think>'})
reasoning = Reasoning(model_setting.get('reasoning_content_start', '<think>'),
model_setting.get('reasoning_content_end', '</think>'))
response_reasoning_content = False
for chunk in response:
reasoning_chunk = reasoning.get_reasoning_content(chunk)
content_chunk = reasoning_chunk.get('content')
if 'reasoning_content' in chunk.additional_kwargs:
response_reasoning_content = True
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
else:
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
answer += content_chunk
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
reasoning_content += reasoning_content_chunk
yield {'content': content_chunk,
'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable',
False) else ''}
reasoning_chunk = reasoning.get_end_reasoning_content()
answer += reasoning_chunk.get('content')
reasoning_content_chunk = ""
if not response_reasoning_content:
reasoning_content_chunk = reasoning_chunk.get(
'reasoning_content')
yield {'content': reasoning_chunk.get('content'),
'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable',
False) else ''}
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
async def _yield_mcp_response(chat_model, message_list, mcp_servers):
async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
agent = create_react_agent(chat_model, client.get_tools())
response = agent.astream({"messages": message_list}, stream_mode='messages')
async for chunk in response:
if isinstance(chunk[0], ToolMessage):
content = tool_message_template % (chunk[0].name, chunk[0].content)
chunk[0].content = content
yield chunk[0]
if isinstance(chunk[0], AIMessageChunk):
yield chunk[0]
def mcp_response_generator(chat_model, message_list, mcp_servers):
loop = asyncio.new_event_loop()
try:
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
while True:
try:
chunk = loop.run_until_complete(anext_async(async_gen))
yield chunk
except StopAsyncIteration:
break
except Exception as e:
print(f'exception: {e}')
finally:
loop.close()
async def anext_async(agen):
return await agen.__anext__()
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点实例对象
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
model_setting = node.context.get('model_setting',
{'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
'reasoning_content_start': '<think>'})
reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end'))
reasoning_result = reasoning.get_reasoning_content(response)
reasoning_result_end = reasoning.get_end_reasoning_content()
content = reasoning_result.get('content') + reasoning_result_end.get('content')
if 'reasoning_content' in response.response_metadata:
reasoning_content = response.response_metadata.get('reasoning_content', '')
else:
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting
def get_node_message(chat_record, runtime_node_id):
node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id)
if node_details is None:
return []
return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))]
def get_workflow_message(chat_record):
return [chat_record.get_human_message(), chat_record.get_ai_message()]
def get_message(chat_record, dialogue_type, runtime_node_id):
return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message(
chat_record)
class BaseChatNode(IChatNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
self.context['reasoning_content'] = details.get('reasoning_content')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
dialogue_type=None,
model_setting=None,
mcp_enable=False,
mcp_servers=None,
**kwargs) -> NodeResult:
if dialogue_type is None:
dialogue_type = 'WORKFLOW'
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
if model_setting is None:
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
'reasoning_content_start': '<think>'}
self.context['model_setting'] = model_setting
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
self.runtime_node_id)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
system = self.workflow_manage.generate_prompt(system)
self.context['system'] = system
message_list = self.generate_message_list(system, prompt, history_message)
self.context['message_list'] = message_list
if mcp_enable and mcp_servers is not None:
r = mcp_response_generator(chat_model, message_list, mcp_servers)
return NodeResult(
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream)
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context)
@staticmethod
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
get_message(history_chat_record[index], dialogue_type, runtime_node_id)
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
for message in history_message:
if isinstance(message.content, str):
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
return history_message
def generate_prompt_question(self, prompt):
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
def generate_message_list(self, system: str, prompt: str, history_message):
if system is not None and len(system) > 0:
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
else:
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'system': self.context.get('system'),
'history_message': [{'content': message.content, 'role': message.type} for message in
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'reasoning_content': self.context.get('reasoning_content'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,2 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,86 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class ApplicationNodeSerializer(serializers.Serializer):
application_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Application ID")))
question_reference_address = serializers.ListField(required=True,
error_messages=ErrMessage.list(_("User Questions")))
api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("API Input Fields")))
user_input_field_list = serializers.ListField(required=False,
error_messages=ErrMessage.uuid(_("User Input Fields")))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio")))
child_node = serializers.DictField(required=False, allow_null=True,
error_messages=ErrMessage.dict(_("Child Nodes")))
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
class IApplicationNode(INode):
type = 'application-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ApplicationNodeSerializer
def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address')[0],
self.node_params_serializer.data.get('question_reference_address')[1:])
kwargs = {}
for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else ''
kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(value,
api_input_field['value'][
1:]) if value != '' else ''
for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else ''
kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(value,
user_input_field['value'][
1:]) if value != '' else ''
# 判断是否包含这个属性
app_document_list = self.node_params_serializer.data.get('document_list', [])
if app_document_list and len(app_document_list) > 0:
app_document_list = self.workflow_manage.get_reference_field(
app_document_list[0],
app_document_list[1:])
for document in app_document_list:
if 'file_id' not in document:
raise ValueError(
_("Parameter value error: The uploaded document lacks file_id, and the document upload fails"))
app_image_list = self.node_params_serializer.data.get('image_list', [])
if app_image_list and len(app_image_list) > 0:
app_image_list = self.workflow_manage.get_reference_field(
app_image_list[0],
app_image_list[1:])
for image in app_image_list:
if 'file_id' not in image:
raise ValueError(
_("Parameter value error: The uploaded image lacks file_id, and the image upload fails"))
app_audio_list = self.node_params_serializer.data.get('audio_list', [])
if app_audio_list and len(app_audio_list) > 0:
app_audio_list = self.workflow_manage.get_reference_field(
app_audio_list[0],
app_audio_list[1:])
for audio in app_audio_list:
if 'file_id' not in audio:
raise ValueError(
_("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails."))
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
app_document_list=app_document_list, app_image_list=app_image_list,
app_audio_list=app_audio_list,
message=str(question), **kwargs)
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,2 @@
# coding=utf-8
from .base_application_node import BaseApplicationNode

View File

@ -0,0 +1,267 @@
# coding=utf-8
import json
import re
import time
import uuid
from typing import Dict, List
from application.flow.common import Answer
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.application_node.i_application_node import IApplicationNode
from application.models import Chat
def string_to_uuid(input_str):
return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))
def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
return node_variable.get('is_interrupt_exec', False)
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
reasoning_content: str):
result = node_variable.get('result')
node.context['application_node_dict'] = node_variable.get('application_node_dict')
node.context['node_dict'] = node_variable.get('node_dict', {})
node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec')
node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
node.context['answer'] = answer
node.context['result'] = answer
node.context['reasoning_content'] = reasoning_content
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
node.answer_text = answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = ''
reasoning_content = ''
usage = {}
node_child_node = {}
application_node_dict = node.context.get('application_node_dict', {})
is_interrupt_exec = False
for chunk in response:
# 先把流转成字符串
response_content = chunk.decode('utf-8')[6:]
response_content = json.loads(response_content)
content = response_content.get('content', '')
runtime_node_id = response_content.get('runtime_node_id', '')
chat_record_id = response_content.get('chat_record_id', '')
child_node = response_content.get('child_node')
view_type = response_content.get('view_type')
node_type = response_content.get('node_type')
real_node_id = response_content.get('real_node_id')
node_is_end = response_content.get('node_is_end', False)
_reasoning_content = response_content.get('reasoning_content', '')
if node_type == 'form-node':
is_interrupt_exec = True
answer += content
reasoning_content += _reasoning_content
node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
'child_node': child_node}
if real_node_id is not None:
application_node = application_node_dict.get(real_node_id, None)
if application_node is None:
application_node_dict[real_node_id] = {'content': content,
'runtime_node_id': runtime_node_id,
'chat_record_id': chat_record_id,
'child_node': child_node,
'index': len(application_node_dict),
'view_type': view_type,
'reasoning_content': _reasoning_content}
else:
application_node['content'] += content
application_node['reasoning_content'] += _reasoning_content
yield {'content': content,
'node_type': node_type,
'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
'reasoning_content': _reasoning_content,
'child_node': child_node,
'real_node_id': real_node_id,
'node_is_end': node_is_end,
'view_type': view_type}
usage = response_content.get('usage', {})
node_variable['result'] = {'usage': usage}
node_variable['is_interrupt_exec'] = is_interrupt_exec
node_variable['child_node'] = node_child_node
node_variable['application_node_dict'] = application_node_dict
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点实例对象
@param workflow: 工作流管理器
"""
response = node_variable.get('result', {}).get('data', {})
node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'),
'prompt_tokens': response.get('prompt_tokens')}}
answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
reasoning_content = response.get('reasoning_content', '')
answer_list = response.get('answer_list', [])
node_variable['application_node_dict'] = {answer.get('real_node_id'): {**answer, 'index': index} for answer, index
in
zip(answer_list, range(len(answer_list)))}
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
def reset_application_node_dict(application_node_dict, runtime_node_id, node_data):
try:
if application_node_dict is None:
return
for key in application_node_dict:
application_node = application_node_dict[key]
if application_node.get('runtime_node_id') == runtime_node_id:
content: str = application_node.get('content')
match = re.search('<form_rander>.*?</form_rander>', content)
if match:
form_setting_str = match.group().replace('<form_rander>', '').replace('</form_rander>', '')
form_setting = json.loads(form_setting_str)
form_setting['is_submit'] = True
form_setting['form_data'] = node_data
value = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
res = re.sub('<form_rander>.*?</form_rander>',
'${value}', content)
application_node['content'] = res.replace('${value}', value)
except Exception as e:
pass
class BaseApplicationNode(IApplicationNode):
def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
application_node_dict = self.context.get('application_node_dict')
if application_node_dict is None or len(application_node_dict) == 0:
return [
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'],
self.context.get('child_node'), self.runtime_node_id, '')]
else:
return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id,
self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'),
'chat_record_id': n.get('chat_record_id')
, 'child_node': n.get('child_node')}, n.get('real_node_id'),
n.get('reasoning_content', ''))
for n in
sorted(application_node_dict.values(), key=lambda item: item.get('index'))]
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['result'] = details.get('answer')
self.context['question'] = details.get('question')
self.context['type'] = details.get('type')
self.context['reasoning_content'] = details.get('reasoning_content')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
from application.serializers.chat_message_serializers import ChatMessageSerializer
# 生成嵌入应用的chat_id
current_chat_id = string_to_uuid(chat_id + application_id)
Chat.objects.get_or_create(id=current_chat_id, defaults={
'application_id': application_id,
'abstract': message[0:1024],
'client_id': client_id,
})
if app_document_list is None:
app_document_list = []
if app_image_list is None:
app_image_list = []
if app_audio_list is None:
app_audio_list = []
runtime_node_id = None
record_id = None
child_node_value = None
if child_node is not None:
runtime_node_id = child_node.get('runtime_node_id')
record_id = child_node.get('chat_record_id')
child_node_value = child_node.get('child_node')
application_node_dict = self.context.get('application_node_dict')
reset_application_node_dict(application_node_dict, runtime_node_id, node_data)
response = ChatMessageSerializer(
data={'chat_id': current_chat_id, 'message': message,
're_chat': re_chat,
'stream': stream,
'application_id': application_id,
'client_id': client_id,
'client_type': client_type,
'document_list': app_document_list,
'image_list': app_image_list,
'audio_list': app_audio_list,
'runtime_node_id': runtime_node_id,
'chat_record_id': record_id,
'child_node': child_node_value,
'node_data': node_data,
'form_data': kwargs}).chat()
if response.status_code == 200:
if stream:
content_generator = response.streaming_content
return NodeResult({'result': content_generator, 'question': message}, {},
_write_context=write_context_stream, _is_interrupt=_is_interrupt_exec)
else:
data = json.loads(response.content)
return NodeResult({'result': data, 'question': message}, {},
_write_context=write_context, _is_interrupt=_is_interrupt_exec)
def get_details(self, index: int, **kwargs):
global_fields = []
for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else ''
global_fields.append({
'label': api_input_field['variable'],
'key': api_input_field['variable'],
'value': self.workflow_manage.get_reference_field(
value,
api_input_field['value'][1:]
) if value != '' else ''
})
for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else ''
global_fields.append({
'label': user_input_field['label'],
'key': user_input_field['field'],
'value': self.workflow_manage.get_reference_field(
value,
user_input_field['value'][1:]
) if value != '' else ''
})
return {
'name': self.node.properties.get('stepName'),
"index": index,
"info": self.node.properties.get('node_data'),
'run_time': self.context.get('run_time'),
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'reasoning_content': self.context.get('reasoning_content'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'global_fields': global_fields,
'document_list': self.workflow_manage.document_list,
'image_list': self.workflow_manage.image_list,
'audio_list': self.workflow_manage.audio_list,
'application_node_dict': self.context.get('application_node_dict')
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""
from .impl import *

View File

@ -0,0 +1,30 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""
from .contain_compare import *
from .equal_compare import *
from .ge_compare import *
from .gt_compare import *
from .is_not_null_compare import *
from .is_not_true import IsNotTrueCompare
from .is_null_compare import *
from .is_true import IsTrueCompare
from .le_compare import *
from .len_equal_compare import *
from .len_ge_compare import *
from .len_gt_compare import *
from .len_le_compare import *
from .len_lt_compare import *
from .lt_compare import *
from .not_contain_compare import *
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(),
IsNullCompare(),
IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare()]

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file compare.py
@date2024/6/7 14:37
@desc:
"""
from abc import abstractmethod
from typing import List
class Compare:
@abstractmethod
def support(self, node_id, fields: List[str], source_value, compare, target_value):
pass
@abstractmethod
def compare(self, source_value, compare, target_value):
pass

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file contain_compare.py
@date2024/6/11 10:02
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class ContainCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'contain':
return True
def compare(self, source_value, compare, target_value):
if isinstance(source_value, str):
return str(target_value) in source_value
return any([str(item) == str(target_value) for item in source_value])

View File

@ -0,0 +1,21 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file equal_compare.py
@date2024/6/7 14:44
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class EqualCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'eq':
return True
def compare(self, source_value, compare, target_value):
return str(source_value) == str(target_value)

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class GECompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'ge':
return True
def compare(self, source_value, compare, target_value):
try:
return float(source_value) >= float(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class GTCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'gt':
return True
def compare(self, source_value, compare, target_value):
try:
return float(source_value) > float(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,21 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file is_not_null_compare.py
@date2024/6/28 10:45
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare import Compare
class IsNotNullCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'is_not_null':
return True
def compare(self, source_value, compare, target_value):
return source_value is not None and len(source_value) > 0

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file is_not_true.py
@date2025/4/7 13:44
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare import Compare
class IsNotTrueCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'is_not_true':
return True
def compare(self, source_value, compare, target_value):
try:
return source_value is False
except Exception as e:
return False

View File

@ -0,0 +1,21 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file is_null_compare.py
@date2024/6/28 10:45
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare import Compare
class IsNullCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'is_null':
return True
def compare(self, source_value, compare, target_value):
return source_value is None or len(source_value) == 0

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file IsTrue.py
@date2025/4/7 13:38
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare import Compare
class IsTrueCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'is_true':
return True
def compare(self, source_value, compare, target_value):
try:
return source_value is True
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LECompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'le':
return True
def compare(self, source_value, compare, target_value):
try:
return float(source_value) <= float(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file equal_compare.py
@date2024/6/7 14:44
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenEqualCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'len_eq':
return True
def compare(self, source_value, compare, target_value):
try:
return len(source_value) == int(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenGECompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'len_ge':
return True
def compare(self, source_value, compare, target_value):
try:
return len(source_value) >= int(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenGTCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'len_gt':
return True
def compare(self, source_value, compare, target_value):
try:
return len(source_value) > int(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenLECompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'len_le':
return True
def compare(self, source_value, compare, target_value):
try:
return len(source_value) <= int(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenLTCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'len_lt':
return True
def compare(self, source_value, compare, target_value):
try:
return len(source_value) < int(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LTCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'lt':
return True
def compare(self, source_value, compare, target_value):
try:
return float(source_value) < float(target_value)
except Exception as e:
return False

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file contain_compare.py
@date2024/6/11 10:02
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class NotContainCompare(Compare):
def support(self, node_id, fields: List[str], source_value, compare, target_value):
if compare == 'not_contain':
return True
def compare(self, source_value, compare, target_value):
if isinstance(source_value, str):
return str(target_value) not in source_value
return not any([str(item) == str(target_value) for item in source_value])

View File

@ -0,0 +1,39 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_condition_node.py
@date2024/6/7 9:54
@desc:
"""
from typing import Type
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.flow.i_step_node import INode
from common.util.field_message import ErrMessage
class ConditionSerializer(serializers.Serializer):
compare = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Comparator")))
value = serializers.CharField(required=True, error_messages=ErrMessage.char(_("value")))
field = serializers.ListField(required=True, error_messages=ErrMessage.char(_("Fields")))
class ConditionBranchSerializer(serializers.Serializer):
id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch id")))
type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch Type")))
condition = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Condition or|and")))
conditions = ConditionSerializer(many=True)
class ConditionNodeParamsSerializer(serializers.Serializer):
branch = ConditionBranchSerializer(many=True)
class IConditionNode(INode):
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ConditionNodeParamsSerializer
type = 'condition-node'

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:35
@desc:
"""
from .base_condition_node import BaseConditionNode

View File

@ -0,0 +1,62 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_condition_node.py
@date2024/6/7 11:29
@desc:
"""
from typing import List
from application.flow.i_step_node import NodeResult
from application.flow.step_node.condition_node.compare import compare_handle_list
from application.flow.step_node.condition_node.i_condition_node import IConditionNode
class BaseConditionNode(IConditionNode):
def save_context(self, details, workflow_manage):
self.context['branch_id'] = details.get('branch_id')
self.context['branch_name'] = details.get('branch_name')
def execute(self, **kwargs) -> NodeResult:
branch_list = self.node_params_serializer.data['branch']
branch = self._execute(branch_list)
r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})
return r
def _execute(self, branch_list: List):
for branch in branch_list:
if self.branch_assertion(branch):
return branch
def branch_assertion(self, branch):
condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
branch.get('conditions')]
condition = branch.get('condition')
return all(condition_list) if condition == 'and' else any(condition_list)
def assertion(self, field_list: List[str], compare: str, value):
try:
value = self.workflow_manage.generate_prompt(value)
except Exception as e:
pass
field_value = None
try:
field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
except Exception as e:
pass
for compare_handler in compare_handle_list:
if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
return compare_handler.compare(field_value, compare, value)
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'branch_id': self.context.get('branch_id'),
'branch_name': self.context.get('branch_name'),
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 17:50
@desc:
"""
from .impl import *

View File

@ -0,0 +1,48 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_reply_node.py
@date2024/6/11 16:25
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class ReplyNodeParamsSerializer(serializers.Serializer):
reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Response Type")))
fields = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Reference Field")))
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Direct answer content")))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if self.data.get('reply_type') == 'referencing':
if 'fields' not in self.data:
raise AppApiException(500, _("Reference field cannot be empty"))
if len(self.data.get('fields')) < 2:
raise AppApiException(500, _("Reference field error"))
else:
if 'content' not in self.data or self.data.get('content') is None:
raise AppApiException(500, _("Content cannot be empty"))
class IReplyNode(INode):
type = 'reply-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ReplyNodeParamsSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 17:49
@desc:
"""
from .base_reply_node import *

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_reply_node.py
@date2024/6/11 17:25
@desc:
"""
from typing import List
from application.flow.i_step_node import NodeResult
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
class BaseReplyNode(IReplyNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
if reply_type == 'referencing':
result = self.get_reference_content(fields)
else:
result = self.generate_reply_content(content)
return NodeResult({'answer': result}, {})
def generate_reply_content(self, prompt):
return self.workflow_manage.generate_prompt(prompt)
def get_reference_content(self, fields: List[str]):
return str(self.workflow_manage.get_reference_field(
fields[0],
fields[1:]))
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'type': self.node.type,
'answer': self.context.get('answer'),
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1 @@
from .impl import *

View File

@ -0,0 +1,28 @@
# coding=utf-8
from typing import Type
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class DocumentExtractNodeSerializer(serializers.Serializer):
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
class IDocumentExtractNode(INode):
type = 'document-extract-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return DocumentExtractNodeSerializer
def _run(self):
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0],
self.node_params_serializer.data.get('document_list')[1:])
return self.execute(document=res, **self.flow_params_serializer.data)
def execute(self, document, chat_id, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1 @@
from .base_document_extract_node import BaseDocumentExtractNode

View File

@ -0,0 +1,94 @@
# coding=utf-8
import io
import mimetypes
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
from dataset.models import File
from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
from dataset.serializers.file_serializers import FileSerializer
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
content_type, _ = mimetypes.guess_type(file_name)
if content_type is None:
# 如果未能识别,设置为默认的二进制文件类型
content_type = "application/octet-stream"
# 创建一个内存中的字节流对象
file_stream = io.BytesIO(file_bytes)
# 获取文件大小
file_size = len(file_bytes)
# 创建 InMemoryUploadedFile 对象
uploaded_file = InMemoryUploadedFile(
file=file_stream,
field_name=None,
name=file_name,
content_type=content_type,
size=file_size,
charset=None,
)
return uploaded_file
splitter = '\n`-----------------------------------`\n'
class BaseDocumentExtractNode(IDocumentExtractNode):
def save_context(self, details, workflow_manage):
self.context['content'] = details.get('content')
def execute(self, document, chat_id, **kwargs):
get_buffer = FileBufferHandle().get_buffer
self.context['document_list'] = document
content = []
if document is None or not isinstance(document, list):
return NodeResult({'content': ''}, {})
application = self.workflow_manage.work_flow_post_handler.chat_info.application
# doc文件中的图片保存
def save_image(image_list):
for image in image_list:
meta = {
'debug': False if application.id else True,
'chat_id': chat_id,
'application_id': str(application.id) if application.id else None,
'file_id': str(image.id)
}
file = bytes_to_uploaded_file(image.image, image.image_name)
FileSerializer(data={'file': file, 'meta': meta}).upload()
for doc in document:
file = QuerySet(File).filter(id=doc['file_id']).first()
buffer = io.BytesIO(file.get_byte().tobytes())
buffer.name = doc['name'] # this is the important line
for split_handle in (parse_table_handle_list + split_handles):
if split_handle.support(buffer, get_buffer):
# 回到文件头
buffer.seek(0)
file_content = split_handle.get_content(buffer, save_image)
content.append('### ' + doc['name'] + '\n' + file_content)
break
return NodeResult({'content': splitter.join(content)}, {})
def get_details(self, index: int, **kwargs):
content = self.context.get('content', '').split(splitter)
# 不保存content全部内容因为content内容可能会很大
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'type': self.node.type,
'content': [file_content[:500] for file_content in content],
'status': self.status,
'err_message': self.err_message,
'document_list': self.context.get('document_list')
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/11/4 14:48
@desc:
"""
from .impl import *

View File

@ -0,0 +1,35 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_form_node.py
@date2024/11/4 14:48
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class FormNodeParamsSerializer(serializers.Serializer):
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Form Configuration")))
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Form output content')))
form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
class IFormNode(INode):
type = 'form-node'
view_type = 'single_view'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return FormNodeParamsSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/11/4 14:49
@desc:
"""
from .base_form_node import BaseFormNode

View File

@ -0,0 +1,107 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_form_node.py
@date2024/11/4 14:52
@desc:
"""
import json
import time
from typing import Dict, List
from langchain_core.prompts import PromptTemplate
from application.flow.common import Answer
from application.flow.i_step_node import NodeResult
from application.flow.step_node.form_node.i_form_node import IFormNode
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
result = step_variable['result']
yield result
node.answer_text = result
node.context['run_time'] = time.time() - node.context['start_time']
class BaseFormNode(IFormNode):
def save_context(self, details, workflow_manage):
form_data = details.get('form_data', None)
self.context['result'] = details.get('result')
self.context['form_content_format'] = details.get('form_content_format')
self.context['form_field_list'] = details.get('form_field_list')
self.context['run_time'] = details.get('run_time')
self.context['start_time'] = details.get('start_time')
self.context['form_data'] = form_data
self.context['is_submit'] = details.get('is_submit')
if self.node_params.get('is_result', False):
self.answer_text = details.get('result')
if form_data is not None:
for key in form_data:
self.context[key] = form_data[key]
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
if form_data is not None:
self.context['is_submit'] = True
self.context['form_data'] = form_data
for key in form_data:
self.context[key] = form_data.get(key)
else:
self.context['is_submit'] = False
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
"is_submit": self.context.get("is_submit", False)}
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
context = self.workflow_manage.get_workflow_content()
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
value = prompt_template.format(form=form, context=context)
return NodeResult(
{'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
_write_context=write_context)
def get_answer_list(self) -> List[Answer] | None:
form_content_format = self.context.get('form_content_format')
form_field_list = self.context.get('form_field_list')
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
'form_data': self.context.get('form_data', {}),
"is_submit": self.context.get("is_submit", False)}
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
context = self.workflow_manage.get_workflow_content()
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
value = prompt_template.format(form=form, context=context)
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
self.runtime_node_id, '')]
def get_details(self, index: int, **kwargs):
form_content_format = self.context.get('form_content_format')
form_field_list = self.context.get('form_field_list')
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
'form_data': self.context.get('form_data', {}),
"is_submit": self.context.get("is_submit", False)}
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
context = self.workflow_manage.get_workflow_content()
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
value = prompt_template.format(form=form, context=context)
return {
'name': self.node.properties.get('stepName'),
"index": index,
"result": value,
"form_content_format": self.context.get('form_content_format'),
"form_field_list": self.context.get('form_field_list'),
'form_data': self.context.get('form_data'),
'start_time': self.context.get('start_time'),
'is_submit': self.context.get('is_submit'),
'run_time': self.context.get('run_time'),
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/8/8 17:45
@desc:
"""
from .impl import *

View File

@ -0,0 +1,48 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_function_lib_node.py
@date2024/8/8 16:21
@desc:
"""
from typing import Type
from django.db.models import QuerySet
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.field.common import ObjectField
from common.util.field_message import ErrMessage
from function_lib.models.function import FunctionLib
from django.utils.translation import gettext_lazy as _
class InputField(serializers.Serializer):
name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
class FunctionLibNodeParamsSerializer(serializers.Serializer):
function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Library ID')))
input_field_list = InputField(required=True, many=True)
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
f_lib = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id')).first()
if f_lib is None:
raise Exception(_('The function has been deleted'))
class IFunctionLibNode(INode):
type = 'function-lib-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return FunctionLibNodeParamsSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/8/8 17:48
@desc:
"""
from .base_function_lib_node import BaseFunctionLibNodeNode

View File

@ -0,0 +1,150 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_function_lib_node.py
@date2024/8/8 17:49
@desc:
"""
import json
import time
from typing import Dict
from django.db.models import QuerySet
from django.utils.translation import gettext as _
from application.flow.i_step_node import NodeResult
from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode
from common.exception.app_exception import AppApiException
from common.util.function_code import FunctionExecutor
from common.util.rsa_util import rsa_long_decrypt
from function_lib.models.function import FunctionLib
from smartdoc.const import CONFIG
function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
result = str(step_variable['result']) + '\n'
yield result
node.answer_text = result
node.context['run_time'] = time.time() - node.context['start_time']
def get_field_value(debug_field_list, name, is_required):
result = [field for field in debug_field_list if field.get('name') == name]
if len(result) > 0:
return result[-1]['value']
if is_required:
raise AppApiException(500, _('Field: {name} No value set').format(name=name))
return None
def valid_reference_value(_type, value, name):
if _type == 'int':
instance_type = int | float
elif _type == 'float':
instance_type = float | int
elif _type == 'dict':
instance_type = dict
elif _type == 'array':
instance_type = list
elif _type == 'string':
instance_type = str
else:
raise Exception(_('Field: {name} Type: {_type} Value: {value} Unsupported types').format(name=name,
_type=_type))
if not isinstance(value, instance_type):
raise Exception(
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
value=value))
def convert_value(name: str, value, _type, is_required, source, node):
if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
return None
if not is_required and source == 'reference' and (value is None or len(value) == 0):
return None
if source == 'reference':
value = node.workflow_manage.get_reference_field(
value[0],
value[1:])
valid_reference_value(_type, value, name)
if _type == 'int':
return int(value)
if _type == 'float':
return float(value)
return value
try:
if _type == 'int':
return int(value)
if _type == 'float':
return float(value)
if _type == 'dict':
v = json.loads(value)
if isinstance(v, dict):
return v
raise Exception(_('type error'))
if _type == 'array':
v = json.loads(value)
if isinstance(v, list):
return v
raise Exception(_('type error'))
return value
except Exception as e:
raise Exception(
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
value=value))
def valid_function(function_lib, user_id):
if function_lib is None:
raise Exception(_('Function does not exist'))
if function_lib.permission_type == 'PRIVATE' and str(function_lib.user_id) != str(user_id):
raise Exception(_('No permission to use this function {name}').format(name=function_lib.name))
if not function_lib.is_active:
raise Exception(_('Function {name} is unavailable').format(name=function_lib.name))
class BaseFunctionLibNodeNode(IFunctionLibNode):
def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
if self.node_params.get('is_result'):
self.answer_text = str(details.get('result'))
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
valid_function(function_lib, self.flow_params_serializer.data.get('user_id'))
params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
field.get('is_required'),
field.get('source'), self)
for field in
[{'value': get_field_value(input_field_list, field.get('name'), field.get('is_required'),
), **field}
for field in
function_lib.input_field_list]}
self.context['params'] = params
# 合并初始化参数
if function_lib.init_params is not None:
all_params = json.loads(rsa_long_decrypt(function_lib.init_params)) | params
else:
all_params = params
result = function_executor.exec_code(function_lib.code, all_params)
return NodeResult({'result': result}, {}, _write_context=write_context)
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
"result": self.context.get('result'),
"params": self.context.get('params'),
'run_time': self.context.get('run_time'),
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/8/13 10:43
@desc:
"""
from .impl import *

View File

@ -0,0 +1,63 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_function_lib_node.py
@date2024/8/8 16:21
@desc:
"""
import re
from typing import Type
from django.core import validators
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.exception.app_exception import AppApiException
from common.field.common import ObjectField
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
from rest_framework.utils.formatting import lazy_format
class InputField(serializers.Serializer):
name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Is this field required")))
type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("type")), validators=[
validators.RegexValidator(regex=re.compile("^string|int|dict|array|float$"),
message=_("The field only supports string|int|dict|array|float"), code=500)
])
source = serializers.CharField(required=True, error_messages=ErrMessage.char(_("source")), validators=[
validators.RegexValidator(regex=re.compile("^custom|reference$"),
message=_("The field only supports custom|reference"), code=500)
])
value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
is_required = self.data.get('is_required')
if is_required and self.data.get('value') is None:
message = lazy_format(_('{field}, this field is required.'), field=self.data.get("name"))
raise AppApiException(500, message)
class FunctionNodeParamsSerializer(serializers.Serializer):
input_field_list = InputField(required=True, many=True)
code = serializers.CharField(required=True, error_messages=ErrMessage.char(_("function")))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
class IFunctionNode(INode):
type = 'function-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return FunctionNodeParamsSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, input_field_list, code, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/8/13 11:19
@desc:
"""
from .base_function_node import BaseFunctionNodeNode

View File

@ -0,0 +1,108 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_function_lib_node.py
@date2024/8/8 17:49
@desc:
"""
import json
import time
from typing import Dict
from application.flow.i_step_node import NodeResult
from application.flow.step_node.function_node.i_function_node import IFunctionNode
from common.exception.app_exception import AppApiException
from common.util.function_code import FunctionExecutor
from smartdoc.const import CONFIG
function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
result = str(step_variable['result']) + '\n'
yield result
node.answer_text = result
node.context['run_time'] = time.time() - node.context['start_time']
def valid_reference_value(_type, value, name):
if _type == 'int':
instance_type = int | float
elif _type == 'float':
instance_type = float | int
elif _type == 'dict':
instance_type = dict
elif _type == 'array':
instance_type = list
elif _type == 'string':
instance_type = str
else:
raise Exception(500, f'字段:{name}类型:{_type} 不支持的类型')
if not isinstance(value, instance_type):
raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
def convert_value(name: str, value, _type, is_required, source, node):
if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
return None
if source == 'reference':
value = node.workflow_manage.get_reference_field(
value[0],
value[1:])
valid_reference_value(_type, value, name)
if _type == 'int':
return int(value)
if _type == 'float':
return float(value)
return value
try:
if _type == 'int':
return int(value)
if _type == 'float':
return float(value)
if _type == 'dict':
v = json.loads(value)
if isinstance(v, dict):
return v
raise Exception("类型错误")
if _type == 'array':
v = json.loads(value)
if isinstance(v, list):
return v
raise Exception("类型错误")
return value
except Exception as e:
raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
class BaseFunctionNodeNode(IFunctionNode):
def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
if self.node_params.get('is_result', False):
self.answer_text = str(details.get('result'))
def execute(self, input_field_list, code, **kwargs) -> NodeResult:
params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
field.get('is_required'), field.get('source'), self)
for field in input_field_list}
result = function_executor.exec_code(code, params)
self.context['params'] = params
return NodeResult({'result': result}, {}, _write_context=write_context)
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
"result": self.context.get('result'),
"params": self.context.get('params'),
'run_time': self.context.get('run_time'),
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,45 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class ImageGenerateNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word (positive)")))
negative_prompt = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Prompt word (negative)")),
allow_null=True, allow_blank=True, )
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=False, default=0,
error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
dialogue_type = serializers.CharField(required=False, default='NODE',
error_messages=ErrMessage.char(_("Conversation storage type")))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
model_params_setting = serializers.JSONField(required=False, default=dict,
error_messages=ErrMessage.json(_("Model parameter settings")))
class IImageGenerateNode(INode):
type = 'image-generate-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ImageGenerateNodeSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
model_params_setting,
chat_record_id,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .base_image_generate_node import BaseImageGenerateNode

View File

@ -0,0 +1,122 @@
# coding=utf-8
from functools import reduce
from typing import List
import requests
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.tools import get_model_instance_by_model_user_id
class BaseImageGenerateNode(IImageGenerateNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
model_params_setting,
chat_record_id,
**kwargs) -> NodeResult:
print(model_params_setting)
application = self.workflow_manage.work_flow_post_handler.chat_info.application
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question
message_list = self.generate_message_list(question, history_message)
self.context['message_list'] = message_list
self.context['dialogue_type'] = dialogue_type
print(message_list)
image_urls = tti_model.generate_image(question, negative_prompt)
# 保存图片
file_urls = []
for image_url in image_urls:
file_name = 'generated_image.png'
file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
meta = {
'debug': False if application.id else True,
'chat_id': chat_id,
'application_id': str(application.id) if application.id else None,
}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
self.context['image_list'] = [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls]
answer = ' '.join([f"![Image]({path})" for path in file_urls])
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
'image': [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls],
'history_message': history_message, 'question': question}, {})
def generate_history_ai_message(self, chat_record):
for val in chat_record.details.values():
if self.node.id == val['node_id'] and 'image_list' in val:
if val['dialogue_type'] == 'WORKFLOW':
return chat_record.get_ai_message()
image_list = val['image_list']
return AIMessage(content=[
*[{'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list]
])
return chat_record.get_ai_message()
def get_history_message(self, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[self.generate_history_human_message(history_chat_record[index]),
self.generate_history_ai_message(history_chat_record[index])]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
def generate_history_human_message(self, chat_record):
for data in chat_record.details.values():
if self.node.id == data['node_id'] and 'image_list' in data:
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
return HumanMessage(content=data['question'])
return HumanMessage(content=chat_record.problem_text)
def generate_prompt_question(self, prompt):
return self.workflow_manage.generate_prompt(prompt)
def generate_message_list(self, question: str, history_message):
return [
*history_message,
question
]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'history_message': [{'content': message.content, 'role': message.type} for message in
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,46 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class ImageUnderstandNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Role Setting")))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Conversation storage type")))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
model_params_setting = serializers.JSONField(required=False, default=dict,
error_messages=ErrMessage.json(_("Model parameter settings")))
class IImageUnderstandNode(INode):
type = 'image-understand-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ImageUnderstandNodeSerializer
def _run(self):
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0],
self.node_params_serializer.data.get('image_list')[1:])
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
chat_record_id,
image,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .base_image_understand_node import BaseImageUnderstandNode

View File

@ -0,0 +1,224 @@
# coding=utf-8
import base64
import os
import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id
from imghdr import what
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
chat_model = node_variable.get('chat_model')
message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
node.answer_text = answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = ''
for chunk in response:
answer += chunk.content
yield chunk.content
_write_context(node_variable, workflow_variable, node, workflow, answer)
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点实例对象
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = response.content
_write_context(node_variable, workflow_variable, node, workflow, answer)
def file_id_to_base64(file_id: str):
file = QuerySet(File).filter(id=file_id).first()
file_bytes = file.get_byte()
base64_image = base64.b64encode(file_bytes).decode("utf-8")
return [base64_image, what(None, file_bytes.tobytes())]
class BaseImageUnderstandNode(IImageUnderstandNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
chat_record_id,
image,
**kwargs) -> NodeResult:
# 处理不正确的参数
if image is None or not isinstance(image, list):
image = []
print(model_params_setting)
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
# 执行详情中的历史消息不需要图片内容
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
# 生成消息列表, 真实的history_message
message_list = self.generate_message_list(image_model, system, prompt,
self.get_history_message(history_chat_record, dialogue_number), image)
self.context['message_list'] = message_list
self.context['image_list'] = image
self.context['dialogue_type'] = dialogue_type
if stream:
r = image_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream)
else:
r = image_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context)
def get_history_message_for_details(self, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[self.generate_history_human_message_for_details(history_chat_record[index]),
self.generate_history_ai_message(history_chat_record[index])]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
def generate_history_ai_message(self, chat_record):
for val in chat_record.details.values():
if self.node.id == val['node_id'] and 'image_list' in val:
if val['dialogue_type'] == 'WORKFLOW':
return chat_record.get_ai_message()
return AIMessage(content=val['answer'])
return chat_record.get_ai_message()
def generate_history_human_message_for_details(self, chat_record):
for data in chat_record.details.values():
if self.node.id == data['node_id'] and 'image_list' in data:
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
file_id_list = [image.get('file_id') for image in image_list]
return HumanMessage(content=[
{'type': 'text', 'text': data['question']},
*[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
])
return HumanMessage(content=chat_record.problem_text)
def get_history_message(self, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[self.generate_history_human_message(history_chat_record[index]),
self.generate_history_ai_message(history_chat_record[index])]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
def generate_history_human_message(self, chat_record):
for data in chat_record.details.values():
if self.node.id == data['node_id'] and 'image_list' in data:
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
return HumanMessage(
content=[
{'type': 'text', 'text': data['question']},
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
base64_image in image_base64_list]
])
return HumanMessage(content=chat_record.problem_text)
def generate_prompt_question(self, prompt):
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
if image is not None and len(image) > 0:
# 处理多张图片
images = []
for img in image:
file_id = img['file_id']
file = QuerySet(File).filter(id=file_id).first()
image_bytes = file.get_byte()
base64_image = base64.b64encode(image_bytes).decode("utf-8")
image_format = what(None, image_bytes.tobytes())
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
messages = [HumanMessage(
content=[
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
*images
])]
else:
messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
if system is not None and len(system) > 0:
return [
SystemMessage(self.workflow_manage.generate_prompt(system)),
*history_message,
*messages
]
else:
return [
*history_message,
*messages
]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'system': self.node_params.get('system'),
'history_message': [{'content': message.content, 'role': message.type} for message in
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,35 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class McpNodeSerializer(serializers.Serializer):
mcp_servers = serializers.JSONField(required=True,
error_messages=ErrMessage.char(_("Mcp servers")))
mcp_server = serializers.CharField(required=True,
error_messages=ErrMessage.char(_("Mcp server")))
mcp_tool = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Mcp tool")))
tool_params = serializers.DictField(required=True,
error_messages=ErrMessage.char(_("Tool parameters")))
class IMcpNode(INode):
type = 'mcp-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return McpNodeSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .base_mcp_node import BaseMcpNode

View File

@ -0,0 +1,61 @@
# coding=utf-8
import asyncio
import json
from typing import List
from langchain_mcp_adapters.client import MultiServerMCPClient
from application.flow.i_step_node import NodeResult
from application.flow.step_node.mcp_node.i_mcp_node import IMcpNode
class BaseMcpNode(IMcpNode):
def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
self.context['tool_params'] = details.get('tool_params')
self.context['mcp_tool'] = details.get('mcp_tool')
if self.node_params.get('is_result', False):
self.answer_text = details.get('result')
def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
servers = json.loads(mcp_servers)
params = json.loads(json.dumps(tool_params))
params = self.handle_variables(params)
async def call_tool(s, session, t, a):
async with MultiServerMCPClient(s) as client:
s = await client.sessions[session].call_tool(t, a)
return s
res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params))
return NodeResult(
{'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {})
def handle_variables(self, tool_params):
# 处理参数中的变量
for k, v in tool_params.items():
if type(v) == str:
tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k])
if type(v) == dict:
self.handle_variables(v)
if (type(v) == list) and (type(v[0]) == str):
tool_params[k] = self.get_reference_content(v)
return tool_params
def get_reference_content(self, fields: List[str]):
return str(self.workflow_manage.get_reference_field(
fields[0],
fields[1:]))
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'status': self.status,
'err_message': self.err_message,
'type': self.node.type,
'mcp_tool': self.context.get('mcp_tool'),
'tool_params': self.context.get('tool_params'),
'result': self.context.get('result'),
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:30
@desc:
"""
from .impl import *

View File

@ -0,0 +1,42 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_node.py
@date2024/6/4 13:58
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class QuestionNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Role Setting")))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer(_("Model parameter settings")))
class IQuestionNode(INode):
type = 'question-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return QuestionNodeSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:35
@desc:
"""
from .base_question_node import BaseQuestionNode

View File

@ -0,0 +1,159 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_question_node.py
@date2024/6/4 14:30
@desc:
"""
import re
import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
node.answer_text = answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = ''
for chunk in response:
answer += chunk.content
yield chunk.content
_write_context(node_variable, workflow_variable, node, workflow, answer)
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点实例对象
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = response.content
_write_context(node_variable, workflow_variable, node, workflow, answer)
def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting
class BaseQuestionNode(IQuestionNode):
def save_context(self, details, workflow_manage):
self.context['run_time'] = details.get('run_time')
self.context['question'] = details.get('question')
self.context['answer'] = details.get('answer')
self.context['message_tokens'] = details.get('message_tokens')
self.context['answer_tokens'] = details.get('answer_tokens')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
**kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
system = self.workflow_manage.generate_prompt(system)
self.context['system'] = system
message_list = self.generate_message_list(system, prompt, history_message)
self.context['message_list'] = message_list
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context)
@staticmethod
def get_history_message(history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
for message in history_message:
if isinstance(message.content, str):
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
return history_message
def generate_prompt_question(self, prompt):
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
def generate_message_list(self, system: str, prompt: str, history_message):
if system is None or len(system) == 0:
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
else:
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'system': self.context.get('system'),
'history_message': [{'content': message.content, 'role': message.type} for message in
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/9/4 11:37
@desc:
"""
from .impl import *

View File

@ -0,0 +1,60 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_reranker_node.py
@date2024/9/4 10:40
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class RerankerSettingSerializer(serializers.Serializer):
# 需要查询的条数
top_n = serializers.IntegerField(required=True,
error_messages=ErrMessage.integer(_("Reference segment number")))
# 相似度 0-1之间
similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
error_messages=ErrMessage.float(_("Reference segment number")))
max_paragraph_char_number = serializers.IntegerField(required=True,
error_messages=ErrMessage.float(_("Maximum number of words in a quoted segment")))
class RerankerStepNodeSerializer(serializers.Serializer):
reranker_setting = RerankerSettingSerializer(required=True)
question_reference_address = serializers.ListField(required=True)
reranker_model_id = serializers.UUIDField(required=True)
reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
class IRerankerNode(INode):
type = 'reranker-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return RerankerStepNodeSerializer
def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address')[0],
self.node_params_serializer.data.get('question_reference_address')[1:])
reranker_list = [self.workflow_manage.get_reference_field(
reference[0],
reference[1:]) for reference in
self.node_params_serializer.data.get('reranker_reference_list')]
return self.execute(**self.node_params_serializer.data, question=str(question),
reranker_list=reranker_list)
def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/9/4 11:39
@desc:
"""
from .base_reranker_node import *

View File

@ -0,0 +1,106 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_reranker_node.py
@date2024/9/4 11:41
@desc:
"""
from typing import List
from langchain_core.documents import Document
from application.flow.i_step_node import NodeResult
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
from setting.models_provider.tools import get_model_instance_by_model_user_id
def merge_reranker_list(reranker_list, result=None):
if result is None:
result = []
for document in reranker_list:
if isinstance(document, list):
merge_reranker_list(document, result)
elif isinstance(document, dict):
content = document.get('title', '') + document.get('content', '')
title = document.get("title")
dataset_name = document.get("dataset_name")
document_name = document.get('document_name')
result.append(
Document(page_content=str(document) if len(content) == 0 else content,
metadata={'title': title, 'dataset_name': dataset_name, 'document_name': document_name}))
else:
result.append(Document(page_content=str(document), metadata={}))
return result
def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
use_len = 0
result = []
for index in range(len(document_list)):
document = document_list[index]
if use_len >= max_paragraph_char_number or index >= top_n or document.metadata.get(
'relevance_score') < similarity:
break
content = document.page_content[0:max_paragraph_char_number - use_len]
use_len = use_len + len(content)
result.append({'page_content': content, 'metadata': document.metadata})
return result
def reset_result_list(result_list: List[Document], document_list: List[Document]):
r = []
document_list = document_list.copy()
for result in result_list:
filter_result_list = [document for document in document_list if document.page_content == result.page_content]
if len(filter_result_list) > 0:
item = filter_result_list[0]
document_list.remove(item)
r.append(Document(page_content=item.page_content,
metadata={**item.metadata, 'relevance_score': result.metadata.get('relevance_score')}))
else:
r.append(result)
return r
class BaseRerankerNode(IRerankerNode):
def save_context(self, details, workflow_manage):
self.context['document_list'] = details.get('document_list', [])
self.context['question'] = details.get('question')
self.context['run_time'] = details.get('run_time')
self.context['result_list'] = details.get('result_list')
self.context['result'] = details.get('result')
def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
**kwargs) -> NodeResult:
documents = merge_reranker_list(reranker_list)
top_n = reranker_setting.get('top_n', 3)
self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
document in documents]
self.context['question'] = question
reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
self.flow_params_serializer.data.get('user_id'),
top_n=top_n)
result = reranker_model.compress_documents(
documents,
question)
similarity = reranker_setting.get('similarity', 0.6)
max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000)
result = reset_result_list(result, documents)
r = filter_result(result, max_paragraph_char_number, top_n, similarity)
return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {})
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'document_list': self.context.get('document_list'),
"question": self.context.get('question'),
'run_time': self.context.get('run_time'),
'type': self.node.type,
'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
'result_list': self.context.get('result_list'),
'result': self.context.get('result'),
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:30
@desc:
"""
from .impl import *

View File

@ -0,0 +1,79 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_search_dataset_node.py
@date2024/6/3 17:52
@desc:
"""
import re
from typing import Type
from django.core import validators
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.common import flat_map
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class DatasetSettingSerializer(serializers.Serializer):
# 需要查询的条数
top_n = serializers.IntegerField(required=True,
error_messages=ErrMessage.integer(_("Reference segment number")))
# 相似度 0-1之间
similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
error_messages=ErrMessage.float(_('similarity')))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message=_("The type only supports embedding|keywords|blend"), code=500)
], error_messages=ErrMessage.char(_("Retrieval Mode")))
max_paragraph_char_number = serializers.IntegerField(required=True,
error_messages=ErrMessage.float(_("Maximum number of words in a quoted segment")))
class SearchDatasetStepNodeSerializer(serializers.Serializer):
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list(_("Dataset id list")))
dataset_setting = DatasetSettingSerializer(required=True)
question_reference_address = serializers.ListField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
def get_paragraph_list(chat_record, node_id):
return flat_map([chat_record.details[key].get('paragraph_list', []) for key in chat_record.details if
(chat_record.details[
key].get('type', '') == 'search-dataset-node') and chat_record.details[key].get(
'paragraph_list', []) is not None and key == node_id])
class ISearchDatasetStepNode(INode):
type = 'search-dataset-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return SearchDatasetStepNodeSerializer
def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address')[0],
self.node_params_serializer.data.get('question_reference_address')[1:])
exclude_paragraph_id_list = []
if self.flow_params_serializer.data.get('re_chat', False):
history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
paragraph_id_list = [p.get('id') for p in flat_map(
[get_paragraph_list(chat_record, self.runtime_node_id) for chat_record in history_chat_record if
chat_record.problem_text == question])]
exclude_paragraph_id_list = list(set(paragraph_id_list))
return self.execute(**self.node_params_serializer.data, question=str(question),
exclude_paragraph_id_list=exclude_paragraph_id_list)
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:35
@desc:
"""
from .base_search_dataset_node import BaseSearchDatasetNode

View File

@ -0,0 +1,146 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_search_dataset_node.py
@date2024/6/4 11:56
@desc:
"""
import os
from typing import List, Dict
from django.db.models import QuerySet
from django.db import connection
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import VectorStore
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode
from setting.models_provider.tools import get_model_instance_by_model_user_id
from smartdoc.conf import PROJECT_DIR
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("关联知识库的向量模型不一致,无法召回分段。")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return dataset_list[0].embedding_mode_id
def get_none_result(question):
return NodeResult(
{'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
'directly_return': ''}, {})
def reset_title(title):
if title is None or len(title.strip()) == 0:
return ""
else:
return f"#### {title}\n"
class BaseSearchDatasetNode(ISearchDatasetStepNode):
def save_context(self, details, workflow_manage):
result = details.get('paragraph_list', [])
dataset_setting = self.node_params_serializer.data.get('dataset_setting')
directly_return = '\n'.join(
[f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in result if
paragraph.get('is_hit_handling_method')])
self.context['paragraph_list'] = result
self.context['question'] = details.get('question')
self.context['run_time'] = details.get('run_time')
self.context['is_hit_handling_method_list'] = [row for row in result if row.get('is_hit_handling_method')]
self.context['data'] = '\n'.join(
[f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in
result])[0:dataset_setting.get('max_paragraph_char_number', 5000)]
self.context['directly_return'] = directly_return
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
self.context['question'] = question
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
# 手动关闭数据库连接
connection.close()
if embedding_list is None:
return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
return NodeResult({'paragraph_list': result,
'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')],
'data': '\n'.join(
[f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in
result])[0:dataset_setting.get('max_paragraph_char_number', 5000)],
'directly_return': '\n'.join(
[paragraph.get('content') for paragraph in
result if
paragraph.get('is_hit_handling_method')]),
'question': question},
{})
@staticmethod
def reset_paragraph(paragraph: Dict, embedding_list: List):
filter_embedding_list = [embedding for embedding in embedding_list if
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
find_embedding = filter_embedding_list[-1]
return {
**paragraph,
'similarity': find_embedding.get('similarity'),
'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return',
'update_time': paragraph.get('update_time').strftime("%Y-%m-%d %H:%M:%S"),
'create_time': paragraph.get('create_time').strftime("%Y-%m-%d %H:%M:%S"),
'id': str(paragraph.get('id')),
'dataset_id': str(paragraph.get('dataset_id')),
'document_id': str(paragraph.get('document_id'))
}
@staticmethod
def list_paragraph(embedding_list: List, vector):
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
# 如果向量库中存在脏数据 直接删除
if len(paragraph_list) != len(paragraph_id_list):
exist_paragraph_list = [row.get('id') for row in paragraph_list]
for paragraph_id in paragraph_id_list:
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
return paragraph_list
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
'question': self.context.get('question'),
"index": index,
'run_time': self.context.get('run_time'),
'paragraph_list': self.context.get('paragraph_list'),
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
}

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,38 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
from django.utils.translation import gettext_lazy as _
class SpeechToTextNodeSerializer(serializers.Serializer):
stt_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
audio_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("The audio file cannot be empty")))
class ISpeechToTextNode(INode):
type = 'speech-to-text-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return SpeechToTextNodeSerializer
def _run(self):
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('audio_list')[0],
self.node_params_serializer.data.get('audio_list')[1:])
for audio in res:
if 'file_id' not in audio:
raise ValueError(_("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails"))
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, stt_model_id, chat_id,
audio,
**kwargs) -> NodeResult:
pass

Some files were not shown because too many files have changed in this diff Show More