feat: Implement background tasks for NASA data download

main
mula.liu 2025-11-30 23:58:37 +08:00
parent a91dc00b13
commit e0cf0ce598
6 changed files with 384 additions and 127 deletions

View File

@ -2,7 +2,7 @@
API routes for celestial data
"""
from datetime import datetime
from fastapi import APIRouter, HTTPException, Query, Depends, UploadFile, File, status
from fastapi import APIRouter, HTTPException, Query, Depends, UploadFile, File, status, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional, Dict, Any
import logging
@ -12,7 +12,7 @@ from app.models.celestial import (
CelestialDataResponse,
BodyInfo,
)
from app.models.db import Resource
from app.models.db import Resource, Task
from app.services.horizons import horizons_service
from app.services.cache import cache_service
from app.services.redis_cache import redis_cache, make_cache_key, get_ttl_seconds
@ -26,6 +26,8 @@ from app.services.db_service import (
)
from app.services.orbit_service import orbit_service
from app.services.system_settings_service import system_settings_service
from app.services.task_service import task_service
from app.services.nasa_worker import download_positions_task
from app.database import get_db
logger = logging.getLogger(__name__)
@ -1328,136 +1330,83 @@ class DownloadPositionRequest(BaseModel):
@router.post("/positions/download")
async def download_positions(
    request: DownloadPositionRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Start an asynchronous background task that downloads position data
    from NASA Horizons for the requested bodies and dates.

    Returns immediately with the id of the created task record; clients
    poll GET /tasks/{task_id} for progress and the final result.

    Raises:
        HTTPException 400: if body_ids or dates is empty.
    """
    # Reject empty input up front so we never persist a no-op task record
    # (the worker would compute 0 total operations and do nothing useful).
    if not request.body_ids or not request.dates:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="body_ids and dates must be non-empty"
        )

    logger.info(f"Downloading positions for {len(request.body_ids)} bodies on {len(request.dates)} dates")

    # Persist the task row first so the client has an id to poll even if
    # the background task has not started yet.
    task = await task_service.create_task(
        db,
        task_type="nasa_download",
        description=f"Download positions for {len(request.body_ids)} bodies on {len(request.dates)} dates",
        params=request.dict(),
        created_by=None
    )

    # Only plain values are handed to the worker; it opens its own DB
    # session because this request-scoped session closes after we return.
    background_tasks.add_task(
        download_positions_task,
        task.id,
        request.body_ids,
        request.dates
    )

    return {
        "message": "Download task started",
        "task_id": task.id
    }
try:
results = []
total_success = 0
total_failed = 0
for body_id in request.body_ids:
# Check if body exists
body = await celestial_body_service.get_body_by_id(body_id, db)
if not body:
results.append({
"body_id": body_id,
"status": "failed",
"error": "Body not found"
})
total_failed += 1
continue
@router.get("/tasks")
async def list_tasks(
    limit: int = 20,
    offset: int = 0,
    db: AsyncSession = Depends(get_db)
):
    """Return background tasks, newest first, with simple pagination."""
    from sqlalchemy import select, desc

    # Newest-first listing; limit/offset give basic pagination.
    query = (
        select(Task)
        .order_by(desc(Task.created_at))
        .limit(limit)
        .offset(offset)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
body_results = {
"body_id": body_id,
"body_name": body.name_zh or body.name,
"dates": []
}
for date_str in request.dates:
try:
# Parse date and set to midnight UTC
target_date = datetime.strptime(date_str, "%Y-%m-%d")
# Check if data already exists for this date
existing = await position_service.get_positions(
body_id=body_id,
start_time=target_date,
end_time=target_date.replace(hour=23, minute=59, second=59),
session=db
)
if existing and len(existing) > 0:
body_results["dates"].append({
"date": date_str,
"status": "exists",
"message": "Data already exists"
})
total_success += 1
continue
# Download from NASA Horizons
logger.info(f"Downloading data for {body_id} on {date_str}")
positions = horizons_service.get_body_positions(
body_id=body_id,
start_time=target_date,
end_time=target_date,
step="1d"
)
if positions and len(positions) > 0:
# Save to database
position_data = [{
"time": target_date,
"x": positions[0].x,
"y": positions[0].y,
"z": positions[0].z,
"vx": positions[0].vx if hasattr(positions[0], 'vx') else None,
"vy": positions[0].vy if hasattr(positions[0], 'vy') else None,
"vz": positions[0].vz if hasattr(positions[0], 'vz') else None,
}]
await position_service.save_positions(
body_id=body_id,
positions=position_data,
source="nasa_horizons",
session=db
)
body_results["dates"].append({
"date": date_str,
"status": "success",
"position": {
"x": positions[0].x,
"y": positions[0].y,
"z": positions[0].z
}
})
total_success += 1
logger.info(f"✅ Downloaded data for {body_id} on {date_str}")
else:
body_results["dates"].append({
"date": date_str,
"status": "failed",
"error": "No data returned from NASA"
})
total_failed += 1
except Exception as e:
logger.error(f"Failed to download {body_id} on {date_str}: {e}")
body_results["dates"].append({
"date": date_str,
"status": "failed",
"error": str(e)
})
total_failed += 1
results.append(body_results)
logger.info(f"🎉 Download complete: {total_success} succeeded, {total_failed} failed")
return {
"message": f"Downloaded {total_success} positions ({total_failed} failed)",
"total_success": total_success,
"total_failed": total_failed,
"results": results
}
except Exception as e:
logger.error(f"Download failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/tasks/{task_id}")
async def get_task_status(
    task_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Return a task's state, overlaying the fresher progress kept in Redis."""
    # Redis is written on every progress tick, so it can be ahead of the
    # durable DB row while the worker is running.
    redis_data = await task_service.get_task_progress_from_redis(task_id)

    task = await task_service.get_task(db, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    response = {
        "id": task.id,
        "task_type": task.task_type,
        "status": task.status,
        "progress": task.progress,
        "description": task.description,
        "created_at": task.created_at,
        "started_at": task.started_at,
        "completed_at": task.completed_at,
        "error_message": task.error_message,
        "result": task.result,
    }

    # Overlay the transient Redis view when it exists; fall back to the
    # DB values for any field Redis does not carry.
    if redis_data:
        response["status"] = redis_data.get("status", task.status)
        response["progress"] = redis_data.get("progress", task.progress)
        if "error" in redis_data:
            response["error_message"] = redis_data["error"]

    return response

View File

@ -11,6 +11,7 @@ from .user import User, user_roles
from .role import Role
from .menu import Menu, RoleMenu
from .system_settings import SystemSettings
from .task import Task
__all__ = [
"CelestialBody",
@ -25,4 +26,5 @@ __all__ = [
"RoleMenu",
"SystemSettings",
"user_roles",
"Task",
]

View File

@ -0,0 +1,26 @@
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, ForeignKey, func
from sqlalchemy.orm import relationship
from app.database import Base
class Task(Base):
    """Background Task Model

    Durable record of one asynchronous job (e.g. a NASA Horizons position
    download). Status lifecycle: pending -> running -> completed / failed /
    cancelled. High-frequency progress is also mirrored to Redis by the
    task service; this row is the source of truth once the task finishes.
    """
    __tablename__ = "tasks"
    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Discriminator for the kind of work this task performs.
    task_type = Column(String(50), nullable=False, comment="Task type (e.g., 'nasa_download')")
    # Indexed so dashboards can filter by lifecycle state cheaply.
    status = Column(String(20), nullable=False, default='pending', index=True, comment="pending, running, completed, failed, cancelled")
    # Short human-readable summary shown in task listings.
    description = Column(String(255), nullable=True)
    # Input parameters captured at creation time (must be JSON-serializable).
    params = Column(JSON, nullable=True, comment="Input parameters")
    # Final output payload written when the task completes.
    result = Column(JSON, nullable=True, comment="Output results")
    # Integer percentage in [0, 100].
    progress = Column(Integer, default=0, comment="Progress 0-100")
    # Populated only when the task fails.
    error_message = Column(Text, nullable=True)
    # ID of the initiating user; NULL for system-triggered tasks.
    created_by = Column(Integer, nullable=True, comment="User ID")
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # Refreshed by SQLAlchemy on each UPDATE via onupdate.
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
    # Set when the worker first reports 'running'.
    started_at = Column(DateTime(timezone=True), nullable=True)
    # Set on completion or failure.
    completed_at = Column(DateTime(timezone=True), nullable=True)

    def __repr__(self) -> str:
        # Compact representation for logs and debugging.
        return f"<Task(id={self.id}, type='{self.task_type}', status='{self.status}')>"

View File

@ -0,0 +1,120 @@
import logging
import asyncio
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.database import AsyncSessionLocal
from app.services.task_service import task_service
from app.services.db_service import celestial_body_service, position_service
from app.services.horizons import horizons_service
logger = logging.getLogger(__name__)
async def download_positions_task(task_id: int, body_ids: List[str], dates: List[str]):
    """
    Background worker that downloads NASA Horizons positions for every
    (body, date) pair and records per-item progress on the task record.

    Opens its own DB session because it runs outside the request scope.
    Never raises: any unexpected failure marks the task as failed instead.

    Args:
        task_id: id of the Task row created by the endpoint.
        body_ids: celestial body identifiers to download.
        dates: dates in "YYYY-MM-DD" format; each is fetched at 00:00 UTC.
    """
    logger.info(f"Task {task_id}: Starting download for {len(body_ids)} bodies and {len(dates)} dates")
    async with AsyncSessionLocal() as db:
        try:
            # Mark as running before any work so pollers see the transition.
            await task_service.update_progress(db, task_id, 0, "running")

            total_operations = len(body_ids) * len(dates)
            current_op = 0
            success_count = 0
            failed_count = 0
            results = []

            for body_id in body_ids:
                body = await celestial_body_service.get_body_by_id(body_id, db)
                if not body:
                    # Unknown body: count every requested date as failed and
                    # keep the progress denominator consistent.
                    results.append({"body_id": body_id, "error": "Body not found"})
                    failed_count += len(dates)
                    current_op += len(dates)
                    continue

                body_result = {
                    "body_id": body_id,
                    "body_name": body.name,
                    "dates": []
                }
                for date_str in dates:
                    date_result, ok = await _process_single_date(db, body_id, date_str)
                    body_result["dates"].append(date_result)
                    if ok:
                        success_count += 1
                    else:
                        failed_count += 1

                    current_op += 1
                    progress = int((current_op / total_operations) * 100)
                    # Redis absorbs the high-frequency polling; the DB write
                    # per item is acceptable at current batch sizes.
                    await task_service.update_progress(db, task_id, progress)

                results.append(body_result)

            final_result = {
                "total_success": success_count,
                "total_failed": failed_count,
                "details": results
            }
            await task_service.complete_task(db, task_id, final_result)
            logger.info(f"Task {task_id} completed successfully")
        except Exception as e:
            logger.error(f"Task {task_id} failed critically: {e}")
            await task_service.fail_task(db, task_id, str(e))


async def _process_single_date(db, body_id: str, date_str: str):
    """Download and persist one (body, date) pair.

    Returns:
        (result_dict, succeeded) where result_dict is the per-date entry
        appended to the task result and succeeded drives the counters.
    """
    try:
        target_date = datetime.strptime(date_str, "%Y-%m-%d")

        # Skip the download when any sample already exists for that day.
        existing = await position_service.get_positions(
            body_id=body_id,
            start_time=target_date,
            end_time=target_date.replace(hour=23, minute=59, second=59),
            session=db
        )
        if existing and len(existing) > 0:
            return {"date": date_str, "status": "skipped"}, True

        # horizons_service.get_body_positions is a blocking synchronous call
        # (it is invoked without await elsewhere); run it in a worker thread
        # so the event loop keeps serving other requests during the fetch.
        positions = await asyncio.to_thread(
            horizons_service.get_body_positions,
            body_id=body_id,
            start_time=target_date,
            end_time=target_date,
            step="1d"
        )
        if not positions:
            return {"date": date_str, "status": "failed", "error": "No data"}, False

        first = positions[0]
        pos_data = [{
            "time": target_date,
            "x": first.x,
            "y": first.y,
            "z": first.z,
            # Velocity fields are optional on the Horizons result.
            "vx": getattr(first, 'vx', None),
            "vy": getattr(first, 'vy', None),
            "vz": getattr(first, 'vz', None),
        }]
        await position_service.save_positions(
            body_id=body_id,
            positions=pos_data,
            source="nasa_horizons",
            session=db
        )
        return {"date": date_str, "status": "success"}, True
    except Exception as e:
        logger.error(f"Error processing {body_id} on {date_str}: {e}")
        return {"date": date_str, "status": "error", "error": str(e)}, False

View File

@ -0,0 +1,141 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from typing import Optional, Dict, Any
from datetime import datetime
import logging
import asyncio
from app.models.db import Task
from app.services.redis_cache import redis_cache
logger = logging.getLogger(__name__)
class TaskService:
    """DB-backed task bookkeeping with a Redis mirror for live progress.

    The DB row (Task) is the durable record; Redis holds the same status
    and progress under a short TTL so pollers do not hammer the database.
    """

    def __init__(self):
        # Full Redis key is f"{self.redis_prefix}{task_id}".
        self.redis_prefix = "task:progress:"

    async def create_task(
        self,
        db: AsyncSession,
        task_type: str,
        description: str = None,
        params: Dict[str, Any] = None,
        created_by: int = None
    ) -> Task:
        """Create a new 'pending' task row and seed its Redis progress entry."""
        task = Task(
            task_type=task_type,
            description=description,
            params=params,
            status="pending",
            created_by=created_by,
            progress=0
        )
        db.add(task)
        await db.commit()
        await db.refresh(task)
        # Init Redis status so the first poll finds an entry immediately.
        await self._update_redis(task.id, 0, "pending")
        return task

    async def update_progress(
        self,
        db: AsyncSession,
        task_id: int,
        progress: int,
        status: str = "running"
    ):
        """Update task progress in DB and Redis.

        started_at is stamped only on the first transition into 'running'
        (progress == 0); subsequent calls must leave it untouched, so the
        column is included in the UPDATE only when it is being set.
        Unconditionally passing it would write NULL on every later call.
        """
        values = {"progress": progress, "status": status}
        if status == "running" and progress == 0:
            # NOTE(review): utcnow() is naive while the columns are
            # timezone-aware; consider datetime.now(timezone.utc) — confirm
            # project-wide convention before changing.
            values["started_at"] = datetime.utcnow()

        stmt = update(Task).where(Task.id == task_id).values(**values)
        await db.execute(stmt)
        await db.commit()
        # Update Redis for fast polling.
        await self._update_redis(task_id, progress, status)

    async def complete_task(
        self,
        db: AsyncSession,
        task_id: int,
        result: Dict[str, Any] = None
    ):
        """Mark task as completed with its final result payload."""
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="completed",
                progress=100,
                completed_at=datetime.utcnow(),
                result=result
            )
        )
        await db.execute(stmt)
        await db.commit()
        await self._update_redis(task_id, 100, "completed")

    async def fail_task(
        self,
        db: AsyncSession,
        task_id: int,
        error_message: str
    ):
        """Mark task as failed, storing the error message."""
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=error_message
            )
        )
        await db.execute(stmt)
        await db.commit()
        # progress=-1 is the Redis-side sentinel for a failed task.
        await self._update_redis(task_id, -1, "failed", error=error_message)

    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Get task row from the DB, or None if it does not exist."""
        result = await db.execute(select(Task).where(Task.id == task_id))
        return result.scalar_one_or_none()

    async def _update_redis(
        self,
        task_id: int,
        progress: int,
        status: str,
        error: str = None
    ):
        """Write the transient progress snapshot to Redis (1h TTL)."""
        key = f"{self.redis_prefix}{task_id}"
        data = {
            "id": task_id,
            "progress": progress,
            "status": status,
            "updated_at": datetime.utcnow().isoformat()
        }
        if error:
            data["error"] = error
        # TTL keeps stale entries from lingering after a crash.
        await redis_cache.set(key, data, ttl_seconds=3600)

    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
        """Get the real-time progress snapshot from Redis, or None."""
        key = f"{self.redis_prefix}{task_id}"
        return await redis_cache.get(key)


task_service = TaskService()

View File

@ -0,0 +1,19 @@
-- Create tasks table for background job management.
-- The whole script is idempotent: table and indexes all use IF NOT EXISTS
-- so re-running the migration is safe (previously the bare CREATE INDEX
-- statements failed on a second run while the table creation succeeded).
CREATE TABLE IF NOT EXISTS tasks (
    id SERIAL PRIMARY KEY,
    task_type VARCHAR(50) NOT NULL,                 -- e.g., 'nasa_download'
    status VARCHAR(20) NOT NULL DEFAULT 'pending',  -- pending, running, completed, failed, cancelled
    description VARCHAR(255),
    params JSONB,                                   -- Store input parameters (body_ids, dates)
    result JSONB,                                   -- Store output results
    progress INTEGER DEFAULT 0,                     -- 0 to 100
    error_message TEXT,
    created_by INTEGER,                             -- User ID who initiated
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    started_at TIMESTAMP WITH TIME ZONE,
    completed_at TIMESTAMP WITH TIME ZONE
);

-- Status filter for dashboards; newest-first listing for pagination.
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at DESC);