diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 9f8e0d087..f44ffac2c 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -10,6 +10,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from common.config.embedding_config import ModelManage +from common.database_model_manage.database_model_manage import DatabaseModelManage from common.exception.app_exception import AppApiException from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt from models_provider.base_model_provider import ValidCode, DownModelChunkStatus @@ -394,7 +395,22 @@ class ModelSerializer(serializers.Serializer): return True -class SharedModelSerializer(serializers.Serializer): +def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization): + white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( + workspace_id=workspace_id, authentication_type='WHITE_LIST' + ).values_list('model_id', flat=True) + black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( + workspace_id=workspace_id, authentication_type='BLACK_LIST' + ).values_list('model_id', flat=True) + tool_query_set = tool_query_set.filter( + id__in=white_authorized_tool_ids + ).exclude( + id__in=black_authorized_tool_ids + ) + return tool_query_set + + +class WorkspaceSharedModelSerializer(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_('workspace id')) name = serializers.CharField(required=False, max_length=64, label=_('model name')) model_type = serializers.CharField(required=False, label=_('model type')) @@ -404,7 +420,10 @@ class SharedModelSerializer(serializers.Serializer): def get_share_model_list(self): self.is_valid(raise_exception=True) - queryset = QuerySet(Model).filter(workspace_id='None') + workspace_id = self.data.get('workspace_id') + + queryset = self._build_queryset(workspace_id) + return [ { 'id': str(model.id), @@ -419,3 +438,23 @@ class SharedModelSerializer(serializers.Serializer): } for model in queryset.order_by("-create_time") ] + + def _build_queryset(self, workspace_id): + queryset = QuerySet(Model) + if workspace_id: + model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization") + if model_workspace_authorization is not None: + queryset = get_authorized_tool(queryset, workspace_id, + model_workspace_authorization=model_workspace_authorization) + + for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: + value = self.data.get(field) + if value is not None: + if field == 'name': + queryset = queryset.filter(**{f'{field}__icontains': value}) + elif field == 'create_user': + queryset = queryset.filter(user_id=value) + else: + queryset = queryset.filter(**{field: value}) + + return queryset diff --git a/apps/models_provider/urls.py b/apps/models_provider/urls.py index 4f126d5aa..a5b25016b 100644 --- a/apps/models_provider/urls.py +++ b/apps/models_provider/urls.py @@ -18,7 +18,7 @@ urlpatterns = [ path('workspace//model//pause_download', views.ModelSetting.PauseDownload.as_view()), path('workspace//model//meta', views.ModelSetting.ModelMeta.as_view()), - path('workspace//shared/model', views.SharedModel.as_view()), + path('system/shared/workspace//model', views.WorkspaceSharedModelSetting.as_view()), ] if os.environ.get('SERVER_NAME', 'web') == 'local_model': diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index f23f8d539..02a77d4b9 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -21,7 +21,8 @@ from common.utils.common import query_params_to_single_dict from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse from models_provider.api.provide import ProvideApi from models_provider.models import Model -from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer +from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer, \ + WorkspaceSharedModelSerializer from system_manage.views import encryption_str @@ -65,7 +66,7 @@ class ModelSetting(APIView): request=ModelCreateAPI.get_request(), responses=ModelCreateAPI.get_response()) @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) @log(menu='model', operate='Create model', get_operation_object=lambda r, k: {'name': r.date.get('name')}, get_details=get_edit_model_details, @@ -95,7 +96,7 @@ class ModelSetting(APIView): responses=ModelListResponse.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) def get(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Query( @@ -114,7 +115,7 @@ class ModelSetting(APIView): responses=ModelEditApi.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) @log(menu='model', operate='Update model', get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), get_details=get_edit_model_details, @@ -133,7 +134,7 @@ class ModelSetting(APIView): responses=DefaultModelResponse.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) @log(menu='model', operate='Delete model', get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), ) @@ -150,7 +151,7 @@ class ModelSetting(APIView): responses=GetModelApi.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.Operate( @@ -168,7 +169,7 @@ class ModelSetting(APIView): responses=ProvideApi.ModelParamsForm.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.ModelParams(data={'id': model_id}).get_model_params()) @@ -182,7 +183,7 @@ class ModelSetting(APIView): responses=ProvideApi.ModelParamsForm.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) @log(menu='model', operate='Save model parameter form', get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), ) @@ -204,7 +205,7 @@ class ModelSetting(APIView): responses=GetModelApi.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True)) @@ -221,25 +222,29 @@ class ModelSetting(APIView): responses=DefaultModelResponse.get_response(), tags=[_('Model')]) # type: ignore @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) def put(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download()) -class SharedModel(APIView): +class WorkspaceSharedModelSetting(APIView): authentication_classes = [TokenAuth] @extend_schema( methods=['Get'], - summary=_('Get Share model'), - description=_('Get Share model'), - operation_id=_('Get Share model'), # type: ignore - parameters=ModelCreateAPI.get_parameters(), - responses=ModelListResponse.get_response(), + summary=_('Get Share model by workspace id'), + description=_('Get Share model by workspace id'), + operation_id=_('Get Share model by workspace id'), # type: ignore + parameters=ModelListResponse.get_parameters(), + responses=DefaultModelResponse.get_response(), tags=[_('Shared Model')] ) # type: ignore - @has_permissions(PermissionConstants.MODEL_READ, RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + @has_permissions( + PermissionConstants.MODEL_READ.get_workspace_permission(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + RoleConstants.USER.get_workspace_role(), + ) def get(self, request: Request, workspace_id: str): return result.success( - SharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list()) + WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list())