From e0cf0ce5981f89018c89120f07ba175e9cea7a49 Mon Sep 17 00:00:00 2001 From: "mula.liu" Date: Sun, 30 Nov 2025 23:58:37 +0800 Subject: [PATCH] feat: Implement background tasks for NASA data download --- backend/app/api/routes.py | 203 ++++++++++----------------- backend/app/models/db/__init__.py | 2 + backend/app/models/db/task.py | 26 ++++ backend/app/services/nasa_worker.py | 120 ++++++++++++++++ backend/app/services/task_service.py | 141 +++++++++++++++++++ backend/scripts/add_tasks_table.sql | 19 +++ 6 files changed, 384 insertions(+), 127 deletions(-) create mode 100644 backend/app/models/db/task.py create mode 100644 backend/app/services/nasa_worker.py create mode 100644 backend/app/services/task_service.py create mode 100644 backend/scripts/add_tasks_table.sql diff --git a/backend/app/api/routes.py b/backend/app/api/routes.py index 09ba632..2ac0037 100644 --- a/backend/app/api/routes.py +++ b/backend/app/api/routes.py @@ -2,7 +2,7 @@ API routes for celestial data """ from datetime import datetime -from fastapi import APIRouter, HTTPException, Query, Depends, UploadFile, File, status +from fastapi import APIRouter, HTTPException, Query, Depends, UploadFile, File, status, BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession from typing import Optional, Dict, Any import logging @@ -12,7 +12,7 @@ from app.models.celestial import ( CelestialDataResponse, BodyInfo, ) -from app.models.db import Resource +from app.models.db import Resource, Task from app.services.horizons import horizons_service from app.services.cache import cache_service from app.services.redis_cache import redis_cache, make_cache_key, get_ttl_seconds @@ -26,6 +26,8 @@ from app.services.db_service import ( ) from app.services.orbit_service import orbit_service from app.services.system_settings_service import system_settings_service +from app.services.task_service import task_service +from app.services.nasa_worker import download_positions_task from app.database import get_db logger = logging.getLogger(__name__) @@ -1328,136 +1330,83 @@ class DownloadPositionRequest(BaseModel): @router.post("/positions/download") async def download_positions( request: DownloadPositionRequest, + background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db) ): """ - Download position data for specified bodies on specified dates - - This endpoint will: - 1. Query NASA Horizons API for the position at 00:00:00 UTC on each date - 2. Save the data to the positions table - 3. 
Return the downloaded data
-
-    Args:
-    - body_ids: List of celestial body IDs
-    - dates: List of dates (YYYY-MM-DD format)
-
-    Returns:
-    - Summary of downloaded data with success/failure status
+    Start an asynchronous background task to download position data.
+
+    Returns immediately with a task_id that can be polled via GET /tasks/{task_id}.
     """
-    logger.info(f"Downloading positions for {len(request.body_ids)} bodies on {len(request.dates)} dates")
+    # Create task record
+    task = await task_service.create_task(
+        db,
+        task_type="nasa_download",
+        description=f"Download positions for {len(request.body_ids)} bodies on {len(request.dates)} dates",
+        params=request.dict(),
+        created_by=None
+    )
+
+    # Add to background tasks
+    background_tasks.add_task(
+        download_positions_task,
+        task.id,
+        request.body_ids,
+        request.dates
+    )
+
+    return {
+        "message": "Download task started",
+        "task_id": task.id
+    }
 
-    try:
-        results = []
-        total_success = 0
-        total_failed = 0
 
-        for body_id in request.body_ids:
-            # Check if body exists
-            body = await celestial_body_service.get_body_by_id(body_id, db)
-            if not body:
-                results.append({
-                    "body_id": body_id,
-                    "status": "failed",
-                    "error": "Body not found"
-                })
-                total_failed += 1
-                continue
+@router.get("/tasks")
+async def list_tasks(
+    limit: int = 20,
+    offset: int = 0,
+    db: AsyncSession = Depends(get_db)
+):
+    """List background tasks, newest first"""
+    from sqlalchemy import select, desc
+
+    result = await db.execute(
+        select(Task).order_by(desc(Task.created_at)).limit(limit).offset(offset)
+    )
+    tasks = result.scalars().all()
+    # Serialize explicitly; returning raw ORM objects leaves FastAPI's JSON
+    # encoder to choke on SQLAlchemy internals when no response_model is set
+    return [
+        {
+            "id": t.id,
+            "task_type": t.task_type,
+            "status": t.status,
+            "progress": t.progress,
+            "description": t.description,
+            "created_at": t.created_at,
+            "completed_at": t.completed_at,
+        }
+        for t in tasks
+    ]
 
-            body_results = {
-                "body_id": body_id,
-                "body_name": body.name_zh or body.name,
-                "dates": []
-            }
 
-            for date_str in request.dates:
-                try:
-                    # Parse date and set to midnight UTC
-                    target_date = datetime.strptime(date_str, "%Y-%m-%d")
-
-                    # Check if data already exists for this date
-                    existing = await position_service.get_positions(
-                        body_id=body_id,
-                        start_time=target_date,
-                        end_time=target_date.replace(hour=23, minute=59, second=59),
-                        session=db
-                    )
-
-                    if existing and len(existing) > 0:
-                        body_results["dates"].append({
-                            "date": date_str,
-                            "status": "exists",
-                            "message": "Data already exists"
-                        })
-                        total_success += 1
-                        continue
-
-                    # Download from NASA Horizons
-                    logger.info(f"Downloading data for {body_id} on {date_str}")
-                    positions = horizons_service.get_body_positions(
-                        body_id=body_id,
-                        start_time=target_date,
-                        end_time=target_date,
-                        step="1d"
-                    )
-
-                    if positions and len(positions) > 0:
-                        # Save to database
-                        position_data = [{
-                            "time": target_date,
-                            "x": positions[0].x,
-                            "y": positions[0].y,
-                            "z": positions[0].z,
-                            "vx": positions[0].vx if hasattr(positions[0], 'vx') else None,
-                            "vy": positions[0].vy if hasattr(positions[0], 'vy') else None,
-                            "vz": positions[0].vz if hasattr(positions[0], 'vz') else None,
-                        }]
-
-                        await position_service.save_positions(
-                            body_id=body_id,
-                            positions=position_data,
-                            source="nasa_horizons",
-                            session=db
-                        )
-
-                        body_results["dates"].append({
-                            "date": date_str,
-                            "status": "success",
-                            "position": {
-                                "x": positions[0].x,
-                                "y": positions[0].y,
-                                "z": positions[0].z
-                            }
-                        })
-                        total_success += 1
-                        logger.info(f"✅ Downloaded data for {body_id} on {date_str}")
-                    else:
-                        body_results["dates"].append({
-                            "date": date_str,
-                            "status": "failed",
-                            "error": "No data returned from NASA"
-                        })
-                        total_failed += 1
-
-                except Exception as e:
-                    logger.error(f"Failed to download {body_id} on {date_str}: {e}")
-                    body_results["dates"].append({
-                        "date": date_str,
-                        "status": "failed",
-                        "error": str(e)
-                    })
-                    total_failed += 1
-
-            results.append(body_results)
-
logger.info(f"🎉 Download complete: {total_success} succeeded, {total_failed} failed") - return { - "message": f"Downloaded {total_success} positions ({total_failed} failed)", - "total_success": total_success, - "total_failed": total_failed, - "results": results - } - - except Exception as e: - logger.error(f"Download failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) +@router.get("/tasks/{task_id}") +async def get_task_status( + task_id: int, + db: AsyncSession = Depends(get_db) +): + """Get task status""" + # Check Redis first for real-time progress + redis_data = await task_service.get_task_progress_from_redis(task_id) + + # Get DB record + task = await task_service.get_task(db, task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + + # Merge Redis data if available (Redis has fresher progress) + response = { + "id": task.id, + "task_type": task.task_type, + "status": task.status, + "progress": task.progress, + "description": task.description, + "created_at": task.created_at, + "started_at": task.started_at, + "completed_at": task.completed_at, + "error_message": task.error_message, + "result": task.result + } + + if redis_data: + response["status"] = redis_data.get("status", task.status) + response["progress"] = redis_data.get("progress", task.progress) + if "error" in redis_data: + response["error_message"] = redis_data["error"] + + return response diff --git a/backend/app/models/db/__init__.py b/backend/app/models/db/__init__.py index d8b592c..87b7d1d 100644 --- a/backend/app/models/db/__init__.py +++ b/backend/app/models/db/__init__.py @@ -11,6 +11,7 @@ from .user import User, user_roles from .role import Role from .menu import Menu, RoleMenu from .system_settings import SystemSettings +from .task import Task __all__ = [ "CelestialBody", @@ -25,4 +26,5 @@ __all__ = [ "RoleMenu", "SystemSettings", "user_roles", + "Task", ] diff --git a/backend/app/models/db/task.py b/backend/app/models/db/task.py new file mode 100644 index 0000000..50ea790 --- /dev/null +++ b/backend/app/models/db/task.py @@ -0,0 +1,26 @@ +from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, ForeignKey, func +from sqlalchemy.orm import relationship +from app.database import Base + +class Task(Base): + """Background Task Model""" + __tablename__ = "tasks" + + id = Column(Integer, primary_key=True, index=True) + task_type = Column(String(50), nullable=False, comment="Task type (e.g., 'nasa_download')") + status = Column(String(20), nullable=False, default='pending', index=True, comment="pending, running, completed, failed, cancelled") + description = Column(String(255), nullable=True) + params = Column(JSON, nullable=True, comment="Input parameters") + result = Column(JSON, nullable=True, comment="Output results") + progress = Column(Integer, default=0, comment="Progress 0-100") + error_message = Column(Text, nullable=True) + + created_by = Column(Integer, nullable=True, comment="User ID") + + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + started_at = Column(DateTime(timezone=True), nullable=True) + completed_at = Column(DateTime(timezone=True), nullable=True) + + def __repr__(self): + return f"" diff --git a/backend/app/services/nasa_worker.py b/backend/app/services/nasa_worker.py new file mode 100644 index 0000000..3ebf8b6 --- /dev/null +++ b/backend/app/services/nasa_worker.py @@ -0,0 +1,120 @@ +import logging +import asyncio 
+from datetime import datetime
+from sqlalchemy.ext.asyncio import AsyncSession
+from typing import List
+
+from app.database import AsyncSessionLocal
+from app.services.task_service import task_service
+from app.services.db_service import celestial_body_service, position_service
+from app.services.horizons import horizons_service
+
+logger = logging.getLogger(__name__)
+
+async def download_positions_task(task_id: int, body_ids: List[str], dates: List[str]):
+    """
+    Background task worker for downloading NASA positions
+    """
+    logger.info(f"Task {task_id}: Starting download for {len(body_ids)} bodies and {len(dates)} dates")
+
+    async with AsyncSessionLocal() as db:
+        try:
+            # Mark as running
+            await task_service.update_progress(db, task_id, 0, "running")
+
+            # Guard against empty input so the progress math never divides by zero
+            total_operations = max(len(body_ids) * len(dates), 1)
+            current_op = 0
+            success_count = 0
+            failed_count = 0
+            results = []
+
+            for body_id in body_ids:
+                # Check body
+                body = await celestial_body_service.get_body_by_id(body_id, db)
+                if not body:
+                    results.append({"body_id": body_id, "error": "Body not found"})
+                    failed_count += len(dates)
+                    current_op += len(dates)
+                    continue
+
+                body_result = {
+                    "body_id": body_id,
+                    "body_name": body.name,
+                    "dates": []
+                }
+
+                for date_str in dates:
+                    try:
+                        target_date = datetime.strptime(date_str, "%Y-%m-%d")
+
+                        # Check existing
+                        existing = await position_service.get_positions(
+                            body_id=body_id,
+                            start_time=target_date,
+                            end_time=target_date.replace(hour=23, minute=59, second=59),
+                            session=db
+                        )
+
+                        if existing and len(existing) > 0:
+                            body_result["dates"].append({"date": date_str, "status": "skipped"})
+                            success_count += 1
+                        else:
+                            # The Horizons client is synchronous; run it in a worker
+                            # thread so the event loop keeps serving requests
+                            positions = await asyncio.to_thread(
+                                horizons_service.get_body_positions,
+                                body_id=body_id,
+                                start_time=target_date,
+                                end_time=target_date,
+                                step="1d"
+                            )
+
+                            if positions and len(positions) > 0:
+                                pos_data = [{
+                                    "time": target_date,
+                                    "x": positions[0].x,
+                                    "y": positions[0].y,
+                                    "z": positions[0].z,
+                                    "vx": getattr(positions[0], 'vx', None),
+                                    "vy": getattr(positions[0], 'vy', None),
+                                    "vz": getattr(positions[0], 'vz', None),
+                                }]
+                                await position_service.save_positions(
+                                    body_id=body_id,
+                                    positions=pos_data,
+                                    source="nasa_horizons",
+                                    session=db
+                                )
+                                body_result["dates"].append({"date": date_str, "status": "success"})
+                                success_count += 1
+                            else:
+                                body_result["dates"].append({"date": date_str, "status": "failed", "error": "No data"})
+                                failed_count += 1
+
+                            # Brief pause between NASA requests to stay under rate limits
+                            await asyncio.sleep(0.1)
+
+                    except Exception as e:
+                        logger.error(f"Error processing {body_id} on {date_str}: {e}")
+                        body_result["dates"].append({"date": date_str, "status": "error", "error": str(e)})
+                        failed_count += 1
+
+                    # Update progress
+                    current_op += 1
+                    progress = int((current_op / total_operations) * 100)
+                    # Writing the DB on every item is simple but chatty; a follow-up could
+                    # batch DB writes (e.g., every 5%) and let Redis carry fine-grained progress
+                    await task_service.update_progress(db, task_id, progress)
+
+                results.append(body_result)
+
+            # Complete
+            final_result = {
+                "total_success": success_count,
+                "total_failed": failed_count,
+                "details": results
+            }
+            await task_service.complete_task(db, task_id, final_result)
+            logger.info(f"Task {task_id} completed successfully")
+
+        except Exception as e:
+            logger.error(f"Task {task_id} failed critically: {e}")
+            await task_service.fail_task(db, task_id, str(e))
diff --git a/backend/app/services/task_service.py
new file mode 100644
index 0000000..a731409
--- /dev/null
+++ b/backend/app/services/task_service.py
@@ -0,0 +1,141 @@
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, update
+from typing import Optional, Dict, Any
+from datetime import datetime, timezone
+import logging
+
+from app.models.db import Task
+from app.services.redis_cache import redis_cache
+
+logger = logging.getLogger(__name__)
+
+class TaskService:
+    def __init__(self):
+        self.redis_prefix = "task:progress:"
+
+    async def create_task(
+        self,
+        db: AsyncSession,
+        task_type: str,
+        description: Optional[str] = None,
+        params: Optional[Dict[str, Any]] = None,
+        created_by: Optional[int] = None
+    ) -> Task:
+        """Create a new task record"""
+        task = Task(
+            task_type=task_type,
+            description=description,
+            params=params,
+            status="pending",
+            created_by=created_by,
+            progress=0
+        )
+        db.add(task)
+        await db.commit()
+        await db.refresh(task)
+
+        # Init Redis status
+        await self._update_redis(task.id, 0, "pending")
+
+        return task
+
+    async def update_progress(
+        self,
+        db: AsyncSession,
+        task_id: int,
+        progress: int,
+        status: str = "running"
+    ):
+        """Update task progress in DB and Redis"""
+        # Update DB; stamp started_at only on the first transition to running.
+        # An unconditional values(started_at=... else None) would overwrite
+        # the timestamp with NULL on every later update.
+        values = {"progress": progress, "status": status}
+        if status == "running" and progress == 0:
+            values["started_at"] = datetime.now(timezone.utc)
+        stmt = update(Task).where(Task.id == task_id).values(**values)
+        await db.execute(stmt)
+        await db.commit()
+
+        # Update Redis for fast polling
+        await self._update_redis(task_id, progress, status)
+
+    async def complete_task(
+        self,
+        db: AsyncSession,
+        task_id: int,
+        result: Optional[Dict[str, Any]] = None
+    ):
+        """Mark task as completed"""
+        stmt = (
+            update(Task)
+            .where(Task.id == task_id)
+            .values(
+                status="completed",
+                progress=100,
+                completed_at=datetime.now(timezone.utc),
+                result=result
+            )
+        )
+        await db.execute(stmt)
+        await db.commit()
+
+        await self._update_redis(task_id, 100, "completed")
+
+    async def fail_task(
+        self,
+        db: AsyncSession,
+        task_id: int,
+        error_message: str
+    ):
+        """Mark task as failed"""
+        stmt = (
+            update(Task)
+            .where(Task.id == task_id)
+            .values(
+                status="failed",
+                completed_at=datetime.now(timezone.utc),
+                error_message=error_message
+            )
+        )
+        await db.execute(stmt)
+        await db.commit()
+
+        await self._update_redis(task_id, -1, "failed", error=error_message)
+
+    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
+        """Get task from DB"""
+        result = await db.execute(select(Task).where(Task.id == task_id))
+        return result.scalar_one_or_none()
+
+    async def _update_redis(
+        self,
+        task_id: int,
+        progress: int,
+        status: str,
+        error: Optional[str] = None
+    ):
+        """Update transient state in Redis"""
+        key = f"{self.redis_prefix}{task_id}"
+        data = {
+            "id": task_id,
+            "progress": progress,
+            "status": status,
+            "updated_at": datetime.now(timezone.utc).isoformat()
+        }
+        if error:
+            data["error"] = error
+
+        # Set TTL for 1 hour
+        await redis_cache.set(key, data, ttl_seconds=3600)
+
+    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
+        """Get real-time progress from Redis"""
+        key = f"{self.redis_prefix}{task_id}"
+        return await redis_cache.get(key)
+
+task_service = TaskService()
diff --git a/backend/scripts/add_tasks_table.sql
new file mode 100644
index 0000000..558289a
--- /dev/null
+++ b/backend/scripts/add_tasks_table.sql
@@ -0,0 +1,19 @@
+-- Create tasks table for background job management
+CREATE TABLE IF NOT EXISTS tasks (
+    id SERIAL PRIMARY KEY,
+    task_type VARCHAR(50) NOT NULL,                 -- e.g., 'nasa_download'
+    status VARCHAR(20) NOT NULL DEFAULT 'pending',  -- pending, running, completed, failed, cancelled
+    description VARCHAR(255),
+    params JSONB,                -- Store input parameters (body_ids, dates)
+    result JSONB,                -- Store output results
+    progress INTEGER DEFAULT 0,  -- 0 to 100
+    error_message TEXT,
+    created_by INTEGER,          -- User ID who initiated
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+    started_at TIMESTAMP WITH TIME ZONE,
+    completed_at TIMESTAMP WITH TIME ZONE
+);
+
+-- IF NOT EXISTS keeps the script re-runnable, matching the CREATE TABLE above
+CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
+CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at DESC);
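
Reviewer note (not part of the patch): a minimal sketch of how a client could drive the new endpoints, assuming the router is mounted under /api and the server listens on localhost:8000. Neither the prefix, the host, nor the body_ids below appear in this patch; all are illustrative assumptions.

# poll_download.py: hypothetical polling client (prefix, host, and IDs are assumed)
import time
import requests

BASE = "http://localhost:8000/api"  # assumed mount point for the router

# Kick off the background download; the endpoint returns immediately with a task_id
resp = requests.post(
    f"{BASE}/positions/download",
    json={"body_ids": ["499"], "dates": ["2025-01-01", "2025-01-02"]},  # illustrative values
)
task_id = resp.json()["task_id"]

# Poll until the task reaches a terminal state; GET /tasks/{task_id} reads
# progress from Redis first, so updates show up between DB commits
while True:
    status = requests.get(f"{BASE}/tasks/{task_id}").json()
    print(f"task {task_id}: {status['status']} ({status['progress']}%)")
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(1)

print(status.get("result") or status.get("error_message"))

One design caveat worth flagging: FastAPI's BackgroundTasks runs the worker in the same process as the web server, so a restart kills in-flight downloads. The tasks table and Redis keys make that state observable across restarts, but nothing in this patch resumes an interrupted task.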