267 lines
8.6 KiB
Python
267 lines
8.6 KiB
Python
import logging
|
||
import json
|
||
import time
|
||
import uuid
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from app.models.project import ProjectMember
|
||
from app.core.redis_client import get_redis
|
||
from typing import List, Optional, Dict, Any
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
class NotificationService:
    """Per-user notifications stored in Redis.

    Each user owns two keys:

    * ``notifications:order:{user_id}`` -- a sorted set of notification ids
      scored by creation timestamp (drives ordering and pagination);
    * ``notifications:content:{user_id}`` -- a hash mapping notification id
      to its JSON-encoded payload.

    Every write trims ids older than ``EXPIRATION_SECONDS`` from the sorted
    set, deletes their payloads from the hash, and refreshes a TTL of
    ``KEY_TTL_SECONDS`` on both keys. Every public method returns an empty
    result / does nothing when ``get_redis()`` returns a falsy client.
    """

    # Notification retention window: 14 days, in seconds.
    EXPIRATION_SECONDS = 14 * 24 * 60 * 60
    # TTL refreshed on both keys at each write: the retention window plus one
    # day, so an inactive user's keys eventually disappear on their own.
    KEY_TTL_SECONDS = EXPIRATION_SECONDS + 86400

    def _get_order_key(self, user_id: int) -> str:
        """Return the sorted-set key holding *user_id*'s notification ids."""
        return f"notifications:order:{user_id}"

    def _get_content_key(self, user_id: int) -> str:
        """Return the hash key holding *user_id*'s notification payloads."""
        return f"notifications:content:{user_id}"

    @staticmethod
    def _build_notification(
        user_id: int,
        title: str,
        content: Optional[str],
        notification_type: str,
        category: str,
        link: Optional[str],
        created_at: float,
    ) -> Dict[str, Any]:
        """Return a new, unread notification payload with a fresh UUID4 id."""
        return {
            "id": str(uuid.uuid4()),
            "user_id": user_id,
            "title": title,
            "content": content,
            "type": notification_type,
            "category": category,
            "link": link,
            "is_read": False,
            "created_at": created_at,
        }

    @staticmethod
    def _decode(raw: Any) -> Optional[Dict[str, Any]]:
        """Decode one stored payload; return None if missing, corrupt or not an object."""
        if not raw:
            return None
        try:
            data = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            logger.warning("Ignoring undecodable notification payload")
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _is_unread(notification: Dict[str, Any]) -> bool:
        """Single unread predicate shared by listing and counting."""
        return not notification.get("is_read")

    async def _store(
        self, redis: Any, notifications: List[Dict[str, Any]], now: float
    ) -> None:
        """Persist *notifications* and trim each recipient's expired entries.

        One pipeline fetches each recipient's expired ids, writes the new
        payloads, trims the sorted sets and refreshes TTLs. If anything had
        expired, a second pipeline deletes those payloads from the content
        hashes so they do not accumulate.
        """
        min_score = now - self.EXPIRATION_SECONDS
        async with redis.pipeline() as pipe:
            # Queue the reads first so results[i] belongs to notifications[i].
            for item in notifications:
                pipe.zrangebyscore(self._get_order_key(item["user_id"]), "-inf", min_score)
            for item in notifications:
                order_key = self._get_order_key(item["user_id"])
                content_key = self._get_content_key(item["user_id"])
                pipe.hset(content_key, item["id"], json.dumps(item, ensure_ascii=False))
                pipe.zadd(order_key, {item["id"]: item["created_at"]})
                pipe.zremrangebyscore(order_key, "-inf", min_score)
                pipe.expire(order_key, self.KEY_TTL_SECONDS)
                pipe.expire(content_key, self.KEY_TTL_SECONDS)
            results = await pipe.execute()

        expired_by_user: Dict[int, set] = {}
        for item, expired in zip(notifications, results):
            if expired:
                expired_by_user.setdefault(item["user_id"], set()).update(expired)
        if not expired_by_user:
            return
        async with redis.pipeline() as pipe:
            for uid, expired in expired_by_user.items():
                pipe.hdel(self._get_content_key(uid), *expired)
            await pipe.execute()

    async def _load(self, redis: Any, user_id: int, ids: List[Any]) -> List[Dict[str, Any]]:
        """Fetch and decode the payloads for *ids*, preserving their order.

        Ids whose payload no longer exists are removed from the user's
        sorted set; corrupt payloads are logged and skipped.
        """
        if not ids:
            return []
        raw_values = await redis.hmget(self._get_content_key(user_id), ids)
        notifications: List[Dict[str, Any]] = []
        missing = []
        for notification_id, raw in zip(ids, raw_values):
            if not raw:
                missing.append(notification_id)
                continue
            data = self._decode(raw)
            if data is not None:
                notifications.append(data)
        if missing:
            await redis.zrem(self._get_order_key(user_id), *missing)
        return notifications

    async def create_notification(
        self,
        db: AsyncSession,
        user_id: int,
        title: str,
        content: Optional[str] = None,
        type: str = "info",
        category: str = "system",
        link: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """Create a single notification for *user_id* (written to Redis only).

        ``db`` is not used by this method.

        Returns:
            The stored notification payload, or None when Redis is unavailable.
        """
        redis = get_redis()
        if not redis:
            return None

        timestamp = time.time()
        notification = self._build_notification(
            user_id, title, content, type, category, link, timestamp
        )
        await self._store(redis, [notification], timestamp)
        return notification

    async def broadcast_system_notification(
        self,
        db: AsyncSession,
        title: str,
        content: str,
        user_ids: List[int],
        link: Optional[str] = None,
        type: str = "info",
        category: str = "system",
    ) -> None:
        """Send the same notification to every user in *user_ids*.

        Each recipient gets its own notification id; all share one timestamp.
        Duplicate ids in *user_ids* produce duplicate notifications. ``db`` is
        not used by this method.
        """
        redis = get_redis()
        if not redis:
            return

        timestamp = time.time()
        notifications = [
            self._build_notification(uid, title, content, type, category, link, timestamp)
            for uid in user_ids
        ]
        if notifications:
            await self._store(redis, notifications, timestamp)

    async def notify_project_members(
        self,
        db: AsyncSession,
        project_id: int,
        exclude_user_id: int,
        title: str,
        content: str,
        link: Optional[str] = None,
        category: str = "project",
    ) -> None:
        """Notify every member of *project_id* except *exclude_user_id*.

        Member ids are read from the ``ProjectMember`` table through *db*;
        the notifications are tagged with *category*.
        """
        result = await db.execute(
            select(ProjectMember.user_id).where(
                ProjectMember.project_id == project_id,
                ProjectMember.user_id != exclude_user_id,
            )
        )
        member_ids = result.scalars().all()

        if member_ids:
            await self.broadcast_system_notification(
                db,
                title=title,
                content=content,
                user_ids=member_ids,
                link=link,
                category=category,
            )

    async def get_user_notifications(
        self,
        user_id: int,
        limit: int = 50,
        skip: int = 0,
        unread_only: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return up to *limit* notifications for *user_id*, newest first.

        Args:
            limit: Page size; a non-positive value yields an empty page.
            skip: Number of (matching) notifications to skip; clamped at 0.
            unread_only: Only return notifications not yet marked read.
        """
        redis = get_redis()
        if not redis or limit <= 0:
            # limit <= 0 would otherwise turn into ZREVRANGE key 0 -1 (everything).
            return []
        skip = max(skip, 0)
        order_key = self._get_order_key(user_id)

        if not unread_only:
            # Unfiltered: paginate directly in Redis.
            ids = await redis.zrevrange(order_key, skip, skip + limit - 1)
            return await self._load(redis, user_id, ids)

        # The hash cannot be filtered by value in Redis, so load the whole
        # (14-day bounded) list, filter in memory, then paginate.
        all_ids = await redis.zrevrange(order_key, 0, -1)
        unread = [n for n in await self._load(redis, user_id, all_ids) if self._is_unread(n)]
        return unread[skip : skip + limit]

    async def get_unread_count(self, user_id: int) -> int:
        """Return how many of *user_id*'s listed notifications are unread.

        Counts only ids still present in the sorted set, so the number matches
        what ``get_user_notifications(unread_only=True)`` can return.
        """
        redis = get_redis()
        if not redis:
            return 0

        ids = await redis.zrange(self._get_order_key(user_id), 0, -1)
        notifications = await self._load(redis, user_id, ids)
        return sum(1 for n in notifications if self._is_unread(n))

    async def mark_read(self, user_id: int, notification_id: str) -> None:
        """Mark one notification as read; unknown or corrupt ids are ignored."""
        redis = get_redis()
        if not redis:
            return

        content_key = self._get_content_key(user_id)
        data = self._decode(await redis.hget(content_key, notification_id))
        if data is None or not self._is_unread(data):
            return
        data["is_read"] = True
        await redis.hset(content_key, notification_id, json.dumps(data, ensure_ascii=False))

    async def mark_all_read(self, user_id: int) -> None:
        """Mark every listed notification of *user_id* as read in one HSET."""
        redis = get_redis()
        if not redis:
            return

        order_key = self._get_order_key(user_id)
        content_key = self._get_content_key(user_id)

        ids = await redis.zrange(order_key, 0, -1)
        if not ids:
            return
        raw_values = await redis.hmget(content_key, ids)

        updates = {}
        for notification_id, raw in zip(ids, raw_values):
            data = self._decode(raw)
            if data is not None and self._is_unread(data):
                data["is_read"] = True
                updates[notification_id] = json.dumps(data, ensure_ascii=False)

        if updates:
            await redis.hset(content_key, mapping=updates)

    async def delete_notification(self, user_id: int, notification_id: str) -> None:
        """Remove a notification's id and payload in a single pipeline."""
        redis = get_redis()
        if not redis:
            return

        async with redis.pipeline() as pipe:
            pipe.zrem(self._get_order_key(user_id), notification_id)
            pipe.hdel(self._get_content_key(user_id), notification_id)
            await pipe.execute()
|
||
|
||
# Module-level instance of NotificationService.
notification_service: NotificationService = NotificationService()