feat: Implement background tasks for NASA data download
parent
a91dc00b13
commit
e0cf0ce598
|
|
@ -2,7 +2,7 @@
|
|||
API routes for celestial data
|
||||
"""
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, UploadFile, File, status
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, UploadFile, File, status, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
|
|
@ -12,7 +12,7 @@ from app.models.celestial import (
|
|||
CelestialDataResponse,
|
||||
BodyInfo,
|
||||
)
|
||||
from app.models.db import Resource
|
||||
from app.models.db import Resource, Task
|
||||
from app.services.horizons import horizons_service
|
||||
from app.services.cache import cache_service
|
||||
from app.services.redis_cache import redis_cache, make_cache_key, get_ttl_seconds
|
||||
|
|
@ -26,6 +26,8 @@ from app.services.db_service import (
|
|||
)
|
||||
from app.services.orbit_service import orbit_service
|
||||
from app.services.system_settings_service import system_settings_service
|
||||
from app.services.task_service import task_service
|
||||
from app.services.nasa_worker import download_positions_task
|
||||
from app.database import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -1328,136 +1330,83 @@ class DownloadPositionRequest(BaseModel):
|
|||
@router.post("/positions/download")
async def download_positions(
    request: DownloadPositionRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Start an asynchronous background task to download position data.

    A Task row is persisted first so the client immediately receives an id
    it can poll via GET /tasks/{task_id}; the actual NASA Horizons download
    runs after this response has been sent.

    Args:
    - body_ids: List of celestial body IDs
    - dates: List of dates (YYYY-MM-DD format)

    Returns:
    - Confirmation message plus the id of the tracking task
    """
    body_count = len(request.body_ids)
    date_count = len(request.dates)
    logger.info(f"Downloading positions for {body_count} bodies on {date_count} dates")

    # Persist a trackable task record before scheduling any work.
    task = await task_service.create_task(
        db,
        task_type="nasa_download",
        description=f"Download positions for {body_count} bodies on {date_count} dates",
        params=request.dict(),
        created_by=None
    )

    # FastAPI executes this after the response has gone out.
    background_tasks.add_task(
        download_positions_task,
        task.id,
        request.body_ids,
        request.dates
    )

    return {
        "message": "Download task started",
        "task_id": task.id
    }
|
||||
|
||||
try:
|
||||
results = []
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
|
||||
for body_id in request.body_ids:
|
||||
# Check if body exists
|
||||
body = await celestial_body_service.get_body_by_id(body_id, db)
|
||||
if not body:
|
||||
results.append({
|
||||
"body_id": body_id,
|
||||
"status": "failed",
|
||||
"error": "Body not found"
|
||||
})
|
||||
total_failed += 1
|
||||
continue
|
||||
@router.get("/tasks")
async def list_tasks(
    limit: int = Query(20, ge=1, le=100, description="Maximum number of tasks to return"),
    offset: int = Query(0, ge=0, description="Number of tasks to skip"),
    db: AsyncSession = Depends(get_db)
):
    """List background tasks, newest first.

    Pagination parameters are now validated (previously any int was
    accepted, so negative or huge values reached the database unchecked).
    Defaults are unchanged, so existing callers are unaffected.
    """
    # Local import keeps this endpoint self-contained, matching the
    # original's style.
    from sqlalchemy import select, desc

    result = await db.execute(
        select(Task).order_by(desc(Task.created_at)).limit(limit).offset(offset)
    )
    return result.scalars().all()
|
||||
|
||||
body_results = {
|
||||
"body_id": body_id,
|
||||
"body_name": body.name_zh or body.name,
|
||||
"dates": []
|
||||
}
|
||||
|
||||
for date_str in request.dates:
|
||||
try:
|
||||
# Parse date and set to midnight UTC
|
||||
target_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
# Check if data already exists for this date
|
||||
existing = await position_service.get_positions(
|
||||
body_id=body_id,
|
||||
start_time=target_date,
|
||||
end_time=target_date.replace(hour=23, minute=59, second=59),
|
||||
session=db
|
||||
)
|
||||
|
||||
if existing and len(existing) > 0:
|
||||
body_results["dates"].append({
|
||||
"date": date_str,
|
||||
"status": "exists",
|
||||
"message": "Data already exists"
|
||||
})
|
||||
total_success += 1
|
||||
continue
|
||||
|
||||
# Download from NASA Horizons
|
||||
logger.info(f"Downloading data for {body_id} on {date_str}")
|
||||
positions = horizons_service.get_body_positions(
|
||||
body_id=body_id,
|
||||
start_time=target_date,
|
||||
end_time=target_date,
|
||||
step="1d"
|
||||
)
|
||||
|
||||
if positions and len(positions) > 0:
|
||||
# Save to database
|
||||
position_data = [{
|
||||
"time": target_date,
|
||||
"x": positions[0].x,
|
||||
"y": positions[0].y,
|
||||
"z": positions[0].z,
|
||||
"vx": positions[0].vx if hasattr(positions[0], 'vx') else None,
|
||||
"vy": positions[0].vy if hasattr(positions[0], 'vy') else None,
|
||||
"vz": positions[0].vz if hasattr(positions[0], 'vz') else None,
|
||||
}]
|
||||
|
||||
await position_service.save_positions(
|
||||
body_id=body_id,
|
||||
positions=position_data,
|
||||
source="nasa_horizons",
|
||||
session=db
|
||||
)
|
||||
|
||||
body_results["dates"].append({
|
||||
"date": date_str,
|
||||
"status": "success",
|
||||
"position": {
|
||||
"x": positions[0].x,
|
||||
"y": positions[0].y,
|
||||
"z": positions[0].z
|
||||
}
|
||||
})
|
||||
total_success += 1
|
||||
logger.info(f"✅ Downloaded data for {body_id} on {date_str}")
|
||||
else:
|
||||
body_results["dates"].append({
|
||||
"date": date_str,
|
||||
"status": "failed",
|
||||
"error": "No data returned from NASA"
|
||||
})
|
||||
total_failed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download {body_id} on {date_str}: {e}")
|
||||
body_results["dates"].append({
|
||||
"date": date_str,
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
total_failed += 1
|
||||
|
||||
results.append(body_results)
|
||||
|
||||
logger.info(f"🎉 Download complete: {total_success} succeeded, {total_failed} failed")
|
||||
return {
|
||||
"message": f"Downloaded {total_success} positions ({total_failed} failed)",
|
||||
"total_success": total_success,
|
||||
"total_failed": total_failed,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Download failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@router.get("/tasks/{task_id}")
async def get_task_status(
    task_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Get task status, overlaying real-time progress from Redis.

    The database row is authoritative for the task's existence and final
    result; Redis holds the freshest progress written by the worker, so
    its status/progress/error values win when present.
    """
    # Fetch the transient snapshot first (may be None if it expired).
    redis_data = await task_service.get_task_progress_from_redis(task_id)

    task = await task_service.get_task(db, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    # Same fields, same order, as the hand-written dict this replaces.
    response = {
        field: getattr(task, field)
        for field in (
            "id", "task_type", "status", "progress", "description",
            "created_at", "started_at", "completed_at", "error_message",
            "result",
        )
    }

    if redis_data:
        response["status"] = redis_data.get("status", task.status)
        response["progress"] = redis_data.get("progress", task.progress)
        if "error" in redis_data:
            response["error_message"] = redis_data["error"]

    return response
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from .user import User, user_roles
|
|||
from .role import Role
|
||||
from .menu import Menu, RoleMenu
|
||||
from .system_settings import SystemSettings
|
||||
from .task import Task
|
||||
|
||||
__all__ = [
|
||||
"CelestialBody",
|
||||
|
|
@ -25,4 +26,5 @@ __all__ = [
|
|||
"RoleMenu",
|
||||
"SystemSettings",
|
||||
"user_roles",
|
||||
"Task",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, ForeignKey, func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
|
||||
class Task(Base):
    """Background task model.

    One row per long-running background job (e.g. a NASA Horizons
    download), tracking its input parameters, progress and outcome so
    API clients can poll for status.
    """
    __tablename__ = "tasks"

    id = Column(Integer, primary_key=True, index=True)
    task_type = Column(String(50), nullable=False, comment="Task type (e.g., 'nasa_download')")
    # Indexed so listing by lifecycle state stays cheap.
    status = Column(String(20), nullable=False, default='pending', index=True, comment="pending, running, completed, failed, cancelled")
    description = Column(String(255), nullable=True)
    params = Column(JSON, nullable=True, comment="Input parameters")
    result = Column(JSON, nullable=True, comment="Output results")
    progress = Column(Integer, default=0, comment="Progress 0-100")
    error_message = Column(Text, nullable=True)

    # NOTE(review): plain integer, not a ForeignKey to the users table —
    # confirm this is intentional.
    created_by = Column(Integer, nullable=True, comment="User ID")

    # Timestamps are server-side so they do not depend on app clock skew.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    def __repr__(self):
        return f"<Task(id={self.id}, type='{self.task_type}', status='{self.status}')>"
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List
|
||||
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.task_service import task_service
|
||||
from app.services.db_service import celestial_body_service, position_service
|
||||
from app.services.horizons import horizons_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def download_positions_task(task_id: int, body_ids: List[str], dates: List[str]):
    """
    Background worker that downloads NASA Horizons positions.

    For each (body, date) pair: skip if a stored position already exists,
    otherwise fetch the position at 00:00 UTC from Horizons and persist it.
    Progress is pushed to the Task record after every item, and a summary
    (or the failure reason) is stored on completion.

    Args:
        task_id: ID of the Task row used for progress tracking.
        body_ids: Celestial body IDs to download.
        dates: Dates in YYYY-MM-DD format.
    """
    logger.info(f"Task {task_id}: Starting download for {len(body_ids)} bodies and {len(dates)} dates")

    async with AsyncSessionLocal() as db:
        try:
            # Mark as running
            await task_service.update_progress(db, task_id, 0, "running")

            # Guard against an empty work list so the progress division
            # below can never divide by zero.
            total_operations = max(len(body_ids) * len(dates), 1)
            current_op = 0
            success_count = 0
            failed_count = 0
            results = []

            for body_id in body_ids:
                body = await celestial_body_service.get_body_by_id(body_id, db)
                if not body:
                    # Unknown body: count every requested date as failed and
                    # still push a progress update so pollers don't see the
                    # task stall here (previously progress was skipped).
                    results.append({"body_id": body_id, "error": "Body not found"})
                    failed_count += len(dates)
                    current_op += len(dates)
                    await task_service.update_progress(
                        db, task_id, int((current_op / total_operations) * 100)
                    )
                    continue

                body_result = {
                    "body_id": body_id,
                    "body_name": body.name,
                    "dates": []
                }

                for date_str in dates:
                    try:
                        target_date = datetime.strptime(date_str, "%Y-%m-%d")

                        # Skip dates that already have stored positions.
                        existing = await position_service.get_positions(
                            body_id=body_id,
                            start_time=target_date,
                            end_time=target_date.replace(hour=23, minute=59, second=59),
                            session=db
                        )

                        if existing and len(existing) > 0:
                            body_result["dates"].append({"date": date_str, "status": "skipped"})
                            success_count += 1
                        else:
                            # NOTE(review): this call looks synchronous and will
                            # block the event loop while NASA responds — consider
                            # asyncio.to_thread if that becomes a problem.
                            positions = horizons_service.get_body_positions(
                                body_id=body_id,
                                start_time=target_date,
                                end_time=target_date,
                                step="1d"
                            )

                            if positions and len(positions) > 0:
                                first = positions[0]
                                pos_data = [{
                                    "time": target_date,
                                    "x": first.x,
                                    "y": first.y,
                                    "z": first.z,
                                    # Velocity components are optional on the
                                    # Horizons result object.
                                    "vx": getattr(first, 'vx', None),
                                    "vy": getattr(first, 'vy', None),
                                    "vz": getattr(first, 'vz', None),
                                }]
                                await position_service.save_positions(
                                    body_id=body_id,
                                    positions=pos_data,
                                    source="nasa_horizons",
                                    session=db
                                )
                                body_result["dates"].append({"date": date_str, "status": "success"})
                                success_count += 1
                            else:
                                body_result["dates"].append({"date": date_str, "status": "failed", "error": "No data"})
                                failed_count += 1

                    except Exception as e:
                        logger.error(f"Error processing {body_id} on {date_str}: {e}")
                        body_result["dates"].append({"date": date_str, "status": "error", "error": str(e)})
                        failed_count += 1

                    # Push progress after every item; cheap relative to the
                    # network call that dominates each iteration.
                    current_op += 1
                    progress = int((current_op / total_operations) * 100)
                    await task_service.update_progress(db, task_id, progress)

                results.append(body_result)

            # Persist the final summary and mark the task completed.
            final_result = {
                "total_success": success_count,
                "total_failed": failed_count,
                "details": results
            }
            await task_service.complete_task(db, task_id, final_result)
            logger.info(f"Task {task_id} completed successfully")

        except Exception as e:
            logger.error(f"Task {task_id} failed critically: {e}")
            await task_service.fail_task(db, task_id, str(e))
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from app.models.db import Task
|
||||
from app.services.redis_cache import redis_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TaskService:
    """CRUD and progress tracking for background Task records.

    Durable state lives in the tasks table; transient progress is mirrored
    to Redis (1h TTL) so pollers get fresh data without hammering the DB.
    """

    def __init__(self):
        # Namespace for per-task progress keys in Redis.
        self.redis_prefix = "task:progress:"

    async def create_task(
        self,
        db: AsyncSession,
        task_type: str,
        description: str = None,
        params: Dict[str, Any] = None,
        created_by: int = None
    ) -> Task:
        """Create a new task record in 'pending' state and seed Redis."""
        task = Task(
            task_type=task_type,
            description=description,
            params=params,
            status="pending",
            created_by=created_by,
            progress=0
        )
        db.add(task)
        await db.commit()
        await db.refresh(task)

        # Init Redis status
        await self._update_redis(task.id, 0, "pending")

        return task

    async def update_progress(
        self,
        db: AsyncSession,
        task_id: int,
        progress: int,
        status: str = "running"
    ):
        """Update task progress in DB and Redis.

        BUG FIX: started_at is stamped only on the first transition into
        "running" (progress == 0). Previously it was written on every call
        — evaluating to None for any later update — which nulled out the
        start timestamp as soon as progress advanced.
        """
        values = {"progress": progress, "status": status}
        if status == "running" and progress == 0:
            values["started_at"] = datetime.utcnow()

        await db.execute(update(Task).where(Task.id == task_id).values(**values))
        await db.commit()

        # Update Redis for fast polling
        await self._update_redis(task_id, progress, status)

    async def complete_task(
        self,
        db: AsyncSession,
        task_id: int,
        result: Dict[str, Any] = None
    ):
        """Mark task as completed at 100% and store its result payload."""
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="completed",
                progress=100,
                completed_at=datetime.utcnow(),
                result=result
            )
        )
        await db.execute(stmt)
        await db.commit()

        await self._update_redis(task_id, 100, "completed")

    async def fail_task(
        self,
        db: AsyncSession,
        task_id: int,
        error_message: str
    ):
        """Mark task as failed and record the error message."""
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=error_message
            )
        )
        await db.execute(stmt)
        await db.commit()

        # Redis snapshot records progress -1 alongside the failed status.
        await self._update_redis(task_id, -1, "failed", error=error_message)

    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Get task from DB, or None if it does not exist."""
        result = await db.execute(select(Task).where(Task.id == task_id))
        return result.scalar_one_or_none()

    async def _update_redis(
        self,
        task_id: int,
        progress: int,
        status: str,
        error: str = None
    ):
        """Write the transient progress snapshot to Redis."""
        key = f"{self.redis_prefix}{task_id}"
        data = {
            "id": task_id,
            "progress": progress,
            "status": status,
            "updated_at": datetime.utcnow().isoformat()
        }
        if error:
            data["error"] = error

        # Set TTL for 1 hour
        await redis_cache.set(key, data, ttl_seconds=3600)

    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
        """Get real-time progress from Redis; None if no snapshot exists."""
        key = f"{self.redis_prefix}{task_id}"
        return await redis_cache.get(key)


task_service = TaskService()
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
-- Create tasks table for background job management.
-- Fully idempotent: both the table AND the indexes use IF NOT EXISTS,
-- so re-running this migration is safe (previously the CREATE INDEX
-- statements would fail on a second run).
CREATE TABLE IF NOT EXISTS tasks (
    id SERIAL PRIMARY KEY,
    task_type VARCHAR(50) NOT NULL, -- e.g., 'nasa_download'
    status VARCHAR(20) NOT NULL DEFAULT 'pending', -- pending, running, completed, failed, cancelled
    description VARCHAR(255),
    params JSONB, -- Store input parameters (body_ids, dates)
    result JSONB, -- Store output results
    progress INTEGER DEFAULT 0, -- 0 to 100
    error_message TEXT,
    created_by INTEGER, -- User ID who initiated
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    started_at TIMESTAMP WITH TIME ZONE,
    completed_at TIMESTAMP WITH TIME ZONE
);

-- Support fast filtering by state and newest-first listing.
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at DESC);
|
||||
Loading…
Reference in New Issue