From 6e0e0d2366b59d47ca1262eb7d0fd0df131cf541 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 9 Jun 2025 16:32:16 +0800 Subject: [PATCH] refactor: model api --- apps/models_provider/api/model.py | 48 +++++++++++++++- .../serializers/model_serializer.py | 57 +++++++++---------- apps/models_provider/views/model.py | 23 ++++---- 3 files changed, 85 insertions(+), 43 deletions(-) diff --git a/apps/models_provider/api/model.py b/apps/models_provider/api/model.py index c53985143..2b1e797dc 100644 --- a/apps/models_provider/api/model.py +++ b/apps/models_provider/api/model.py @@ -23,6 +23,52 @@ class ModelListResponse(APIMixin): return ModelListResult + @staticmethod + def get_parameters(): + return [OpenApiParameter( + name="workspace_id", + description=_("workspace id"), + type=OpenApiTypes.STR, + location=OpenApiParameter.PATH, + required=True, + ), + OpenApiParameter( + name="name", + description=_("model name"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + required=False, + ), + OpenApiParameter( + name="model_type", + description=_("model type"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + required=False, + ), + OpenApiParameter( + name="model_name", + description=_("base model"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + required=False, + ), + OpenApiParameter( + name="provider", + description=_("provider"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + required=False, + ), + OpenApiParameter( + name="create_user", + description=_("create user"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, + required=False, + ) + ] + class ModelCreateAPI(APIMixin): @staticmethod @@ -34,7 +80,7 @@ class ModelCreateAPI(APIMixin): return ModelCreateResponse @classmethod - def get_query_params_api(cls): + def get_parameters(cls): return [OpenApiParameter( name="workspace_id", description=_("workspace id"), diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 84990ea51..90739ae11 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -105,7 +105,7 @@ class ModelSerializer(serializers.Serializer): class Operate(serializers.Serializer): id = serializers.UUIDField(required=True, label=_("model id")) - user_id = serializers.UUIDField(required=True, label=_("user id")) + user_id = serializers.UUIDField(required=False, label=_("user id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -114,6 +114,8 @@ class ModelSerializer(serializers.Serializer): ).first() if model is None: raise AppApiException(500, _('Model does not exist')) + if model.workspace_id == 'None': + raise AppApiException(500, _('Shared models cannot be deleted or modified')) def one(self, with_valid=False): if with_valid: @@ -147,8 +149,6 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) model_id = self.data.get('id') model = Model.objects.filter(id=model_id).first() - if not model: - raise AppApiException(500, _("Model does not exist")) # TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系 # if model.model_type == 'LLM': # application_count = Application.objects.filter(model_id=model_id).count() @@ -174,35 +174,32 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) model = QuerySet(Model).filter(id=self.data.get('id')).first() - if model is None: - raise AppApiException(500, _('Model does not exist')) - else: - credential, model_credential, provider_handler = ModelSerializer.Edit( - data={**instance}).is_valid( - model=model) - try: - model.status = Status.SUCCESS - default_params = {item['field']: item['default_value'] for item in model.model_params_form} - # 校验模型认证数据 - provider_handler.is_valid_credential(model.model_type, - instance.get("model_name"), - credential, - default_params, - raise_exception=True) + credential, model_credential, provider_handler = ModelSerializer.Edit( + data={**instance}).is_valid( + model=model) + try: + model.status = Status.SUCCESS + default_params = {item['field']: item['default_value'] for item in model.model_params_form} + # 校验模型认证数据 + provider_handler.is_valid_credential(model.model_type, + instance.get("model_name"), + credential, + default_params, + raise_exception=True) - except AppApiException as e: - if e.code == ValidCode.model_not_fount: - model.status = Status.DOWNLOAD + except AppApiException as e: + if e.code == ValidCode.model_not_fount: + model.status = Status.DOWNLOAD + else: + raise e + update_keys = ['credential', 'name', 'model_type', 'model_name'] + for update_key in update_keys: + if update_key in instance and instance.get(update_key) is not None: + if update_key == 'credential': + model_credential_str = json.dumps(credential) + model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) else: - raise e - update_keys = ['credential', 'name', 'model_type', 'model_name'] - for update_key in update_keys: - if update_key in instance and instance.get(update_key) is not None: - if update_key == 'credential': - model_credential_str = json.dumps(credential) - model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) - else: - model.__setattr__(update_key, instance.get(update_key)) + model.__setattr__(update_key, instance.get(update_key)) ModelManage.delete_key(str(model.id)) model.save() diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index c1ecc20d4..436445f0b 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -61,7 +61,7 @@ class ModelSetting(APIView): description=_("Create model"), operation_id=_("Create model"), # type: ignore tags=[_("Model")], # type: ignore - parameters=ModelCreateAPI.get_query_params_api(), + parameters=ModelCreateAPI.get_parameters(), request=ModelCreateAPI.get_request(), responses=ModelCreateAPI.get_response()) @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission()) @@ -90,7 +90,7 @@ class ModelSetting(APIView): summary=_('Query model list'), description=_('Query model list'), operation_id=_('Query model list'), # type: ignore - parameters=ModelCreateAPI.get_query_params_api(), + parameters=ModelListResponse.get_parameters(), responses=ModelListResponse.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @@ -108,7 +108,7 @@ class ModelSetting(APIView): description=_('Update model'), operation_id=_('Update model'), # type: ignore request=ModelEditApi.get_request(), - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), responses=ModelEditApi.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission()) @@ -125,7 +125,7 @@ class ModelSetting(APIView): summary=_('Delete model'), description=_('Delete model'), operation_id=_('Delete model'), # type: ignore - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), responses=DefaultModelResponse.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission()) @@ -139,7 +139,7 @@ class ModelSetting(APIView): summary=_('Query model details'), description=_('Query model details'), operation_id=_('Query model details'), # type: ignore - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), responses=GetModelApi.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @@ -154,7 +154,7 @@ class ModelSetting(APIView): summary=_('Get model parameter form'), description=_('Get model parameter form'), operation_id=_('Get model parameter form'), # type: ignore - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), responses=ProvideApi.ModelParamsForm.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @@ -166,7 +166,7 @@ class ModelSetting(APIView): summary=_('Save model parameter form'), description=_('Save model parameter form'), operation_id=_('Save model parameter form'), # type: ignore - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), request=GetModelApi.get_request(), responses=ProvideApi.ModelParamsForm.get_response(), tags=[_('Model')]) # type: ignore @@ -187,7 +187,7 @@ class ModelSetting(APIView): 'Query model meta information, this interface does not carry authentication information'), operation_id=_( 'Query model meta information, this interface does not carry authentication information'), - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), responses=GetModelApi.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @@ -202,7 +202,7 @@ class ModelSetting(APIView): summary=_('Pause model download'), description=_('Pause model download'), operation_id=_('Pause model download'), # type: ignore - parameters=GetModelApi.get_query_params_api(), + parameters=GetModelApi.get_parameters(), request=GetModelApi.get_request(), responses=DefaultModelResponse.get_response(), tags=[_('Model')]) # type: ignore @@ -218,9 +218,8 @@ class ModelSetting(APIView): summary=_('Get Share model'), description=_('Get Share model'), operation_id=_('Get Share model'), # type: ignore - parameters=GetModelApi.get_query_params_api(), - request=GetModelApi.get_request(), - responses=DefaultModelResponse.get_response(), + parameters=ModelListResponse.get_parameters(), + responses=ModelListResponse.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) def get(self, request: Request, workspace_id: str):