nex_docus/backend/app/core/redis_client.py

153 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Redis 客户端管理
"""
import redis.asyncio as aioredis
from typing import Optional
from app.core.config import settings
import logging
logger = logging.getLogger(__name__)
# Shared Redis client instance for this module; it holds None until
# init_redis() assigns a connected client.
redis_client: aioredis.Redis = None
async def init_redis():
    """Create the shared Redis client and verify it with a PING.

    The client is built from ``settings.REDIS_URL`` and stored in the
    module-level ``redis_client``. Any exception raised while creating the
    client or pinging the server is logged and then re-raised to the caller.
    """
    global redis_client
    client_options = {
        "encoding": "utf-8",
        "decode_responses": True,
        "max_connections": 10,
    }
    try:
        redis_client = await aioredis.from_url(settings.REDIS_URL, **client_options)
        # Test the connection with one round trip.
        await redis_client.ping()
        logger.info("Redis connected successfully")
    except Exception as exc:
        logger.error(f"Redis connection failed: {exc}")
        raise
async def close_redis():
    """Close the shared Redis client, if there is one, and forget it.

    After this coroutine runs, the module-level ``redis_client`` is ``None``,
    so calling it again is a no-op and ``get_redis()`` no longer hands out a
    client that has already been closed.
    """
    global redis_client
    client = redis_client
    if client is None:
        return
    # Drop the shared reference before closing so nobody can pick up a client
    # that is in the middle of shutting down.
    redis_client = None
    # Newer redis-py releases deprecate close() in favour of aclose(); use it
    # when the installed version provides it, otherwise fall back to close().
    close = getattr(client, "aclose", None) or client.close
    await close()
    logger.info("Redis connection closed")
def get_redis() -> aioredis.Redis:
    """Return the module-level Redis client.

    This returns whatever ``redis_client`` currently holds, which is ``None``
    until ``init_redis()`` has assigned a client.
    """
    return redis_client
# Token 缓存相关操作
class TokenCache:
    """Redis-backed bookkeeping for issued tokens.

    Two kinds of keys are maintained:

    * ``token:<token>`` holds the owning user id as a string and carries the
      token's TTL;
    * ``user_tokens:<user_id>`` is a set of that user's tokens (multi-device
      login), used to delete all of them at once.
    """

    TOKEN_PREFIX = "token:"
    USER_TOKEN_PREFIX = "user_tokens:"

    @staticmethod
    async def _keep_user_tokens_alive(redis, user_tokens_key: str, expire_seconds: int) -> None:
        """Make sure the user's token set lives at least ``expire_seconds``.

        The TTL is only ever raised, never lowered: shortening it could let
        the set expire while a longer-lived token of the same user is still
        valid, and that token could then no longer be found by
        ``delete_user_all_tokens``.

        :param redis: Redis client to use
        :param user_tokens_key: key of the user's token set
        :param expire_seconds: minimum remaining lifetime in seconds
        """
        remaining = await redis.ttl(user_tokens_key)
        # TTL replies -1 for "no expiry" and -2 for "no such key"; both are
        # below any positive expire_seconds, so a fresh set gets its TTL here
        # and EXPIRE on a missing key is simply a no-op.
        if remaining < expire_seconds:
            await redis.expire(user_tokens_key, expire_seconds)

    @staticmethod
    async def save_token(user_id: int, token: str, expire_seconds: int = 86400) -> None:
        """
        Save a token to Redis.

        NOTE: set members whose token key has already expired are not pruned
        by this method.

        :param user_id: user id
        :param token: JWT token
        :param expire_seconds: lifetime in seconds, 24 hours by default
        """
        redis = get_redis()
        # token -> user_id mapping
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        await redis.setex(token_key, expire_seconds, str(user_id))
        # user_id -> set of tokens (supports login from several devices)
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
        await redis.sadd(user_tokens_key, token)
        await TokenCache._keep_user_tokens_alive(redis, user_tokens_key, expire_seconds)
        logger.info(f"Token saved for user {user_id}, expires in {expire_seconds}s")

    @staticmethod
    async def get_user_id(token: str) -> Optional[int]:
        """
        Look up the user id a token belongs to.

        :param token: JWT token
        :return: the user id, or None when the token is not stored
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)
        if user_id_str:
            return int(user_id_str)
        return None

    @staticmethod
    async def delete_token(token: str) -> None:
        """
        Delete a single token.

        :param token: JWT token
        """
        redis = get_redis()
        # The owning user is only known through the token key, so read it
        # before deleting.
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)
        if user_id_str:
            user_id = int(user_id_str)
            # Remove the token from the user's token set.
            user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
            await redis.srem(user_tokens_key, token)
        await redis.delete(token_key)
        logger.info(f"Token deleted: {token[:20]}...")

    @staticmethod
    async def delete_user_all_tokens(user_id: int) -> None:
        """
        Delete every token of a user (used to force a logout).

        :param user_id: user id
        """
        redis = get_redis()
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
        tokens = await redis.smembers(user_tokens_key)
        token_keys = [f"{TokenCache.TOKEN_PREFIX}{token}" for token in tokens]
        # One DEL for all token keys plus the set itself: a single round trip
        # instead of one per token. The set key is always included, so the
        # call never goes out with an empty key list.
        await redis.delete(*token_keys, user_tokens_key)
        logger.info(f"All tokens deleted for user {user_id}")

    @staticmethod
    async def extend_token(token: str, expire_seconds: int = 86400) -> None:
        """
        Extend a token's lifetime (used when a token is refreshed).

        :param token: JWT token
        :param expire_seconds: new lifetime in seconds
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        # EXPIRE reports whether a timeout was actually set; a falsy result
        # means the token key does not exist, so there is nothing to extend.
        if not await redis.expire(token_key, expire_seconds):
            logger.warning(f"Token not found, nothing to extend: {token[:20]}...")
            return
        # Keep the owning user's token set alive at least as long.
        user_id_str = await redis.get(token_key)
        if user_id_str:
            user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id_str}"
            await TokenCache._keep_user_tokens_alive(redis, user_tokens_key, expire_seconds)
        logger.info(f"Token extended: {token[:20]}...")