cosmo_backend/app/services/token_service.py

137 lines
3.9 KiB
Python

"""
Token management service using Redis
"""
from typing import Optional
from datetime import timedelta
from app.services.redis_cache import redis_cache
from app.config import settings
import json
class TokenService:
    """Token management with Redis.

    Key layout, built from the prefixes set in ``__init__``:

    * ``token:<jwt>`` -- JSON ``{"user_id": ..., "username": ...}`` for a saved token.
    * ``token:blacklist:<jwt>`` -- the value ``"1"`` for a revoked token.
    * ``user:tokens:<user_id>`` -- Redis set of the user's saved tokens
      (multi-device support).

    All keys are written with a TTL of
    ``settings.jwt_access_token_expire_minutes * 60`` seconds.
    """

    def __init__(self):
        self.prefix = "token:"
        self.blacklist_prefix = "token:blacklist:"
        self.user_tokens_prefix = "user:tokens:"

    @staticmethod
    def _ttl_seconds():
        """Return the configured access-token lifetime converted to seconds."""
        return settings.jwt_access_token_expire_minutes * 60

    async def save_token(self, token: str, user_id: int, username: str) -> None:
        """Save a token to Redis together with its owner's info.

        Args:
            token: JWT access token.
            user_id: ID of the user the token belongs to.
            username: Username stored alongside the ID.
        """
        ttl_seconds = self._ttl_seconds()
        token_data = {
            "user_id": user_id,
            "username": username,
        }
        # TTL comes from settings (jwt_access_token_expire_minutes), not a
        # fixed duration.
        await redis_cache.set(
            f"{self.prefix}{token}",
            json.dumps(token_data),
            ttl_seconds=ttl_seconds,
        )
        # Track the user's active tokens so revoke_all_user_tokens can log out
        # every device. Set commands need the raw client, which may be absent.
        if redis_cache.client:
            user_tokens_key = f"{self.user_tokens_prefix}{user_id}"
            await redis_cache.client.sadd(user_tokens_key, token)
            # Every token uses the same TTL, so refreshing the set's expiry on
            # each save keeps it alive as long as its newest member.
            await redis_cache.client.expire(user_tokens_key, ttl_seconds)

    async def get_token_data(self, token: str) -> Optional[dict]:
        """Get the stored data for a token.

        Args:
            token: JWT access token.

        Returns:
            The decoded token data dict, or None if the token is blacklisted,
            was never saved, or its Redis entry has expired.
        """
        # A blacklisted token is rejected even if its data key still exists.
        if await redis_cache.exists(f"{self.blacklist_prefix}{token}"):
            return None
        data = await redis_cache.get(f"{self.prefix}{token}")
        return json.loads(data) if data else None

    async def revoke_token(self, token: str) -> None:
        """Revoke a token (logout).

        Blacklists the token, deletes its data key and removes it from its
        owner's token set.

        Args:
            token: JWT access token.
        """
        # Read the data before deleting it so we know which user's set to update.
        token_data = await self.get_token_data(token)

        # NOTE(review): keyword must match redis_cache.set's signature; it is
        # aligned with save_token's ``ttl_seconds`` -- confirm against redis_cache.
        await redis_cache.set(
            f"{self.blacklist_prefix}{token}",
            "1",
            ttl_seconds=self._ttl_seconds(),
        )
        await redis_cache.delete(f"{self.prefix}{token}")

        if token_data and redis_cache.client:
            user_id = token_data.get("user_id")
            # Compare against None so a user_id of 0 is not skipped.
            if user_id is not None:
                await redis_cache.client.srem(
                    f"{self.user_tokens_prefix}{user_id}",
                    token,
                )

    async def revoke_all_user_tokens(self, user_id: int) -> None:
        """Revoke every tracked token of a user (logout from all devices).

        Does nothing when the raw Redis client is unavailable.

        Args:
            user_id: User ID.
        """
        if not redis_cache.client:
            return

        user_tokens_key = f"{self.user_tokens_prefix}{user_id}"
        tokens = await redis_cache.client.smembers(user_tokens_key)
        for raw_token in tokens:
            # smembers may return bytes depending on client configuration.
            token = raw_token.decode() if isinstance(raw_token, bytes) else raw_token
            await self.revoke_token(token)

        # Drop the set itself, including members whose data had already expired.
        await redis_cache.delete(user_tokens_key)

    async def is_token_valid(self, token: str) -> bool:
        """Check whether a token is valid (not blacklisted and present in Redis).

        Args:
            token: JWT access token.

        Returns:
            True if valid, False otherwise.
        """
        return await self.get_token_data(token) is not None
# Single TokenService instance, created once when this module is imported.
token_service: TokenService = TokenService()