diff --git a/apps/models_provider/urls.py b/apps/models_provider/urls.py index 1d45eadf1..e1e7ed166 100644 --- a/apps/models_provider/urls.py +++ b/apps/models_provider/urls.py @@ -11,11 +11,11 @@ urlpatterns = [ path('provider/model_list', views.Provide.ModelList.as_view()), path('provider/model_params_form', views.Provide.ModelParamsForm.as_view()), path('provider/model_form', views.Provide.ModelForm.as_view()), - path('workspace//model', views.Model.as_view()), - path('workspace//model//model_params_form', views.Model.ModelParamsForm.as_view()), - path('workspace//model/', views.Model.Operate.as_view()), - path('workspace//model//pause_download', views.Model.PauseDownload.as_view()), - path('workspace//model//meta', views.Model.ModelMeta.as_view()), + path('workspace//model', views.ModelSetting.as_view()), + path('workspace//model//model_params_form', views.ModelSetting.ModelParamsForm.as_view()), + path('workspace//model/', views.ModelSetting.Operate.as_view()), + path('workspace//model//pause_download', views.ModelSetting.PauseDownload.as_view()), + path('workspace//model//meta', views.ModelSetting.ModelMeta.as_view()), ] if os.environ.get('SERVER_NAME', 'web') == 'local_model': diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index 62e6a20df..18451a028 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -31,7 +31,6 @@ def encryption_credential(credential): return credential - def get_edit_model_details(request): path = request.path body = request.data @@ -40,20 +39,21 @@ def get_edit_model_details(request): credential_encryption_ed = encryption_credential(credential) return { 'path': path, - 'body': {**body, 'credential':credential_encryption_ed}, + 'body': {**body, 'credential': credential_encryption_ed}, 'query': query } + def get_model_operation_object(model_id): model_model = QuerySet(model=Model).filter(id=model_id).first() if model_model is not None: return { - "name":model_model.name + "name": model_model.name } return {} -class Model(APIView): +class ModelSetting(APIView): authentication_classes = [TokenAuth] @extend_schema(methods=['POST'], @@ -66,7 +66,7 @@ class Model(APIView): responses=ModelCreateAPI.get_response()) @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission()) @log(menu='model', operate='Create model', - get_operation_object=lambda r,k: {'name': r.date.get('name')}, + get_operation_object=lambda r, k: {'name': r.date.get('name')}, get_details=get_edit_model_details ) def post(self, request: Request, workspace_id: str): @@ -113,7 +113,7 @@ class Model(APIView): tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission()) @log(menu='model', operate='Update model', - get_operation_object=lambda r,k: get_model_operation_object(k.get('model_id')), + get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), get_details=get_edit_model_details ) def put(self, request: Request, workspace_id, model_id: str): @@ -172,7 +172,7 @@ class Model(APIView): tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) @log(menu='model', operate='Save model parameter form', - get_operation_object=lambda r,k: get_model_operation_object(k.get('model_id'))) + get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id'))) def put(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.ModelParams(data={'id': model_id}).save_model_params_form(request.data)) diff --git a/apps/models_provider/views/model_apply.py b/apps/models_provider/views/model_apply.py index d7e691c33..e09a6923d 100644 --- a/apps/models_provider/views/model_apply.py +++ b/apps/models_provider/views/model_apply.py @@ -28,7 +28,7 @@ class ModelApply(APIView): responses=DefaultModelResponse.get_response(), tags=[_('Model')] # type: ignore ) - def post(self, request: Request, model_id): + def post(self, request: Request, workspace_id, model_id): return result.success( ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data)) @@ -40,7 +40,7 @@ class ModelApply(APIView): responses=DefaultModelResponse.get_response(), tags=[_('Model')] # type: ignore ) - def post(self, request: Request, model_id): + def post(self, request: Request, workspace_id, model_id): return result.success( ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)) @@ -52,6 +52,6 @@ class ModelApply(APIView): responses=DefaultModelResponse.get_response(), tags=[_('Model')] # type: ignore ) - def post(self, request: Request, model_id): + def post(self, request: Request, workspace_id, model_id): return result.success( ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))