UnisKB/apps/setting/views/model.py

124 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding=utf-8
"""
@project: maxkb
@Author
@file: model.py
@date: 2023/11/2 13:55
@desc: API views for models and model providers.
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import PermissionConstants
from common.response import result
from common.util.common import query_params_to_single_dict
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi
class Model(APIView):
    """Create models and list them for the authenticated caller."""

    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建模型",
                         operation_id="创建模型",
                         request_body=ModelCreateApi.get_request_body_api(),
                         tags=["模型"])
    @has_permissions(PermissionConstants.MODEL_CREATE)
    def post(self, request: Request):
        """Create a model from the request body.

        Guarded by ``has_permissions(PermissionConstants.MODEL_CREATE)``. The
        caller's id is merged into the body as ``user_id`` (converted to
        ``str``) and is also passed positionally to ``Create.insert``.

        :param request: authenticated DRF request whose ``data`` carries the
            model definition.
        :return: ``result.success`` wrapping whatever ``insert`` returns.
        """
        payload = {**request.data, 'user_id': str(request.user.id)}
        creator = ModelSerializer.Create(data=payload)
        created = creator.insert(request.user.id, with_valid=True)
        return result.success(created)

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取模型列表",
                         operation_id="获取模型列表",
                         manual_parameters=ModelQueryApi.get_request_params_api(),
                         tags=["模型"])
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request):
        """List models through ``ModelSerializer.Query``.

        Guarded by ``has_permissions(PermissionConstants.MODEL_READ)``. The
        query parameters are passed through ``query_params_to_single_dict``
        (presumably collapsing multi-valued params — confirm in
        ``common.util.common``), and the caller's id is added as ``user_id``.

        :param request: authenticated DRF request.
        :return: ``result.success`` wrapping the result of ``Query.list``.
        """
        params = query_params_to_single_dict(request.query_params)
        query = ModelSerializer.Query(data={**params, 'user_id': request.user.id})
        return result.success(query.list(with_valid=True))
def _get_model_provider(provider):
    """Resolve the value of the ``ModelProvideConstants`` member named *provider*.

    :param provider: provider name read from the ``provider`` query parameter;
        ``None`` when the parameter is absent.
    :return: ``ModelProvideConstants[provider].value``.
    :raises rest_framework.exceptions.ValidationError: if *provider* is missing
        or is not a member name of ``ModelProvideConstants``. Indexing the enum
        directly would instead leak a bare ``KeyError`` (a server error) for
        what is really bad client input.
    """
    # Function-scope import keeps this fix self-contained; rest_framework is
    # already a dependency of this module.
    from rest_framework.exceptions import ValidationError

    # __members__ is the same name->member map that Enum[name] consults, but
    # .get() lets us report a clean 400 instead of raising KeyError.
    member = ModelProvideConstants.__members__.get(provider)
    if member is None:
        raise ValidationError(f'Unknown model provider: {provider}')
    return member.value


class Provide(APIView):
    """Model-provider endpoints.

    ``get`` lists info for every member of ``ModelProvideConstants``. The
    nested views cover per-provider operations: invoking a provider method
    (``Exec``), listing model types (``ModelTypeList``), listing models
    (``ModelList``) and building a credential form (``ModelForm``).
    """

    authentication_classes = [TokenAuth]

    class Exec(APIView):
        """Invoke a named method on a provider."""

        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="调用供应商函数,获取表单数据",
                             operation_id="调用供应商函数,获取表单数据",
                             manual_parameters=ProvideApi.get_request_params_api(),
                             request_body=ProvideApi.get_request_body_api(),
                             tags=["模型"])
        @has_permissions(PermissionConstants.MODEL_READ)
        def post(self, request: Request, provider: str, method: str):
            """Run ``ProviderSerializer.exec`` for *provider* / *method*.

            :param provider: provider name from the URL.
            :param method: provider method name from the URL.
            :return: ``result.success`` wrapping the value of ``exec``, which
                receives the request body and ``with_valid=True``.
            """
            serializer = ProviderSerializer(data={'provider': provider, 'method': method})
            return result.success(serializer.exec(request.data, with_valid=True))

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取模型供应商数据",
                         operation_id="获取模型供应商列表",
                         tags=["模型"])
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request):
        """Return ``get_model_provide_info().to_dict()`` for every provider member."""
        # Iterating __members__.values() visits exactly the members the old
        # "for key in __members__: Enum[key]" loop did (aliases included),
        # without the redundant second lookup.
        return result.success(
            [member.value.get_model_provide_info().to_dict()
             for member in ModelProvideConstants.__members__.values()])

    class ModelTypeList(APIView):
        """List the model types a provider supports."""

        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取模型类型列表",
                             operation_id="获取模型类型列表",
                             manual_parameters=ProvideApi.ModelTypeList.get_request_params_api(),
                             responses=result.get_api_array_response(
                                 ProvideApi.ModelTypeList.get_response_body_api()),
                             tags=["模型"])
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request):
            """Return ``get_model_type_list()`` for the ``provider`` query parameter.

            :raises rest_framework.exceptions.ValidationError: unknown or
                missing provider.
            """
            provider = request.query_params.get('provider')
            return result.success(_get_model_provider(provider).get_model_type_list())

    class ModelList(APIView):
        """List a provider's models, optionally for one model type."""

        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取模型列表",
                             # operationId must be unique across the API; it
                             # must differ from ModelForm's and Model.get's ids.
                             operation_id="获取供应商模型列表",
                             manual_parameters=ProvideApi.ModelList.get_request_params_api(),
                             responses=result.get_api_array_response(
                                 ProvideApi.ModelList.get_response_body_api()),
                             tags=["模型"])
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request):
            """Return ``get_model_list(model_type)`` for the given provider.

            ``model_type`` is read from the query string and passed through
            unchanged (``None`` when absent).

            :raises rest_framework.exceptions.ValidationError: unknown or
                missing provider.
            """
            provider = request.query_params.get('provider')
            model_type = request.query_params.get('model_type')
            return result.success(_get_model_provider(provider).get_model_list(model_type))

    class ModelForm(APIView):
        """Build the credential form used when creating a model."""

        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取模型创建表单",
                             operation_id="获取模型创建表单",
                             manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
                             tags=["模型"])
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request):
            """Return ``get_model_credential(model_type, model_name).to_form_list()``.

            ``model_type`` and ``model_name`` are read from the query string and
            passed through unchanged (``None`` when absent).

            :raises rest_framework.exceptions.ValidationError: unknown or
                missing provider.
            """
            provider = request.query_params.get('provider')
            model_type = request.query_params.get('model_type')
            model_name = request.query_params.get('model_name')
            credential = _get_model_provider(provider).get_model_credential(model_type, model_name)
            return result.success(credential.to_form_list())