137 lines
3.9 KiB
Python
137 lines
3.9 KiB
Python
"""
|
|
Token management service using Redis
|
|
"""
|
|
from typing import Optional
|
|
from datetime import timedelta
|
|
from app.services.redis_cache import redis_cache
|
|
from app.config import settings
|
|
import json
|
|
|
|
|
|
class TokenService:
    """Token management with Redis.

    Redis keys used (``<token>`` is the raw token string):

    * ``token:<token>`` -- JSON ``{"user_id": ..., "username": ...}`` for an
      active token, expiring after the configured access-token lifetime.
    * ``token:blacklist:<token>`` -- marker written when a token is revoked.
    * ``user:tokens:<user_id>`` -- set of tokens saved for a user, used to
      revoke all of that user's tokens at once.
    """

    def __init__(self):
        self.prefix = "token:"
        self.blacklist_prefix = "token:blacklist:"
        self.user_tokens_prefix = "user:tokens:"

    @staticmethod
    def _ttl_seconds() -> int:
        """Return the access-token lifetime in seconds, taken from settings."""
        return settings.jwt_access_token_expire_minutes * 60

    async def save_token(self, token: str, user_id: int, username: str) -> None:
        """
        Save token to Redis with user info.

        Args:
            token: JWT access token
            user_id: User ID
            username: Username
        """
        token_data = {"user_id": user_id, "username": username}

        # TTL is the configured access-token lifetime (not a fixed value).
        ttl_seconds = self._ttl_seconds()
        await redis_cache.set(
            f"{self.prefix}{token}",
            json.dumps(token_data),
            ttl_seconds=ttl_seconds,
        )

        # Track the user's active tokens (for multi-device logout). This is
        # skipped when the raw client is unavailable, so such tokens are not
        # reachable from revoke_all_user_tokens().
        if redis_cache.client:
            user_tokens_key = f"{self.user_tokens_prefix}{user_id}"
            await redis_cache.client.sadd(user_tokens_key, token)
            # The TTL is reset on every save, so the set expires together
            # with the most recently saved token.
            await redis_cache.client.expire(user_tokens_key, ttl_seconds)

    async def get_token_data(self, token: str) -> Optional[dict]:
        """
        Get token data from Redis.

        Args:
            token: JWT access token

        Returns:
            The stored token data dict, or None if the token is blacklisted
            or has no record (never saved, deleted, or expired).
        """
        # The blacklist check takes precedence over any stored record.
        if await redis_cache.exists(f"{self.blacklist_prefix}{token}"):
            return None

        data = await redis_cache.get(f"{self.prefix}{token}")
        if not data:
            return None
        return json.loads(data)

    async def revoke_token(self, token: str) -> None:
        """
        Revoke a token (logout).

        Blacklists the token, deletes its record and removes it from its
        owner's token set (when the owner is known).

        Args:
            token: JWT access token
        """
        # Read the record before deleting it so the owner's user_id is known.
        token_data = await self.get_token_data(token)

        # NOTE(review): keyword matches save_token(); confirm against the
        # signature of redis_cache.set (this call previously used "expire=").
        await redis_cache.set(
            f"{self.blacklist_prefix}{token}",
            "1",
            ttl_seconds=self._ttl_seconds(),
        )

        await redis_cache.delete(f"{self.prefix}{token}")

        if token_data and redis_cache.client:
            user_id = token_data.get("user_id")
            # Compare against None so a falsy id such as 0 is still handled.
            if user_id is not None:
                await redis_cache.client.srem(
                    f"{self.user_tokens_prefix}{user_id}",
                    token,
                )

    async def revoke_all_user_tokens(self, user_id: int) -> None:
        """
        Revoke all tokens for a user (logout from all devices).

        Only tokens recorded in the user's token set are revoked. Nothing
        happens when the raw Redis client is unavailable.

        Args:
            user_id: User ID
        """
        if not redis_cache.client:
            return

        user_tokens_key = f"{self.user_tokens_prefix}{user_id}"
        tokens = await redis_cache.client.smembers(user_tokens_key)

        for raw_token in tokens:
            # Set members may come back as bytes depending on client config.
            token = raw_token.decode() if isinstance(raw_token, bytes) else raw_token
            await self.revoke_token(token)

        await redis_cache.delete(user_tokens_key)

    async def is_token_valid(self, token: str) -> bool:
        """
        Check if token is valid (not blacklisted and exists in Redis).

        Args:
            token: JWT access token

        Returns:
            True if valid, False otherwise
        """
        return await self.get_token_data(token) is not None


# Single module-level TokenService instance; importers share this object.
token_service: TokenService = TokenService()