"""Task bookkeeping: durable state in the database, transient progress in Redis."""
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, Optional

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.db import Task
from app.services.redis_cache import redis_cache

logger = logging.getLogger(__name__)


class TaskService:
    """Create and update `Task` rows and mirror their progress into Redis.

    Every state change is committed to the database first and then written to
    a Redis key (``task:progress:<id>``) so pollers can read progress cheaply.
    """

    # Lifetime of the Redis progress entry; the DB row remains the durable record.
    PROGRESS_TTL_SECONDS = 3600

    def __init__(self):
        self.redis_prefix = "task:progress:"

    async def create_task(
        self,
        db: AsyncSession,
        task_type: str,
        description: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        created_by: Optional[int] = None,
    ) -> Task:
        """Insert a new task in the ``pending`` state with progress 0.

        Args:
            db: Session used to insert and commit the row.
            task_type: Value stored in ``Task.task_type``.
            description: Optional free-text description.
            params: Optional parameters stored on the row.
            created_by: Optional id stored in ``Task.created_by``.

        Returns:
            The committed and refreshed ``Task`` (so ``task.id`` is populated).
        """
        task = Task(
            task_type=task_type,
            description=description,
            params=params,
            status="pending",
            created_by=created_by,
            progress=0,
        )
        db.add(task)
        await db.commit()
        # Refresh to load server-generated fields such as the primary key.
        await db.refresh(task)

        # Seed the Redis progress entry so pollers see the task immediately.
        await self._update_redis(task.id, 0, "pending")
        return task

    async def update_progress(
        self,
        db: AsyncSession,
        task_id: int,
        progress: int,
        status: str = "running",
    ):
        """Set a task's progress and status in the DB, then mirror it to Redis.

        ``started_at`` is stamped only when the update is the initial
        ``running`` update (``status == "running"`` and ``progress == 0``).
        Other updates leave the stored ``started_at`` untouched.

        Args:
            db: Session used to execute and commit the UPDATE.
            task_id: Primary key of the task to update.
            progress: New progress value.
            status: New status; defaults to ``"running"``.
        """
        values: Dict[str, Any] = {"progress": progress, "status": status}
        if status == "running" and progress == 0:
            # Only include started_at when stamping it: passing None on later
            # updates (as before) would wipe the start time already stored.
            values["started_at"] = datetime.utcnow()

        stmt = update(Task).where(Task.id == task_id).values(**values)
        await db.execute(stmt)
        await db.commit()

        # Mirror to Redis for fast polling.
        await self._update_redis(task_id, progress, status)

    async def complete_task(
        self,
        db: AsyncSession,
        task_id: int,
        result: Optional[Dict[str, Any]] = None,
    ):
        """Mark a task ``completed`` at progress 100 and store its result.

        Args:
            db: Session used to execute and commit the UPDATE.
            task_id: Primary key of the task to complete.
            result: Optional result payload stored in ``Task.result``.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="completed",
                progress=100,
                completed_at=datetime.utcnow(),
                result=result,
            )
        )
        await db.execute(stmt)
        await db.commit()

        await self._update_redis(task_id, 100, "completed")

    async def fail_task(
        self,
        db: AsyncSession,
        task_id: int,
        error_message: str,
    ):
        """Mark a task ``failed`` and record the error message.

        The DB ``progress`` column is left unchanged; the Redis entry gets
        progress ``-1`` as a failure sentinel together with the error text.

        Args:
            db: Session used to execute and commit the UPDATE.
            task_id: Primary key of the task that failed.
            error_message: Stored in ``Task.error_message`` and in Redis.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=error_message,
            )
        )
        await db.execute(stmt)
        await db.commit()

        await self._update_redis(task_id, -1, "failed", error=error_message)

    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Return the ``Task`` with ``task_id`` from the DB, or ``None`` if absent."""
        result = await db.execute(select(Task).where(Task.id == task_id))
        return result.scalar_one_or_none()

    async def _update_redis(
        self,
        task_id: int,
        progress: int,
        status: str,
        error: Optional[str] = None,
    ):
        """Write the task's transient progress snapshot to Redis with a TTL.

        The ``error`` key is included only when ``error`` is truthy.
        """
        key = f"{self.redis_prefix}{task_id}"
        data = {
            "id": task_id,
            "progress": progress,
            "status": status,
            # NOTE(review): naive UTC timestamp (datetime.utcnow is deprecated
            # since Python 3.12); kept to match the naive values written to
            # the DB columns above — confirm before switching to aware times.
            "updated_at": datetime.utcnow().isoformat(),
        }
        if error:
            data["error"] = error

        # presumably redis_cache.set serializes the dict — verify in redis_cache.
        await redis_cache.set(key, data, ttl_seconds=self.PROGRESS_TTL_SECONDS)

    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
        """Return the cached progress snapshot for ``task_id``.

        The value is whatever ``redis_cache.get`` returns for the key; presumably
        ``None`` once the entry is missing or expired (confirm in redis_cache).
        """
        key = f"{self.redis_prefix}{task_id}"
        return await redis_cache.get(key)


task_service = TaskService()