# nex_docus/backend/app/services/notification_service.py
#
# 267 lines
# 8.6 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import json
import time
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.project import ProjectMember
from app.core.redis_client import get_redis
from typing import List, Optional, Dict, Any
# Module-level logger named after this module (standard logging.getLogger(__name__) pattern).
logger = logging.getLogger(__name__)
class NotificationService:
    """Per-user notification store backed by Redis.

    Each user owns two Redis keys:

    * ``notifications:order:{user_id}`` -- sorted set of notification ids,
      scored by creation time (``time.time()`` seconds). It provides the
      newest-first ordering and the age used for pruning.
    * ``notifications:content:{user_id}`` -- hash mapping each notification
      id to its JSON-encoded notification dict.

    Every public method first calls ``get_redis()``; when that returns a
    falsy client, writers do nothing and readers return an empty result.
    """

    # Notifications older than this many seconds (14 days) are pruned.
    EXPIRATION_SECONDS = 14 * 24 * 60 * 60
    # Extra slack (1 day) added to the key TTL on every write, so an idle
    # user's keys only disappear after their newest notification has aged out.
    KEY_TTL_GRACE_SECONDS = 24 * 60 * 60

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_order_key(self, user_id: int) -> str:
        """Return the sorted-set key that orders ``user_id``'s notifications."""
        return f"notifications:order:{user_id}"

    def _get_content_key(self, user_id: int) -> str:
        """Return the hash key that stores ``user_id``'s notification payloads."""
        return f"notifications:content:{user_id}"

    @staticmethod
    def _build_notification(
        user_id: int,
        title: str,
        content: Optional[str],
        notif_type: str,
        category: str,
        link: Optional[str],
        timestamp: float,
    ) -> Dict[str, Any]:
        """Return a new, unread notification dict with a fresh UUID4 id."""
        return {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "title": title,
            "content": content,
            "type": notif_type,
            "category": category,
            "link": link,
            "is_read": False,
            "created_at": timestamp,
        }

    @staticmethod
    def _decode(raw: Any) -> Optional[Dict[str, Any]]:
        """Parse one stored payload.

        Returns None for a missing/empty payload. Also returns None, after
        logging a warning, for a payload that is not valid JSON or is not a
        JSON object.
        """
        if not raw:
            return None
        try:
            data = json.loads(raw)
        except (ValueError, TypeError):
            # json.JSONDecodeError and UnicodeDecodeError are ValueError
            # subclasses. Catching narrowly (instead of a bare ``except:``)
            # lets asyncio cancellation and real errors propagate.
            logger.warning("Ignoring undecodable notification payload")
            return None
        if not isinstance(data, dict):
            logger.warning("Ignoring non-object notification payload")
            return None
        return data

    async def _store(self, redis, notifications: List[Dict[str, Any]]) -> None:
        """Write payloads and sorted-set entries and refresh both keys' TTLs.

        All commands for the whole batch are queued on a single pipeline.
        """
        ttl = self.EXPIRATION_SECONDS + self.KEY_TTL_GRACE_SECONDS
        async with redis.pipeline() as pipe:
            for item in notifications:
                order_key = self._get_order_key(item["user_id"])
                content_key = self._get_content_key(item["user_id"])
                pipe.hset(content_key, item["id"], json.dumps(item, ensure_ascii=False))
                pipe.zadd(order_key, {item["id"]: item["created_at"]})
                pipe.expire(order_key, ttl)
                pipe.expire(content_key, ttl)
            await pipe.execute()

    async def _prune_expired(
        self, redis, user_ids: List[int], now: Optional[float] = None
    ) -> None:
        """Delete notifications older than EXPIRATION_SECONDS from BOTH keys.

        Trimming only the sorted set (e.g. ZREMRANGEBYSCORE alone) would leave
        the payloads behind in the hash. There they are never listed again,
        yet they keep using memory, and the key TTL is refreshed on every
        write, so the hash would never shrink for an active user.
        """
        if not user_ids:
            return
        cutoff = (time.time() if now is None else now) - self.EXPIRATION_SECONDS
        # Round trip 1: collect the stale ids of every user.
        async with redis.pipeline() as pipe:
            for uid in user_ids:
                pipe.zrangebyscore(self._get_order_key(uid), "-inf", cutoff)
            stale_per_user = await pipe.execute()
        # Round trip 2 (only if needed): remove them from both structures.
        async with redis.pipeline() as pipe:
            queued = False
            for uid, stale_ids in zip(user_ids, stale_per_user):
                if stale_ids:
                    pipe.hdel(self._get_content_key(uid), *stale_ids)
                    pipe.zrem(self._get_order_key(uid), *stale_ids)
                    queued = True
            if queued:
                await pipe.execute()

    async def _load_entries(
        self, redis, user_id: int, start: int = 0, stop: int = -1
    ) -> List[Any]:
        """Return ``(notification_id, payload_dict)`` pairs, newest first.

        ``start``/``stop`` are inclusive ZREVRANGE ranks (``-1`` = last).
        Ids whose payload is missing from the hash are removed from the
        sorted set so they do not linger. Payloads that cannot be decoded
        are skipped.
        """
        order_key = self._get_order_key(user_id)
        ids = await redis.zrevrange(order_key, start, stop)
        if not ids:
            return []
        payloads = await redis.hmget(self._get_content_key(user_id), ids)
        entries = []
        dangling = []
        for notification_id, raw in zip(ids, payloads):
            if not raw:
                dangling.append(notification_id)
                continue
            data = self._decode(raw)
            if data is not None:
                entries.append((notification_id, data))
        if dangling:
            await redis.zrem(order_key, *dangling)
        return entries

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def create_notification(
        self,
        db: AsyncSession,
        user_id: int,
        title: str,
        content: Optional[str] = None,
        type: str = "info",
        category: str = "system",
        link: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """Create a single notification for ``user_id`` in Redis.

        ``db`` is not used by this method.

        Returns the stored notification dict, or None when no Redis client
        is available.
        """
        redis = get_redis()
        if not redis:
            return None
        now = time.time()
        notification = self._build_notification(
            user_id, title, content, type, category, link, now
        )
        await self._store(redis, [notification])
        await self._prune_expired(redis, [user_id], now)
        return notification

    async def broadcast_system_notification(
        self,
        db: AsyncSession,
        title: str,
        content: str,
        user_ids: List[int],
        link: Optional[str] = None,
        type: str = "info",
        category: str = "system",
    ) -> None:
        """Send the same notification to each user in ``user_ids``.

        Each user gets an independent copy with its own id. ``type`` and
        ``category`` default to the historical "info"/"system" values.
        ``db`` is not used by this method.
        """
        redis = get_redis()
        user_ids = list(user_ids)  # iterated more than once below
        if not redis or not user_ids:
            return
        now = time.time()
        batch = [
            self._build_notification(uid, title, content, type, category, link, now)
            for uid in user_ids
        ]
        await self._store(redis, batch)
        await self._prune_expired(redis, user_ids, now)

    async def notify_project_members(
        self,
        db: AsyncSession,
        project_id: int,
        exclude_user_id: int,
        title: str,
        content: str,
        link: Optional[str] = None,
        category: str = "project",
    ) -> None:
        """Notify every member of ``project_id`` except ``exclude_user_id``.

        Member ids come from the ProjectMember table. ``category`` is now
        actually stored on the notifications; previously it was ignored and
        every notification was tagged "system".
        """
        result = await db.execute(
            select(ProjectMember.user_id).where(
                ProjectMember.project_id == project_id,
                ProjectMember.user_id != exclude_user_id,
            )
        )
        member_ids = list(result.scalars().all())
        if member_ids:
            await self.broadcast_system_notification(
                db,
                title=title,
                content=content,
                user_ids=member_ids,
                link=link,
                category=category,
            )

    async def get_user_notifications(
        self,
        user_id: int,
        limit: int = 50,
        skip: int = 0,
        unread_only: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` notifications for ``user_id``, newest first.

        ``skip`` is the paging offset (negative values are treated as 0). A
        non-positive ``limit`` returns an empty list; previously ``limit=0``
        mapped to ``ZREVRANGE 0 -1`` and returned everything. With
        ``unread_only``, only notifications whose ``is_read`` is falsy are
        returned.
        """
        redis = get_redis()
        if not redis or limit <= 0:
            return []
        skip = max(skip, 0)
        if not unread_only:
            # Unfiltered listing: ZREVRANGE pages server-side.
            entries = await self._load_entries(redis, user_id, skip, skip + limit - 1)
            return [data for _, data in entries]
        # A Redis hash cannot be filtered by value server-side, so load the
        # whole retained history (bounded by the pruning window) and page in
        # memory.
        entries = await self._load_entries(redis, user_id)
        unread = [data for _, data in entries if not data.get("is_read")]
        return unread[skip : skip + limit]

    async def get_unread_count(self, user_id: int) -> int:
        """Return how many of ``user_id``'s listed notifications are unread.

        Only ids present in the order sorted set are counted, so the count
        matches what ``get_user_notifications`` can return. A hash-wide scan
        would also count orphaned payloads that are never listed and can
        never be marked read.
        """
        redis = get_redis()
        if not redis:
            return 0
        entries = await self._load_entries(redis, user_id)
        return sum(1 for _, data in entries if not data.get("is_read"))

    async def mark_read(self, user_id: int, notification_id: str) -> None:
        """Mark one notification as read. Missing or corrupt entries are ignored."""
        redis = get_redis()
        if not redis:
            return
        content_key = self._get_content_key(user_id)
        data = self._decode(await redis.hget(content_key, notification_id))
        if data is None or data.get("is_read"):
            return
        data["is_read"] = True
        await redis.hset(
            content_key, notification_id, json.dumps(data, ensure_ascii=False)
        )

    async def mark_all_read(self, user_id: int) -> None:
        """Mark every listed unread notification of ``user_id`` as read.

        All changed payloads are written back in a single HSET.
        """
        redis = get_redis()
        if not redis:
            return
        updates = {}
        for notification_id, data in await self._load_entries(redis, user_id):
            if not data.get("is_read"):
                data["is_read"] = True
                updates[notification_id] = json.dumps(data, ensure_ascii=False)
        if updates:
            await redis.hset(self._get_content_key(user_id), mapping=updates)

    async def delete_notification(self, user_id: int, notification_id: str) -> None:
        """Remove one notification from both the order set and the content hash."""
        redis = get_redis()
        if not redis:
            return
        async with redis.pipeline() as pipe:
            pipe.zrem(self._get_order_key(user_id), notification_id)
            pipe.hdel(self._get_content_key(user_id), notification_id)
            await pipe.execute()
# Shared module-level instance. NotificationService defines no __init__ and
# keeps only class-level constants, so one instance can be imported and
# reused by all callers.
notification_service = NotificationService()