85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
from datetime import timedelta
|
|
from uuid import uuid4
|
|
from sqlalchemy.orm import Session
|
|
from redis import Redis
|
|
from app.core.security import verify_password, create_token, decode_token
|
|
from app.core.config import get_settings
|
|
from app.models import User
|
|
from app.models.enums import StatusEnum
|
|
from app.services.param_service import get_token_minutes
|
|
|
|
|
|
# Application settings, resolved once when this module is imported.
# create_token_pair reads its default JWT lifetimes (in minutes) from here.
settings = get_settings()
|
|
|
|
|
|
def authenticate_user(db: Session, username: str, password: str) -> User | None:
    """Return the user matching *username* and *password*, or ``None``.

    Only accounts whose status equals ``StatusEnum.ENABLED`` and that are
    not flagged as deleted are considered. ``None`` is returned both when
    no such account exists and when the password does not verify, so the
    caller cannot tell the two failures apart.

    Args:
        db: SQLAlchemy session used for the lookup.
        username: Login name to search for.
        password: Plain-text password checked against the stored hash.

    Returns:
        The matching ``User``, or ``None`` on any authentication failure.
    """
    active_account = (
        User.username == username,
        User.status == int(StatusEnum.ENABLED),
        User.is_deleted.is_(False),
    )
    candidate = db.query(User).filter(*active_account).first()

    # The password is only verified when a candidate row was found.
    if candidate and verify_password(password, candidate.password_hash):
        return candidate
    return None
|
|
|
|
|
|
def create_token_pair(db: Session, redis: Redis, user: User) -> tuple[str, str]:
    """Issue an access/refresh token pair for *user* and register it in Redis.

    Token lifetimes (in minutes) come from ``get_token_minutes``, with the
    module-level settings supplying the defaults. The refresh token carries
    a freshly generated ``jti``; that id is stored under
    ``auth:refresh:<jti>`` with the user id as its value, and an
    ``auth:online:<user_id>`` marker is written with the same TTL.

    Args:
        db: SQLAlchemy session handed to ``get_token_minutes``.
        redis: Redis client that stores the refresh and online keys.
        user: The user the tokens are issued for.

    Returns:
        ``(access_token, refresh_token)``.
    """
    access_ttl_min, refresh_ttl_min = get_token_minutes(
        db,
        default_access=settings.jwt_access_token_minutes,
        default_refresh=settings.jwt_refresh_token_minutes,
    )

    subject = str(user.user_id)
    token_id = str(uuid4())

    access_token = create_token(
        {"sub": subject, "type": "access"},
        timedelta(minutes=access_ttl_min),
    )
    refresh_token = create_token(
        {"sub": subject, "type": "refresh", "jti": token_id},
        timedelta(minutes=refresh_ttl_min),
    )

    # Both Redis entries share the refresh token's lifetime, in seconds.
    ttl_seconds = refresh_ttl_min * 60
    redis.setex(f"auth:refresh:{token_id}", ttl_seconds, subject)
    # Track online user with TTL synced to refresh token
    redis.setex(f"auth:online:{user.user_id}", ttl_seconds, "1")

    return access_token, refresh_token
|
|
|
|
|
|
def refresh_access_token(db: Session, redis: Redis, refresh_token: str) -> tuple[str, str]:
    """Exchange a valid refresh token for a new access/refresh token pair.

    The token must be of type ``"refresh"``, carry a ``jti``, and that
    ``jti`` must still be present in Redis under ``auth:refresh:<jti>``.
    The owning user must exist, not be deleted, and be enabled. On success
    the old Redis entry is removed and a new pair is issued through
    ``create_token_pair``.

    Args:
        db: SQLAlchemy session used to load the user.
        redis: Redis client holding the refresh-token registry.
        refresh_token: The encoded refresh token presented by the client.

    Returns:
        ``(access_token, refresh_token)`` for the new pair.

    Raises:
        ValueError: If the token has the wrong type, has no ``jti``, is no
            longer registered in Redis, or its user is missing or disabled.
    """
    claims = decode_token(refresh_token)
    if claims.get("type") != "refresh":
        raise ValueError("Invalid token type")

    token_id = claims.get("jti")
    if not token_id:
        raise ValueError("Invalid token")

    stored_key = f"auth:refresh:{token_id}"
    owner_id = redis.get(stored_key)
    if not owner_id:
        raise ValueError("Refresh token expired")

    account = (
        db.query(User)
        .filter(User.user_id == int(owner_id), User.is_deleted.is_(False))
        .first()
    )
    if not account or account.status != int(StatusEnum.ENABLED):
        raise ValueError("User not found")

    # Invalidate the old refresh token to avoid stale online counts
    redis.delete(stored_key)
    return create_token_pair(db, redis, account)
|
|
|
|
|
|
def logout_refresh_token(redis: Redis, refresh_token: str) -> None:
    """Revoke a refresh token and clear its owner's online marker.

    The token's ``jti`` is looked up under ``auth:refresh:<jti>``; that key
    is deleted unconditionally. If it held a user id, the matching
    ``auth:online:<user_id>`` key is deleted as well.

    Args:
        redis: Redis client holding the refresh-token registry.
        refresh_token: The encoded refresh token to revoke.

    Raises:
        ValueError: If the token is not of type ``"refresh"`` or has no
            ``jti``.
    """
    payload = decode_token(refresh_token)
    if payload.get("type") != "refresh":
        raise ValueError("Invalid token type")
    jti = payload.get("jti")
    if not jti:
        raise ValueError("Invalid token")

    redis_key = f"auth:refresh:{jti}"
    user_id = redis.get(redis_key)
    redis.delete(redis_key)
    if not user_id:
        return

    # A Redis client created without decode_responses=True returns bytes.
    # Formatting bytes into the key would yield "auth:online:b'42'", which
    # never matches the "auth:online:42" key written at login, so decode
    # the value before building the key.
    if isinstance(user_id, (bytes, bytearray)):
        user_id = user_id.decode()
    redis.delete(f"auth:online:{user_id}")
|