import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Sequence, Tuple

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.redis_client import get_redis
from app.models.project import ProjectMember

logger = logging.getLogger(__name__)


class NotificationService:
    """Per-user notification store backed by Redis.

    Each user owns two keys:

    * ``notifications:order:{user_id}`` -- sorted set of notification ids,
      scored by the ``time.time()`` value at creation. It provides
      newest-first ordering and pagination.
    * ``notifications:content:{user_id}`` -- hash mapping notification id to
      the JSON-encoded notification payload.

    Every public method degrades to a no-op or empty result when
    ``get_redis()`` returns a falsy value.
    """

    # Notification retention window: 14 days, in seconds.
    EXPIRATION_SECONDS = 14 * 24 * 60 * 60

    def _get_order_key(self, user_id: int) -> str:
        return f"notifications:order:{user_id}"

    def _get_content_key(self, user_id: int) -> str:
        return f"notifications:content:{user_id}"

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _build_notification(
        user_id: int,
        title: str,
        content: Optional[str],
        notification_type: str,
        category: str,
        link: Optional[str],
        timestamp: float,
    ) -> Dict[str, Any]:
        """Build a fresh, unread notification payload with a random UUID id."""
        return {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "title": title,
            "content": content,
            "type": notification_type,
            "category": category,
            "link": link,
            "is_read": False,
            "created_at": timestamp,
        }

    @staticmethod
    def _decode(raw: Any, notification_id: Any) -> Optional[Dict[str, Any]]:
        """Decode one stored payload; return None (and log) if it is unusable."""
        try:
            data = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError / UnicodeDecodeError subclass ValueError
            logger.warning("Skipping undecodable notification payload %r", notification_id)
            return None
        if not isinstance(data, dict):
            logger.warning("Skipping non-object notification payload %r", notification_id)
            return None
        return data

    async def _fetch_expired_ids(
        self, redis: Any, user_ids: Sequence[int], min_score: float
    ) -> List[Any]:
        """For each user id (same order), list the ids scored at or below ``min_score``."""
        async with redis.pipeline() as pipe:
            for uid in user_ids:
                pipe.zrangebyscore(self._get_order_key(uid), "-inf", min_score)
            return await pipe.execute()

    def _queue_store(
        self,
        pipe: Any,
        user_id: int,
        notification: Dict[str, Any],
        expired_ids: Sequence[Any],
        min_score: float,
    ) -> None:
        """Queue the writes that store ``notification`` and prune expired entries."""
        order_key = self._get_order_key(user_id)
        content_key = self._get_content_key(user_id)
        notification_id = notification["id"]

        pipe.hset(content_key, notification_id, json.dumps(notification, ensure_ascii=False))
        pipe.zadd(order_key, {notification_id: notification["created_at"]})
        # Remove payloads of entries that left the retention window. Without
        # this the hash grows forever, because its TTL is refreshed on every write.
        if expired_ids:
            pipe.hdel(content_key, *expired_ids)
        pipe.zremrangebyscore(order_key, "-inf", min_score)
        # Keys outlive the retention window by one extra day.
        ttl = self.EXPIRATION_SECONDS + 86400
        pipe.expire(order_key, ttl)
        pipe.expire(content_key, ttl)

    async def _load_entries(
        self, redis: Any, order_key: str, content_key: str, ids: Sequence[Any]
    ) -> List[Tuple[Any, Dict[str, Any]]]:
        """Fetch and decode payloads for ``ids``, preserving order.

        Ids whose payload is missing from the hash are removed from the
        order set, so they do not linger as dangling references.
        """
        if not ids:
            return []
        raw_payloads = await redis.hmget(content_key, ids)
        entries: List[Tuple[Any, Dict[str, Any]]] = []
        dangling = []
        for notification_id, raw in zip(ids, raw_payloads):
            if not raw:
                dangling.append(notification_id)
                continue
            data = self._decode(raw, notification_id)
            if data is not None:
                entries.append((notification_id, data))
        if dangling:
            await redis.zrem(order_key, *dangling)
        return entries

    # ------------------------------------------------------------------ #
    # Writes
    # ------------------------------------------------------------------ #

    async def create_notification(
        self,
        db: AsyncSession,
        user_id: int,
        title: str,
        content: Optional[str] = None,
        type: str = "info",
        category: str = "system",
        link: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """Create a single notification for ``user_id`` (written to Redis).

        ``db`` is accepted for interface compatibility and is not used here.

        Returns:
            The stored notification dict, or ``None`` when Redis is unavailable.
        """
        redis = get_redis()
        if not redis:
            return None

        timestamp = time.time()
        min_score = timestamp - self.EXPIRATION_SECONDS
        (expired_ids,) = await self._fetch_expired_ids(redis, [user_id], min_score)

        notification = self._build_notification(
            user_id, title, content, type, category, link, timestamp
        )
        async with redis.pipeline() as pipe:
            self._queue_store(pipe, user_id, notification, expired_ids, min_score)
            await pipe.execute()
        return notification

    async def broadcast_system_notification(
        self,
        db: AsyncSession,
        title: str,
        content: str,
        user_ids: List[int],
        link: Optional[str] = None,
        category: str = "system",
    ) -> None:
        """Send the same notification (type ``"info"``) to several users.

        Every user gets their own notification id. All writes share one
        timestamp and go through a single pipeline. Expired entries are
        pruned exactly as in ``create_notification``.
        """
        redis = get_redis()
        if not redis:
            return
        user_ids = list(user_ids)
        if not user_ids:
            return

        timestamp = time.time()
        min_score = timestamp - self.EXPIRATION_SECONDS
        expired_per_user = await self._fetch_expired_ids(redis, user_ids, min_score)

        async with redis.pipeline() as pipe:
            for uid, expired_ids in zip(user_ids, expired_per_user):
                notification = self._build_notification(
                    uid, title, content, "info", category, link, timestamp
                )
                self._queue_store(pipe, uid, notification, expired_ids, min_score)
            await pipe.execute()

    async def notify_project_members(
        self,
        db: AsyncSession,
        project_id: int,
        exclude_user_id: int,
        title: str,
        content: str,
        link: Optional[str] = None,
        category: str = "project",
    ) -> None:
        """Notify every member of ``project_id`` except ``exclude_user_id``.

        The notification is stored with the given ``category``.
        """
        result = await db.execute(
            select(ProjectMember.user_id).where(
                ProjectMember.project_id == project_id,
                ProjectMember.user_id != exclude_user_id,
            )
        )
        member_ids = list(result.scalars().all())
        if member_ids:
            await self.broadcast_system_notification(
                db,
                title=title,
                content=content,
                user_ids=member_ids,
                link=link,
                category=category,
            )

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #

    async def get_user_notifications(
        self,
        user_id: int,
        limit: int = 50,
        skip: int = 0,
        unread_only: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return a newest-first page of the user's notifications.

        Args:
            limit: Maximum number of items. A value <= 0 returns an empty list.
            skip: Number of items to skip. Negative values are treated as 0.
            unread_only: If true, only items not yet marked read are returned.
        """
        redis = get_redis()
        if not redis or limit <= 0:
            return []
        skip = max(skip, 0)

        order_key = self._get_order_key(user_id)
        content_key = self._get_content_key(user_id)

        if not unread_only:
            # Unfiltered listing: paginate directly on the sorted set.
            ids = await redis.zrevrange(order_key, skip, skip + limit - 1)
            entries = await self._load_entries(redis, order_key, content_key, ids)
            return [data for _, data in entries]

        # Redis hashes cannot be filtered by value, so load every id and filter
        # in memory. Entries are pruned after 14 days, so the set stays small.
        ids = await redis.zrevrange(order_key, 0, -1)
        entries = await self._load_entries(redis, order_key, content_key, ids)
        unread = [data for _, data in entries if not data.get("is_read")]
        return unread[skip : skip + limit]

    async def get_unread_count(self, user_id: int) -> int:
        """Count unread notifications that are present in the user's order set.

        The count is computed from the order set rather than the raw hash.
        It therefore agrees with ``get_user_notifications(unread_only=True)``
        and ignores stale payloads that are no longer listed.
        """
        redis = get_redis()
        if not redis:
            return 0

        order_key = self._get_order_key(user_id)
        content_key = self._get_content_key(user_id)
        ids = await redis.zrange(order_key, 0, -1)
        entries = await self._load_entries(redis, order_key, content_key, ids)
        return sum(1 for _, data in entries if not data.get("is_read"))

    # ------------------------------------------------------------------ #
    # State changes
    # ------------------------------------------------------------------ #

    async def mark_read(self, user_id: int, notification_id: str) -> None:
        """Mark a single notification as read; unknown ids are ignored."""
        redis = get_redis()
        if not redis:
            return

        content_key = self._get_content_key(user_id)
        raw = await redis.hget(content_key, notification_id)
        if not raw:
            return
        data = self._decode(raw, notification_id)
        if data is None or data.get("is_read"):
            return
        data["is_read"] = True
        await redis.hset(content_key, notification_id, json.dumps(data, ensure_ascii=False))

    async def mark_all_read(self, user_id: int) -> None:
        """Mark every listed notification of the user as read."""
        redis = get_redis()
        if not redis:
            return

        order_key = self._get_order_key(user_id)
        content_key = self._get_content_key(user_id)
        ids = await redis.zrange(order_key, 0, -1)
        entries = await self._load_entries(redis, order_key, content_key, ids)

        updates = {}
        for notification_id, data in entries:
            if not data.get("is_read"):
                data["is_read"] = True
                updates[notification_id] = json.dumps(data, ensure_ascii=False)
        if updates:
            await redis.hset(content_key, mapping=updates)

    async def delete_notification(self, user_id: int, notification_id: str) -> None:
        """Remove a notification from both the order set and the content hash."""
        redis = get_redis()
        if not redis:
            return

        order_key = self._get_order_key(user_id)
        content_key = self._get_content_key(user_id)
        async with redis.pipeline() as pipe:
            pipe.zrem(order_key, notification_id)
            pipe.hdel(content_key, notification_id)
            await pipe.execute()


notification_service = NotificationService()