653 lines
21 KiB
Python
653 lines
21 KiB
Python
"""
|
|
Database service layer for celestial data operations
|
|
"""
|
|
from typing import List, Optional, Dict, Any
|
|
from datetime import datetime
|
|
from sqlalchemy import select, and_, delete
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
import logging
|
|
|
|
from app.models.db import CelestialBody, Position, StaticData, NasaCache, Resource
|
|
from app.database import AsyncSessionLocal
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CelestialBodyService:
    """CRUD operations for celestial bodies.

    Every public method takes an optional ``AsyncSession``. When one is
    passed, the work runs inside it (so any commit performed here commits the
    caller's transaction); otherwise a dedicated session is opened from
    ``AsyncSessionLocal`` and closed when the call returns.
    """

    @staticmethod
    async def _run(session, work):
        """Await ``work(session)`` on the caller's session or on a fresh one."""
        if not session:
            async with AsyncSessionLocal() as fresh:
                return await work(fresh)
        return await work(session)

    @staticmethod
    async def get_all_bodies(
        session: Optional[AsyncSession] = None,
        body_type: Optional[str] = None,
        system_id: Optional[int] = None
    ) -> List[CelestialBody]:
        """
        Return celestial bodies ordered by name.

        Args:
            session: Optional database session to reuse.
            body_type: Keep only bodies of this type (star, planet,
                dwarf_planet, etc.); ignored when empty or None.
            system_id: Keep only bodies in this star system
                (1=Solar System, 2+=Exoplanets); ignored when None.
        """
        stmt = select(CelestialBody)
        if body_type:
            stmt = stmt.where(CelestialBody.type == body_type)
        if system_id is not None:
            stmt = stmt.where(CelestialBody.system_id == system_id)
        stmt = stmt.order_by(CelestialBody.name)

        async def fetch(db: AsyncSession):
            rows = await db.execute(stmt)
            return rows.scalars().all()

        return await CelestialBodyService._run(session, fetch)

    @staticmethod
    async def get_body_by_id(
        body_id: str,
        session: Optional[AsyncSession] = None
    ) -> Optional[CelestialBody]:
        """Return the celestial body whose id is ``body_id``, or None if absent."""
        stmt = select(CelestialBody).where(CelestialBody.id == body_id)

        async def fetch(db: AsyncSession):
            rows = await db.execute(stmt)
            return rows.scalar_one_or_none()

        return await CelestialBodyService._run(session, fetch)

    @staticmethod
    async def create_body(
        body_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> CelestialBody:
        """Insert a body built from ``body_data`` keyword arguments, commit,
        and return it refreshed from the database."""
        async def insert_row(db: AsyncSession):
            new_body = CelestialBody(**body_data)
            db.add(new_body)
            await db.commit()
            await db.refresh(new_body)
            return new_body

        return await CelestialBodyService._run(session, insert_row)

    @staticmethod
    async def update_body(
        body_id: str,
        update_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> Optional[CelestialBody]:
        """
        Apply ``update_data`` to the body with id ``body_id`` and commit.

        Keys that are not attributes of the model are silently skipped.
        Returns the refreshed body, or None when no body has that id.
        """
        async def apply_changes(db: AsyncSession):
            lookup = await db.execute(
                select(CelestialBody).where(CelestialBody.id == body_id)
            )
            target = lookup.scalar_one_or_none()
            if not target:
                return None

            for attr_name, new_value in update_data.items():
                if not hasattr(target, attr_name):
                    continue
                setattr(target, attr_name, new_value)

            await db.commit()
            await db.refresh(target)
            return target

        return await CelestialBodyService._run(session, apply_changes)

    @staticmethod
    async def delete_body(
        body_id: str,
        session: Optional[AsyncSession] = None
    ) -> bool:
        """Delete the body with id ``body_id`` and commit.

        Returns True if a row was deleted, False if no body had that id.
        """
        async def remove(db: AsyncSession):
            lookup = await db.execute(
                select(CelestialBody).where(CelestialBody.id == body_id)
            )
            target = lookup.scalar_one_or_none()
            if not target:
                return False
            await db.delete(target)
            await db.commit()
            return True

        return await CelestialBodyService._run(session, remove)
|
|
|
|
|
|
class PositionService:
    """Service for position data operations.

    Every method takes an optional ``AsyncSession``; when omitted, a
    short-lived session is opened from ``AsyncSessionLocal`` for the call.
    """

    # Maximum number of rows per multi-row INSERT in save_positions. Each row
    # binds nine parameters, and PostgreSQL drivers cap bind parameters per
    # statement (asyncpg, for example, allows at most 32767), so large batches
    # are split into chunks of this size.
    _UPSERT_CHUNK_SIZE = 1000

    @staticmethod
    async def save_positions(
        body_id: str,
        positions: List[Dict[str, Any]],
        source: str = "nasa_horizons",
        session: Optional[AsyncSession] = None
    ) -> int:
        """Upsert many position records for one celestial body.

        Records are written with chunked multi-row
        ``INSERT ... ON CONFLICT (body_id, time) DO UPDATE`` statements rather
        than one round trip per record, then committed once.

        Args:
            body_id: Body the positions belong to.
            positions: Dicts with required keys ``time``, ``x``, ``y``, ``z``
                and optional ``vx``, ``vy``, ``vz`` (missing ones are stored
                as NULL).
            source: Value written to the ``source`` column of every row.
            session: Optional database session to reuse; it is committed.

        Returns:
            The number of input records processed, i.e. ``len(positions)``
            (records that share a ``time`` are each counted).

        Raises:
            KeyError: If a record lacks ``time``, ``x``, ``y`` or ``z``. This
                is raised before any statement is executed.
        """
        async def _save(s: AsyncSession):
            from sqlalchemy.dialects.postgresql import insert

            # One INSERT ... ON CONFLICT DO UPDATE may not touch the same row
            # twice, so collapse records sharing a timestamp. Later records
            # replace earlier ones, matching the previous row-by-row
            # behaviour where the last write won.
            rows_by_time: Dict[Any, Dict[str, Any]] = {}
            for pos_data in positions:
                rows_by_time[pos_data["time"]] = {
                    "body_id": body_id,
                    "time": pos_data["time"],
                    "x": pos_data["x"],
                    "y": pos_data["y"],
                    "z": pos_data["z"],
                    "vx": pos_data.get("vx"),
                    "vy": pos_data.get("vy"),
                    "vz": pos_data.get("vz"),
                    "source": source,
                }
            rows = list(rows_by_time.values())

            chunk_size = PositionService._UPSERT_CHUNK_SIZE
            for offset in range(0, len(rows), chunk_size):
                stmt = insert(Position).values(rows[offset:offset + chunk_size])
                # On conflict, overwrite the stored row with the incoming
                # values; ``excluded`` is the row that was proposed for insert.
                stmt = stmt.on_conflict_do_update(
                    index_elements=["body_id", "time"],
                    set_={
                        "x": stmt.excluded.x,
                        "y": stmt.excluded.y,
                        "z": stmt.excluded.z,
                        "vx": stmt.excluded.vx,
                        "vy": stmt.excluded.vy,
                        "vz": stmt.excluded.vz,
                        "source": stmt.excluded.source,
                    },
                )
                await s.execute(stmt)

            await s.commit()
            return len(positions)

        if session is not None:
            return await _save(session)
        async with AsyncSessionLocal() as s:
            return await _save(s)

    @staticmethod
    async def get_positions(
        body_id: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        session: Optional[AsyncSession] = None
    ) -> List[Position]:
        """Return a body's positions ordered by time.

        Args:
            body_id: Body whose positions are returned.
            start_time: Inclusive lower bound on ``Position.time``; None
                leaves the range open at the start.
            end_time: Inclusive upper bound on ``Position.time``; None leaves
                the range open at the end.
            session: Optional database session to reuse.
        """
        async def _query(s: AsyncSession):
            query = select(Position).where(Position.body_id == body_id)
            # Each bound is applied on its own; successive where() calls are ANDed.
            if start_time is not None:
                query = query.where(Position.time >= start_time)
            if end_time is not None:
                query = query.where(Position.time <= end_time)

            result = await s.execute(query.order_by(Position.time))
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def get_positions_in_range(
        body_id: str,
        start_time: datetime,
        end_time: datetime,
        session: Optional[AsyncSession] = None
    ) -> List[Position]:
        """Alias for get_positions with a required, inclusive time range."""
        return await PositionService.get_positions(
            body_id, start_time, end_time, session=session
        )

    @staticmethod
    async def save_position(
        body_id: str,
        time: datetime,
        x: float,
        y: float,
        z: float,
        source: str = "nasa_horizons",
        vx: Optional[float] = None,
        vy: Optional[float] = None,
        vz: Optional[float] = None,
        session: Optional[AsyncSession] = None
    ) -> Position:
        """Insert or update the single position for ``(body_id, time)``.

        Uses an atomic ``INSERT ... ON CONFLICT (body_id, time) DO UPDATE``,
        so concurrent writers to the same key cannot collide the way a
        separate SELECT-then-INSERT could. On update, every value column
        (including velocities passed as None) is overwritten.

        Returns:
            The stored row, refreshed after commit.
        """
        async def _save(s: AsyncSession):
            from sqlalchemy.dialects.postgresql import insert

            values = {
                "x": x,
                "y": y,
                "z": z,
                "vx": vx,
                "vy": vy,
                "vz": vz,
                "source": source,
            }
            stmt = (
                insert(Position)
                .values(body_id=body_id, time=time, **values)
                .on_conflict_do_update(
                    index_elements=["body_id", "time"],
                    set_=values,
                )
                .returning(Position)
            )

            result = await s.execute(stmt)
            position = result.scalar_one()
            await s.commit()
            await s.refresh(position)
            return position

        if session is not None:
            return await _save(session)
        async with AsyncSessionLocal() as s:
            return await _save(s)

    @staticmethod
    async def delete_old_positions(
        before_time: datetime,
        session: Optional[AsyncSession] = None
    ) -> int:
        """Delete positions (for all bodies) with ``time`` strictly before
        ``before_time``, commit, and return the reported row count."""
        async def _delete(s: AsyncSession):
            result = await s.execute(
                delete(Position).where(Position.time < before_time)
            )
            await s.commit()
            return result.rowcount

        if session is not None:
            return await _delete(session)
        async with AsyncSessionLocal() as s:
            return await _delete(s)

    @staticmethod
    async def get_available_dates(
        body_id: str,
        start_time: datetime,
        end_time: datetime,
        session: Optional[AsyncSession] = None
    ) -> List[datetime]:
        """Return the distinct calendar dates, ascending, on which ``body_id``
        has positions within the inclusive ``[start_time, end_time]`` range.

        NOTE(review): each element is whatever the driver returns for SQL
        ``date(...)`` — presumably ``datetime.date`` rather than the annotated
        ``datetime``; confirm against callers.
        """
        async def _query(s: AsyncSession):
            from sqlalchemy import func

            # Truncate each timestamp to its date, then de-duplicate.
            day = func.date(Position.time)
            query = (
                select(day)
                .where(
                    and_(
                        Position.body_id == body_id,
                        Position.time >= start_time,
                        Position.time <= end_time,
                    )
                )
                .distinct()
                .order_by(day)
            )

            result = await s.execute(query)
            return [row[0] for row in result]

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)
|
|
|
|
|
|
class NasaCacheService:
    """Service for NASA API response caching.

    Cache rows use ``TIMESTAMP WITHOUT TIME ZONE`` columns holding naive UTC
    values, and expiry is compared against the current UTC time.
    """

    @staticmethod
    def _to_naive_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Return ``value`` as a naive UTC datetime (None passes through).

        Aware values are converted to UTC before the offset is dropped, so
        e.g. 09:00+08:00 becomes 01:00 rather than 09:00. Naive values are
        assumed to already be UTC and are returned unchanged.
        """
        if value is None:
            return None
        from datetime import timezone
        if value.utcoffset() is not None:
            value = value.astimezone(timezone.utc)
        return value.replace(tzinfo=None)

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Current UTC time as a naive datetime (replaces deprecated utcnow())."""
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    async def get_cached_response(
        body_id: str,
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        step: str,
        session: Optional[AsyncSession] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the cached response data for this query, or None.

        Only unexpired entries are considered. If several unexpired entries
        match, the one with the latest ``expires_at`` is returned.

        Args:
            body_id: Body the cached query was for.
            start_time: Query start; None matches entries with no start.
            end_time: Query end; None matches entries with no end.
            step: Step string the cached query used.
            session: Optional database session to reuse.
        """
        async def _query(s: AsyncSession):
            start_naive = NasaCacheService._to_naive_utc(start_time)
            end_naive = NasaCacheService._to_naive_utc(end_time)
            now_naive = NasaCacheService._utcnow_naive()

            result = await s.execute(
                select(NasaCache)
                .where(
                    and_(
                        NasaCache.body_id == body_id,
                        NasaCache.start_time == start_naive,
                        NasaCache.end_time == end_naive,
                        NasaCache.step == step,
                        NasaCache.expires_at > now_naive,
                    )
                )
                # cache_key embeds the caller's original isoformat (offset
                # included), so several live rows can share these column
                # values. Take the freshest instead of letting
                # scalar_one_or_none() raise MultipleResultsFound.
                .order_by(NasaCache.expires_at.desc())
                .limit(1)
            )
            cache = result.scalars().first()
            return cache.data if cache else None

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def save_response(
        body_id: str,
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        step: str,
        response_data: Dict[str, Any],
        ttl_days: int = 7,
        session: Optional[AsyncSession] = None
    ) -> NasaCache:
        """Save a NASA API response to the cache (atomic upsert on cache_key).

        Args:
            body_id: Body the query was for.
            start_time: Query start, or None.
            end_time: Query end, or None.
            step: Step string the query used.
            response_data: Payload stored in the ``data`` column.
            ttl_days: Days from now until the entry expires.
            session: Optional database session to reuse; it is committed.

        Returns:
            The stored cache row, refreshed after commit.
        """
        async def _save(s: AsyncSession):
            from datetime import timedelta
            from sqlalchemy.dialects.postgresql import insert

            start_naive = NasaCacheService._to_naive_utc(start_time)
            end_naive = NasaCacheService._to_naive_utc(end_time)
            now_naive = NasaCacheService._utcnow_naive()
            expires_at = now_naive + timedelta(days=ttl_days)

            # The key format is unchanged (built from the caller's original
            # isoformat) so upserts keep hitting previously stored rows.
            start_str = start_time.isoformat() if start_time else "null"
            end_str = end_time.isoformat() if end_time else "null"
            cache_key = f"{body_id}:{start_str}:{end_str}:{step}"

            stmt = insert(NasaCache).values(
                cache_key=cache_key,
                body_id=body_id,
                start_time=start_naive,
                end_time=end_naive,
                step=step,
                data=response_data,
                expires_at=expires_at,
            )

            # On conflict, refresh the payload and expiry. start_time/end_time
            # are rewritten too, so rows stored before UTC normalization become
            # matchable again instead of being extended forever unmatched.
            stmt = stmt.on_conflict_do_update(
                index_elements=['cache_key'],
                set_={
                    'start_time': start_naive,
                    'end_time': end_naive,
                    'data': response_data,
                    'created_at': now_naive,
                    'expires_at': expires_at,
                },
            ).returning(NasaCache)

            result = await s.execute(stmt)
            cache = result.scalar_one()

            await s.commit()
            await s.refresh(cache)
            return cache

        if session is not None:
            return await _save(session)
        async with AsyncSessionLocal() as s:
            return await _save(s)
|
|
|
|
|
|
class StaticDataService:
    """Service for static data operations.

    Every method takes an optional ``AsyncSession``; when omitted, a
    short-lived session is opened from ``AsyncSessionLocal`` for the call.
    """

    @staticmethod
    async def get_all_items(
        session: Optional[AsyncSession] = None
    ) -> List[StaticData]:
        """Return all static data items ordered by category, then name."""
        async def _query(s: AsyncSession):
            result = await s.execute(
                select(StaticData).order_by(StaticData.category, StaticData.name)
            )
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def create_static(
        data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> StaticData:
        """Insert an item built from ``data`` keyword arguments, commit, and
        return it refreshed from the database."""
        async def _create(s: AsyncSession):
            item = StaticData(**data)
            s.add(item)
            await s.commit()
            await s.refresh(item)
            return item

        if session is not None:
            return await _create(session)
        async with AsyncSessionLocal() as s:
            return await _create(s)

    @staticmethod
    async def update_static(
        item_id: int,
        update_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> Optional[StaticData]:
        """Apply ``update_data`` to the item with id ``item_id`` and commit.

        Keys that are not attributes of the model are silently skipped.
        Returns the refreshed item, or None if no item has that id.
        """
        async def _update(s: AsyncSession):
            result = await s.execute(
                select(StaticData).where(StaticData.id == item_id)
            )
            item = result.scalar_one_or_none()
            if item is None:
                return None

            for key, value in update_data.items():
                if hasattr(item, key):
                    setattr(item, key, value)

            await s.commit()
            await s.refresh(item)
            return item

        if session is not None:
            return await _update(session)
        async with AsyncSessionLocal() as s:
            return await _update(s)

    @staticmethod
    async def delete_static(
        item_id: int,
        session: Optional[AsyncSession] = None
    ) -> bool:
        """Delete the item with id ``item_id`` and commit.

        Returns True if a row was deleted, False if no item had that id.
        """
        async def _delete(s: AsyncSession):
            result = await s.execute(
                select(StaticData).where(StaticData.id == item_id)
            )
            item = result.scalar_one_or_none()
            if item is None:
                return False

            await s.delete(item)
            await s.commit()
            return True

        if session is not None:
            return await _delete(session)
        async with AsyncSessionLocal() as s:
            return await _delete(s)

    @staticmethod
    async def get_by_category(
        category: str,
        session: Optional[AsyncSession] = None
    ) -> List[StaticData]:
        """Return all static data items in ``category`` ordered by name."""
        async def _query(s: AsyncSession):
            result = await s.execute(
                select(StaticData)
                .where(StaticData.category == category)
                .order_by(StaticData.name)
            )
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def get_all_categories(
        session: Optional[AsyncSession] = None
    ) -> List[str]:
        """Return the distinct categories in ascending order.

        An explicit ORDER BY is used because DISTINCT alone gives no
        guaranteed order, unlike the other listing methods here.
        """
        async def _query(s: AsyncSession):
            result = await s.execute(
                select(StaticData.category)
                .distinct()
                .order_by(StaticData.category)
            )
            return [row[0] for row in result]

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)
|
|
|
|
|
|
class ResourceService:
    """Persistence helpers for resource file records.

    Every public method takes an optional ``AsyncSession``; when none is
    supplied, one is opened from ``AsyncSessionLocal`` just for that call.
    """

    @staticmethod
    async def _in_session(session, work):
        """Await ``work`` with the caller's session, or with a newly opened one."""
        if not session:
            async with AsyncSessionLocal() as own_session:
                return await work(own_session)
        return await work(session)

    @staticmethod
    async def create_resource(
        resource_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> Resource:
        """Insert a Resource built from ``resource_data``, commit, and return
        it refreshed from the database."""
        async def persist(db: AsyncSession):
            record = Resource(**resource_data)
            db.add(record)
            await db.commit()
            await db.refresh(record)
            return record

        return await ResourceService._in_session(session, persist)

    @staticmethod
    async def get_resources_by_body(
        body_id: str,
        resource_type: Optional[str] = None,
        session: Optional[AsyncSession] = None
    ) -> List[Resource]:
        """Return resources linked to ``body_id``, oldest first.

        When ``resource_type`` is non-empty, only resources of that type are
        returned.
        """
        stmt = select(Resource).where(Resource.body_id == body_id)
        if resource_type:
            stmt = stmt.where(Resource.resource_type == resource_type)
        stmt = stmt.order_by(Resource.created_at)

        async def fetch(db: AsyncSession):
            rows = await db.execute(stmt)
            return rows.scalars().all()

        return await ResourceService._in_session(session, fetch)

    @staticmethod
    async def delete_resource(
        resource_id: int,
        session: Optional[AsyncSession] = None
    ) -> bool:
        """Delete the resource with id ``resource_id`` and commit.

        Returns True when a row was removed, False when none matched.
        """
        async def remove(db: AsyncSession):
            lookup = await db.execute(
                select(Resource).where(Resource.id == resource_id)
            )
            record = lookup.scalar_one_or_none()
            if not record:
                return False
            await db.delete(record)
            await db.commit()
            return True

        return await ResourceService._in_session(session, remove)
|
|
|
|
|
|
# Module-level service singletons. Callers import these instances; every
# service method is a @staticmethod, so the instances carry no state.
celestial_body_service: CelestialBodyService = CelestialBodyService()
position_service: PositionService = PositionService()
nasa_cache_service: NasaCacheService = NasaCacheService()
static_data_service: StaticDataService = StaticDataService()
resource_service: ResourceService = ResourceService()
|