# (file-viewer header residue: 142 lines, 3.8 KiB, Python)
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, update
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime
|
|
import logging
|
|
import asyncio
|
|
|
|
from app.models.db import Task
|
|
from app.services.redis_cache import redis_cache
|
|
|
|
# Logger named after this module, so its records can be filtered by module path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class TaskService:
    """Persist task lifecycle state in the database and mirror it to Redis.

    The ``Task`` row is the durable record. A small snapshot
    (``id``, ``progress``, ``status``, ``updated_at`` and optionally
    ``error``) is also written to Redis under ``<redis_prefix><task_id>``
    with a TTL, so clients can poll progress via
    ``get_task_progress_from_redis`` without hitting the database.

    Redis writes are best-effort. Each one happens after the database
    transaction has already been committed, so a Redis failure is logged
    and swallowed rather than making a successful state change look like
    a failure to the caller.
    """

    def __init__(
        self,
        redis_prefix: str = "task:progress:",
        redis_ttl_seconds: int = 3600,
    ):
        """Create the service.

        Args:
            redis_prefix: String prepended to the task id to form Redis keys.
            redis_ttl_seconds: Expiry applied to every Redis progress
                snapshot (default: one hour).
        """
        self.redis_prefix = redis_prefix
        self.redis_ttl_seconds = redis_ttl_seconds

    async def create_task(
        self,
        db: AsyncSession,
        task_type: str,
        description: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        created_by: Optional[int] = None,
    ) -> Task:
        """Insert a new ``pending`` task at progress 0 and seed its Redis snapshot.

        Args:
            db: Async session used for the insert; it is committed here.
            task_type: Value stored in ``Task.task_type``.
            description: Optional value stored in ``Task.description``.
            params: Optional value stored in ``Task.params``.
            created_by: Optional value stored in ``Task.created_by``.

        Returns:
            The persisted ``Task``, refreshed after commit so that fields
            populated by the database (``id`` is read right after) are
            available.
        """
        task = Task(
            task_type=task_type,
            description=description,
            params=params,
            status="pending",
            created_by=created_by,
            progress=0,
        )
        db.add(task)
        await db.commit()
        await db.refresh(task)

        # Seed the Redis snapshot so pollers see the task immediately.
        await self._update_redis(task.id, 0, "pending")
        return task

    async def update_progress(
        self,
        db: AsyncSession,
        task_id: int,
        progress: int,
        status: str = "running",
    ) -> None:
        """Record new progress/status for a task in the database and Redis.

        ``started_at`` is stamped only when the task is reported as
        ``running`` with progress 0. Every other call leaves the stored
        ``started_at`` untouched.

        Args:
            db: Async session; the UPDATE is committed here.
            task_id: Primary key of the task to update.
            progress: New progress value.
            status: New status (default ``"running"``).
        """
        values: Dict[str, Any] = {"progress": progress, "status": status}
        if status == "running" and progress == 0:
            # Only write started_at at the start transition. Writing it on
            # every call (with None otherwise) would erase the recorded start
            # time on each subsequent progress update.
            values["started_at"] = datetime.utcnow()

        stmt = update(Task).where(Task.id == task_id).values(**values)
        await db.execute(stmt)
        await db.commit()

        # Mirror to Redis for cheap polling.
        await self._update_redis(task_id, progress, status)

    async def complete_task(
        self,
        db: AsyncSession,
        task_id: int,
        result: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Mark a task ``completed`` at progress 100, storing ``result`` and ``completed_at``.

        Args:
            db: Async session; the UPDATE is committed here.
            task_id: Primary key of the task to update.
            result: Optional value stored in ``Task.result``.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="completed",
                progress=100,
                completed_at=datetime.utcnow(),
                result=result,
            )
        )
        await db.execute(stmt)
        await db.commit()

        await self._update_redis(task_id, 100, "completed")

    async def fail_task(
        self,
        db: AsyncSession,
        task_id: int,
        error_message: str,
    ) -> None:
        """Mark a task ``failed``, storing ``error_message`` and ``completed_at``.

        The database ``progress`` column is left as it was. The Redis
        snapshot uses progress ``-1`` and carries the error text.

        Args:
            db: Async session; the UPDATE is committed here.
            task_id: Primary key of the task to update.
            error_message: Value stored in ``Task.error_message`` and in the
                Redis snapshot's ``error`` field.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=error_message,
            )
        )
        await db.execute(stmt)
        await db.commit()

        await self._update_redis(task_id, -1, "failed", error=error_message)

    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Load a task row by primary key, or return ``None`` if it does not exist."""
        result = await db.execute(select(Task).where(Task.id == task_id))
        return result.scalar_one_or_none()

    def _redis_key(self, task_id: int) -> str:
        """Return the Redis key holding the progress snapshot for ``task_id``."""
        return f"{self.redis_prefix}{task_id}"

    async def _update_redis(
        self,
        task_id: int,
        progress: int,
        status: str,
        error: Optional[str] = None,
    ) -> None:
        """Write the transient progress snapshot for a task to Redis (best-effort).

        Args:
            task_id: Task id, also stored in the snapshot.
            progress: Progress value to store.
            status: Status string to store.
            error: Optional error text; added under ``"error"`` only if truthy.
        """
        data: Dict[str, Any] = {
            "id": task_id,
            "progress": progress,
            "status": status,
            "updated_at": datetime.utcnow().isoformat(),
        }
        if error:
            data["error"] = error

        try:
            await redis_cache.set(
                self._redis_key(task_id), data, ttl_seconds=self.redis_ttl_seconds
            )
        except Exception:
            # The DB commit has already succeeded and is the source of truth.
            # Don't let a cache outage turn that success into an error for the
            # caller (e.g. a worker marking a completed task as failed).
            logger.warning(
                "Failed to write Redis progress snapshot for task %s (status=%s)",
                task_id,
                status,
                exc_info=True,
            )

    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
        """Return the Redis progress snapshot for ``task_id``.

        Presumably ``None`` when the key is missing or its TTL has expired;
        that depends on ``redis_cache.get``, which is not shown here.
        """
        return await redis_cache.get(self._redis_key(task_id))
|
|
|
|
# Module-level instance built with the default Redis settings; presumably
# shared by importers of this module rather than constructing their own.
task_service: TaskService = TaskService()
|