# cosmo_backend/app/services/task_service.py
#
# 142 lines
# 3.8 KiB
# Python

import asyncio
import logging
from datetime import datetime
from typing import Optional, Dict, Any

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.db import Task
from app.services.redis_cache import redis_cache
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class TaskService:
    """Persist background-task state in the database and mirror it to Redis.

    The ``Task`` row is the durable record. Each state change also writes a
    small snapshot (id, progress, status, updated_at, optional error) to
    Redis under ``task:progress:<task_id>`` with a one-hour TTL, so clients
    can poll progress cheaply.

    Every mutating method commits the passed-in session before it updates
    Redis.
    """

    def __init__(self):
        # Prefix of the per-task Redis key holding the progress snapshot.
        self.redis_prefix = "task:progress:"

    async def create_task(
        self,
        db: AsyncSession,
        task_type: str,
        description: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        created_by: Optional[int] = None,
    ) -> Task:
        """Create a new task record in the ``pending`` state.

        Args:
            db: Session used to insert the row; committed by this method.
            task_type: Task type identifier stored on the row.
            description: Optional description stored on the row.
            params: Optional parameters stored on the row.
            created_by: Optional creator id stored on the row.

        Returns:
            The committed and refreshed ``Task`` (its ``id`` is populated).
        """
        task = Task(
            task_type=task_type,
            description=description,
            params=params,
            status="pending",
            created_by=created_by,
            progress=0,
        )
        db.add(task)
        await db.commit()
        # Refresh so server-generated fields such as ``id`` are loaded.
        await db.refresh(task)
        # Seed the Redis snapshot so pollers see the task immediately.
        await self._update_redis(task.id, 0, "pending")
        return task

    async def update_progress(
        self,
        db: AsyncSession,
        task_id: int,
        progress: int,
        status: str = "running",
    ):
        """Update a task's progress and status in the DB and in Redis.

        ``started_at`` is only touched when ``status == "running"``:

        * at ``progress == 0`` it is stamped with the current UTC time;
        * at any other progress it is filled in only if still NULL, so a
          worker that never reports exactly 0% still gets a start time.

        All other updates leave ``started_at`` unchanged.

        Args:
            db: Session used for the UPDATE; committed by this method.
            task_id: Primary key of the task to update.
            progress: New progress value, written as-is.
            status: New status string (default ``"running"``).
        """
        values: Dict[str, Any] = {"progress": progress, "status": status}
        if status == "running":
            now = datetime.utcnow()
            # Only include started_at when we actually want to set it;
            # always writing it would reset it to NULL on later updates.
            values["started_at"] = (
                now if progress == 0 else func.coalesce(Task.started_at, now)
            )
        stmt = update(Task).where(Task.id == task_id).values(**values)
        await db.execute(stmt)
        await db.commit()
        # Mirror to Redis for fast polling.
        await self._update_redis(task_id, progress, status)

    async def complete_task(
        self,
        db: AsyncSession,
        task_id: int,
        result: Optional[Dict[str, Any]] = None,
    ):
        """Mark a task as completed.

        Sets status ``"completed"``, progress 100, ``completed_at`` to the
        current UTC time, and stores ``result``. Then updates Redis.

        Args:
            db: Session used for the UPDATE; committed by this method.
            task_id: Primary key of the task to complete.
            result: Optional result payload stored on the row.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="completed",
                progress=100,
                completed_at=datetime.utcnow(),
                result=result,
            )
        )
        await db.execute(stmt)
        await db.commit()
        await self._update_redis(task_id, 100, "completed")

    async def fail_task(
        self,
        db: AsyncSession,
        task_id: int,
        error_message: str,
    ):
        """Mark a task as failed and record the error message.

        Sets status ``"failed"``, ``completed_at`` to the current UTC time,
        and stores ``error_message``. The DB ``progress`` column is left as
        is. The Redis snapshot instead gets ``progress = -1`` as a failure
        sentinel, plus the error text.

        Args:
            db: Session used for the UPDATE; committed by this method.
            task_id: Primary key of the task that failed.
            error_message: Error text stored on the row and in Redis.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=error_message,
            )
        )
        await db.execute(stmt)
        await db.commit()
        await self._update_redis(task_id, -1, "failed", error=error_message)

    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Return the ``Task`` with ``task_id`` from the DB, or ``None``."""
        result = await db.execute(select(Task).where(Task.id == task_id))
        return result.scalar_one_or_none()

    async def _update_redis(
        self,
        task_id: int,
        progress: int,
        status: str,
        error: Optional[str] = None,
    ):
        """Write the transient progress snapshot for ``task_id`` to Redis.

        The snapshot holds ``id``, ``progress``, ``status`` and a UTC
        ``updated_at`` ISO timestamp. ``error`` is added only when it is
        truthy. The key expires after one hour.
        """
        key = f"{self.redis_prefix}{task_id}"
        data = {
            "id": task_id,
            "progress": progress,
            "status": status,
            "updated_at": datetime.utcnow().isoformat(),
        }
        if error:
            data["error"] = error
        # Set TTL for 1 hour
        await redis_cache.set(key, data, ttl_seconds=3600)

    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
        """Return the Redis progress snapshot for ``task_id``.

        This is whatever ``redis_cache.get`` returns for the key. Presumably
        that is ``None`` once the one-hour TTL has expired or if the key was
        never written; confirm against ``redis_cache``.
        """
        key = f"{self.redis_prefix}{task_id}"
        return await redis_cache.get(key)
# Module-level TaskService instance, shared by every importer of this module.
task_service = TaskService()