cosmo/backend/app/services/db_service.py

690 lines
22 KiB
Python

"""
Database service layer for celestial data operations
"""
from typing import List, Optional, Dict, Any
from datetime import datetime
from sqlalchemy import select, and_, delete
from sqlalchemy.ext.asyncio import AsyncSession
import logging
from app.models.db import CelestialBody, Position, StaticData, NasaCache, Resource
from app.database import AsyncSessionLocal
logger = logging.getLogger(__name__)
class CelestialBodyService:
    """CRUD operations for ``CelestialBody`` records.

    Every public method takes an optional ``session``. When one is passed the
    work runs inside it (and write methods commit it); otherwise a session is
    opened from ``AsyncSessionLocal`` for the duration of the call.
    """

    @staticmethod
    async def _with_session(session, work):
        """Await ``work`` with the caller's session, or with a fresh one if none was given."""
        if not session:
            async with AsyncSessionLocal() as own_session:
                return await work(own_session)
        return await work(session)

    @staticmethod
    async def get_all_bodies(
        session: Optional[AsyncSession] = None,
        body_type: Optional[str] = None,
        system_id: Optional[int] = None
    ) -> List[CelestialBody]:
        """
        List celestial bodies ordered by name, with optional filters.

        Args:
            session: Database session to use; a temporary one is opened if omitted
            body_type: Keep only bodies of this type (star, planet, dwarf_planet, etc.);
                ignored when empty or None
            system_id: Keep only bodies in this star system
                (1=Solar System, 2+=Exoplanets); ignored when None
        """
        async def fetch(db: AsyncSession):
            conditions = []
            if body_type:
                conditions.append(CelestialBody.type == body_type)
            if system_id is not None:
                conditions.append(CelestialBody.system_id == system_id)
            stmt = select(CelestialBody)
            if conditions:
                # Multiple criteria passed to where() are AND-ed together.
                stmt = stmt.where(*conditions)
            rows = await db.execute(stmt.order_by(CelestialBody.name))
            return rows.scalars().all()

        return await CelestialBodyService._with_session(session, fetch)

    @staticmethod
    async def get_body_by_id(
        body_id: str,
        session: Optional[AsyncSession] = None
    ) -> Optional[CelestialBody]:
        """Return the body whose ``id`` equals ``body_id``, or None if there is none."""
        async def fetch(db: AsyncSession):
            stmt = select(CelestialBody).where(CelestialBody.id == body_id)
            return (await db.execute(stmt)).scalar_one_or_none()

        return await CelestialBodyService._with_session(session, fetch)

    @staticmethod
    async def create_body(
        body_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> CelestialBody:
        """Build a body from ``body_data``, commit it, and return it refreshed from the DB."""
        async def insert_row(db: AsyncSession):
            new_body = CelestialBody(**body_data)
            db.add(new_body)
            await db.commit()
            await db.refresh(new_body)
            return new_body

        return await CelestialBodyService._with_session(session, insert_row)

    @staticmethod
    async def update_body(
        body_id: str,
        update_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> Optional[CelestialBody]:
        """
        Apply ``update_data`` to an existing body and commit.

        Keys that are not attributes of the model object are silently skipped.

        Returns:
            The refreshed body, or None when no body has id ``body_id``.
        """
        async def apply(db: AsyncSession):
            stmt = select(CelestialBody).where(CelestialBody.id == body_id)
            target = (await db.execute(stmt)).scalar_one_or_none()
            if not target:
                return None
            for field_name, field_value in update_data.items():
                if hasattr(target, field_name):
                    setattr(target, field_name, field_value)
            await db.commit()
            await db.refresh(target)
            return target

        return await CelestialBodyService._with_session(session, apply)

    @staticmethod
    async def delete_body(
        body_id: str,
        session: Optional[AsyncSession] = None
    ) -> bool:
        """Delete the body with id ``body_id`` and commit; return whether one was found."""
        async def remove(db: AsyncSession):
            stmt = select(CelestialBody).where(CelestialBody.id == body_id)
            target = (await db.execute(stmt)).scalar_one_or_none()
            if not target:
                return False
            await db.delete(target)
            await db.commit()
            return True

        return await CelestialBodyService._with_session(session, remove)
class PositionService:
    """Read/write access to ``Position`` rows (time-stamped x/y/z samples per body).

    Every public method takes an optional ``session``; when omitted, a session
    is opened from ``AsyncSessionLocal`` for the call. Write methods commit.
    Rows are unique on ``(body_id, time)``; writes upsert on that pair.
    """

    @staticmethod
    async def save_positions(
        body_id: str,
        positions: List[Dict[str, Any]],
        source: str = "nasa_horizons",
        session: Optional[AsyncSession] = None
    ) -> int:
        """
        Upsert many position records for one body, committing once at the end.

        Each record is written with PostgreSQL ``INSERT ... ON CONFLICT
        (body_id, time) DO UPDATE``, so re-saving an existing timestamp
        overwrites the stored values.

        Args:
            body_id: Body the records belong to
            positions: Dicts with required keys ``time``, ``x``, ``y``, ``z`` and
                optional ``vx``, ``vy``, ``vz`` (missing ones are stored as None)
            source: Value written to the ``source`` column
            session: Database session; a temporary one is opened if omitted

        Returns:
            Number of upsert statements executed (one per input record).
        """
        async def _save(s: AsyncSession) -> int:
            from sqlalchemy.dialects.postgresql import insert

            count = 0
            for pos_data in positions:
                # Same values are used for the insert and for the conflict update.
                row = {
                    'x': pos_data["x"],
                    'y': pos_data["y"],
                    'z': pos_data["z"],
                    'vx': pos_data.get("vx"),
                    'vy': pos_data.get("vy"),
                    'vz': pos_data.get("vz"),
                    'source': source,
                }
                stmt = insert(Position).values(
                    body_id=body_id, time=pos_data["time"], **row
                ).on_conflict_do_update(
                    index_elements=['body_id', 'time'],
                    set_=row,
                )
                await s.execute(stmt)
                count += 1
            await s.commit()
            return count

        if session is not None:
            return await _save(session)
        async with AsyncSessionLocal() as s:
            return await _save(s)

    @staticmethod
    async def get_positions(
        body_id: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        session: Optional[AsyncSession] = None
    ) -> List[Position]:
        """
        Return positions for a body ordered by time, optionally bounded.

        Both bounds are inclusive and independently optional.
        """
        async def _query(s: AsyncSession):
            query = select(Position).where(Position.body_id == body_id)
            # Successive where() calls are AND-ed, so supplying both bounds
            # yields the closed range [start_time, end_time].
            if start_time:
                query = query.where(Position.time >= start_time)
            if end_time:
                query = query.where(Position.time <= end_time)
            result = await s.execute(query.order_by(Position.time))
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def get_positions_in_range(
        body_id: str,
        start_time: datetime,
        end_time: datetime,
        session: Optional[AsyncSession] = None
    ) -> List[Position]:
        """Alias for get_positions with a required (inclusive) time range."""
        return await PositionService.get_positions(body_id, start_time, end_time, session)

    @staticmethod
    async def save_position(
        body_id: str,
        time: datetime,
        x: float,
        y: float,
        z: float,
        source: str = "nasa_horizons",
        vx: Optional[float] = None,
        vy: Optional[float] = None,
        vz: Optional[float] = None,
        session: Optional[AsyncSession] = None
    ) -> Position:
        """
        Insert or overwrite the single position of ``body_id`` at ``time`` and commit.

        Uses an atomic ``INSERT ... ON CONFLICT (body_id, time) DO UPDATE ...
        RETURNING``. The previous select-then-insert could race: two concurrent
        callers both saw no row and both inserted, violating the unique
        ``(body_id, time)`` constraint.

        Returns:
            The stored ``Position``, refreshed after commit.
        """
        async def _save(s: AsyncSession) -> Position:
            from sqlalchemy.dialects.postgresql import insert

            values = {
                'x': x,
                'y': y,
                'z': z,
                'vx': vx,
                'vy': vy,
                'vz': vz,
                'source': source,
            }
            stmt = insert(Position).values(
                body_id=body_id, time=time, **values
            ).on_conflict_do_update(
                index_elements=['body_id', 'time'],
                set_=values,
            ).returning(Position)
            # populate_existing overwrites any copy of this row already held
            # in the session's identity map with the values just written.
            result = await s.execute(
                stmt, execution_options={"populate_existing": True}
            )
            position = result.scalar_one()
            await s.commit()
            await s.refresh(position)
            return position

        if session is not None:
            return await _save(session)
        async with AsyncSessionLocal() as s:
            return await _save(s)

    @staticmethod
    async def delete_old_positions(
        before_time: datetime,
        session: Optional[AsyncSession] = None
    ) -> int:
        """Delete all position records with ``time`` strictly before ``before_time``; return the row count."""
        async def _delete(s: AsyncSession) -> int:
            result = await s.execute(
                delete(Position).where(Position.time < before_time)
            )
            await s.commit()
            return result.rowcount

        if session is not None:
            return await _delete(session)
        async with AsyncSessionLocal() as s:
            return await _delete(s)

    @staticmethod
    async def get_available_dates(
        body_id: str,
        start_time: datetime,
        end_time: datetime,
        session: Optional[AsyncSession] = None
    ) -> List[datetime]:
        """
        Return the distinct calendar days on which ``body_id`` has position data.

        Only rows with ``start_time <= time <= end_time`` are considered.
        NOTE(review): values come from SQL ``date()``. The driver presumably
        returns ``datetime.date`` rather than the annotated ``datetime``; verify
        before relying on time components.
        """
        async def _query(s: AsyncSession):
            from sqlalchemy import func

            day = func.date(Position.time)
            query = (
                select(day)
                .where(
                    and_(
                        Position.body_id == body_id,
                        Position.time >= start_time,
                        Position.time <= end_time
                    )
                )
                .distinct()
                .order_by(day)
            )
            result = await s.execute(query)
            return [row[0] for row in result]

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)
class NasaCacheService:
    """Service for NASA API response caching.

    Entries live in the ``NasaCache`` table, whose timestamp columns are
    TIMESTAMP WITHOUT TIME ZONE. Incoming datetimes therefore have their
    tzinfo stripped before storage or comparison, and "now" is the current
    UTC time as a naive value.
    """

    @staticmethod
    def _naive(value: Optional[datetime]) -> Optional[datetime]:
        """Return ``value`` with tzinfo removed (no offset conversion), or None."""
        # NOTE(review): an aware non-UTC value keeps its local wall-clock time
        # here (unchanged from the original behaviour); confirm callers pass UTC.
        return value.replace(tzinfo=None) if value is not None else None

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as ``datetime.utcnow()``, which is deprecated since Python 3.12.
        """
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    async def get_cached_response(
        body_id: str,
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        step: str,
        session: Optional[AsyncSession] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Return the cached payload for a query, or None if no unexpired entry exists.

        Args:
            body_id: Body the cached query was for
            start_time: Query start (tzinfo is ignored); None matches a NULL start
            end_time: Query end (tzinfo is ignored); None matches a NULL end
            step: Step string of the cached query
            session: Database session; a temporary one is opened if omitted

        Returns:
            ``data`` of the matching unexpired entry with the latest ``expires_at``,
            or None.
        """
        async def _query(s: AsyncSession) -> Optional[Dict[str, Any]]:
            start_naive = NasaCacheService._naive(start_time)
            end_naive = NasaCacheService._naive(end_time)
            now_naive = NasaCacheService._utcnow_naive()
            result = await s.execute(
                select(NasaCache)
                .where(
                    and_(
                        NasaCache.body_id == body_id,
                        NasaCache.start_time == start_naive,
                        NasaCache.end_time == end_naive,
                        NasaCache.step == step,
                        NasaCache.expires_at > now_naive
                    )
                )
                # Uniqueness is enforced on ``cache_key``, which embeds the
                # caller's original isoformat (including any UTC offset). This
                # lookup matches the tz-stripped columns instead, so an aware
                # and a naive input for the same wall time produce two rows
                # that both match. Take the freshest rather than letting
                # scalar_one_or_none() raise MultipleResultsFound.
                .order_by(NasaCache.expires_at.desc())
                .limit(1)
            )
            cache = result.scalars().first()
            return cache.data if cache is not None else None

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def save_response(
        body_id: str,
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        step: str,
        response_data: Dict[str, Any],
        ttl_days: int = 7,
        session: Optional[AsyncSession] = None
    ) -> NasaCache:
        """
        Store a NASA API response, overwriting any entry with the same cache key.

        The key is ``"{body_id}:{start}:{end}:{step}"``, built from the
        caller's original datetimes via ``isoformat()`` (``"null"`` for None).
        On conflict, ``data``, ``created_at`` and ``expires_at`` are replaced.

        Args:
            ttl_days: Entry lifetime in days, counted from now (UTC)

        Returns:
            The stored ``NasaCache`` row, refreshed after commit.
        """
        async def _save(s: AsyncSession) -> NasaCache:
            from datetime import timedelta
            from sqlalchemy.dialects.postgresql import insert

            start_naive = NasaCacheService._naive(start_time)
            end_naive = NasaCacheService._naive(end_time)
            now_naive = NasaCacheService._utcnow_naive()
            expires_at = now_naive + timedelta(days=ttl_days)

            start_str = start_time.isoformat() if start_time else "null"
            end_str = end_time.isoformat() if end_time else "null"
            cache_key = f"{body_id}:{start_str}:{end_str}:{step}"

            # INSERT ... ON CONFLICT keeps the upsert atomic under concurrent writers.
            stmt = insert(NasaCache).values(
                cache_key=cache_key,
                body_id=body_id,
                start_time=start_naive,
                end_time=end_naive,
                step=step,
                data=response_data,
                expires_at=expires_at
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=['cache_key'],
                set_={
                    'data': response_data,
                    'created_at': now_naive,
                    'expires_at': expires_at
                }
            ).returning(NasaCache)
            result = await s.execute(stmt)
            cache = result.scalar_one()
            await s.commit()
            await s.refresh(cache)
            return cache

        if session is not None:
            return await _save(session)
        async with AsyncSessionLocal() as s:
            return await _save(s)
class StaticDataService:
    """CRUD and lookup operations for ``StaticData`` items.

    Every public method takes an optional ``session``; when omitted, a session
    is opened from ``AsyncSessionLocal`` for the call. Write methods commit.
    """

    @staticmethod
    async def get_all_items(
        session: Optional[AsyncSession] = None
    ) -> List[StaticData]:
        """Return every static data item, ordered by category then name."""
        async def _query(s: AsyncSession):
            result = await s.execute(
                select(StaticData).order_by(StaticData.category, StaticData.name)
            )
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def create_static(
        data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> StaticData:
        """Create an item from ``data``, commit it, and return it refreshed."""
        async def _create(s: AsyncSession) -> StaticData:
            item = StaticData(**data)
            s.add(item)
            await s.commit()
            await s.refresh(item)
            return item

        if session is not None:
            return await _create(session)
        async with AsyncSessionLocal() as s:
            return await _create(s)

    @staticmethod
    async def update_static(
        item_id: int,
        update_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> Optional[StaticData]:
        """
        Apply ``update_data`` to item ``item_id`` and commit.

        Keys that are not attributes of the model object are silently skipped.

        Returns:
            The refreshed item, or None if no item has that id.
        """
        async def _update(s: AsyncSession) -> Optional[StaticData]:
            result = await s.execute(
                select(StaticData).where(StaticData.id == item_id)
            )
            item = result.scalar_one_or_none()
            if not item:
                return None
            for key, value in update_data.items():
                if hasattr(item, key):
                    setattr(item, key, value)
            await s.commit()
            await s.refresh(item)
            return item

        if session is not None:
            return await _update(session)
        async with AsyncSessionLocal() as s:
            return await _update(s)

    @staticmethod
    async def delete_static(
        item_id: int,
        session: Optional[AsyncSession] = None
    ) -> bool:
        """Delete item ``item_id`` and commit; return whether it existed."""
        async def _delete(s: AsyncSession) -> bool:
            result = await s.execute(
                select(StaticData).where(StaticData.id == item_id)
            )
            item = result.scalar_one_or_none()
            if not item:
                return False
            await s.delete(item)
            await s.commit()
            return True

        if session is not None:
            return await _delete(session)
        async with AsyncSessionLocal() as s:
            return await _delete(s)

    @staticmethod
    async def get_by_category(
        category: str,
        session: Optional[AsyncSession] = None
    ) -> List[StaticData]:
        """Return all items in ``category``, ordered by name."""
        async def _query(s: AsyncSession):
            result = await s.execute(
                select(StaticData)
                .where(StaticData.category == category)
                .order_by(StaticData.name)
            )
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def get_all_categories(
        session: Optional[AsyncSession] = None
    ) -> List[str]:
        """Return the distinct categories in alphabetical order.

        Without ORDER BY, the database may return DISTINCT results in any order.
        Sorting keeps this listing stable, matching the other queries here.
        """
        async def _query(s: AsyncSession) -> List[str]:
            result = await s.execute(
                select(StaticData.category)
                .distinct()
                .order_by(StaticData.category)
            )
            return [row[0] for row in result]

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)
class ResourceService:
    """Service for ``Resource`` records (files attached to celestial bodies).

    Every public method takes an optional ``session``; when omitted, a session
    is opened from ``AsyncSessionLocal`` for the call. Write methods commit.
    """

    @staticmethod
    async def create_resource(
        resource_data: Dict[str, Any],
        session: Optional[AsyncSession] = None
    ) -> Resource:
        """Create a resource from ``resource_data``, commit it, and return it refreshed."""
        async def _create(s: AsyncSession) -> Resource:
            resource = Resource(**resource_data)
            s.add(resource)
            await s.commit()
            await s.refresh(resource)
            return resource

        if session is not None:
            return await _create(session)
        async with AsyncSessionLocal() as s:
            return await _create(s)

    @staticmethod
    async def get_resources_by_body(
        body_id: str,
        resource_type: Optional[str] = None,
        session: Optional[AsyncSession] = None
    ) -> List[Resource]:
        """Return a body's resources ordered by ``created_at``.

        If ``resource_type`` is truthy, only resources of that type are returned.
        """
        async def _query(s: AsyncSession):
            query = select(Resource).where(Resource.body_id == body_id)
            if resource_type:
                query = query.where(Resource.resource_type == resource_type)
            result = await s.execute(query.order_by(Resource.created_at))
            return result.scalars().all()

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def get_all_resources_grouped_by_body(
        body_ids: Optional[List[str]] = None,
        session: Optional[AsyncSession] = None
    ) -> Dict[str, List[Resource]]:
        """
        Load resources in one query and group them by ``body_id``.

        Args:
            body_ids: IDs to restrict to. None means every body. An empty list
                means no bodies and returns ``{}`` without querying.
            session: Database session; a temporary one is opened if omitted

        Returns:
            Mapping of body_id to its resources, ordered by ``created_at``.
            Bodies without resources are absent from the mapping.
        """
        # Previously `if body_ids:` treated [] like None and fell through to an
        # unfiltered SELECT of every resource. An empty filter should match nothing.
        if body_ids is not None and not body_ids:
            return {}

        async def _query(s: AsyncSession) -> Dict[str, List[Resource]]:
            query = select(Resource).order_by(Resource.body_id, Resource.created_at)
            if body_ids is not None:
                query = query.where(Resource.body_id.in_(body_ids))
            result = await s.execute(query)
            grouped: Dict[str, List[Resource]] = {}
            # Rows arrive sorted by created_at within each body, so each list
            # ends up in creation order.
            for resource in result.scalars():
                grouped.setdefault(resource.body_id, []).append(resource)
            return grouped

        if session is not None:
            return await _query(session)
        async with AsyncSessionLocal() as s:
            return await _query(s)

    @staticmethod
    async def delete_resource(
        resource_id: int,
        session: Optional[AsyncSession] = None
    ) -> bool:
        """Delete resource ``resource_id`` and commit; return whether it existed."""
        async def _delete(s: AsyncSession) -> bool:
            result = await s.execute(
                select(Resource).where(Resource.id == resource_id)
            )
            resource = result.scalar_one_or_none()
            if not resource:
                return False
            await s.delete(resource)
            await s.commit()
            return True

        if session is not None:
            return await _delete(session)
        async with AsyncSessionLocal() as s:
            return await _delete(s)
# Module-level service objects that callers import by name. Every method on
# these classes is a @staticmethod, so each instance is just a handle with no state.
(
    celestial_body_service,
    position_service,
    nasa_cache_service,
    static_data_service,
    resource_service,
) = (
    CelestialBodyService(),
    PositionService(),
    NasaCacheService(),
    StaticDataService(),
    ResourceService(),
)