153 lines
4.4 KiB
Python
153 lines
4.4 KiB
Python
"""
|
||
Redis 客户端管理
|
||
"""
|
||
import redis.asyncio as aioredis
|
||
from typing import Optional
|
||
from app.core.config import settings
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)

# Shared Redis client instance; stays None until init_redis() assigns it.
redis_client: Optional[aioredis.Redis] = None
|
||
|
||
|
||
async def init_redis(max_connections: int = 10) -> None:
    """Create the shared Redis client and verify it with a PING.

    The new client is stored in the module-level ``redis_client`` only after
    the PING succeeds. If creating the client or the PING fails, the error is
    logged and the half-built client is closed so its connection pool is not
    leaked. The exception is then re-raised.

    :param max_connections: maximum size of the client's connection pool
        (default 10, the value previously hard-coded here).
    :raises Exception: any error raised while creating or pinging the client.
    """
    global redis_client

    client = None
    try:
        # redis.asyncio.from_url builds the client synchronously; no await needed.
        client = aioredis.from_url(
            settings.REDIS_URL,
            encoding="utf-8",
            decode_responses=True,
            max_connections=max_connections,
        )
        # Test the connection before publishing the client to other code.
        await client.ping()
    except Exception as e:
        logger.error(f"Redis connection failed: {e}")
        if client is not None:
            try:
                # Prefer aclose() (redis-py >= 5.0.1); fall back to close().
                closer = getattr(client, "aclose", None) or client.close
                await closer()
            except Exception:
                # Best-effort cleanup: never mask the original connection error.
                logger.debug(
                    "Failed to close Redis client after init error", exc_info=True
                )
        raise

    redis_client = client
    logger.info("Redis connected successfully")
|
||
|
||
|
||
async def close_redis() -> None:
    """Close the shared Redis client, if one exists.

    The module-level ``redis_client`` is reset to ``None`` before closing, so
    later lookups never receive an already-closed client. Calling this more
    than once is harmless.
    """
    global redis_client

    client = redis_client
    if client is None:
        return
    redis_client = None

    # redis-py >= 5.0.1 deprecates close() in favour of aclose(); use the
    # newer method when the installed version provides it.
    aclose = getattr(client, "aclose", None)
    if aclose is not None:
        await aclose()
    else:
        await client.close()
    logger.info("Redis connection closed")
|
||
|
||
|
||
def get_redis() -> aioredis.Redis:
    """Return the shared Redis client.

    :return: the module-level client set up by ``init_redis()``.
    :raises RuntimeError: if the module-level client is ``None`` (e.g.
        ``init_redis()`` has not run). This replaces returning ``None`` and
        failing later with an opaque ``AttributeError`` at the call site.
    """
    if redis_client is None:
        raise RuntimeError("Redis client is not initialized; call init_redis() first")
    return redis_client
|
||
|
||
|
||
# Token 缓存相关操作
|
||
class TokenCache:
    """Redis-backed token cache.

    Two kinds of keys are maintained:

    * ``token:<token>`` -> user id (as a string), with a per-token TTL.
    * ``user_tokens:<user_id>`` -> set of that user's tokens, so all of one
      user's tokens (multi-device login) can be revoked together.

    All methods are static and use the shared client from ``get_redis()``.
    """

    TOKEN_PREFIX = "token:"
    USER_TOKEN_PREFIX = "user_tokens:"

    @staticmethod
    async def _extend_ttl(redis: "aioredis.Redis", key: str, expire_seconds: int) -> None:
        """Raise ``key``'s TTL to ``expire_seconds`` but never shorten it.

        The per-user token set must outlive the longest-lived token it
        indexes. Overwriting its TTL with a shorter value would let the set
        expire while some of its tokens are still valid. Those tokens would
        then be missed by ``delete_user_all_tokens``.
        """
        # TTL returns -1 for a key without an expiry (e.g. a set just created
        # by SADD) and -2 for a missing key; both are below any positive
        # expire_seconds, so an expiry is applied in those cases.
        current_ttl = await redis.ttl(key)
        if current_ttl < expire_seconds:
            await redis.expire(key, expire_seconds)

    @staticmethod
    async def save_token(user_id: int, token: str, expire_seconds: int = 86400):
        """
        Store a token in Redis.

        :param user_id: id of the user the token belongs to
        :param token: JWT token
        :param expire_seconds: time to live in seconds, 24 hours by default
        """
        redis = get_redis()

        # token -> user_id mapping
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        await redis.setex(token_key, expire_seconds, str(user_id))

        # user_id -> set of tokens (supports multi-device login). The set's
        # TTL is only ever extended so it never expires before a live token.
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
        await redis.sadd(user_tokens_key, token)
        await TokenCache._extend_ttl(redis, user_tokens_key, expire_seconds)

        logger.info(f"Token saved for user {user_id}, expires in {expire_seconds}s")

    @staticmethod
    async def get_user_id(token: str) -> Optional[int]:
        """
        Look up the user id associated with a token.

        :param token: JWT token
        :return: the user id, or None if the token is not cached (unknown or expired)
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)

        if user_id_str:
            return int(user_id_str)
        return None

    @staticmethod
    async def delete_token(token: str):
        """
        Delete a single token.

        :param token: JWT token
        """
        redis = get_redis()

        # Look up the owning user so the token can also be removed from
        # that user's token set.
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)

        if user_id_str:
            user_id = int(user_id_str)
            user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
            await redis.srem(user_tokens_key, token)

        await redis.delete(token_key)
        # Only a short prefix is logged to avoid writing whole tokens to logs.
        logger.info(f"Token deleted: {token[:20]}...")

    @staticmethod
    async def delete_user_all_tokens(user_id: int):
        """
        Delete every token of a user (used for forced logout).

        :param user_id: id of the user whose tokens are revoked
        """
        redis = get_redis()
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"

        tokens = await redis.smembers(user_tokens_key)

        # Remove all token keys and the set itself with a single DEL instead
        # of one round trip per token; DEL ignores keys that already expired.
        token_keys = [f"{TokenCache.TOKEN_PREFIX}{token}" for token in tokens]
        await redis.delete(*token_keys, user_tokens_key)

        logger.info(f"All tokens deleted for user {user_id}")

    @staticmethod
    async def extend_token(token: str, expire_seconds: int = 86400):
        """
        Reset a token's TTL (used when refreshing a token).

        The token's expiry is set to ``expire_seconds`` from now. The owning
        user's token set is extended to at least that long, but never
        shortened.

        :param token: JWT token
        :param expire_seconds: new time to live in seconds
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"

        await redis.expire(token_key, expire_seconds)

        # Find the owning user and make sure their token set outlives this token.
        user_id_str = await redis.get(token_key)
        if user_id_str:
            user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id_str}"
            await TokenCache._extend_ttl(redis, user_tokens_key, expire_seconds)

        logger.info(f"Token extended: {token[:20]}...")
|