mirror of
https://github.com/zvx-echo6/central.git
synced 2026-06-10 03:44:39 +02:00
chore: normalize line endings to LF
This commit is contained in:
parent
43088d7fbb
commit
374a8c067f
26 changed files with 5357 additions and 5346 deletions
|
|
@ -1,430 +1,430 @@
|
|||
"""FIRMS (Fire Information for Resource Management System) adapter."""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from central.adapter import SourceAdapter
|
||||
from central.config_models import AdapterConfig, RegionConfig
|
||||
from central.config_store import ConfigStore
|
||||
from central.models import Event, Geo
|
||||
|
||||
logger = logging.getLogger(__name__)

# Base URL of the FIRMS area CSV endpoint. The adapter appends
# "/<api key>/<satellite>/<west,south,east,north>/1" to it.
FIRMS_API_BASE = "https://firms.modaps.eosdis.nasa.gov/api/area/csv"

# FIRMS source name -> short satellite name used in the event category.
SATELLITE_SHORT = {
    "VIIRS_SNPP_NRT": "viirs_snpp",
    "VIIRS_NOAA20_NRT": "viirs_noaa20",
    "VIIRS_NOAA21_NRT": "viirs_noaa21",
}

# Single-letter confidence code found in the CSV -> confidence name.
CONFIDENCE_MAP = {
    "l": "low",
    "n": "nominal",
    "h": "high",
}

# Confidence name -> numeric event severity.
SEVERITY_MAP = {
    "high": 3,
    "nominal": 2,
    "low": 1,
}
|
||||
|
||||
|
||||
class FIRMSAdapter(SourceAdapter):
    """NASA FIRMS fire hotspot adapter.

    Polls the FIRMS area CSV endpoint once per configured satellite, turns
    each CSV row into an ``Event`` and records yielded event IDs in a SQLite
    ``published_ids`` table so the same hotspot is not yielded twice.
    """

    name = "firms"

    # Satellites polled when the settings do not list any.
    _DEFAULT_SATELLITES = ("VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT")

    def __init__(
        self,
        config: AdapterConfig,
        config_store: ConfigStore,
        cursor_db_path: Path,
    ) -> None:
        """Store collaborators and read settings; no I/O happens here.

        Args:
            config: Adapter configuration. ``config.settings`` is read for
                ``api_key_alias``, ``satellites`` and ``region``.
            config_store: Store queried later for the API key.
            cursor_db_path: Path of the SQLite file holding dedup state.
        """
        self._config_store = config_store
        self._cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        self._db: sqlite3.Connection | None = None
        self._api_key: str | None = None

        # Extract settings from config
        self._api_key_alias: str = config.settings.get("api_key_alias", "firms")
        self._satellites: list[str] = config.settings.get(
            "satellites", list(self._DEFAULT_SATELLITES)
        )

        # Parse region from settings
        region_dict = config.settings.get("region")
        if region_dict:
            self.region: RegionConfig | None = RegionConfig(**region_dict)
        else:
            self.region = None

    async def apply_config(self, new_config: AdapterConfig) -> None:
        """Apply new configuration from hot-reload.

        Everything that can raise (building the region) is done before any
        attribute is changed, so a rejected region leaves the running
        configuration untouched instead of half-applied. The API key is
        re-fetched only when the alias changed.
        """
        settings = new_config.settings

        # Build the new values first; nothing is committed yet.
        region_dict = settings.get("region")
        new_region = RegionConfig(**region_dict) if region_dict else None
        new_alias = settings.get("api_key_alias", "firms")
        new_satellites = settings.get("satellites", list(self._DEFAULT_SATELLITES))

        alias_changed = new_alias != self._api_key_alias

        # Commit.
        self._api_key_alias = new_alias
        self._satellites = new_satellites
        self.region = new_region

        # If API key alias changed, re-fetch the key
        if alias_changed:
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if self._api_key:
                logger.info("FIRMS API key reloaded", extra={"alias": self._api_key_alias})
            else:
                logger.warning(
                    "FIRMS API key not found after alias change",
                    extra={"alias": self._api_key_alias},
                )

        logger.info(
            "FIRMS config applied",
            extra={
                "region": region_dict,
                "satellites": self._satellites,
                "api_key_alias": self._api_key_alias,
            },
        )

    async def startup(self) -> None:
        """Initialize HTTP session, dedup tracker, and fetch API key."""
        # A missing key is logged but not fatal: poll() retries the lookup.
        self._api_key = await self._config_store.get_api_key(self._api_key_alias)
        if not self._api_key:
            logger.error(
                "FIRMS API key not found - polling will be skipped until key is set",
                extra={"alias": self._api_key_alias},
            )

        # Initialize HTTP session
        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=60),
        )

        # Initialize dedup tracker (shared sqlite DB with NWS)
        self._db = sqlite3.connect(str(self._cursor_db_path))
        self._db.execute("""
            CREATE TABLE IF NOT EXISTS published_ids (
                adapter TEXT NOT NULL,
                event_id TEXT NOT NULL,
                first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (adapter, event_id)
            )
        """)
        self._db.execute("""
            CREATE INDEX IF NOT EXISTS published_ids_last_seen
            ON published_ids (last_seen)
        """)
        self._db.commit()

        # Sweep old entries on startup (48h for FIRMS)
        self.sweep_old_ids()

        logger.info(
            "FIRMS adapter started",
            extra={
                "region": {
                    "north": self.region.north,
                    "south": self.region.south,
                    "east": self.region.east,
                    "west": self.region.west,
                } if self.region else None,
                "satellites": self._satellites,
                "api_key_present": self._api_key is not None,
            },
        )

    async def shutdown(self) -> None:
        """Close HTTP session and database."""
        if self._session:
            await self._session.close()
            self._session = None
        if self._db:
            self._db.close()
            self._db = None
        logger.info("FIRMS adapter shut down")

    def is_published(self, stable_id: str) -> bool:
        """Check if an event has already been published.

        Returns False when the database is not open.
        """
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, stable_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, stable_id: str) -> None:
        """Mark an event as published.

        Inserts the ID, or refreshes ``last_seen`` when it is already present.
        Does nothing when the database is not open.
        """
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
                last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, stable_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Remove published_ids older than 48 hours. Returns count deleted.

        Only rows belonging to this adapter are deleted. Returns 0 when the
        database is not open.
        """
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-48 hours')",
            (self.name,),
        )
        self._db.commit()
        count = cur.rowcount
        if count > 0:
            logger.info("FIRMS swept old dedup entries", extra={"count": count})
        return count

    def _build_stable_id(
        self, satellite: str, acq_date: str, acq_time: str, lat: float, lon: float
    ) -> str:
        """Build stable ID for deduplication.

        The ID is ``satellite:acq_date:acq_time:lat:lon`` with the
        coordinates rounded to three decimals.
        """
        # Round lat/lon to 0.001 degrees to handle floating-point comparison
        lat_rounded = round(lat, 3)
        lon_rounded = round(lon, 3)
        return f"{satellite}:{acq_date}:{acq_time}:{lat_rounded}:{lon_rounded}"

    def _build_url(self, satellite: str) -> str | None:
        """Build FIRMS API URL for a satellite.

        Returns None when the API key or the region is not available.
        """
        if not self._api_key or not self.region:
            return None

        # Area format: west,south,east,north
        area = f"{self.region.west},{self.region.south},{self.region.east},{self.region.north}"
        return f"{FIRMS_API_BASE}/{self._api_key}/{satellite}/{area}/1"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=2, max=30),
        retry=retry_if_exception_type((aiohttp.ClientError,)),
        reraise=True,
    )
    async def _fetch_csv(self, url: str) -> str:
        """Fetch CSV data from FIRMS API.

        Only ``aiohttp.ClientError`` is retried (up to three attempts); the
        ``ValueError`` raised for an HTML response is not.

        Raises:
            RuntimeError: The HTTP session has not been created.
            ValueError: The response Content-Type contains ``text/html``.
            aiohttp.ClientError: Request failed, including error statuses
                raised by ``raise_for_status``.
        """
        if not self._session:
            raise RuntimeError("Session not initialized")

        async with self._session.get(url) as resp:
            # Check for error responses
            content_type = resp.headers.get("Content-Type", "")
            if "text/html" in content_type:
                text = await resp.text()
                logger.error(
                    "FIRMS returned HTML (likely auth error)",
                    extra={"status": resp.status, "preview": text[:200]},
                )
                raise ValueError("FIRMS returned HTML instead of CSV")

            resp.raise_for_status()
            return await resp.text()

    def _parse_csv(self, csv_text: str, satellite: str) -> list[dict[str, Any]]:
        """Parse FIRMS CSV response into list of dicts.

        Rows that cannot be parsed are logged and skipped. That includes
        truncated rows: ``csv.DictReader`` fills the missing tail of a short
        row with ``None``, which must not abort parsing of the other rows.

        Args:
            csv_text: Raw CSV text including the header line.
            satellite: Value used for ``satellite`` when the CSV has no such
                column.

        Returns:
            One dict per successfully parsed row.
        """
        rows: list[dict[str, Any]] = []
        reader = csv.DictReader(StringIO(csv_text))

        def opt_float(value: Any) -> float | None:
            # Missing or empty optional columns become None rather than 0.0.
            return float(value) if value else None

        for row in reader:
            try:
                # Parse required fields
                lat = float(row["latitude"])
                lon = float(row["longitude"])
                acq_date = row["acq_date"]
                acq_time = row["acq_time"]
                if acq_date is None or acq_time is None:
                    raise ValueError("truncated row: acquisition date/time missing")

                # A confidence column cut off by truncation is treated like
                # an absent column and falls back to "n".
                raw_confidence = row.get("confidence")
                confidence_raw = "n" if raw_confidence is None else raw_confidence.lower()
                confidence = CONFIDENCE_MAP.get(confidence_raw, "nominal")

                rows.append({
                    "latitude": lat,
                    "longitude": lon,
                    "bright_ti4": opt_float(row.get("bright_ti4")),
                    "bright_ti5": opt_float(row.get("bright_ti5")),
                    "scan": opt_float(row.get("scan")),
                    "track": opt_float(row.get("track")),
                    "acq_date": acq_date,
                    "acq_time": acq_time,
                    "satellite": row.get("satellite", satellite),
                    "instrument": row.get("instrument", "VIIRS"),
                    "confidence": confidence,
                    "confidence_raw": confidence_raw,
                    "version": row.get("version", ""),
                    "frp": opt_float(row.get("frp")),
                    "daynight": row.get("daynight", ""),
                })
            except (KeyError, ValueError, TypeError) as e:
                # TypeError covers float(None) on a truncated row.
                logger.warning(
                    "Failed to parse FIRMS row",
                    extra={"error": str(e), "row": dict(row)},
                )
                continue

        return rows

    def _row_to_event(self, row: dict[str, Any], satellite: str) -> Event:
        """Convert a parsed CSV row to an Event.

        Args:
            row: A dict produced by ``_parse_csv``.
            satellite: FIRMS source name the row was fetched for.
        """
        satellite_short = SATELLITE_SHORT.get(satellite, satellite.lower().replace("_nrt", ""))
        confidence = row["confidence"]
        severity = SEVERITY_MAP.get(confidence, 1)

        # Parse acquisition time
        acq_date = row["acq_date"]
        acq_time = row["acq_time"]
        # acq_time is HHMM format. Zero-pad it first: without the padding
        # "203" would be read as 20:03 instead of 02:03.
        try:
            event_time = datetime.strptime(
                f"{acq_date} {str(acq_time).zfill(4)}", "%Y-%m-%d %H%M"
            ).replace(tzinfo=timezone.utc)
        except ValueError:
            logger.warning(
                "FIRMS acquisition time unparsable, using current time",
                extra={"acq_date": acq_date, "acq_time": acq_time},
            )
            event_time = datetime.now(timezone.utc)

        lat = row["latitude"]
        lon = row["longitude"]

        # Build stable ID (uses the raw acq_time so existing IDs keep matching)
        stable_id = self._build_stable_id(satellite, acq_date, acq_time, lat, lon)

        geo = Geo(
            centroid=(lon, lat),  # GeoJSON order: lon, lat
            bbox=(lon, lat, lon, lat),  # Point bbox
            regions=[],
            primary_region=None,
        )

        return Event(
            id=stable_id,
            source="central/adapters/firms",
            category=f"fire.hotspot.{satellite_short}.{confidence}",
            time=event_time,
            expires=None,
            severity=severity,
            geo=geo,
            data=row,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Poll FIRMS API for fire hotspots.

        Yields one Event per row that is not yet in the dedup table and marks
        it as published after the consumer resumes the generator. A failure
        for one satellite is logged and does not stop the others. Nothing is
        yielded while the API key or the region is missing.
        """
        # Check API key
        if not self._api_key:
            # Try to fetch again in case it was added
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if not self._api_key:
                logger.warning(
                    "FIRMS API key still not available, skipping poll",
                    extra={"alias": self._api_key_alias},
                )
                return

        if not self.region:
            logger.warning("FIRMS region not configured, skipping poll")
            return

        # Sweep old dedup entries periodically
        self.sweep_old_ids()

        total_features = 0
        total_new = 0

        for satellite in self._satellites:
            url = self._build_url(satellite)
            if not url:
                continue

            try:
                csv_text = await self._fetch_csv(url)
                rows = self._parse_csv(csv_text, satellite)
                feature_count = len(rows)
                total_features += feature_count

                new_count = 0
                for row in rows:
                    stable_id = self._build_stable_id(
                        satellite,
                        row["acq_date"],
                        row["acq_time"],
                        row["latitude"],
                        row["longitude"],
                    )

                    if self.is_published(stable_id):
                        continue

                    event = self._row_to_event(row, satellite)
                    yield event
                    self.mark_published(stable_id)
                    new_count += 1

                total_new += new_count
                logger.info(
                    "FIRMS satellite poll completed",
                    extra={
                        "satellite": satellite,
                        "feature_count": feature_count,
                        "new_count": new_count,
                    },
                )

            except Exception as e:
                logger.error(
                    "FIRMS poll failed for satellite",
                    extra={"satellite": satellite, "error": str(e)},
                )
                continue

        logger.info(
            "FIRMS poll completed",
            extra={
                "total_features": total_features,
                "total_new": total_new,
                "satellites": self._satellites,
            },
        )
|
||||
|
||||
|
||||
def subject_for_fire_hotspot(ev: Event) -> str:
    """Return the NATS subject on which a fire hotspot event is published.

    The event category already has the form
    ``fire.hotspot.<satellite>.<confidence>``, so the subject is simply that
    category under the ``central`` root:
    ``central.fire.hotspot.<satellite>.<confidence>``.
    """
    root = "central"
    return f"{root}.{ev.category}"
|
||||
"""FIRMS (Fire Information for Resource Management System) adapter."""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from central.adapter import SourceAdapter
|
||||
from central.config_models import AdapterConfig, RegionConfig
|
||||
from central.config_store import ConfigStore
|
||||
from central.models import Event, Geo
|
||||
|
||||
logger = logging.getLogger(__name__)

# Base URL of the FIRMS area CSV endpoint. The adapter appends
# "/<api key>/<satellite>/<west,south,east,north>/1" to it.
FIRMS_API_BASE = "https://firms.modaps.eosdis.nasa.gov/api/area/csv"

# FIRMS source name -> short satellite name used in the event category.
SATELLITE_SHORT = {
    "VIIRS_SNPP_NRT": "viirs_snpp",
    "VIIRS_NOAA20_NRT": "viirs_noaa20",
    "VIIRS_NOAA21_NRT": "viirs_noaa21",
}

# Single-letter confidence code found in the CSV -> confidence name.
CONFIDENCE_MAP = {
    "l": "low",
    "n": "nominal",
    "h": "high",
}

# Confidence name -> numeric event severity.
SEVERITY_MAP = {
    "high": 3,
    "nominal": 2,
    "low": 1,
}
|
||||
|
||||
|
||||
class FIRMSAdapter(SourceAdapter):
    """NASA FIRMS fire hotspot adapter.

    Polls the FIRMS area CSV endpoint once per configured satellite, turns
    each CSV row into an ``Event`` and records yielded event IDs in a SQLite
    ``published_ids`` table so the same hotspot is not yielded twice.
    """

    name = "firms"

    # Satellites polled when the settings do not list any.
    _DEFAULT_SATELLITES = ("VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT")

    def __init__(
        self,
        config: AdapterConfig,
        config_store: ConfigStore,
        cursor_db_path: Path,
    ) -> None:
        """Store collaborators and read settings; no I/O happens here.

        Args:
            config: Adapter configuration. ``config.settings`` is read for
                ``api_key_alias``, ``satellites`` and ``region``.
            config_store: Store queried later for the API key.
            cursor_db_path: Path of the SQLite file holding dedup state.
        """
        self._config_store = config_store
        self._cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        self._db: sqlite3.Connection | None = None
        self._api_key: str | None = None

        # Extract settings from config
        self._api_key_alias: str = config.settings.get("api_key_alias", "firms")
        self._satellites: list[str] = config.settings.get(
            "satellites", list(self._DEFAULT_SATELLITES)
        )

        # Parse region from settings
        region_dict = config.settings.get("region")
        if region_dict:
            self.region: RegionConfig | None = RegionConfig(**region_dict)
        else:
            self.region = None

    async def apply_config(self, new_config: AdapterConfig) -> None:
        """Apply new configuration from hot-reload.

        Everything that can raise (building the region) is done before any
        attribute is changed, so a rejected region leaves the running
        configuration untouched instead of half-applied. The API key is
        re-fetched only when the alias changed.
        """
        settings = new_config.settings

        # Build the new values first; nothing is committed yet.
        region_dict = settings.get("region")
        new_region = RegionConfig(**region_dict) if region_dict else None
        new_alias = settings.get("api_key_alias", "firms")
        new_satellites = settings.get("satellites", list(self._DEFAULT_SATELLITES))

        alias_changed = new_alias != self._api_key_alias

        # Commit.
        self._api_key_alias = new_alias
        self._satellites = new_satellites
        self.region = new_region

        # If API key alias changed, re-fetch the key
        if alias_changed:
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if self._api_key:
                logger.info("FIRMS API key reloaded", extra={"alias": self._api_key_alias})
            else:
                logger.warning(
                    "FIRMS API key not found after alias change",
                    extra={"alias": self._api_key_alias},
                )

        logger.info(
            "FIRMS config applied",
            extra={
                "region": region_dict,
                "satellites": self._satellites,
                "api_key_alias": self._api_key_alias,
            },
        )

    async def startup(self) -> None:
        """Initialize HTTP session, dedup tracker, and fetch API key."""
        # A missing key is logged but not fatal: poll() retries the lookup.
        self._api_key = await self._config_store.get_api_key(self._api_key_alias)
        if not self._api_key:
            logger.error(
                "FIRMS API key not found - polling will be skipped until key is set",
                extra={"alias": self._api_key_alias},
            )

        # Initialize HTTP session
        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=60),
        )

        # Initialize dedup tracker (shared sqlite DB with NWS)
        self._db = sqlite3.connect(str(self._cursor_db_path))
        self._db.execute("""
            CREATE TABLE IF NOT EXISTS published_ids (
                adapter TEXT NOT NULL,
                event_id TEXT NOT NULL,
                first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (adapter, event_id)
            )
        """)
        self._db.execute("""
            CREATE INDEX IF NOT EXISTS published_ids_last_seen
            ON published_ids (last_seen)
        """)
        self._db.commit()

        # Sweep old entries on startup (48h for FIRMS)
        self.sweep_old_ids()

        logger.info(
            "FIRMS adapter started",
            extra={
                "region": {
                    "north": self.region.north,
                    "south": self.region.south,
                    "east": self.region.east,
                    "west": self.region.west,
                } if self.region else None,
                "satellites": self._satellites,
                "api_key_present": self._api_key is not None,
            },
        )

    async def shutdown(self) -> None:
        """Close HTTP session and database."""
        if self._session:
            await self._session.close()
            self._session = None
        if self._db:
            self._db.close()
            self._db = None
        logger.info("FIRMS adapter shut down")

    def is_published(self, stable_id: str) -> bool:
        """Check if an event has already been published.

        Returns False when the database is not open.
        """
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, stable_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, stable_id: str) -> None:
        """Mark an event as published.

        Inserts the ID, or refreshes ``last_seen`` when it is already present.
        Does nothing when the database is not open.
        """
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
                last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, stable_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Remove published_ids older than 48 hours. Returns count deleted.

        Only rows belonging to this adapter are deleted. Returns 0 when the
        database is not open.
        """
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-48 hours')",
            (self.name,),
        )
        self._db.commit()
        count = cur.rowcount
        if count > 0:
            logger.info("FIRMS swept old dedup entries", extra={"count": count})
        return count

    def _build_stable_id(
        self, satellite: str, acq_date: str, acq_time: str, lat: float, lon: float
    ) -> str:
        """Build stable ID for deduplication.

        The ID is ``satellite:acq_date:acq_time:lat:lon`` with the
        coordinates rounded to three decimals.
        """
        # Round lat/lon to 0.001 degrees to handle floating-point comparison
        lat_rounded = round(lat, 3)
        lon_rounded = round(lon, 3)
        return f"{satellite}:{acq_date}:{acq_time}:{lat_rounded}:{lon_rounded}"

    def _build_url(self, satellite: str) -> str | None:
        """Build FIRMS API URL for a satellite.

        Returns None when the API key or the region is not available.
        """
        if not self._api_key or not self.region:
            return None

        # Area format: west,south,east,north
        area = f"{self.region.west},{self.region.south},{self.region.east},{self.region.north}"
        return f"{FIRMS_API_BASE}/{self._api_key}/{satellite}/{area}/1"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=2, max=30),
        retry=retry_if_exception_type((aiohttp.ClientError,)),
        reraise=True,
    )
    async def _fetch_csv(self, url: str) -> str:
        """Fetch CSV data from FIRMS API.

        Only ``aiohttp.ClientError`` is retried (up to three attempts); the
        ``ValueError`` raised for an HTML response is not.

        Raises:
            RuntimeError: The HTTP session has not been created.
            ValueError: The response Content-Type contains ``text/html``.
            aiohttp.ClientError: Request failed, including error statuses
                raised by ``raise_for_status``.
        """
        if not self._session:
            raise RuntimeError("Session not initialized")

        async with self._session.get(url) as resp:
            # Check for error responses
            content_type = resp.headers.get("Content-Type", "")
            if "text/html" in content_type:
                text = await resp.text()
                logger.error(
                    "FIRMS returned HTML (likely auth error)",
                    extra={"status": resp.status, "preview": text[:200]},
                )
                raise ValueError("FIRMS returned HTML instead of CSV")

            resp.raise_for_status()
            return await resp.text()

    def _parse_csv(self, csv_text: str, satellite: str) -> list[dict[str, Any]]:
        """Parse FIRMS CSV response into list of dicts.

        Rows that cannot be parsed are logged and skipped. That includes
        truncated rows: ``csv.DictReader`` fills the missing tail of a short
        row with ``None``, which must not abort parsing of the other rows.

        Args:
            csv_text: Raw CSV text including the header line.
            satellite: Value used for ``satellite`` when the CSV has no such
                column.

        Returns:
            One dict per successfully parsed row.
        """
        rows: list[dict[str, Any]] = []
        reader = csv.DictReader(StringIO(csv_text))

        def opt_float(value: Any) -> float | None:
            # Missing or empty optional columns become None rather than 0.0.
            return float(value) if value else None

        for row in reader:
            try:
                # Parse required fields
                lat = float(row["latitude"])
                lon = float(row["longitude"])
                acq_date = row["acq_date"]
                acq_time = row["acq_time"]
                if acq_date is None or acq_time is None:
                    raise ValueError("truncated row: acquisition date/time missing")

                # A confidence column cut off by truncation is treated like
                # an absent column and falls back to "n".
                raw_confidence = row.get("confidence")
                confidence_raw = "n" if raw_confidence is None else raw_confidence.lower()
                confidence = CONFIDENCE_MAP.get(confidence_raw, "nominal")

                rows.append({
                    "latitude": lat,
                    "longitude": lon,
                    "bright_ti4": opt_float(row.get("bright_ti4")),
                    "bright_ti5": opt_float(row.get("bright_ti5")),
                    "scan": opt_float(row.get("scan")),
                    "track": opt_float(row.get("track")),
                    "acq_date": acq_date,
                    "acq_time": acq_time,
                    "satellite": row.get("satellite", satellite),
                    "instrument": row.get("instrument", "VIIRS"),
                    "confidence": confidence,
                    "confidence_raw": confidence_raw,
                    "version": row.get("version", ""),
                    "frp": opt_float(row.get("frp")),
                    "daynight": row.get("daynight", ""),
                })
            except (KeyError, ValueError, TypeError) as e:
                # TypeError covers float(None) on a truncated row.
                logger.warning(
                    "Failed to parse FIRMS row",
                    extra={"error": str(e), "row": dict(row)},
                )
                continue

        return rows

    def _row_to_event(self, row: dict[str, Any], satellite: str) -> Event:
        """Convert a parsed CSV row to an Event.

        Args:
            row: A dict produced by ``_parse_csv``.
            satellite: FIRMS source name the row was fetched for.
        """
        satellite_short = SATELLITE_SHORT.get(satellite, satellite.lower().replace("_nrt", ""))
        confidence = row["confidence"]
        severity = SEVERITY_MAP.get(confidence, 1)

        # Parse acquisition time
        acq_date = row["acq_date"]
        acq_time = row["acq_time"]
        # acq_time is HHMM format. Zero-pad it first: without the padding
        # "203" would be read as 20:03 instead of 02:03.
        try:
            event_time = datetime.strptime(
                f"{acq_date} {str(acq_time).zfill(4)}", "%Y-%m-%d %H%M"
            ).replace(tzinfo=timezone.utc)
        except ValueError:
            logger.warning(
                "FIRMS acquisition time unparsable, using current time",
                extra={"acq_date": acq_date, "acq_time": acq_time},
            )
            event_time = datetime.now(timezone.utc)

        lat = row["latitude"]
        lon = row["longitude"]

        # Build stable ID (uses the raw acq_time so existing IDs keep matching)
        stable_id = self._build_stable_id(satellite, acq_date, acq_time, lat, lon)

        geo = Geo(
            centroid=(lon, lat),  # GeoJSON order: lon, lat
            bbox=(lon, lat, lon, lat),  # Point bbox
            regions=[],
            primary_region=None,
        )

        return Event(
            id=stable_id,
            source="central/adapters/firms",
            category=f"fire.hotspot.{satellite_short}.{confidence}",
            time=event_time,
            expires=None,
            severity=severity,
            geo=geo,
            data=row,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Poll FIRMS API for fire hotspots.

        Yields one Event per row that is not yet in the dedup table and marks
        it as published after the consumer resumes the generator. A failure
        for one satellite is logged and does not stop the others. Nothing is
        yielded while the API key or the region is missing.
        """
        # Check API key
        if not self._api_key:
            # Try to fetch again in case it was added
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if not self._api_key:
                logger.warning(
                    "FIRMS API key still not available, skipping poll",
                    extra={"alias": self._api_key_alias},
                )
                return

        if not self.region:
            logger.warning("FIRMS region not configured, skipping poll")
            return

        # Sweep old dedup entries periodically
        self.sweep_old_ids()

        total_features = 0
        total_new = 0

        for satellite in self._satellites:
            url = self._build_url(satellite)
            if not url:
                continue

            try:
                csv_text = await self._fetch_csv(url)
                rows = self._parse_csv(csv_text, satellite)
                feature_count = len(rows)
                total_features += feature_count

                new_count = 0
                for row in rows:
                    stable_id = self._build_stable_id(
                        satellite,
                        row["acq_date"],
                        row["acq_time"],
                        row["latitude"],
                        row["longitude"],
                    )

                    if self.is_published(stable_id):
                        continue

                    event = self._row_to_event(row, satellite)
                    yield event
                    self.mark_published(stable_id)
                    new_count += 1

                total_new += new_count
                logger.info(
                    "FIRMS satellite poll completed",
                    extra={
                        "satellite": satellite,
                        "feature_count": feature_count,
                        "new_count": new_count,
                    },
                )

            except Exception as e:
                logger.error(
                    "FIRMS poll failed for satellite",
                    extra={"satellite": satellite, "error": str(e)},
                )
                continue

        logger.info(
            "FIRMS poll completed",
            extra={
                "total_features": total_features,
                "total_new": total_new,
                "satellites": self._satellites,
            },
        )
|
||||
|
||||
|
||||
def subject_for_fire_hotspot(ev: Event) -> str:
    """Return the NATS subject on which a fire hotspot event is published.

    The event category already has the form
    ``fire.hotspot.<satellite>.<confidence>``, so the subject is simply that
    category under the ``central`` root:
    ``central.fire.hotspot.<satellite>.<confidence>``.
    """
    root = "central"
    return f"{root}.{ev.category}"
|
||||
|
|
|
|||
|
|
@ -1,400 +1,400 @@
|
|||
"""USGS Earthquake Hazards Program adapter."""
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from shapely.geometry import Point, box as shapely_box
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from central.adapter import SourceAdapter
|
||||
from central.config_models import AdapterConfig, RegionConfig
|
||||
from central.config_store import ConfigStore
|
||||
from central.models import Event, Geo
|
||||
|
||||
logger = logging.getLogger(__name__)

# Base URL of the USGS GeoJSON summary feeds.
USGS_FEED_BASE = "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary"

# Feed names accepted for the "feed" setting; any other value is rejected
# by the adapter with a warning.
VALID_FEEDS = {"all_hour", "all_day", "all_week", "all_month"}
|
||||
|
||||
|
||||
def magnitude_tier(mag: float) -> str:
    """Classify magnitude into USGS-style tier.

    Each tier covers the half-open range below its upper bound: below 3.0 is
    "minor", below 4.0 "light", below 5.0 "moderate", below 6.0 "strong",
    below 7.0 "major"; everything else is "great".
    """
    upper_bounds = (
        (3.0, "minor"),
        (4.0, "light"),
        (5.0, "moderate"),
        (6.0, "strong"),
        (7.0, "major"),
    )
    for bound, label in upper_bounds:
        if mag < bound:
            return label
    return "great"
|
||||
|
||||
|
||||
def magnitude_to_severity(mag: float) -> int:
    """Map magnitude to severity level (0-5).

    The level is the index of the first threshold the magnitude is below
    (3.0, 4.0, 5.0, 6.0, 7.0); a magnitude below none of them maps to 5.
    """
    thresholds = (3.0, 4.0, 5.0, 6.0, 7.0)
    for level, bound in enumerate(thresholds):
        if mag < bound:
            return level
    return len(thresholds)
|
||||
|
||||
|
||||
class USGSQuakeAdapter(SourceAdapter):
|
||||
"""USGS Earthquake Hazards Program adapter."""
|
||||
|
||||
name = "usgs_quake"
|
||||
|
||||
def __init__(
    self,
    config: AdapterConfig,
    config_store: ConfigStore,  # Unused, accepted for signature uniformity
    cursor_db_path: Path,
) -> None:
    """Read the adapter settings; the session and database are opened later.

    Args:
        config: Adapter configuration. ``config.settings`` is read for
            ``feed`` and ``region``.
        config_store: Not used by this adapter.
        cursor_db_path: Path of the SQLite file holding dedup state.
    """
    self._cursor_db_path = cursor_db_path
    self._session: aiohttp.ClientSession | None = None
    self._db: sqlite3.Connection | None = None

    settings = config.settings

    # Feed: an unknown value is reported and replaced by "all_hour".
    requested_feed = settings.get("feed", "all_hour")
    if requested_feed in VALID_FEEDS:
        self._feed: str = requested_feed
    else:
        logger.warning(
            "Invalid feed setting, using all_hour",
            extra={"feed": requested_feed, "valid": list(VALID_FEEDS)},
        )
        self._feed = "all_hour"

    # Region: kept both as the config object and as a shapely box built
    # from its west/south/east/north bounds.
    bounds = settings.get("region")
    self.region: RegionConfig | None = RegionConfig(**bounds) if bounds else None
    if self.region is None:
        self._region_box = None
    else:
        self._region_box = shapely_box(
            self.region.west,
            self.region.south,
            self.region.east,
            self.region.north,
        )
|
||||
|
||||
async def apply_config(self, new_config: AdapterConfig) -> None:
    """Apply new configuration from hot-reload.

    An invalid feed is reported and the current feed is kept. The region
    objects are built before anything is assigned, so a region that fails
    to construct leaves the running configuration untouched instead of
    half-applied (new feed, old region).
    """
    settings = new_config.settings

    # Validate feed (no mutation yet)
    new_feed = settings.get("feed", "all_hour")
    feed_is_valid = new_feed in VALID_FEEDS
    if not feed_is_valid:
        logger.warning(
            "Invalid feed in new config, keeping current",
            extra={"new_feed": new_feed, "current": self._feed},
        )

    # Build region objects (may raise; still no mutation)
    region_dict = settings.get("region")
    if region_dict:
        new_region = RegionConfig(**region_dict)
        new_region_box = shapely_box(
            new_region.west,
            new_region.south,
            new_region.east,
            new_region.north,
        )
    else:
        new_region = None
        new_region_box = None

    # Commit
    if feed_is_valid:
        self._feed = new_feed
    self.region = new_region
    self._region_box = new_region_box

    logger.info(
        "USGS quake config applied",
        extra={
            "region": region_dict,
            "feed": self._feed,
        },
    )
|
||||
|
||||
async def startup(self) -> None:
    """Initialize HTTP session and dedup tracker.

    If opening the dedup database or creating its schema fails, the HTTP
    session and any half-open database connection are closed before the
    error propagates, so a failed startup does not leak resources.
    """
    # Initialize HTTP session
    self._session = aiohttp.ClientSession(
        headers={"User-Agent": "Central/1.0 (earthquake monitoring)"},
        timeout=aiohttp.ClientTimeout(total=30),
    )

    try:
        # Initialize dedup tracker (shared sqlite DB)
        self._db = sqlite3.connect(str(self._cursor_db_path))
        self._db.execute("""
            CREATE TABLE IF NOT EXISTS published_ids (
                adapter TEXT NOT NULL,
                event_id TEXT NOT NULL,
                first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (adapter, event_id)
            )
        """)
        self._db.execute("""
            CREATE INDEX IF NOT EXISTS published_ids_last_seen
            ON published_ids (last_seen)
        """)
        self._db.commit()

        # Sweep old entries on startup (7 days for quakes)
        self.sweep_old_ids()
    except Exception:
        # Undo the partial startup so nothing is left open, then re-raise.
        if self._db is not None:
            self._db.close()
            self._db = None
        await self._session.close()
        self._session = None
        raise

    logger.info(
        "USGS quake adapter started",
        extra={
            "region": {
                "north": self.region.north,
                "south": self.region.south,
                "east": self.region.east,
                "west": self.region.west,
            } if self.region else None,
            "feed": self._feed,
        },
    )
|
||||
|
||||
async def shutdown(self) -> None:
    """Close HTTP session and database.

    The database is closed in a ``finally`` clause so it is released even
    when closing the HTTP session raises or the task is cancelled while
    awaiting it.
    """
    try:
        if self._session:
            await self._session.close()
    finally:
        self._session = None
        if self._db:
            self._db.close()
            self._db = None
    logger.info("USGS quake adapter shut down")
|
||||
|
||||
def is_published(self, event_id: str) -> bool:
    """Return True when ``event_id`` is already recorded for this adapter.

    Returns False when no dedup database is open.
    """
    db = self._db
    if not db:
        return False
    row = db.execute(
        "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
        (self.name, event_id),
    ).fetchone()
    return row is not None
|
||||
|
||||
def mark_published(self, event_id: str) -> None:
    """Record ``event_id`` as published, refreshing ``last_seen`` if it exists.

    Does nothing when no dedup database is open.
    """
    db = self._db
    if not db:
        return
    upsert_sql = """
        INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
        VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
        ON CONFLICT (adapter, event_id) DO UPDATE SET
            last_seen = CURRENT_TIMESTAMP
    """
    db.execute(upsert_sql, (self.name, event_id))
    db.commit()
|
||||
|
||||
def sweep_old_ids(self) -> int:
    """Delete this adapter's dedup rows not seen for 7 days.

    Returns the number of rows deleted, or 0 when no database is open.
    """
    db = self._db
    if not db:
        return 0
    cursor = db.execute(
        "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-7 days')",
        (self.name,),
    )
    db.commit()
    deleted = cursor.rowcount
    if deleted > 0:
        logger.info("USGS quake swept old dedup entries", extra={"count": deleted})
    return deleted
|
||||
|
||||
def _build_url(self) -> str:
    """Return the USGS GeoJSON summary URL for the configured feed."""
    return "{}/{}.geojson".format(USGS_FEED_BASE, self._feed)
|
||||
|
||||
@retry(
    retry=retry_if_exception_type((aiohttp.ClientError,)),
    wait=wait_exponential_jitter(initial=1, max=15),
    stop=stop_after_attempt(3),
    reraise=True,
)
async def _fetch_geojson(self) -> dict[str, Any]:
    """Download and decode the configured USGS GeoJSON feed.

    Retried up to three attempts on ``aiohttp.ClientError``; the last
    error is re-raised. Raises RuntimeError when no session is open.
    """
    session = self._session
    if not session:
        raise RuntimeError("Session not initialized")

    async with session.get(self._build_url()) as response:
        response.raise_for_status()
        payload = await response.json()
    return payload
|
||||
|
||||
def _point_in_region(self, lon: float, lat: float) -> bool:
|
||||
"""Check if point intersects region bbox using shapely."""
|
||||
if self._region_box is None:
|
||||
return True
|
||||
point = Point(lon, lat)
|
||||
return self._region_box.intersects(point)
|
||||
|
||||
def _feature_to_event(self, feature: dict[str, Any]) -> Event | None:
    """Convert a GeoJSON feature to an Event.

    Returns None (after logging where appropriate) when the feature has
    no id, no usable magnitude, no usable coordinates, or falls outside
    the configured region. A malformed feature is skipped rather than
    raising, so one bad record cannot abort the whole poll.
    """
    import math  # local import: only needed for the finite-magnitude check

    # `or` also covers explicit JSON nulls, which .get(key, default) lets through.
    props = feature.get("properties") or {}
    geometry = feature.get("geometry") or {}
    coords = geometry.get("coordinates") or []

    # Validate required fields
    event_id = feature.get("id")
    if not event_id:
        logger.warning("Feature missing id", extra={"properties": props})
        return None

    # Get magnitude - skip if null/missing (PM decision)
    raw_mag = props.get("mag")
    if raw_mag is None:
        logger.debug(
            "Skipping event with null magnitude",
            extra={"id": event_id, "place": props.get("place")},
        )
        return None

    try:
        mag = float(raw_mag)
    except (TypeError, ValueError):
        mag = math.nan
    if not math.isfinite(mag):
        # NaN/inf fails every `<` comparison and would be tiered as "great".
        logger.warning(
            "Invalid magnitude value",
            extra={"id": event_id, "mag": raw_mag},
        )
        return None

    # Get coordinates [lon, lat, depth]
    if not isinstance(coords, (list, tuple)) or len(coords) < 2:
        logger.warning("Feature missing coordinates", extra={"id": event_id})
        return None

    lon, lat = coords[0], coords[1]
    if not all(
        isinstance(value, (int, float)) and not isinstance(value, bool)
        for value in (lon, lat)
    ):
        # Null or non-numeric coordinates cannot be turned into a Point.
        logger.warning("Feature missing coordinates", extra={"id": event_id})
        return None
    depth = coords[2] if len(coords) > 2 else None

    # Region filter
    if not self._point_in_region(lon, lat):
        return None

    # Parse event time (milliseconds since epoch); fall back to "now".
    time_ms = props.get("time")
    if time_ms is not None:
        try:
            event_time = datetime.fromtimestamp(time_ms / 1000, tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError):
            event_time = datetime.now(timezone.utc)
    else:
        event_time = datetime.now(timezone.utc)

    # Build tier and severity
    tier = magnitude_tier(mag)
    severity = magnitude_to_severity(mag)

    # Build geo: a point event, so the bbox collapses to the point itself.
    geo = Geo(
        centroid=(lon, lat),
        bbox=(lon, lat, lon, lat),
        regions=[],
        primary_region=None,
    )

    # Build data payload; key order matches the original payload layout.
    passthrough = (
        "tz", "url", "detail", "felt", "cdi", "mmi", "alert", "status",
        "tsunami", "sig", "net", "code", "ids", "sources", "types", "nst",
        "dmin", "rms", "gap", "magType", "type", "title",
    )
    data = {
        "magnitude": mag,
        "place": props.get("place"),
        "time_ms": time_ms,
        "updated_ms": props.get("updated"),
    }
    data.update((key, props.get(key)) for key in passthrough)
    data.update(longitude=lon, latitude=lat, depth=depth)

    return Event(
        id=event_id,
        source="central/adapters/usgs_quake",
        category=f"quake.event.{tier}",
        time=event_time,
        expires=None,
        severity=severity,
        geo=geo,
        data=data,
    )
|
||||
|
||||
async def poll(self) -> AsyncIterator[Event]:
    """Fetch the USGS feed and yield each in-region event not yet published.

    Yields nothing when no region is configured. A fetch failure is
    logged and re-raised. An event is marked published only after the
    consumer resumes the generator past its ``yield``.
    """
    if not self.region:
        logger.warning("USGS quake region not configured, skipping poll")
        return

    # Sweep old dedup entries periodically
    self.sweep_old_ids()

    try:
        payload = await self._fetch_geojson()
    except Exception as exc:
        logger.error("Failed to fetch USGS data", extra={"error": str(exc)})
        raise

    features = payload.get("features", [])
    metadata = payload.get("metadata", {})
    logger.info(
        "USGS quake poll completed",
        extra={
            "feature_count": len(features),
            "title": metadata.get("title"),
            "generated": metadata.get("generated"),
        },
    )

    published = 0
    for feature in features:
        event = self._feature_to_event(feature)
        if event is None or self.is_published(event.id):
            continue
        yield event
        self.mark_published(event.id)
        published += 1

    logger.info("USGS quake yielded events", extra={"count": published})
|
||||
"""USGS Earthquake Hazards Program adapter."""
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from shapely.geometry import Point, box as shapely_box
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from central.adapter import SourceAdapter
|
||||
from central.config_models import AdapterConfig, RegionConfig
|
||||
from central.config_store import ConfigStore
|
||||
from central.models import Event, Geo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# USGS GeoJSON feed base URL
|
||||
USGS_FEED_BASE = "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary"
|
||||
|
||||
# Valid feed options
|
||||
VALID_FEEDS = {"all_hour", "all_day", "all_week", "all_month"}
|
||||
|
||||
|
||||
def magnitude_tier(mag: float) -> str:
    """Return the tier label for ``mag``.

    Each tier covers magnitudes strictly below its ceiling; anything at
    or above 7.0 is "great".
    """
    ladder = (
        (3.0, "minor"),
        (4.0, "light"),
        (5.0, "moderate"),
        (6.0, "strong"),
        (7.0, "major"),
    )
    for ceiling, label in ladder:
        if mag < ceiling:
            return label
    return "great"
|
||||
|
||||
|
||||
def magnitude_to_severity(mag: float) -> int:
    """Return a severity level 0-5: one step per whole magnitude from 3.0 to 7.0."""
    level = 0
    for ceiling in (3.0, 4.0, 5.0, 6.0, 7.0):
        if mag < ceiling:
            break
        level += 1
    return level
|
||||
|
||||
|
||||
class USGSQuakeAdapter(SourceAdapter):
    """USGS Earthquake Hazards Program adapter.

    Polls a USGS GeoJSON summary feed, filters features to a configured
    bounding box, and yields one Event per earthquake not yet recorded in
    the sqlite ``published_ids`` dedup table.
    """

    name = "usgs_quake"

    def __init__(
        self,
        config: AdapterConfig,
        config_store: ConfigStore,  # Unused, accepted for signature uniformity
        cursor_db_path: Path,
    ) -> None:
        self._cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        self._db: sqlite3.Connection | None = None

        # Extract settings from config; an unknown feed falls back to all_hour.
        self._feed: str = config.settings.get("feed", "all_hour")
        if self._feed not in VALID_FEEDS:
            logger.warning(
                "Invalid feed setting, using all_hour",
                extra={"feed": self._feed, "valid": list(VALID_FEEDS)},
            )
            self._feed = "all_hour"

        # Parse region from settings
        region, region_box = self._parse_region(config.settings)
        self.region: RegionConfig | None = region
        self._region_box = region_box

    @staticmethod
    def _parse_region(settings: Any) -> tuple[RegionConfig | None, Any]:
        """Return (RegionConfig, shapely box) for settings["region"], or (None, None) if unset."""
        region_dict = settings.get("region")
        if not region_dict:
            return None, None
        region = RegionConfig(**region_dict)
        return region, shapely_box(
            region.west,
            region.south,
            region.east,
            region.north,
        )

    async def apply_config(self, new_config: AdapterConfig) -> None:
        """Apply new configuration from hot-reload.

        The replacement region is built before any attribute is modified,
        so an exception raised while constructing ``RegionConfig`` leaves
        the previous feed and region untouched.
        """
        settings = new_config.settings

        # This is the only step that can raise; do it before mutating state.
        new_region, new_box = self._parse_region(settings)

        # An unknown feed is ignored and the current one is kept.
        new_feed = settings.get("feed", "all_hour")
        if new_feed in VALID_FEEDS:
            self._feed = new_feed
        else:
            logger.warning(
                "Invalid feed in new config, keeping current",
                extra={"new_feed": new_feed, "current": self._feed},
            )

        self.region = new_region
        self._region_box = new_box

        logger.info(
            "USGS quake config applied",
            extra={
                "region": settings.get("region"),
                "feed": self._feed,
            },
        )

    async def startup(self) -> None:
        """Initialize HTTP session and dedup tracker.

        If opening the dedup database or creating its schema fails, the
        HTTP session and any half-open database connection are closed
        before the error propagates.
        """
        # Initialize HTTP session
        self._session = aiohttp.ClientSession(
            headers={"User-Agent": "Central/1.0 (earthquake monitoring)"},
            timeout=aiohttp.ClientTimeout(total=30),
        )

        try:
            # Initialize dedup tracker (shared sqlite DB)
            self._db = sqlite3.connect(str(self._cursor_db_path))
            self._db.execute("""
                CREATE TABLE IF NOT EXISTS published_ids (
                    adapter TEXT NOT NULL,
                    event_id TEXT NOT NULL,
                    first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (adapter, event_id)
                )
            """)
            self._db.execute("""
                CREATE INDEX IF NOT EXISTS published_ids_last_seen
                ON published_ids (last_seen)
            """)
            self._db.commit()

            # Sweep old entries on startup (7 days for quakes)
            self.sweep_old_ids()
        except Exception:
            # Undo the partial startup so nothing is left open, then re-raise.
            if self._db is not None:
                self._db.close()
                self._db = None
            await self._session.close()
            self._session = None
            raise

        logger.info(
            "USGS quake adapter started",
            extra={
                "region": {
                    "north": self.region.north,
                    "south": self.region.south,
                    "east": self.region.east,
                    "west": self.region.west,
                } if self.region else None,
                "feed": self._feed,
            },
        )

    async def shutdown(self) -> None:
        """Close HTTP session and database.

        The database is closed in a ``finally`` clause so it is released
        even when closing the HTTP session raises or is cancelled.
        """
        try:
            if self._session:
                await self._session.close()
        finally:
            self._session = None
            if self._db:
                self._db.close()
                self._db = None
        logger.info("USGS quake adapter shut down")

    def is_published(self, event_id: str) -> bool:
        """Check if an event has already been published."""
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, event_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, event_id: str) -> None:
        """Mark an event as published, refreshing last_seen if already present."""
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
                last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, event_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Remove published_ids older than 7 days. Returns count deleted."""
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-7 days')",
            (self.name,),
        )
        self._db.commit()
        count = cur.rowcount
        if count > 0:
            logger.info("USGS quake swept old dedup entries", extra={"count": count})
        return count

    def _build_url(self) -> str:
        """Build USGS GeoJSON feed URL."""
        return f"{USGS_FEED_BASE}/{self._feed}.geojson"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=1, max=15),
        retry=retry_if_exception_type((aiohttp.ClientError,)),
        reraise=True,
    )
    async def _fetch_geojson(self) -> dict[str, Any]:
        """Fetch GeoJSON data from USGS.

        Retried up to three attempts on ``aiohttp.ClientError``; raises
        RuntimeError when no session is open.
        """
        if not self._session:
            raise RuntimeError("Session not initialized")

        url = self._build_url()
        async with self._session.get(url) as resp:
            resp.raise_for_status()
            return await resp.json()

    def _point_in_region(self, lon: float, lat: float) -> bool:
        """Check if point intersects region bbox using shapely; True when no bbox is set."""
        if self._region_box is None:
            return True
        point = Point(lon, lat)
        return self._region_box.intersects(point)

    def _feature_to_event(self, feature: dict[str, Any]) -> Event | None:
        """Convert a GeoJSON feature to an Event.

        Returns None (after logging where appropriate) when the feature
        has no id, no usable magnitude, no usable coordinates, or falls
        outside the configured region. A malformed feature is skipped
        rather than raising, so one bad record cannot abort the poll.
        """
        import math  # local import: only needed for the finite-magnitude check

        # `or` also covers explicit JSON nulls, which .get(key, default) lets through.
        props = feature.get("properties") or {}
        geometry = feature.get("geometry") or {}
        coords = geometry.get("coordinates") or []

        # Validate required fields
        event_id = feature.get("id")
        if not event_id:
            logger.warning("Feature missing id", extra={"properties": props})
            return None

        # Get magnitude - skip if null/missing (PM decision)
        raw_mag = props.get("mag")
        if raw_mag is None:
            logger.debug(
                "Skipping event with null magnitude",
                extra={"id": event_id, "place": props.get("place")},
            )
            return None

        try:
            mag = float(raw_mag)
        except (TypeError, ValueError):
            mag = math.nan
        if not math.isfinite(mag):
            # NaN/inf fails every `<` comparison and would be tiered as "great".
            logger.warning(
                "Invalid magnitude value",
                extra={"id": event_id, "mag": raw_mag},
            )
            return None

        # Get coordinates [lon, lat, depth]
        if not isinstance(coords, (list, tuple)) or len(coords) < 2:
            logger.warning("Feature missing coordinates", extra={"id": event_id})
            return None

        lon, lat = coords[0], coords[1]
        if not all(
            isinstance(value, (int, float)) and not isinstance(value, bool)
            for value in (lon, lat)
        ):
            # Null or non-numeric coordinates cannot be turned into a Point.
            logger.warning("Feature missing coordinates", extra={"id": event_id})
            return None
        depth = coords[2] if len(coords) > 2 else None

        # Region filter
        if not self._point_in_region(lon, lat):
            return None

        # Parse event time (milliseconds since epoch); fall back to "now".
        time_ms = props.get("time")
        if time_ms is not None:
            try:
                event_time = datetime.fromtimestamp(time_ms / 1000, tz=timezone.utc)
            except (TypeError, ValueError, OverflowError, OSError):
                event_time = datetime.now(timezone.utc)
        else:
            event_time = datetime.now(timezone.utc)

        # Build tier and severity
        tier = magnitude_tier(mag)
        severity = magnitude_to_severity(mag)

        # Build geo: a point event, so the bbox collapses to the point itself.
        geo = Geo(
            centroid=(lon, lat),
            bbox=(lon, lat, lon, lat),
            regions=[],
            primary_region=None,
        )

        # Build data payload; key order matches the original payload layout.
        passthrough = (
            "tz", "url", "detail", "felt", "cdi", "mmi", "alert", "status",
            "tsunami", "sig", "net", "code", "ids", "sources", "types", "nst",
            "dmin", "rms", "gap", "magType", "type", "title",
        )
        data = {
            "magnitude": mag,
            "place": props.get("place"),
            "time_ms": time_ms,
            "updated_ms": props.get("updated"),
        }
        data.update((key, props.get(key)) for key in passthrough)
        data.update(longitude=lon, latitude=lat, depth=depth)

        return Event(
            id=event_id,
            source="central/adapters/usgs_quake",
            category=f"quake.event.{tier}",
            time=event_time,
            expires=None,
            severity=severity,
            geo=geo,
            data=data,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Poll USGS for earthquake data.

        Yields nothing when no region is configured. An event is marked
        published only after the consumer resumes the generator past its
        ``yield``.
        """
        if not self.region:
            logger.warning("USGS quake region not configured, skipping poll")
            return

        # Sweep old dedup entries periodically
        self.sweep_old_ids()

        try:
            data = await self._fetch_geojson()
        except Exception as e:
            logger.error("Failed to fetch USGS data", extra={"error": str(e)})
            raise

        features = data.get("features", [])
        metadata = data.get("metadata", {})

        logger.info(
            "USGS quake poll completed",
            extra={
                "feature_count": len(features),
                "title": metadata.get("title"),
                "generated": metadata.get("generated"),
            },
        )

        new_count = 0
        for feature in features:
            event = self._feature_to_event(feature)
            if event is None:
                continue

            if self.is_published(event.id):
                continue

            yield event
            self.mark_published(event.id)
            new_count += 1

        logger.info("USGS quake yielded events", extra={"count": new_count})
|
||||
|
|
|
|||
|
|
@ -1,353 +1,353 @@
|
|||
"""Central archive consumer - JetStream to TimescaleDB."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
|
||||
|
||||
from central.bootstrap_config import get_settings
|
||||
|
||||
CONSUMER_NAME = "archive"
|
||||
STREAM_NAME = "CENTRAL_WX"
|
||||
SUBJECT_FILTER = "central.wx.>"
|
||||
BATCH_SIZE = 100
|
||||
FETCH_TIMEOUT = 5.0
|
||||
ACK_WAIT = 30
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
    """JSON log formatter for structured logging.

    Emits one JSON object per record with ``ts``, ``level``, ``logger``
    and ``msg``, plus ``exc`` when exception info is attached, plus every
    non-standard record attribute (the ``extra=`` fields).
    """

    # Standard LogRecord attributes that must not be echoed as extra fields.
    _RESERVED = frozenset({
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Serialize ``record`` to a single-line JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED:
                log_obj[key] = value
        # default=str: a log call must not fail (and lose the line) because an
        # extra value such as bytes or a datetime is not JSON-serializable.
        return json.dumps(log_obj, default=str)
|
||||
|
||||
|
||||
def setup_logging() -> None:
    """Replace the root logger's handlers with one JSON stdout handler at INFO."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())

    root = logging.root
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)
|
||||
|
||||
|
||||
logger = logging.getLogger("central.archive")
|
||||
|
||||
|
||||
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
|
||||
"""Build PostGIS geometry from event geo data."""
|
||||
if not geo_data:
|
||||
return None
|
||||
|
||||
bbox = geo_data.get("bbox")
|
||||
centroid = geo_data.get("centroid")
|
||||
|
||||
if bbox and len(bbox) == 4:
|
||||
# Create polygon from bbox
|
||||
min_lon, min_lat, max_lon, max_lat = bbox
|
||||
return json.dumps({
|
||||
"type": "Polygon",
|
||||
"coordinates": [[
|
||||
[min_lon, min_lat],
|
||||
[max_lon, min_lat],
|
||||
[max_lon, max_lat],
|
||||
[min_lon, max_lat],
|
||||
[min_lon, min_lat],
|
||||
]]
|
||||
})
|
||||
elif centroid and len(centroid) == 2:
|
||||
# Create point from centroid
|
||||
return json.dumps({
|
||||
"type": "Point",
|
||||
"coordinates": centroid
|
||||
})
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ArchiveConsumer:
    """Archive consumer process.

    Pulls event envelopes from a durable JetStream pull consumer and
    upserts them into the ``events`` table through an asyncpg pool.
    """

    # Shared upsert statement; only the geom value and placeholder numbering
    # differ between the with-geometry and without-geometry variants.
    _UPSERT_TEMPLATE = """
        INSERT INTO events (id, source, category, time, expires, severity,
                            geom, regions, primary_region, payload)
        VALUES ($1, $2, $3, $4, $5, $6, {geom_and_rest})
        ON CONFLICT (id, time) DO UPDATE SET
            source = EXCLUDED.source,
            category = EXCLUDED.category,
            expires = EXCLUDED.expires,
            severity = EXCLUDED.severity,
            geom = EXCLUDED.geom,
            regions = EXCLUDED.regions,
            primary_region = EXCLUDED.primary_region,
            payload = EXCLUDED.payload
    """
    _UPSERT_WITH_GEOM = _UPSERT_TEMPLATE.format(
        geom_and_rest="ST_GeomFromGeoJSON($7), $8, $9, $10"
    )
    _UPSERT_WITHOUT_GEOM = _UPSERT_TEMPLATE.format(
        geom_and_rest="NULL, $7, $8, $9"
    )

    def __init__(self, nats_url: str, postgres_dsn: str) -> None:
        self._nats_url = nats_url
        self._postgres_dsn = postgres_dsn
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._pool: asyncpg.Pool | None = None
        self._shutdown_event = asyncio.Event()

    async def connect(self) -> None:
        """Connect to NATS and PostgreSQL."""
        self._nc = await nats.connect(self._nats_url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self._nats_url})

        self._pool = await asyncpg.create_pool(
            self._postgres_dsn,
            min_size=1,
            max_size=5,
        )
        logger.info("Connected to PostgreSQL")

    async def disconnect(self) -> None:
        """Disconnect from NATS and PostgreSQL."""
        if self._pool:
            await self._pool.close()
            self._pool = None
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected")

    async def _ensure_consumer(self) -> None:
        """Ensure the durable consumer exists, creating it when not found."""
        if not self._js:
            return

        try:
            await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
            logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
        except nats.js.errors.NotFoundError:
            consumer_config = ConsumerConfig(
                durable_name=CONSUMER_NAME,
                deliver_policy=DeliverPolicy.ALL,
                ack_policy=AckPolicy.EXPLICIT,
                ack_wait=ACK_WAIT,
                filter_subject=SUBJECT_FILTER,
            )
            await self._js.add_consumer(STREAM_NAME, consumer_config)
            logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})

    @staticmethod
    def _as_mapping(value: Any) -> dict[str, Any]:
        """Return ``value`` if it is a dict, else an empty dict (covers JSON null)."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse an ISO-8601 string (trailing ``Z`` allowed); None if absent or invalid."""
        if not isinstance(value, str) or not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
        """Process a single message and insert into database.

        Messages that can never be archived (undecodable bytes, invalid
        JSON, a non-object envelope, missing id or time) are acked and
        dropped so they are not redelivered forever. A database failure
        is logged and left un-acked so the message is redelivered.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning("Invalid JSON in message", extra={"error": str(e)})
            await msg.ack()
            return

        if not isinstance(envelope, dict):
            logger.warning(
                "Message envelope is not a JSON object",
                extra={"type": type(envelope).__name__},
            )
            await msg.ack()
            return

        # Null or non-object "data"/"geo" are treated as empty rather than crashing.
        event_data = self._as_mapping(envelope.get("data"))
        geo_data = self._as_mapping(event_data.get("geo"))

        event_id = envelope.get("id")
        source = event_data.get("source", "")
        category = event_data.get("category", "")
        time_str = event_data.get("time")
        expires_str = event_data.get("expires")
        severity = event_data.get("severity")
        regions = geo_data.get("regions", [])
        primary_region = geo_data.get("primary_region")

        # Parse timestamps
        event_time = self._parse_timestamp(time_str)
        expires_time = self._parse_timestamp(expires_str)

        if not event_id or not event_time:
            logger.warning(
                "Message missing required fields",
                extra={"id": event_id, "time": time_str}
            )
            await msg.ack()
            return

        geom_json = _build_geom_sql(geo_data)

        try:
            if geom_json:
                await conn.execute(
                    self._UPSERT_WITH_GEOM,
                    event_id, source, category, event_time, expires_time, severity,
                    geom_json, regions, primary_region, json.dumps(envelope)
                )
            else:
                await conn.execute(
                    self._UPSERT_WITHOUT_GEOM,
                    event_id, source, category, event_time, expires_time, severity,
                    regions, primary_region, json.dumps(envelope)
                )

            await msg.ack()
            logger.info("Archived event", extra={"id": event_id, "category": category})

        except Exception as e:
            logger.error(
                "Failed to insert event",
                extra={"id": event_id, "error": str(e)}
            )
            # Don't ack - let it be redelivered

    async def _consume_loop(self) -> None:
        """Main consume loop: fetch batches until the shutdown event is set."""
        if not self._js or not self._pool:
            return

        await self._ensure_consumer()

        sub = await self._js.pull_subscribe(
            SUBJECT_FILTER,
            durable=CONSUMER_NAME,
            stream=STREAM_NAME,
        )

        logger.info(
            "Subscribed to stream",
            extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
        )

        while not self._shutdown_event.is_set():
            try:
                msgs = await sub.fetch(
                    batch=BATCH_SIZE,
                    timeout=FETCH_TIMEOUT,
                )

                if msgs:
                    # One pooled connection serves the whole batch.
                    async with self._pool.acquire() as conn:
                        for msg in msgs:
                            await self._process_message(msg, conn)

            except nats.errors.TimeoutError:
                # No messages available, continue
                pass
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Error in consume loop", extra={"error": str(e)})
                await asyncio.sleep(1)

        logger.info("Consume loop stopped")

    async def start(self) -> None:
        """Start the consumer."""
        await self.connect()
        logger.info("Archive consumer ready")

    async def run(self) -> None:
        """Run the consume loop until shutdown."""
        await self._consume_loop()

    async def stop(self) -> None:
        """Stop the consumer gracefully."""
        logger.info("Archive consumer shutting down")
        self._shutdown_event.set()
        await self.disconnect()
        logger.info("Archive consumer stopped")
|
||||
|
||||
|
||||
async def async_main() -> None:
    """Async entry point.

    Starts the archive consumer and runs until SIGTERM/SIGINT arrives or
    the consume task ends on its own. If the consume task failed, the
    exception is logged and re-raised after cleanup, so the process exits
    instead of idling with a dead consumer.
    """
    setup_logging()

    settings = get_settings()
    logger.info(
        "Archive starting",
        extra={
            "nats_url": settings.nats_url,
        },
    )

    consumer = ArchiveConsumer(
        nats_url=settings.nats_url,
        postgres_dsn=settings.db_dsn,
    )

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, shutdown_event.set)

    await consumer.start()

    # Run consumer in background
    consume_task = asyncio.create_task(consumer.run())
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    # Wake on whichever comes first: a shutdown signal or the consumer ending.
    await asyncio.wait(
        {consume_task, shutdown_task},
        return_when=asyncio.FIRST_COMPLETED,
    )
    shutdown_task.cancel()

    consumer._shutdown_event.set()
    consume_task.cancel()
    failure: Exception | None = None
    try:
        await consume_task
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        logger.exception("Consume task failed", extra={"error": str(exc)})
        failure = exc

    await consumer.stop()

    if failure is not None:
        raise failure
|
||||
|
||||
|
||||
def main() -> None:
    """Synchronous entry point: run the async archive consumer to completion."""
    entry = async_main()
    asyncio.run(entry)


if __name__ == "__main__":
    main()
|
||||
"""Central archive consumer - JetStream to TimescaleDB."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
|
||||
|
||||
from central.bootstrap_config import get_settings
|
||||
|
||||
CONSUMER_NAME = "archive"
|
||||
STREAM_NAME = "CENTRAL_WX"
|
||||
SUBJECT_FILTER = "central.wx.>"
|
||||
BATCH_SIZE = 100
|
||||
FETCH_TIMEOUT = 5.0
|
||||
ACK_WAIT = 30
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
    """JSON log formatter for structured logging.

    Each record becomes one JSON object with ``ts``, ``level``, ``logger``
    and ``msg`` keys, an ``exc`` key when exception info is attached, and
    one key per custom attribute passed through ``extra=``.
    """

    # Standard LogRecord attributes that must not be copied into the output;
    # every other attribute on the record is treated as a caller-supplied
    # ``extra`` field.
    _RESERVED_ATTRS = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* to a single-line JSON string.

        Args:
            record: The log record to render.

        Returns:
            The JSON document as a string.
        """
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str: an ``extra`` value that is not JSON-serializable
        # (datetime, exception, Path, ...) is stringified instead of making
        # the logging call itself raise TypeError.
        return json.dumps(log_obj, default=str)
|
||||
|
||||
|
||||
def setup_logging() -> None:
    """Make a JSON-formatting stdout handler the root logger's only handler.

    Any handlers already attached to the root logger are replaced, and the
    root level is set to INFO.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())

    root_logger = logging.root
    root_logger.setLevel(logging.INFO)
    root_logger.handlers = [stdout_handler]
|
||||
|
||||
|
||||
logger = logging.getLogger("central.archive")
|
||||
|
||||
|
||||
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
|
||||
"""Build PostGIS geometry from event geo data."""
|
||||
if not geo_data:
|
||||
return None
|
||||
|
||||
bbox = geo_data.get("bbox")
|
||||
centroid = geo_data.get("centroid")
|
||||
|
||||
if bbox and len(bbox) == 4:
|
||||
# Create polygon from bbox
|
||||
min_lon, min_lat, max_lon, max_lat = bbox
|
||||
return json.dumps({
|
||||
"type": "Polygon",
|
||||
"coordinates": [[
|
||||
[min_lon, min_lat],
|
||||
[max_lon, min_lat],
|
||||
[max_lon, max_lat],
|
||||
[min_lon, max_lat],
|
||||
[min_lon, min_lat],
|
||||
]]
|
||||
})
|
||||
elif centroid and len(centroid) == 2:
|
||||
# Create point from centroid
|
||||
return json.dumps({
|
||||
"type": "Point",
|
||||
"coordinates": centroid
|
||||
})
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ArchiveConsumer:
    """Archive consumer process.

    Pulls event envelopes from a durable JetStream pull consumer and upserts
    them into the ``events`` table.
    """

    # Shared tail of both upserts: on an (id, time) conflict every other
    # column is overwritten with the incoming values.
    _ON_CONFLICT_SQL = """
        ON CONFLICT (id, time) DO UPDATE SET
            source = EXCLUDED.source,
            category = EXCLUDED.category,
            expires = EXCLUDED.expires,
            severity = EXCLUDED.severity,
            geom = EXCLUDED.geom,
            regions = EXCLUDED.regions,
            primary_region = EXCLUDED.primary_region,
            payload = EXCLUDED.payload
    """

    _INSERT_WITH_GEOM_SQL = """
        INSERT INTO events (id, source, category, time, expires, severity,
                            geom, regions, primary_region, payload)
        VALUES ($1, $2, $3, $4, $5, $6,
                ST_GeomFromGeoJSON($7), $8, $9, $10)
    """ + _ON_CONFLICT_SQL

    _INSERT_WITHOUT_GEOM_SQL = """
        INSERT INTO events (id, source, category, time, expires, severity,
                            geom, regions, primary_region, payload)
        VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
    """ + _ON_CONFLICT_SQL

    def __init__(self, nats_url: str, postgres_dsn: str) -> None:
        """Store connection parameters; no I/O happens until ``connect()``.

        Args:
            nats_url: URL passed to ``nats.connect``.
            postgres_dsn: DSN passed to ``asyncpg.create_pool``.
        """
        self._nats_url = nats_url
        self._postgres_dsn = postgres_dsn
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._pool: asyncpg.Pool | None = None
        self._shutdown_event = asyncio.Event()

    async def connect(self) -> None:
        """Connect to NATS and PostgreSQL."""
        self._nc = await nats.connect(self._nats_url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self._nats_url})

        self._pool = await asyncpg.create_pool(
            self._postgres_dsn,
            min_size=1,
            max_size=5,
        )
        logger.info("Connected to PostgreSQL")

    async def disconnect(self) -> None:
        """Disconnect from NATS and PostgreSQL."""
        if self._pool:
            await self._pool.close()
            self._pool = None
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected")

    async def _ensure_consumer(self) -> None:
        """Ensure the durable consumer exists, creating it if it is missing."""
        if not self._js:
            return

        try:
            await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
            logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
        except nats.js.errors.NotFoundError:
            consumer_config = ConsumerConfig(
                durable_name=CONSUMER_NAME,
                deliver_policy=DeliverPolicy.ALL,
                ack_policy=AckPolicy.EXPLICIT,
                ack_wait=ACK_WAIT,
                filter_subject=SUBJECT_FILTER,
            )
            await self._js.add_consumer(STREAM_NAME, consumer_config)
            logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse an ISO-8601 string (a ``Z`` suffix is accepted) or return None."""
        if not isinstance(value, str) or not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
        """Process a single message and insert into database.

        Malformed messages (undecodable payload, non-object envelope, missing
        id or unparseable time) are logged and acked so they are not
        redelivered. A failed insert is logged and left un-acked so the
        message is redelivered.

        Args:
            msg: The fetched JetStream message; ``data`` and ``ack()`` are used.
            conn: Database connection used for the upsert.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning("Invalid JSON in message", extra={"error": str(e)})
            await msg.ack()
            return

        if not isinstance(envelope, dict):
            # Valid JSON but not an object: nothing can be extracted from it.
            logger.warning(
                "Invalid JSON in message",
                extra={"error": "envelope is not a JSON object"},
            )
            await msg.ack()
            return

        # Tolerate explicit nulls / wrong types for the nested objects, so a
        # single odd message cannot raise and stall the rest of the batch.
        event_data = envelope.get("data")
        if not isinstance(event_data, dict):
            event_data = {}
        geo_data = event_data.get("geo")
        if not isinstance(geo_data, dict):
            geo_data = None
        geo_fields = geo_data or {}

        event_id = envelope.get("id")
        source = event_data.get("source", "")
        category = event_data.get("category", "")
        time_str = event_data.get("time")
        expires_str = event_data.get("expires")
        severity = event_data.get("severity")
        regions = geo_fields.get("regions", [])
        primary_region = geo_fields.get("primary_region")

        # Parse timestamps
        event_time = self._parse_timestamp(time_str)
        expires_time = self._parse_timestamp(expires_str)

        if not event_id or not event_time:
            logger.warning(
                "Message missing required fields",
                extra={"id": event_id, "time": time_str}
            )
            await msg.ack()
            return

        try:
            geom_json = _build_geom_sql(geo_data)
        except (TypeError, ValueError) as e:
            # Unusable bbox/centroid shape: archive the event without geometry.
            logger.warning(
                "Invalid geo data in message",
                extra={"id": event_id, "error": str(e)}
            )
            geom_json = None

        payload = json.dumps(envelope)

        try:
            if geom_json:
                await conn.execute(
                    self._INSERT_WITH_GEOM_SQL,
                    event_id, source, category, event_time, expires_time, severity,
                    geom_json, regions, primary_region, payload
                )
            else:
                await conn.execute(
                    self._INSERT_WITHOUT_GEOM_SQL,
                    event_id, source, category, event_time, expires_time, severity,
                    regions, primary_region, payload
                )

            await msg.ack()
            logger.info("Archived event", extra={"id": event_id, "category": category})

        except Exception as e:
            logger.error(
                "Failed to insert event",
                extra={"id": event_id, "error": str(e)}
            )
            # Don't ack - let it be redelivered

    async def _consume_loop(self) -> None:
        """Fetch and archive message batches until shutdown is requested."""
        if not self._js or not self._pool:
            return

        await self._ensure_consumer()

        sub = await self._js.pull_subscribe(
            SUBJECT_FILTER,
            durable=CONSUMER_NAME,
            stream=STREAM_NAME,
        )

        logger.info(
            "Subscribed to stream",
            extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
        )

        while not self._shutdown_event.is_set():
            try:
                msgs = await sub.fetch(
                    batch=BATCH_SIZE,
                    timeout=FETCH_TIMEOUT,
                )

                if msgs:
                    # One pooled connection is reused for the whole batch.
                    async with self._pool.acquire() as conn:
                        for msg in msgs:
                            await self._process_message(msg, conn)

            except nats.errors.TimeoutError:
                # No messages available, continue
                pass
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Error in consume loop", extra={"error": str(e)})
                await asyncio.sleep(1)

        logger.info("Consume loop stopped")

    async def start(self) -> None:
        """Start the consumer."""
        await self.connect()
        logger.info("Archive consumer ready")

    async def run(self) -> None:
        """Run the consume loop until shutdown."""
        await self._consume_loop()

    async def stop(self) -> None:
        """Stop the consumer gracefully."""
        logger.info("Archive consumer shutting down")
        self._shutdown_event.set()
        await self.disconnect()
        logger.info("Archive consumer stopped")
|
||||
|
||||
|
||||
async def async_main() -> None:
    """Async entry point for the archive process.

    Configures logging, builds an ``ArchiveConsumer`` from the bootstrap
    settings, and runs its consume loop as a background task until SIGTERM
    or SIGINT arrives, or until the consume loop ends on its own.

    Raises:
        RuntimeError: If the consume loop returned before any shutdown
            signal was received.
        Exception: Whatever the consume loop raised, re-raised after the
            consumer has been stopped.
    """
    setup_logging()

    settings = get_settings()
    logger.info(
        "Archive starting",
        extra={
            "nats_url": settings.nats_url,
        },
    )

    consumer = ArchiveConsumer(
        nats_url=settings.nats_url,
        postgres_dsn=settings.db_dsn,
    )

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    await consumer.start()

    # Run consumer in background
    consume_task = asyncio.create_task(consumer.run())
    shutdown_waiter = asyncio.create_task(shutdown_event.wait())

    # Wake on whichever comes first. Waiting on the signal alone would leave
    # the process idle forever if the consume loop died on its own.
    await asyncio.wait(
        {consume_task, shutdown_waiter},
        return_when=asyncio.FIRST_COMPLETED,
    )
    shutdown_waiter.cancel()
    exited_early = not shutdown_event.is_set()

    consumer._shutdown_event.set()
    consume_task.cancel()  # no-op when the task has already finished
    failure: Exception | None = None
    try:
        await consume_task
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        logger.exception("Consume loop failed")
        failure = exc

    # Always release the NATS / PostgreSQL connections before reporting.
    await consumer.stop()

    if failure is not None:
        raise failure
    if exited_early:
        raise RuntimeError(
            "Archive consume loop exited before shutdown was requested"
        )
|
||||
|
||||
|
||||
def main() -> None:
    """Synchronous entry point: run ``async_main()`` to completion."""
    entry_coroutine = async_main()
    asyncio.run(entry_coroutine)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,75 +1,75 @@
|
|||
"""Central CLI commands."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
|
||||
async def config_store_check() -> int:
    """Smoke test for config store connectivity.

    Connects via bootstrap_config, lists adapters, and verifies that an
    encrypt/decrypt round trip returns the original bytes.

    Returns:
        0 on success, 1 on failure.
    """
    from central.bootstrap_config import get_settings
    from central.config_store import ConfigStore
    from central.crypto import decrypt, encrypt

    settings = get_settings()
    # Show only what follows the LAST "@": a password may itself contain "@",
    # and a DSN without any "@" must neither crash the check nor be echoed
    # in full (it could be a key=value DSN carrying a password).
    _, at_sign, db_location = settings.db_dsn.rpartition("@")
    if not at_sign:
        db_location = "(location hidden)"
    print(f"Connecting to: {db_location}")

    try:
        store = await ConfigStore.create(settings.db_dsn)
    except Exception as e:
        print(f"ERROR: Failed to connect to database: {e}")
        return 1

    try:
        # List adapters
        adapters = await store.list_adapters()
        print(f"\nAdapters ({len(adapters)}):")
        for adapter in adapters:
            print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
            print(f" settings: {adapter.settings}")

        # Test crypto
        test_plaintext = b"config_store_check_test"
        try:
            ciphertext = encrypt(test_plaintext)
            decrypted = decrypt(ciphertext)
            if decrypted == test_plaintext:
                print("\ncrypto: ok")
            else:
                print("\ncrypto: FAILED (round-trip mismatch)")
                return 1
        except Exception as e:
            print(f"\ncrypto: FAILED ({e})")
            return 1

        print("\nAll checks passed.")
        return 0

    finally:
        # Close the pool on every exit path, including the early returns.
        await store.close()
|
||||
|
||||
|
||||
def main_config_store_check() -> None:
    """Run the config-store check and exit the process with its result code."""
    exit_code = asyncio.run(config_store_check())
    sys.exit(exit_code)
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the command line and dispatch to the selected subcommand.

    A subcommand is mandatory; argparse exits with an error when it is
    missing or unknown.
    """
    cli = argparse.ArgumentParser(description="Central CLI")
    commands = cli.add_subparsers(dest="command", required=True)
    commands.add_parser("config-store-check", help="Test config store connectivity")

    selected = cli.parse_args().command
    if selected != "config-store-check":
        return
    main_config_store_check()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
"""Central CLI commands."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
|
||||
async def config_store_check() -> int:
    """Smoke test for config store connectivity.

    Connects via bootstrap_config, lists adapters, and verifies that an
    encrypt/decrypt round trip returns the original bytes.

    Returns:
        0 on success, 1 on failure.
    """
    from central.bootstrap_config import get_settings
    from central.config_store import ConfigStore
    from central.crypto import decrypt, encrypt

    settings = get_settings()
    # Show only what follows the LAST "@": a password may itself contain "@",
    # and a DSN without any "@" must neither crash the check nor be echoed
    # in full (it could be a key=value DSN carrying a password).
    _, at_sign, db_location = settings.db_dsn.rpartition("@")
    if not at_sign:
        db_location = "(location hidden)"
    print(f"Connecting to: {db_location}")

    try:
        store = await ConfigStore.create(settings.db_dsn)
    except Exception as e:
        print(f"ERROR: Failed to connect to database: {e}")
        return 1

    try:
        # List adapters
        adapters = await store.list_adapters()
        print(f"\nAdapters ({len(adapters)}):")
        for adapter in adapters:
            print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
            print(f" settings: {adapter.settings}")

        # Test crypto
        test_plaintext = b"config_store_check_test"
        try:
            ciphertext = encrypt(test_plaintext)
            decrypted = decrypt(ciphertext)
            if decrypted == test_plaintext:
                print("\ncrypto: ok")
            else:
                print("\ncrypto: FAILED (round-trip mismatch)")
                return 1
        except Exception as e:
            print(f"\ncrypto: FAILED ({e})")
            return 1

        print("\nAll checks passed.")
        return 0

    finally:
        # Close the pool on every exit path, including the early returns.
        await store.close()
|
||||
|
||||
|
||||
def main_config_store_check() -> None:
    """Run the config-store check and exit the process with its result code."""
    exit_code = asyncio.run(config_store_check())
    sys.exit(exit_code)
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the command line and dispatch to the selected subcommand.

    A subcommand is mandatory; argparse exits with an error when it is
    missing or unknown.
    """
    cli = argparse.ArgumentParser(description="Central CLI")
    commands = cli.add_subparsers(dest="command", required=True)
    commands.add_parser("config-store-check", help="Test config store connectivity")

    selected = cli.parse_args().command
    if selected != "config-store-check":
        return
    main_config_store_check()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,332 +1,332 @@
|
|||
"""Database-backed configuration store.
|
||||
|
||||
Provides async access to the config schema tables with support for
|
||||
Postgres LISTEN/NOTIFY for real-time config change notifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from central.config_models import AdapterConfig, StreamConfig
|
||||
from central.crypto import decrypt, encrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
    """Register ``json.dumps`` / ``json.loads`` as the connection's jsonb codec."""
    codec_options = {
        "encoder": json.dumps,
        "decoder": json.loads,
        "schema": "pg_catalog",
    }
    await conn.set_type_codec("jsonb", **codec_options)
|
||||
|
||||
|
||||
class ConfigStore:
    """Async interface to the config schema in Postgres."""

    def __init__(self, pool: asyncpg.Pool) -> None:
        """Wrap an existing connection pool; see ``create()`` to build one."""
        self._pool = pool

    @classmethod
    async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
        """Create a ConfigStore with a new connection pool.

        Args:
            dsn: Connection string passed to ``asyncpg.create_pool``.
            min_size: Minimum number of pooled connections.
            max_size: Maximum number of pooled connections.

        Returns:
            A store whose connections have the jsonb codec installed.
        """
        pool = await asyncpg.create_pool(
            dsn,
            min_size=min_size,
            max_size=max_size,
            init=_setup_json_codec,
        )
        return cls(pool)

    async def close(self) -> None:
        """Close the connection pool."""
        await self._pool.close()

    # -------------------------------------------------------------------------
    # Adapter configuration
    # -------------------------------------------------------------------------

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Get configuration for a specific adapter, or None if it is absent."""
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT name, enabled, cadence_s, settings, paused_at, updated_at
                FROM config.adapters
                WHERE name = $1
                """,
                name,
            )
            if row is None:
                return None
            return AdapterConfig(**dict(row))

    async def list_adapters(self) -> list[AdapterConfig]:
        """List all configured adapters, ordered by name."""
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT name, enabled, cadence_s, settings, paused_at, updated_at
                FROM config.adapters
                ORDER BY name
                """
            )
            return [AdapterConfig(**dict(row)) for row in rows]

    async def upsert_adapter(
        self,
        name: str,
        enabled: bool,
        cadence_s: int,
        settings: dict[str, Any],
    ) -> None:
        """Insert or update an adapter configuration."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
                VALUES ($1, $2, $3, $4, now())
                ON CONFLICT (name) DO UPDATE SET
                    enabled = EXCLUDED.enabled,
                    cadence_s = EXCLUDED.cadence_s,
                    settings = EXCLUDED.settings,
                    updated_at = now()
                """,
                name,
                enabled,
                cadence_s,
                settings,  # Will be encoded as JSON by the codec
            )

    async def pause_adapter(self, name: str) -> None:
        """Pause an adapter by setting paused_at."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE config.adapters
                SET paused_at = now(), updated_at = now()
                WHERE name = $1
                """,
                name,
            )

    async def unpause_adapter(self, name: str) -> None:
        """Unpause an adapter by clearing paused_at."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE config.adapters
                SET paused_at = NULL, updated_at = now()
                WHERE name = $1
                """,
                name,
            )

    # -------------------------------------------------------------------------
    # Stream configuration
    # -------------------------------------------------------------------------

    async def get_stream(self, name: str) -> StreamConfig | None:
        """Get configuration for a specific stream, or None if it is absent."""
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
                FROM config.streams
                WHERE name = $1
                """,
                name,
            )
            if row is None:
                return None
            return StreamConfig(**dict(row))

    async def list_streams(self) -> list[StreamConfig]:
        """List all configured streams, ordered by name."""
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
                FROM config.streams
                ORDER BY name
                """
            )
            return [StreamConfig(**dict(row)) for row in rows]

    async def upsert_stream(self, name: str, max_age_s: int) -> None:
        """Insert or update a stream's max_age_s (operator-facing)."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO config.streams (name, max_age_s, updated_at)
                VALUES ($1, $2, now())
                ON CONFLICT (name) DO UPDATE SET
                    max_age_s = EXCLUDED.max_age_s,
                    updated_at = now()
                """,
                name,
                max_age_s,
            )

    async def update_stream_max_bytes(self, name: str, max_bytes: int) -> None:
        """Update a stream's max_bytes (supervisor-internal).

        This update only touches max_bytes, which does NOT trigger
        the column-filtered NOTIFY (only max_age_s changes fire NOTIFY).
        """
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE config.streams
                SET max_bytes = $2, updated_at = now()
                WHERE name = $1
                """,
                name,
                max_bytes,
            )

    # -------------------------------------------------------------------------
    # API key management
    # -------------------------------------------------------------------------

    async def set_api_key(self, alias: str, plaintext_value: str) -> None:
        """Store an API key, encrypting it with the master key.

        An existing alias has its value replaced and ``rotated_at`` refreshed.
        """
        encrypted = encrypt(plaintext_value.encode("utf-8"))
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO config.api_keys (alias, encrypted_value)
                VALUES ($1, $2)
                ON CONFLICT (alias) DO UPDATE SET
                    encrypted_value = EXCLUDED.encrypted_value,
                    rotated_at = now()
                """,
                alias,
                encrypted,
            )

    async def get_api_key(self, alias: str) -> str | None:
        """Retrieve and decrypt an API key by alias.

        A successful lookup also stamps the row's ``last_used_at``.

        Returns:
            The decrypted key, or None when the alias does not exist.
        """
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT encrypted_value FROM config.api_keys WHERE alias = $1
                """,
                alias,
            )
            if row is None:
                return None
            # Update last_used_at
            await conn.execute(
                """
                UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
                """,
                alias,
            )
        return decrypt(row["encrypted_value"]).decode("utf-8")

    async def delete_api_key(self, alias: str) -> bool:
        """Delete an API key. Returns True if key existed."""
        async with self._pool.acquire() as conn:
            result = await conn.execute(
                "DELETE FROM config.api_keys WHERE alias = $1", alias
            )
            # The command status tag reports how many rows were removed.
            return result == "DELETE 1"

    # -------------------------------------------------------------------------
    # Change notifications
    # -------------------------------------------------------------------------

    async def listen_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Listen for config changes via Postgres NOTIFY.

        Runs forever, calling callback(table, key) each time a change is
        detected. The callback can be sync or async; a coroutine result is
        scheduled as a task, and an exception it raises is logged.

        On connection loss, automatically reconnects with exponential backoff.
        Cancellation (via task.cancel()) propagates cleanly.

        Args:
            callback: Function called with (table_name, row_key) on each change.
        """
        backoff = 1.0
        max_backoff = 30.0

        # The event loop keeps only weak references to tasks, so hold strong
        # ones here until each async callback finishes; otherwise a callback
        # could be garbage-collected mid-execution.
        pending_callbacks: set[asyncio.Task[None]] = set()

        def on_callback_done(task: "asyncio.Task[None]") -> None:
            pending_callbacks.discard(task)
            if task.cancelled():
                return
            error = task.exception()
            if error is not None:
                logger.error("Config change callback failed", exc_info=error)

        def notification_handler(
            connection: asyncpg.Connection,
            pid: int,
            channel: str,
            payload: str,
        ) -> None:
            # payload format: "table_name:key" (key is "" when no ":" present)
            table, _, key = payload.partition(":")

            result = callback(table, key)
            if asyncio.iscoroutine(result):
                task = asyncio.create_task(result)
                pending_callbacks.add(task)
                task.add_done_callback(on_callback_done)

        while True:
            conn = None
            try:
                conn = await self._pool.acquire()
                logger.info("Config listener connected to database")
                backoff = 1.0  # Reset backoff on successful connect

                await conn.add_listener("config_changed", notification_handler)

                try:
                    # Keep connection alive with periodic keepalive
                    while True:
                        await asyncio.sleep(60)
                        await conn.execute("SELECT 1")
                finally:
                    await conn.remove_listener("config_changed", notification_handler)

            except asyncio.CancelledError:
                # Cancellation must propagate cleanly
                logger.info("Config listener cancelled")
                raise

            except (
                asyncpg.PostgresConnectionError,
                asyncpg.InterfaceError,
                ConnectionResetError,
                OSError,
            ) as e:
                logger.warning(
                    "Config listener connection lost, reconnecting in %.1fs: %s",
                    backoff,
                    e,
                )
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)

            except Exception:
                # Unexpected error - log and retry with backoff
                logger.exception(
                    "Config listener unexpected error, reconnecting in %.1fs",
                    backoff,
                )
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)

            finally:
                if conn is not None:
                    try:
                        await self._pool.release(conn)
                    except Exception:
                        pass  # Connection may already be invalid
|
||||
"""Database-backed configuration store.
|
||||
|
||||
Provides async access to the config schema tables with support for
|
||||
Postgres LISTEN/NOTIFY for real-time config change notifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from central.config_models import AdapterConfig, StreamConfig
|
||||
from central.crypto import decrypt, encrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
    """Register ``json.dumps`` / ``json.loads`` as the connection's jsonb codec."""
    codec_options = {
        "encoder": json.dumps,
        "decoder": json.loads,
        "schema": "pg_catalog",
    }
    await conn.set_type_codec("jsonb", **codec_options)
|
||||
|
||||
|
||||
class ConfigStore:
|
||||
"""Async interface to the config schema in Postgres."""
|
||||
|
||||
def __init__(self, pool: asyncpg.Pool) -> None:
|
||||
self._pool = pool
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
|
||||
"""Create a ConfigStore with a new connection pool."""
|
||||
pool = await asyncpg.create_pool(
|
||||
dsn,
|
||||
min_size=min_size,
|
||||
max_size=max_size,
|
||||
init=_setup_json_codec,
|
||||
)
|
||||
return cls(pool)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection pool."""
|
||||
await self._pool.close()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Adapter configuration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def get_adapter(self, name: str) -> AdapterConfig | None:
|
||||
"""Get configuration for a specific adapter."""
|
||||
async with self._pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return AdapterConfig(**dict(row))
|
||||
|
||||
async def list_adapters(self) -> list[AdapterConfig]:
|
||||
"""List all configured adapters."""
|
||||
async with self._pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
ORDER BY name
|
||||
"""
|
||||
)
|
||||
return [AdapterConfig(**dict(row)) for row in rows]
|
||||
|
||||
async def upsert_adapter(
|
||||
self,
|
||||
name: str,
|
||||
enabled: bool,
|
||||
cadence_s: int,
|
||||
settings: dict[str, Any],
|
||||
) -> None:
|
||||
"""Insert or update an adapter configuration."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
|
||||
VALUES ($1, $2, $3, $4, now())
|
||||
ON CONFLICT (name) DO UPDATE SET
|
||||
enabled = EXCLUDED.enabled,
|
||||
cadence_s = EXCLUDED.cadence_s,
|
||||
settings = EXCLUDED.settings,
|
||||
updated_at = now()
|
||||
""",
|
||||
name,
|
||||
enabled,
|
||||
cadence_s,
|
||||
settings, # Will be encoded as JSON by the codec
|
||||
)
|
||||
|
||||
async def pause_adapter(self, name: str) -> None:
|
||||
"""Pause an adapter by setting paused_at."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.adapters
|
||||
SET paused_at = now(), updated_at = now()
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
|
||||
async def unpause_adapter(self, name: str) -> None:
|
||||
"""Unpause an adapter by clearing paused_at."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.adapters
|
||||
SET paused_at = NULL, updated_at = now()
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Stream configuration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def get_stream(self, name: str) -> StreamConfig | None:
|
||||
"""Get configuration for a specific stream."""
|
||||
async with self._pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
|
||||
FROM config.streams
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return StreamConfig(**dict(row))
|
||||
|
||||
async def list_streams(self) -> list[StreamConfig]:
|
||||
"""List all configured streams."""
|
||||
async with self._pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
|
||||
FROM config.streams
|
||||
ORDER BY name
|
||||
"""
|
||||
)
|
||||
return [StreamConfig(**dict(row)) for row in rows]
|
||||
|
||||
async def upsert_stream(self, name: str, max_age_s: int) -> None:
|
||||
"""Insert or update a stream's max_age_s (operator-facing)."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO config.streams (name, max_age_s, updated_at)
|
||||
VALUES ($1, $2, now())
|
||||
ON CONFLICT (name) DO UPDATE SET
|
||||
max_age_s = EXCLUDED.max_age_s,
|
||||
updated_at = now()
|
||||
""",
|
||||
name,
|
||||
max_age_s,
|
||||
)
|
||||
|
||||
async def update_stream_max_bytes(self, name: str, max_bytes: int) -> None:
|
||||
"""Update a stream's max_bytes (supervisor-internal).
|
||||
|
||||
This update only touches max_bytes, which does NOT trigger
|
||||
the column-filtered NOTIFY (only max_age_s changes fire NOTIFY).
|
||||
"""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.streams
|
||||
SET max_bytes = $2, updated_at = now()
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
max_bytes,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# API key management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def set_api_key(self, alias: str, plaintext_value: str) -> None:
    """Encrypt an API key with the master key and store it under ``alias``.

    An existing alias has its ciphertext replaced and ``rotated_at`` set to
    the current time.

    Args:
        alias: Lookup name of the key.
        plaintext_value: The secret; it is UTF-8 encoded before encryption.
    """
    ciphertext = encrypt(plaintext_value.encode("utf-8"))
    statement = """
        INSERT INTO config.api_keys (alias, encrypted_value)
        VALUES ($1, $2)
        ON CONFLICT (alias) DO UPDATE SET
        encrypted_value = EXCLUDED.encrypted_value,
        rotated_at = now()
        """
    async with self._pool.acquire() as conn:
        await conn.execute(statement, alias, ciphertext)
|
||||
|
||||
async def get_api_key(self, alias: str) -> str | None:
    """Look up an API key by alias and return it decrypted.

    A successful lookup also stamps ``last_used_at`` on the row.

    Args:
        alias: Lookup name of the key.

    Returns:
        The decrypted key as text, or None when no row has that alias.
    """
    async with self._pool.acquire() as conn:
        record = await conn.fetchrow(
            """
            SELECT encrypted_value FROM config.api_keys WHERE alias = $1
            """,
            alias,
        )
        if record is None:
            return None

        # Update last_used_at
        await conn.execute(
            """
            UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
            """,
            alias,
        )

    # Decrypt only after the connection has gone back to the pool.
    secret = decrypt(record["encrypted_value"])
    return secret.decode("utf-8")
|
||||
|
||||
async def delete_api_key(self, alias: str) -> bool:
    """Remove the API key stored under ``alias``.

    Returns:
        True when the statement reports exactly one deleted row.
    """
    async with self._pool.acquire() as conn:
        status = await conn.execute(
            "DELETE FROM config.api_keys WHERE alias = $1", alias
        )
    # The command status tag has the form "DELETE <rowcount>".
    return status == "DELETE 1"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Change notifications
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def listen_for_changes(
    self,
    callback: Callable[[str, str], Awaitable[None] | None],
) -> None:
    """Listen for config changes via Postgres NOTIFY.

    Runs forever, calling ``callback(table, key)`` for each notification on
    the ``config_changed`` channel. The callback can be sync or async; a
    coroutine result is scheduled as a background task.

    On connection loss (or any unexpected error) the listener reconnects
    with exponential backoff, starting at 1s and capped at 30s.
    Cancellation (via task.cancel()) propagates to the caller.

    Args:
        callback: Function called with (table_name, row_key) on each change.
    """
    backoff = 1.0
    max_backoff = 30.0

    # The event loop keeps only weak references to tasks, so hold strong
    # references here until each callback task has finished.
    callback_tasks: set[asyncio.Task] = set()

    def notification_handler(
        conn: asyncpg.Connection,
        pid: int,
        channel: str,
        payload: str,
    ) -> None:
        # payload format: "table_name:key"; without ":" the key is empty.
        table, _, key = payload.partition(":")

        result = callback(table, key)
        if asyncio.iscoroutine(result):
            task = asyncio.create_task(result)
            callback_tasks.add(task)
            task.add_done_callback(callback_tasks.discard)

    while True:
        conn = None
        try:
            conn = await self._pool.acquire()
            logger.info("Config listener connected to database")
            backoff = 1.0  # Reset backoff on successful connect

            await conn.add_listener("config_changed", notification_handler)

            try:
                # Keep connection alive with periodic keepalive
                while True:
                    await asyncio.sleep(60)
                    await conn.execute("SELECT 1")
            finally:
                try:
                    await conn.remove_listener(
                        "config_changed", notification_handler
                    )
                except (
                    asyncpg.PostgresConnectionError,
                    asyncpg.InterfaceError,
                    OSError,
                ):
                    # The connection is already unusable. Swallow this so it
                    # cannot replace the exception being propagated (in
                    # particular CancelledError).
                    logger.debug(
                        "Config listener could not remove listener",
                        exc_info=True,
                    )

        except asyncio.CancelledError:
            # Cancellation must propagate cleanly
            logger.info("Config listener cancelled")
            raise

        except (
            asyncpg.PostgresConnectionError,
            asyncpg.InterfaceError,
            ConnectionResetError,
            OSError,
        ) as e:
            logger.warning(
                "Config listener connection lost, reconnecting in %.1fs: %s",
                backoff,
                e,
            )
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)

        except Exception:
            # Unexpected error - log (with traceback) and retry with backoff
            logger.exception(
                "Config listener unexpected error, reconnecting in %.1fs",
                backoff,
            )
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)

        finally:
            if conn is not None:
                try:
                    await self._pool.release(conn)
                except Exception:
                    # Best effort: the connection may already be invalid.
                    logger.debug(
                        "Config listener failed to release connection",
                        exc_info=True,
                    )
|
||||
|
|
|
|||
|
|
@ -1,111 +1,111 @@
|
|||
"""Cryptographic primitives for secret storage.
|
||||
|
||||
Uses AES-256-GCM for authenticated encryption. The master key is read
|
||||
from the path specified in bootstrap config on first use and cached.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
# AES-256 requires 32-byte key
|
||||
KEY_SIZE = 32
|
||||
# GCM nonce size (96 bits recommended by NIST)
|
||||
NONCE_SIZE = 12
|
||||
|
||||
|
||||
class CryptoError(Exception):
    """Base exception for crypto operations.

    Parent of ``KeyLoadError`` and ``DecryptionError``.
    """


class KeyLoadError(CryptoError):
    """Failed to load master key.

    Raised by ``_load_master_key`` when the key file is missing, cannot be
    read or decoded, or does not decode to exactly ``KEY_SIZE`` bytes.
    """


class DecryptionError(CryptoError):
    """Failed to decrypt ciphertext (wrong key or tampered data).

    Also raised by ``decrypt`` when the input is too short to hold a nonce
    and an authentication tag.
    """
|
||||
|
||||
|
||||
@lru_cache
def _load_master_key(path: Path) -> bytes:
    """Load and decode the base64-encoded master key from file.

    Successful results are cached per ``path`` by ``lru_cache``; use
    ``_load_master_key.cache_clear()`` to force a re-read.

    Args:
        path: Location of the file holding the base64 text of the key.

    Returns:
        The decoded key, exactly ``KEY_SIZE`` bytes long.

    Raises:
        KeyLoadError: If the file is missing, cannot be read or decoded, or
            decodes to the wrong number of bytes. The underlying error is
            attached as ``__cause__``.
    """
    try:
        key_b64 = path.read_text().strip()
        key = base64.b64decode(key_b64)
    except FileNotFoundError as e:
        raise KeyLoadError(f"Master key file not found: {path}") from e
    except Exception as e:
        raise KeyLoadError(f"Failed to read master key from {path}: {e}") from e

    if len(key) != KEY_SIZE:
        raise KeyLoadError(
            f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
        )
    return key
|
||||
|
||||
|
||||
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
    """Encrypt ``plaintext`` with AES-256-GCM under the master key.

    Args:
        plaintext: Data to encrypt.
        key_path: Path to master key file. If None, the path comes from the
            bootstrap settings (``master_key_path``).

    Returns:
        nonce (12 bytes) || ciphertext || tag (16 bytes)
    """
    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    master_key = _load_master_key(key_path)
    # A fresh random nonce for every message.
    nonce = os.urandom(NONCE_SIZE)

    # AESGCM.encrypt returns the ciphertext with the 16-byte tag appended.
    sealed = AESGCM(master_key).encrypt(nonce, plaintext, associated_data=None)
    return nonce + sealed
|
||||
|
||||
|
||||
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
    """Decrypt ciphertext using AES-256-GCM.

    Args:
        ciphertext: Data in format: nonce || ciphertext || tag
        key_path: Path to master key file. If None, uses default from
            bootstrap config.

    Returns:
        Decrypted plaintext.

    Raises:
        DecryptionError: If the input is too short, or decryption fails
            (wrong key or tampered data). The underlying error is attached
            as ``__cause__``.
    """
    tag_size = 16  # length of the GCM authentication tag

    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    # Reject inputs that cannot contain a nonce plus a tag before the key
    # is loaded.
    if len(ciphertext) < NONCE_SIZE + tag_size:
        raise DecryptionError("Ciphertext too short")

    key = _load_master_key(key_path)
    nonce = ciphertext[:NONCE_SIZE]
    ciphertext_with_tag = ciphertext[NONCE_SIZE:]

    aesgcm = AESGCM(key)
    try:
        plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
    except Exception as e:
        raise DecryptionError(f"Decryption failed: {e}") from e

    return plaintext
|
||||
|
||||
|
||||
def clear_key_cache() -> None:
    """Clear the cached master key. Use after key rotation.

    Empties the ``lru_cache`` on ``_load_master_key`` so the next call
    reads the key file again.
    """
    _load_master_key.cache_clear()
|
||||
"""Cryptographic primitives for secret storage.
|
||||
|
||||
Uses AES-256-GCM for authenticated encryption. The master key is read
|
||||
from the path specified in bootstrap config on first use and cached.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
# AES-256 requires 32-byte key
|
||||
KEY_SIZE = 32
|
||||
# GCM nonce size (96 bits recommended by NIST)
|
||||
NONCE_SIZE = 12
|
||||
|
||||
|
||||
class CryptoError(Exception):
    """Base exception for crypto operations.

    Parent of ``KeyLoadError`` and ``DecryptionError``.
    """


class KeyLoadError(CryptoError):
    """Failed to load master key.

    Raised by ``_load_master_key`` when the key file is missing, cannot be
    read or decoded, or does not decode to exactly ``KEY_SIZE`` bytes.
    """


class DecryptionError(CryptoError):
    """Failed to decrypt ciphertext (wrong key or tampered data).

    Also raised by ``decrypt`` when the input is too short to hold a nonce
    and an authentication tag.
    """
|
||||
|
||||
|
||||
@lru_cache
def _load_master_key(path: Path) -> bytes:
    """Load and decode the base64-encoded master key from file.

    Successful results are cached per ``path`` by ``lru_cache``; use
    ``_load_master_key.cache_clear()`` to force a re-read.

    Args:
        path: Location of the file holding the base64 text of the key.

    Returns:
        The decoded key, exactly ``KEY_SIZE`` bytes long.

    Raises:
        KeyLoadError: If the file is missing, cannot be read or decoded, or
            decodes to the wrong number of bytes. The underlying error is
            attached as ``__cause__``.
    """
    try:
        key_b64 = path.read_text().strip()
        key = base64.b64decode(key_b64)
    except FileNotFoundError as e:
        raise KeyLoadError(f"Master key file not found: {path}") from e
    except Exception as e:
        raise KeyLoadError(f"Failed to read master key from {path}: {e}") from e

    if len(key) != KEY_SIZE:
        raise KeyLoadError(
            f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
        )
    return key
|
||||
|
||||
|
||||
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
    """Encrypt ``plaintext`` with AES-256-GCM under the master key.

    Args:
        plaintext: Data to encrypt.
        key_path: Path to master key file. If None, the path comes from the
            bootstrap settings (``master_key_path``).

    Returns:
        nonce (12 bytes) || ciphertext || tag (16 bytes)
    """
    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    master_key = _load_master_key(key_path)
    # A fresh random nonce for every message.
    nonce = os.urandom(NONCE_SIZE)

    # AESGCM.encrypt returns the ciphertext with the 16-byte tag appended.
    sealed = AESGCM(master_key).encrypt(nonce, plaintext, associated_data=None)
    return nonce + sealed
|
||||
|
||||
|
||||
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
    """Decrypt ciphertext using AES-256-GCM.

    Args:
        ciphertext: Data in format: nonce || ciphertext || tag
        key_path: Path to master key file. If None, uses default from
            bootstrap config.

    Returns:
        Decrypted plaintext.

    Raises:
        DecryptionError: If the input is too short, or decryption fails
            (wrong key or tampered data). The underlying error is attached
            as ``__cause__``.
    """
    tag_size = 16  # length of the GCM authentication tag

    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    # Reject inputs that cannot contain a nonce plus a tag before the key
    # is loaded.
    if len(ciphertext) < NONCE_SIZE + tag_size:
        raise DecryptionError("Ciphertext too short")

    key = _load_master_key(key_path)
    nonce = ciphertext[:NONCE_SIZE]
    ciphertext_with_tag = ciphertext[NONCE_SIZE:]

    aesgcm = AESGCM(key)
    try:
        plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
    except Exception as e:
        raise DecryptionError(f"Decryption failed: {e}") from e

    return plaintext
|
||||
|
||||
|
||||
def clear_key_cache() -> None:
    """Clear the cached master key. Use after key rotation.

    Empties the ``lru_cache`` on ``_load_master_key`` so the next call
    reads the key file again.
    """
    _load_master_key.cache_clear()
|
||||
|
|
|
|||
|
|
@ -1,125 +1,125 @@
|
|||
"""Simple database migration runner.
|
||||
|
||||
Tracks applied migrations in a `schema_migrations` table. Migrations are
|
||||
plain SQL files in `sql/migrations/` named with numeric prefixes:
|
||||
001_create_config_schema.sql
|
||||
002_add_operators_table.sql
|
||||
...
|
||||
|
||||
Usage:
|
||||
central-migrate [--dry-run]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
|
||||
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
|
||||
|
||||
|
||||
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
    """Create the ``schema_migrations`` bookkeeping table when it is absent.

    Args:
        conn: Open connection on which the DDL is executed.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS schema_migrations (
        version TEXT PRIMARY KEY,
        applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
        """
    await conn.execute(ddl)
|
||||
|
||||
|
||||
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
    """Fetch the versions recorded in ``schema_migrations``.

    Args:
        conn: Open connection used for the query.

    Returns:
        The set of version strings found in the table.
    """
    versions: set[str] = set()
    for record in await conn.fetch("SELECT version FROM schema_migrations"):
        versions.add(record["version"])
    return versions
|
||||
|
||||
|
||||
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
    """Find all .sql files in migrations directory, sorted by name.

    Only regular files are returned; a directory whose name happens to end
    in ``.sql`` is skipped because it cannot be read as a migration.

    Args:
        migrations_dir: Directory to scan (not recursive).

    Returns:
        List of (version, path) tuples where version is the filename
        without extension. Empty when the directory does not exist.
    """
    if not migrations_dir.is_dir():
        return []

    return [
        (sql_file.stem, sql_file)  # e.g., "001_create_config_schema"
        for sql_file in sorted(migrations_dir.glob("*.sql"))
        if sql_file.is_file()
    ]
|
||||
|
||||
|
||||
async def apply_migration(
    conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
) -> None:
    """Run one migration file and record it in ``schema_migrations``.

    Args:
        conn: Open connection the migration runs on.
        version: Version string stored for this migration.
        sql_path: File containing the migration SQL.
        dry_run: When True, only print what would run (SQL preview capped
            at 200 characters) and change nothing.
    """
    sql = sql_path.read_text()

    if dry_run:
        print(f"[DRY RUN] Would apply: {version}")
        if len(sql) > 200:
            print(f" SQL: {sql[:200]}...")
        else:
            print(f" SQL: {sql}")
        return

    # The migration and its bookkeeping row share one transaction, so a
    # failure in either leaves neither behind.
    async with conn.transaction():
        await conn.execute(sql)
        await conn.execute(
            "INSERT INTO schema_migrations (version) VALUES ($1)", version
        )
    print(f"Applied: {version}")
|
||||
|
||||
|
||||
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
    """Apply every migration in ``MIGRATIONS_DIR`` that is not yet recorded.

    Args:
        dsn: Connection string passed to ``asyncpg.connect``.
        dry_run: Forwarded to ``apply_migration``; nothing is executed.

    Returns:
        The number of pending migrations processed (0 when none are pending).
    """
    conn = await asyncpg.connect(dsn)
    try:
        await ensure_migrations_table(conn)
        already_applied = await get_applied_migrations(conn)

        pending = []
        for candidate in discover_migrations(MIGRATIONS_DIR):
            if candidate[0] not in already_applied:
                pending.append(candidate)

        if len(pending) == 0:
            print("No pending migrations.")
            return 0

        print(f"Found {len(pending)} pending migration(s).")
        for version, path in pending:
            await apply_migration(conn, version, path, dry_run)

        return len(pending)
    finally:
        # Close the connection on success and on failure alike.
        await conn.close()
|
||||
|
||||
|
||||
async def async_main() -> None:
    """Parse the command line, run the migrations and print a summary."""
    cli = argparse.ArgumentParser(description="Run database migrations")
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be applied without executing",
    )
    options = cli.parse_args()

    # Imported here rather than at module import time.
    from central.bootstrap_config import get_settings

    dsn = get_settings().db_dsn
    applied_count = await run_migrations(dsn, dry_run=options.dry_run)

    if options.dry_run or applied_count <= 0:
        return
    print(f"Successfully applied {applied_count} migration(s).")
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point.

    Runs ``async_main`` to completion on a new event loop via ``asyncio.run``.
    """
    asyncio.run(async_main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
"""Simple database migration runner.
|
||||
|
||||
Tracks applied migrations in a `schema_migrations` table. Migrations are
|
||||
plain SQL files in `sql/migrations/` named with numeric prefixes:
|
||||
001_create_config_schema.sql
|
||||
002_add_operators_table.sql
|
||||
...
|
||||
|
||||
Usage:
|
||||
central-migrate [--dry-run]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
|
||||
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
|
||||
|
||||
|
||||
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
    """Create the ``schema_migrations`` bookkeeping table when it is absent.

    Args:
        conn: Open connection on which the DDL is executed.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS schema_migrations (
        version TEXT PRIMARY KEY,
        applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
        """
    await conn.execute(ddl)
|
||||
|
||||
|
||||
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
    """Fetch the versions recorded in ``schema_migrations``.

    Args:
        conn: Open connection used for the query.

    Returns:
        The set of version strings found in the table.
    """
    versions: set[str] = set()
    for record in await conn.fetch("SELECT version FROM schema_migrations"):
        versions.add(record["version"])
    return versions
|
||||
|
||||
|
||||
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
    """Find all .sql files in migrations directory, sorted by name.

    Only regular files are returned; a directory whose name happens to end
    in ``.sql`` is skipped because it cannot be read as a migration.

    Args:
        migrations_dir: Directory to scan (not recursive).

    Returns:
        List of (version, path) tuples where version is the filename
        without extension. Empty when the directory does not exist.
    """
    if not migrations_dir.is_dir():
        return []

    return [
        (sql_file.stem, sql_file)  # e.g., "001_create_config_schema"
        for sql_file in sorted(migrations_dir.glob("*.sql"))
        if sql_file.is_file()
    ]
|
||||
|
||||
|
||||
async def apply_migration(
    conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
) -> None:
    """Run one migration file and record it in ``schema_migrations``.

    Args:
        conn: Open connection the migration runs on.
        version: Version string stored for this migration.
        sql_path: File containing the migration SQL.
        dry_run: When True, only print what would run (SQL preview capped
            at 200 characters) and change nothing.
    """
    sql = sql_path.read_text()

    if dry_run:
        print(f"[DRY RUN] Would apply: {version}")
        if len(sql) > 200:
            print(f" SQL: {sql[:200]}...")
        else:
            print(f" SQL: {sql}")
        return

    # The migration and its bookkeeping row share one transaction, so a
    # failure in either leaves neither behind.
    async with conn.transaction():
        await conn.execute(sql)
        await conn.execute(
            "INSERT INTO schema_migrations (version) VALUES ($1)", version
        )
    print(f"Applied: {version}")
|
||||
|
||||
|
||||
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
    """Apply every migration in ``MIGRATIONS_DIR`` that is not yet recorded.

    Args:
        dsn: Connection string passed to ``asyncpg.connect``.
        dry_run: Forwarded to ``apply_migration``; nothing is executed.

    Returns:
        The number of pending migrations processed (0 when none are pending).
    """
    conn = await asyncpg.connect(dsn)
    try:
        await ensure_migrations_table(conn)
        already_applied = await get_applied_migrations(conn)

        pending = []
        for candidate in discover_migrations(MIGRATIONS_DIR):
            if candidate[0] not in already_applied:
                pending.append(candidate)

        if len(pending) == 0:
            print("No pending migrations.")
            return 0

        print(f"Found {len(pending)} pending migration(s).")
        for version, path in pending:
            await apply_migration(conn, version, path, dry_run)

        return len(pending)
    finally:
        # Close the connection on success and on failure alike.
        await conn.close()
|
||||
|
||||
|
||||
async def async_main() -> None:
    """Parse the command line, run the migrations and print a summary."""
    cli = argparse.ArgumentParser(description="Run database migrations")
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be applied without executing",
    )
    options = cli.parse_args()

    # Imported here rather than at module import time.
    from central.bootstrap_config import get_settings

    dsn = get_settings().db_dsn
    applied_count = await run_migrations(dsn, dry_run=options.dry_run)

    if options.dry_run or applied_count <= 0:
        return
    print(f"Successfully applied {applied_count} migration(s).")
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point.

    Runs ``async_main`` to completion on a new event loop via ``asyncio.run``.
    """
    asyncio.run(async_main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ def subject_for_event(ev: Event) -> str:
|
|||
|
||||
Dispatch by category prefix:
|
||||
- fire.*: returns central.<category> directly
|
||||
- quake.*: returns central.<category> directly
|
||||
- wx.*: uses weather alert subject logic
|
||||
|
||||
Weather alert subjects:
|
||||
|
|
@ -49,18 +48,11 @@ def subject_for_event(ev: Event) -> str:
|
|||
|
||||
Fire hotspot subjects:
|
||||
central.fire.hotspot.<satellite>.<confidence>
|
||||
|
||||
Quake event subjects:
|
||||
central.quake.event.<magnitude_tier>
|
||||
"""
|
||||
# Fire events: subject is just central.<category>
|
||||
if ev.category.startswith("fire."):
|
||||
return f"central.{ev.category}"
|
||||
|
||||
# Quake events: subject is just central.<category>
|
||||
if ev.category.startswith("quake."):
|
||||
return f"central.{ev.category}"
|
||||
|
||||
# Weather events: use geo-based subject logic
|
||||
prefix = "central.wx"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,262 +1,262 @@
|
|||
"""JetStream stream manager for retention configuration."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nats.js import JetStreamContext
|
||||
from nats.js.api import StreamConfig, DiscardPolicy, RetentionPolicy
|
||||
|
||||
from central.config_models import StreamConfig as StreamConfigModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
ONE_GB = 1024 * 1024 * 1024 # 1 GiB in bytes
|
||||
NATS_CONFIG_PATH = Path("/etc/nats/nats-server.conf")
|
||||
|
||||
|
||||
class StreamManager:
|
||||
"""Manages JetStream stream configuration and retention."""
|
||||
|
||||
def __init__(self, js: JetStreamContext) -> None:
    """Keep the JetStream context and start with an empty size cache.

    Args:
        js: JetStream context used for all stream operations.
    """
    self._js = js
    # Cached by server_max_file_store_bytes(); None means "not read yet".
    self._server_max_file_store: int | None = None
|
||||
|
||||
async def server_max_file_store_bytes(self) -> int:
    """Get the server's max_file_store setting in bytes.

    Parses the NATS server config file at ``NATS_CONFIG_PATH`` and caches
    the result on the instance. Falls back to 20 GiB when the file cannot
    be read, the setting is absent, or its size suffix is not recognised.

    Returns:
        The file-store limit in bytes.
    """
    if self._server_max_file_store is not None:
        return self._server_max_file_store

    default_value = 20 * ONE_GB  # 20GB default

    # NATS size suffixes: a bare letter is a power of 1000 ("1G" = 10**9),
    # while "<letter>B", "<letter>I" and "<letter>IB" are powers of 1024.
    multipliers = {"": 1}
    for power, letter in enumerate("KMGTPE", start=1):
        multipliers[letter] = 1000**power
        for binary_suffix in (f"{letter}B", f"{letter}I", f"{letter}IB"):
            multipliers[binary_suffix] = 1024**power

    try:
        config_text = NATS_CONFIG_PATH.read_text()

        # Drop "#" and "//" comments so a commented-out setting is ignored.
        active_text = re.sub(r"(?:#|//).*", "", config_text)

        # Key and value may be separated by ":", "=" or plain whitespace;
        # "max_file" is accepted as an alias of "max_file_store".
        match = re.search(
            r"\bmax_file(?:_store)?\s*[:=]?\s*(\d+)([a-z]*)",
            active_text,
            re.IGNORECASE,
        )

        if match:
            suffix = match.group(2).upper()
            multiplier = multipliers.get(suffix)
            if multiplier is not None:
                value = int(match.group(1)) * multiplier
                self._server_max_file_store = value
                logger.info(
                    "Parsed server max_file_store",
                    extra={"max_file_store_bytes": value},
                )
                return value

            logger.warning(
                "Unrecognised max_file_store size suffix, using default",
                extra={"suffix": suffix, "default": default_value},
            )
        else:
            logger.warning(
                "max_file_store not found in config, using default",
                extra={"default": default_value},
            )
        self._server_max_file_store = default_value
        return default_value

    except Exception as e:
        logger.warning(
            "Failed to read NATS config, using default",
            extra={"error": str(e), "default": default_value},
        )
        self._server_max_file_store = default_value
        return default_value
|
||||
|
||||
def _compute_ceiling(self, server_max: int) -> int:
|
||||
"""Compute per-stream ceiling as 30% of server max_file_store."""
|
||||
return int(server_max * 0.30)
|
||||
|
||||
async def ensure_stream(
    self,
    name: str,
    subjects: list[str],
    config: StreamConfigModel,
) -> None:
    """Ensure a stream exists with the given configuration.

    Creates the stream if it doesn't exist, or updates it if it does (the
    update is sent unconditionally, without diffing against the current
    config). Always enforces: discard=old, max_msgs=-1 (unlimited).

    Args:
        name: Stream name.
        subjects: Subjects the stream captures.
        config: Source of ``max_age_s`` and ``max_bytes``.

    Raises:
        Exception: Any JetStream error other than "stream not found" is
            re-raised unchanged.
    """
    # Imported locally; only needed to recognise the missing-stream error.
    from nats.js.errors import NotFoundError

    server_max = await self.server_max_file_store_bytes()
    ceiling = self._compute_ceiling(server_max)

    # Clamp max_bytes to [1GB, ceiling]; the 1GB floor wins if the ceiling
    # is smaller than it.
    max_bytes = max(ONE_GB, min(config.max_bytes, ceiling))

    stream_config = StreamConfig(
        name=name,
        subjects=subjects,
        retention=RetentionPolicy.LIMITS,
        discard=DiscardPolicy.OLD,
        max_age=config.max_age_s,
        max_bytes=max_bytes,
        max_msgs=-1,  # Unlimited messages
    )

    try:
        # Probe for the stream; this raises when it does not exist yet.
        await self._js.stream_info(name)

        await self._js.update_stream(config=stream_config)
        logger.info(
            "Updated stream",
            extra={
                "stream": name,
                "max_age_s": config.max_age_s,
                "max_bytes": max_bytes,
            },
        )

    except Exception as e:
        # Prefer the typed error; keep the message check as a fallback.
        stream_missing = (
            isinstance(e, NotFoundError) or "stream not found" in str(e).lower()
        )
        if not stream_missing:
            raise

        # Create new stream
        await self._js.add_stream(config=stream_config)
        logger.info(
            "Created stream",
            extra={
                "stream": name,
                "subjects": subjects,
                "max_age_s": config.max_age_s,
                "max_bytes": max_bytes,
            },
        )
|
||||
|
||||
async def apply_retention(self, name: str, config: StreamConfigModel) -> None:
    """Push new retention limits to an existing stream.

    The stream keeps its current subjects; max_age and max_bytes come from
    ``config``. Always enforces discard=old, max_msgs=-1.

    Args:
        name: Name of the stream to update.
        config: Source of ``max_age_s`` and ``max_bytes``.

    Raises:
        Exception: Any error from JetStream is logged and re-raised.
    """
    ceiling = self._compute_ceiling(await self.server_max_file_store_bytes())

    # Clamp max_bytes to [1GB, ceiling]
    capped = min(config.max_bytes, ceiling)
    max_bytes = max(ONE_GB, capped)

    try:
        # Only the subjects are carried over from the live config.
        info = await self._js.stream_info(name)
        current_subjects = info.config.subjects

        await self._js.update_stream(
            config=StreamConfig(
                name=name,
                subjects=current_subjects,
                retention=RetentionPolicy.LIMITS,
                discard=DiscardPolicy.OLD,
                max_age=config.max_age_s,
                max_bytes=max_bytes,
                max_msgs=-1,
            )
        )
        applied = {
            "stream": name,
            "max_age_s": config.max_age_s,
            "max_bytes": max_bytes,
        }
        logger.info("Applied retention", extra=applied)

    except Exception as e:
        logger.error(
            "Failed to apply retention",
            extra={"stream": name, "error": str(e)},
        )
        raise
|
||||
|
||||
async def recompute_max_bytes(self, name: str, max_age_s: int) -> int:
    """Recompute max_bytes based on observed throughput.

    Formula: rate × max_age × 1.5 safety margin, clamped to [1GB, ceiling].
    The rate is an approximation: the bytes currently stored divided by
    the stream's configured max_age (or one day when max_age is unset or
    not positive).

    Args:
        name: Stream to inspect.
        max_age_s: Retention age, in seconds, to size the stream for.

    Returns:
        The computed max_bytes value; 1GB when the stream is empty or its
        info cannot be fetched.
    """
    server_max = await self.server_max_file_store_bytes()
    ceiling = self._compute_ceiling(server_max)

    try:
        info = await self._js.stream_info(name)
        current_bytes = info.state.bytes

        if info.state.messages == 0 or info.state.last_seq == 0:
            # No messages yet, use floor
            return ONE_GB

        # Use the stream's configured max_age as the observation window.
        configured_max_age = info.config.max_age

        if configured_max_age and configured_max_age > 0:
            # Rate = current_bytes / configured_max_age (in seconds)
            rate_per_second = current_bytes / configured_max_age
        else:
            # Fallback (unset or unlimited max_age): assume 1 day of data
            rate_per_second = current_bytes / 86400

        # Project bytes needed for new max_age with 1.5x safety margin
        projected = int(rate_per_second * max_age_s * 1.5)

        # Clamp to [1GB, ceiling]
        result = max(ONE_GB, min(projected, ceiling))

        logger.info(
            "Recomputed max_bytes",
            extra={
                "stream": name,
                "current_bytes": current_bytes,
                "rate_per_second": rate_per_second,
                "max_age_s": max_age_s,
                "projected": projected,
                "result": result,
                "ceiling": ceiling,
            },
        )

        return result

    except Exception as e:
        logger.error(
            "Failed to recompute max_bytes, using floor",
            extra={"stream": name, "error": str(e)},
        )
        return ONE_GB
|
||||
|
||||
async def get_stream_stats(self, name: str) -> dict[str, Any]:
    """Collect a monitoring snapshot for one stream.

    Args:
        name: Stream to inspect.

    Returns:
        A dict with the stream's byte and message counts, configured
        limits and consumer count. On failure, a dict holding only
        ``stream`` and ``error`` (nothing is raised).
    """
    try:
        info = await self._js.stream_info(name)
        state, limits = info.state, info.config
        stats: dict[str, Any] = {"stream": name}
        stats["bytes"] = state.bytes
        stats["messages"] = state.messages
        stats["max_bytes"] = limits.max_bytes
        stats["max_age_s"] = limits.max_age
        stats["consumers"] = state.consumer_count
        return stats
    except Exception as e:
        logger.error(
            "Failed to get stream stats",
            extra={"stream": name, "error": str(e)},
        )
        return {"stream": name, "error": str(e)}
|
||||
"""JetStream stream manager for retention configuration."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nats.js import JetStreamContext
|
||||
from nats.js.api import StreamConfig, DiscardPolicy, RetentionPolicy
|
||||
|
||||
from central.config_models import StreamConfig as StreamConfigModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
ONE_GB = 1024 * 1024 * 1024 # 1 GiB in bytes
|
||||
NATS_CONFIG_PATH = Path("/etc/nats/nats-server.conf")
|
||||
|
||||
|
||||
class StreamManager:
|
||||
"""Manages JetStream stream configuration and retention."""
|
||||
|
||||
def __init__(self, js: JetStreamContext) -> None:
    """Keep the JetStream context and start with an empty size cache.

    Args:
        js: JetStream context used for all stream operations.
    """
    self._js = js
    # Cached by server_max_file_store_bytes(); None means "not read yet".
    self._server_max_file_store: int | None = None
|
||||
|
||||
async def server_max_file_store_bytes(self) -> int:
    """Get the server's max_file_store setting in bytes.

    Parses the NATS server config file at ``NATS_CONFIG_PATH`` and caches
    the result on the instance. Falls back to 20 GiB when the file cannot
    be read, the setting is absent, or its size suffix is not recognised.

    Returns:
        The file-store limit in bytes.
    """
    if self._server_max_file_store is not None:
        return self._server_max_file_store

    default_value = 20 * ONE_GB  # 20GB default

    # NATS size suffixes: a bare letter is a power of 1000 ("1G" = 10**9),
    # while "<letter>B", "<letter>I" and "<letter>IB" are powers of 1024.
    multipliers = {"": 1}
    for power, letter in enumerate("KMGTPE", start=1):
        multipliers[letter] = 1000**power
        for binary_suffix in (f"{letter}B", f"{letter}I", f"{letter}IB"):
            multipliers[binary_suffix] = 1024**power

    try:
        config_text = NATS_CONFIG_PATH.read_text()

        # Drop "#" and "//" comments so a commented-out setting is ignored.
        active_text = re.sub(r"(?:#|//).*", "", config_text)

        # Key and value may be separated by ":", "=" or plain whitespace;
        # "max_file" is accepted as an alias of "max_file_store".
        match = re.search(
            r"\bmax_file(?:_store)?\s*[:=]?\s*(\d+)([a-z]*)",
            active_text,
            re.IGNORECASE,
        )

        if match:
            suffix = match.group(2).upper()
            multiplier = multipliers.get(suffix)
            if multiplier is not None:
                value = int(match.group(1)) * multiplier
                self._server_max_file_store = value
                logger.info(
                    "Parsed server max_file_store",
                    extra={"max_file_store_bytes": value},
                )
                return value

            logger.warning(
                "Unrecognised max_file_store size suffix, using default",
                extra={"suffix": suffix, "default": default_value},
            )
        else:
            logger.warning(
                "max_file_store not found in config, using default",
                extra={"default": default_value},
            )
        self._server_max_file_store = default_value
        return default_value

    except Exception as e:
        logger.warning(
            "Failed to read NATS config, using default",
            extra={"error": str(e), "default": default_value},
        )
        self._server_max_file_store = default_value
        return default_value
|
||||
|
||||
def _compute_ceiling(self, server_max: int) -> int:
|
||||
"""Compute per-stream ceiling as 30% of server max_file_store."""
|
||||
return int(server_max * 0.30)
|
||||
|
||||
async def ensure_stream(
    self,
    name: str,
    subjects: list[str],
    config: StreamConfigModel,
) -> None:
    """Ensure a stream exists with the given configuration.

    Creates the stream if it doesn't exist, or updates it if it does (the
    update is sent unconditionally, without diffing against the current
    config). Always enforces: discard=old, max_msgs=-1 (unlimited).

    Args:
        name: Stream name.
        subjects: Subjects the stream captures.
        config: Source of ``max_age_s`` and ``max_bytes``.

    Raises:
        Exception: Any JetStream error other than "stream not found" is
            re-raised unchanged.
    """
    # Imported locally; only needed to recognise the missing-stream error.
    from nats.js.errors import NotFoundError

    server_max = await self.server_max_file_store_bytes()
    ceiling = self._compute_ceiling(server_max)

    # Clamp max_bytes to [1GB, ceiling]; the 1GB floor wins if the ceiling
    # is smaller than it.
    max_bytes = max(ONE_GB, min(config.max_bytes, ceiling))

    stream_config = StreamConfig(
        name=name,
        subjects=subjects,
        retention=RetentionPolicy.LIMITS,
        discard=DiscardPolicy.OLD,
        max_age=config.max_age_s,
        max_bytes=max_bytes,
        max_msgs=-1,  # Unlimited messages
    )

    try:
        # Probe for the stream; this raises when it does not exist yet.
        await self._js.stream_info(name)

        await self._js.update_stream(config=stream_config)
        logger.info(
            "Updated stream",
            extra={
                "stream": name,
                "max_age_s": config.max_age_s,
                "max_bytes": max_bytes,
            },
        )

    except Exception as e:
        # Prefer the typed error; keep the message check as a fallback.
        stream_missing = (
            isinstance(e, NotFoundError) or "stream not found" in str(e).lower()
        )
        if not stream_missing:
            raise

        # Create new stream
        await self._js.add_stream(config=stream_config)
        logger.info(
            "Created stream",
            extra={
                "stream": name,
                "subjects": subjects,
                "max_age_s": config.max_age_s,
                "max_bytes": max_bytes,
            },
        )
|
||||
|
||||
async def apply_retention(self, name: str, config: StreamConfigModel) -> None:
    """Push new retention limits to an existing stream.

    Only max_age and max_bytes come from ``config``; the stream's current
    subjects are carried over. retention=limits, discard=old and
    max_msgs=-1 are always set. Any failure is logged and re-raised.
    """
    ceiling = self._compute_ceiling(await self.server_max_file_store_bytes())

    # Cap at the ceiling, then raise to the 1GB floor if needed.
    capped = min(config.max_bytes, ceiling)
    max_bytes = max(ONE_GB, capped)

    try:
        # Look up the live stream so its subjects can be preserved.
        stream_info = await self._js.stream_info(name)
        live_config = stream_info.config

        new_config = StreamConfig(
            name=name,
            subjects=live_config.subjects,
            retention=RetentionPolicy.LIMITS,
            discard=DiscardPolicy.OLD,
            max_age=config.max_age_s,
            max_bytes=max_bytes,
            max_msgs=-1,
        )
        await self._js.update_stream(config=new_config)

        log_fields = {
            "stream": name,
            "max_age_s": config.max_age_s,
            "max_bytes": max_bytes,
        }
        logger.info("Applied retention", extra=log_fields)

    except Exception as e:
        logger.error(
            "Failed to apply retention",
            extra={"stream": name, "error": str(e)},
        )
        raise
|
||||
|
||||
async def recompute_max_bytes(self, name: str, max_age_s: int) -> int:
    """Recompute max_bytes based on observed throughput.

    Formula: rate × max_age × 1.5 safety margin, clamped to [1GB, ceiling]
    (the 1GB floor wins if the ceiling is below it).

    The rate is an approximation: the bytes currently stored are divided
    by the stream's configured max_age, i.e. the stream is assumed to hold
    one full retention window of data. A stream that has been filling for
    less than that window yields an underestimated rate.

    Args:
        name: Stream to inspect.
        max_age_s: Retention age, in seconds, to size the stream for.

    Returns:
        The computed max_bytes value. Returns the 1GB floor when the stream
        is empty or when inspecting it fails (the failure is logged, not
        raised).
    """
    server_max = await self.server_max_file_store_bytes()
    ceiling = self._compute_ceiling(server_max)

    try:
        info = await self._js.stream_info(name)
        state = info.state
        current_bytes = state.bytes

        if state.messages == 0 or state.last_seq == 0:
            # No messages yet, use floor
            return ONE_GB

        # Use the stream's configured max_age as the observation window.
        # NOTE(review): assumes the client reports max_age in seconds --
        # confirm against the JetStream client in use.
        configured_max_age = info.config.max_age

        # A missing (None) or non-positive max_age gives no usable window;
        # fall back to assuming one day of data rather than failing.
        if configured_max_age is not None and configured_max_age > 0:
            window_s = configured_max_age
        else:
            window_s = 86400

        rate_per_second = current_bytes / window_s

        # Project bytes needed for new max_age with 1.5x safety margin
        projected = int(rate_per_second * max_age_s * 1.5)

        # Clamp to [1GB, ceiling]
        result = max(ONE_GB, min(projected, ceiling))

        logger.info(
            "Recomputed max_bytes",
            extra={
                "stream": name,
                "current_bytes": current_bytes,
                "rate_per_second": rate_per_second,
                "max_age_s": max_age_s,
                "projected": projected,
                "result": result,
                "ceiling": ceiling,
            },
        )

        return result

    except Exception as e:
        # Best effort: sizing must not fail the caller, so fall back to
        # the floor and record why.
        logger.error(
            "Failed to recompute max_bytes, using floor",
            extra={"stream": name, "error": str(e)},
        )
        return ONE_GB
|
||||
|
||||
async def get_stream_stats(self, name: str) -> dict[str, Any]:
    """Return a monitoring snapshot of a stream.

    On success the dict holds the stream name, stored bytes and message
    count, the configured max_bytes / max_age, and the consumer count.
    On any failure the error is logged and a dict with only the stream
    name and the error text is returned instead of raising.
    """
    try:
        info = await self._js.stream_info(name)
        state, stream_cfg = info.state, info.config

        stats: dict[str, Any] = {"stream": name}
        stats["bytes"] = state.bytes
        stats["messages"] = state.messages
        stats["max_bytes"] = stream_cfg.max_bytes
        stats["max_age_s"] = stream_cfg.max_age
        stats["consumers"] = state.consumer_count
        return stats
    except Exception as exc:
        message = str(exc)
        logger.error(
            "Failed to get stream stats",
            extra={"stream": name, "error": message},
        )
        return {"stream": name, "error": message}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue