chore: normalize line endings to LF

This commit is contained in:
Matt Johnson 2026-05-16 21:27:30 +00:00
commit 374a8c067f
26 changed files with 5357 additions and 5346 deletions

View file

@ -1,430 +1,430 @@
"""FIRMS (Fire Information for Resource Management System) adapter."""
import csv
import logging
import sqlite3
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from io import StringIO
from pathlib import Path
from typing import Any
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception_type,
)
from central.adapter import SourceAdapter
from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore
from central.models import Event, Geo
logger = logging.getLogger(__name__)

# Base URL of the FIRMS area CSV API.
FIRMS_API_BASE = "https://firms.modaps.eosdis.nasa.gov/api/area/csv"

# FIRMS source identifier -> short satellite token used in event categories.
SATELLITE_SHORT = dict(
    VIIRS_SNPP_NRT="viirs_snpp",
    VIIRS_NOAA20_NRT="viirs_noaa20",
    VIIRS_NOAA21_NRT="viirs_noaa21",
)

# Single-letter confidence code from the CSV -> readable confidence label.
CONFIDENCE_MAP = dict(l="low", n="nominal", h="high")

# Confidence label -> numeric event severity.
SEVERITY_MAP = dict(high=3, nominal=2, low=1)
class FIRMSAdapter(SourceAdapter):
    """NASA FIRMS fire hotspot adapter.

    For each configured satellite the adapter downloads the FIRMS area CSV
    for the configured region, converts every row into an ``Event`` and
    yields the ones whose stable ID is not yet recorded in the SQLite
    ``published_ids`` table.
    """

    name = "firms"

    def __init__(
        self,
        config: AdapterConfig,
        config_store: ConfigStore,
        cursor_db_path: Path,
    ) -> None:
        """Store collaborators and read settings; no I/O happens here.

        Args:
            config: Adapter configuration; ``settings`` may carry
                ``api_key_alias``, ``satellites`` and ``region``.
            config_store: Store used to look up the API key by alias.
            cursor_db_path: Path of the SQLite file holding dedup state.
        """
        self._config_store = config_store
        self._cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        self._db: sqlite3.Connection | None = None
        self._api_key: str | None = None
        self._api_key_alias: str = "firms"
        self._satellites: list[str] = []
        self.region: RegionConfig | None = None
        self._load_settings(config)

    def _load_settings(self, config: AdapterConfig) -> Any:
        """Copy alias, satellites and region from *config* onto the adapter.

        The region object is built before any attribute is assigned, so an
        error while constructing it leaves the previous settings untouched
        instead of half-applied.

        Returns:
            The raw ``region`` setting (used by callers for logging).
        """
        settings = config.settings
        region_dict = settings.get("region")
        region = RegionConfig(**region_dict) if region_dict else None

        self._api_key_alias = settings.get("api_key_alias", "firms")
        self._satellites = settings.get(
            "satellites", ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"]
        )
        self.region = region
        return region_dict

    def _redact(self, text: str) -> str:
        """Return *text* with the API key masked so it is safe to log."""
        if self._api_key:
            return text.replace(self._api_key, "***")
        return text

    async def apply_config(self, new_config: AdapterConfig) -> None:
        """Apply new configuration from hot-reload.

        The API key is re-fetched only when the alias changed.
        """
        old_alias = self._api_key_alias
        region_dict = self._load_settings(new_config)

        if self._api_key_alias != old_alias:
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if self._api_key:
                logger.info("FIRMS API key reloaded", extra={"alias": self._api_key_alias})
            else:
                logger.warning(
                    "FIRMS API key not found after alias change",
                    extra={"alias": self._api_key_alias},
                )

        logger.info(
            "FIRMS config applied",
            extra={
                "region": region_dict,
                "satellites": self._satellites,
                "api_key_alias": self._api_key_alias,
            },
        )

    async def startup(self) -> None:
        """Initialize HTTP session, dedup tracker, and fetch API key."""
        self._api_key = await self._config_store.get_api_key(self._api_key_alias)
        if not self._api_key:
            logger.error(
                "FIRMS API key not found - polling will be skipped until key is set",
                extra={"alias": self._api_key_alias},
            )

        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=60),
        )

        # Dedup tracker (shared sqlite DB with NWS); rows are namespaced by
        # the ``adapter`` column.
        self._db = sqlite3.connect(str(self._cursor_db_path))
        self._db.execute("""
            CREATE TABLE IF NOT EXISTS published_ids (
                adapter TEXT NOT NULL,
                event_id TEXT NOT NULL,
                first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (adapter, event_id)
            )
        """)
        self._db.execute("""
            CREATE INDEX IF NOT EXISTS published_ids_last_seen
            ON published_ids (last_seen)
        """)
        self._db.commit()

        # Sweep old entries on startup (48h for FIRMS)
        self.sweep_old_ids()

        logger.info(
            "FIRMS adapter started",
            extra={
                "region": {
                    "north": self.region.north,
                    "south": self.region.south,
                    "east": self.region.east,
                    "west": self.region.west,
                } if self.region else None,
                "satellites": self._satellites,
                "api_key_present": self._api_key is not None,
            },
        )

    async def shutdown(self) -> None:
        """Close HTTP session and database."""
        if self._session:
            await self._session.close()
            self._session = None
        if self._db:
            self._db.close()
            self._db = None
        logger.info("FIRMS adapter shut down")

    def is_published(self, stable_id: str) -> bool:
        """Return True if *stable_id* is recorded for this adapter.

        Returns False when the database has not been opened.
        """
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, stable_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, stable_id: str) -> None:
        """Record *stable_id* as published, refreshing ``last_seen`` if present."""
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
            last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, stable_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Remove published_ids older than 48 hours. Returns count deleted."""
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-48 hours')",
            (self.name,),
        )
        self._db.commit()
        count = cur.rowcount
        if count > 0:
            logger.info("FIRMS swept old dedup entries", extra={"count": count})
        return count

    def _build_stable_id(
        self, satellite: str, acq_date: str, acq_time: str, lat: float, lon: float
    ) -> str:
        """Build stable ID for deduplication.

        The ID joins satellite, acquisition date/time and the coordinates
        rounded to 0.001 degrees with ``:``.
        """
        lat_rounded = round(lat, 3)
        lon_rounded = round(lon, 3)
        return f"{satellite}:{acq_date}:{acq_time}:{lat_rounded}:{lon_rounded}"

    def _build_url(self, satellite: str) -> str | None:
        """Build FIRMS API URL for a satellite.

        Returns None when the API key or the region is not available.
        The returned URL embeds the API key and must not be logged verbatim.
        """
        if not self._api_key or not self.region:
            return None
        # Area format: west,south,east,north
        area = f"{self.region.west},{self.region.south},{self.region.east},{self.region.north}"
        return f"{FIRMS_API_BASE}/{self._api_key}/{satellite}/{area}/1"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=2, max=30),
        retry=retry_if_exception_type((aiohttp.ClientError,)),
        reraise=True,
    )
    async def _fetch_csv(self, url: str) -> str:
        """Fetch CSV data from FIRMS API.

        Retried up to three attempts on ``aiohttp.ClientError``.

        Raises:
            RuntimeError: If the HTTP session has not been created.
            ValueError: If the server answers with HTML instead of CSV.
            aiohttp.ClientError: On transport or HTTP status errors.
        """
        if not self._session:
            raise RuntimeError("Session not initialized")
        async with self._session.get(url) as resp:
            # An HTML body is checked before the status so the preview is logged.
            content_type = resp.headers.get("Content-Type", "")
            if "text/html" in content_type:
                text = await resp.text()
                logger.error(
                    "FIRMS returned HTML (likely auth error)",
                    extra={"status": resp.status, "preview": text[:200]},
                )
                raise ValueError("FIRMS returned HTML instead of CSV")
            resp.raise_for_status()
            return await resp.text()

    @staticmethod
    def _optional_float(value: str | None) -> float | None:
        """Convert a CSV cell to float; empty or absent cells become None."""
        return float(value) if value else None

    def _parse_csv(self, csv_text: str, satellite: str) -> list[dict[str, Any]]:
        """Parse FIRMS CSV response into list of dicts.

        Rows that lack a required field or hold a non-numeric value are
        logged and skipped; they never abort the rest of the batch.

        Args:
            csv_text: Raw response body.
            satellite: Fallback for the ``satellite`` column when the CSV has
                no such column.

        Returns:
            One dict per valid row with typed values.
        """
        required = ("latitude", "longitude", "acq_date", "acq_time")
        reader = csv.DictReader(StringIO(csv_text))
        fieldnames = list(reader.fieldnames or [])
        missing = [column for column in required if column not in fieldnames]
        if fieldnames and missing:
            # A non-CSV body (e.g. a plain-text error) would otherwise look
            # like "zero hotspots" without any trace in the logs.
            logger.warning(
                "FIRMS response is missing required CSV columns",
                extra={
                    "satellite": satellite,
                    "missing": missing,
                    "header": ",".join(str(f) for f in fieldnames)[:200],
                },
            )
            return []

        rows: list[dict[str, Any]] = []
        for row in reader:
            try:
                acq_date = row["acq_date"]
                acq_time = row["acq_time"]
                if not acq_date or not acq_time:
                    # DictReader pads truncated rows with None rather than
                    # raising KeyError, so absence must be checked explicitly.
                    raise ValueError("missing acq_date/acq_time")
                lat = float(row["latitude"])
                lon = float(row["longitude"])
                confidence_cell = row.get("confidence", "n")
                confidence_raw = "n" if confidence_cell is None else confidence_cell.lower()
                confidence = CONFIDENCE_MAP.get(confidence_raw, "nominal")
                rows.append({
                    "latitude": lat,
                    "longitude": lon,
                    "bright_ti4": self._optional_float(row.get("bright_ti4")),
                    "bright_ti5": self._optional_float(row.get("bright_ti5")),
                    "scan": self._optional_float(row.get("scan")),
                    "track": self._optional_float(row.get("track")),
                    "acq_date": acq_date,
                    "acq_time": acq_time,
                    "satellite": row.get("satellite", satellite),
                    "instrument": row.get("instrument", "VIIRS"),
                    "confidence": confidence,
                    "confidence_raw": confidence_raw,
                    "version": row.get("version", ""),
                    "frp": self._optional_float(row.get("frp")),
                    "daynight": row.get("daynight", ""),
                })
            except (KeyError, TypeError, ValueError) as e:
                logger.warning(
                    "Failed to parse FIRMS row",
                    extra={"error": str(e), "row": dict(row)},
                )
                continue
        return rows

    def _row_to_event(self, row: dict[str, Any], satellite: str) -> Event:
        """Convert a parsed CSV row to an Event.

        Args:
            row: A dict produced by ``_parse_csv``.
            satellite: FIRMS source identifier the row was fetched for.

        Returns:
            The event; its ``time`` falls back to the current UTC time when
            the acquisition date/time cannot be parsed.
        """
        satellite_short = SATELLITE_SHORT.get(satellite, satellite.lower().replace("_nrt", ""))
        confidence = row["confidence"]
        severity = SEVERITY_MAP.get(confidence, 1)

        acq_date = row["acq_date"]
        acq_time = row["acq_time"]
        # acq_time is HHMM. Pad to four digits before parsing: an unpadded
        # value such as "38" would otherwise be read by strptime as 03:08
        # instead of 00:38.
        padded_time = str(acq_time).zfill(4)
        try:
            event_time = datetime.strptime(
                f"{acq_date} {padded_time}", "%Y-%m-%d %H%M"
            ).replace(tzinfo=timezone.utc)
        except ValueError:
            logger.warning(
                "FIRMS acquisition time unparseable, using current time",
                extra={"acq_date": acq_date, "acq_time": acq_time},
            )
            event_time = datetime.now(timezone.utc)

        lat = row["latitude"]
        lon = row["longitude"]
        # The stable ID keeps the raw acq_time so existing dedup entries
        # continue to match.
        stable_id = self._build_stable_id(satellite, acq_date, acq_time, lat, lon)
        geo = Geo(
            centroid=(lon, lat),  # GeoJSON order: lon, lat
            bbox=(lon, lat, lon, lat),  # Point bbox
            regions=[],
            primary_region=None,
        )
        return Event(
            id=stable_id,
            source="central/adapters/firms",
            category=f"fire.hotspot.{satellite_short}.{confidence}",
            time=event_time,
            expires=None,
            severity=severity,
            geo=geo,
            data=row,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Poll FIRMS API for fire hotspots.

        Yields each hotspot not yet published. An ID is marked as published
        only after the consumer resumes the generator, i.e. after the event
        was handed over. A failure for one satellite is logged and does not
        stop the remaining satellites.
        """
        if not self._api_key:
            # Try to fetch again in case it was added
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if not self._api_key:
                logger.warning(
                    "FIRMS API key still not available, skipping poll",
                    extra={"alias": self._api_key_alias},
                )
                return
        if not self.region:
            logger.warning("FIRMS region not configured, skipping poll")
            return

        # Sweep old dedup entries periodically
        self.sweep_old_ids()

        total_features = 0
        total_new = 0
        for satellite in self._satellites:
            url = self._build_url(satellite)
            if not url:
                continue
            try:
                csv_text = await self._fetch_csv(url)
                rows = self._parse_csv(csv_text, satellite)
                feature_count = len(rows)
                total_features += feature_count
                new_count = 0
                for row in rows:
                    stable_id = self._build_stable_id(
                        satellite,
                        row["acq_date"],
                        row["acq_time"],
                        row["latitude"],
                        row["longitude"],
                    )
                    if self.is_published(stable_id):
                        continue
                    event = self._row_to_event(row, satellite)
                    yield event
                    self.mark_published(stable_id)
                    new_count += 1
                total_new += new_count
                logger.info(
                    "FIRMS satellite poll completed",
                    extra={
                        "satellite": satellite,
                        "feature_count": feature_count,
                        "new_count": new_count,
                    },
                )
            except Exception as e:
                # HTTP errors stringify with the request URL, which carries
                # the API key as a path segment. Mask it, and do not attach
                # the traceback, which would repeat the unmasked message.
                logger.error(
                    "FIRMS poll failed for satellite",
                    extra={"satellite": satellite, "error": self._redact(str(e))},
                )
                continue

        logger.info(
            "FIRMS poll completed",
            extra={
                "total_features": total_features,
                "total_new": total_new,
                "satellites": self._satellites,
            },
        )
def subject_for_fire_hotspot(ev: Event) -> str:
    """Return the NATS subject for a fire hotspot event.

    The event category already has the form
    ``fire.hotspot.<satellite>.<confidence>``, so the subject is simply that
    category under the ``central.`` prefix:
    ``central.fire.hotspot.<satellite>.<confidence>``.
    """
    return "central.{}".format(ev.category)
"""FIRMS (Fire Information for Resource Management System) adapter."""
import csv
import logging
import sqlite3
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from io import StringIO
from pathlib import Path
from typing import Any
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception_type,
)
from central.adapter import SourceAdapter
from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore
from central.models import Event, Geo
logger = logging.getLogger(__name__)

# Base URL of the FIRMS area CSV API.
FIRMS_API_BASE = "https://firms.modaps.eosdis.nasa.gov/api/area/csv"

# FIRMS source identifier -> short satellite token used in event categories.
SATELLITE_SHORT = dict(
    VIIRS_SNPP_NRT="viirs_snpp",
    VIIRS_NOAA20_NRT="viirs_noaa20",
    VIIRS_NOAA21_NRT="viirs_noaa21",
)

# Single-letter confidence code from the CSV -> readable confidence label.
CONFIDENCE_MAP = dict(l="low", n="nominal", h="high")

# Confidence label -> numeric event severity.
SEVERITY_MAP = dict(high=3, nominal=2, low=1)
class FIRMSAdapter(SourceAdapter):
    """NASA FIRMS fire hotspot adapter.

    For each configured satellite the adapter downloads the FIRMS area CSV
    for the configured region, converts every row into an ``Event`` and
    yields the ones whose stable ID is not yet recorded in the SQLite
    ``published_ids`` table.
    """

    name = "firms"

    def __init__(
        self,
        config: AdapterConfig,
        config_store: ConfigStore,
        cursor_db_path: Path,
    ) -> None:
        """Store collaborators and read settings; no I/O happens here.

        Args:
            config: Adapter configuration; ``settings`` may carry
                ``api_key_alias``, ``satellites`` and ``region``.
            config_store: Store used to look up the API key by alias.
            cursor_db_path: Path of the SQLite file holding dedup state.
        """
        self._config_store = config_store
        self._cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        self._db: sqlite3.Connection | None = None
        self._api_key: str | None = None
        self._api_key_alias: str = "firms"
        self._satellites: list[str] = []
        self.region: RegionConfig | None = None
        self._load_settings(config)

    def _load_settings(self, config: AdapterConfig) -> Any:
        """Copy alias, satellites and region from *config* onto the adapter.

        The region object is built before any attribute is assigned, so an
        error while constructing it leaves the previous settings untouched
        instead of half-applied.

        Returns:
            The raw ``region`` setting (used by callers for logging).
        """
        settings = config.settings
        region_dict = settings.get("region")
        region = RegionConfig(**region_dict) if region_dict else None

        self._api_key_alias = settings.get("api_key_alias", "firms")
        self._satellites = settings.get(
            "satellites", ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"]
        )
        self.region = region
        return region_dict

    def _redact(self, text: str) -> str:
        """Return *text* with the API key masked so it is safe to log."""
        if self._api_key:
            return text.replace(self._api_key, "***")
        return text

    async def apply_config(self, new_config: AdapterConfig) -> None:
        """Apply new configuration from hot-reload.

        The API key is re-fetched only when the alias changed.
        """
        old_alias = self._api_key_alias
        region_dict = self._load_settings(new_config)

        if self._api_key_alias != old_alias:
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if self._api_key:
                logger.info("FIRMS API key reloaded", extra={"alias": self._api_key_alias})
            else:
                logger.warning(
                    "FIRMS API key not found after alias change",
                    extra={"alias": self._api_key_alias},
                )

        logger.info(
            "FIRMS config applied",
            extra={
                "region": region_dict,
                "satellites": self._satellites,
                "api_key_alias": self._api_key_alias,
            },
        )

    async def startup(self) -> None:
        """Initialize HTTP session, dedup tracker, and fetch API key."""
        self._api_key = await self._config_store.get_api_key(self._api_key_alias)
        if not self._api_key:
            logger.error(
                "FIRMS API key not found - polling will be skipped until key is set",
                extra={"alias": self._api_key_alias},
            )

        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=60),
        )

        # Dedup tracker (shared sqlite DB with NWS); rows are namespaced by
        # the ``adapter`` column.
        self._db = sqlite3.connect(str(self._cursor_db_path))
        self._db.execute("""
            CREATE TABLE IF NOT EXISTS published_ids (
                adapter TEXT NOT NULL,
                event_id TEXT NOT NULL,
                first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (adapter, event_id)
            )
        """)
        self._db.execute("""
            CREATE INDEX IF NOT EXISTS published_ids_last_seen
            ON published_ids (last_seen)
        """)
        self._db.commit()

        # Sweep old entries on startup (48h for FIRMS)
        self.sweep_old_ids()

        logger.info(
            "FIRMS adapter started",
            extra={
                "region": {
                    "north": self.region.north,
                    "south": self.region.south,
                    "east": self.region.east,
                    "west": self.region.west,
                } if self.region else None,
                "satellites": self._satellites,
                "api_key_present": self._api_key is not None,
            },
        )

    async def shutdown(self) -> None:
        """Close HTTP session and database."""
        if self._session:
            await self._session.close()
            self._session = None
        if self._db:
            self._db.close()
            self._db = None
        logger.info("FIRMS adapter shut down")

    def is_published(self, stable_id: str) -> bool:
        """Return True if *stable_id* is recorded for this adapter.

        Returns False when the database has not been opened.
        """
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, stable_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, stable_id: str) -> None:
        """Record *stable_id* as published, refreshing ``last_seen`` if present."""
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
            last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, stable_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Remove published_ids older than 48 hours. Returns count deleted."""
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-48 hours')",
            (self.name,),
        )
        self._db.commit()
        count = cur.rowcount
        if count > 0:
            logger.info("FIRMS swept old dedup entries", extra={"count": count})
        return count

    def _build_stable_id(
        self, satellite: str, acq_date: str, acq_time: str, lat: float, lon: float
    ) -> str:
        """Build stable ID for deduplication.

        The ID joins satellite, acquisition date/time and the coordinates
        rounded to 0.001 degrees with ``:``.
        """
        lat_rounded = round(lat, 3)
        lon_rounded = round(lon, 3)
        return f"{satellite}:{acq_date}:{acq_time}:{lat_rounded}:{lon_rounded}"

    def _build_url(self, satellite: str) -> str | None:
        """Build FIRMS API URL for a satellite.

        Returns None when the API key or the region is not available.
        The returned URL embeds the API key and must not be logged verbatim.
        """
        if not self._api_key or not self.region:
            return None
        # Area format: west,south,east,north
        area = f"{self.region.west},{self.region.south},{self.region.east},{self.region.north}"
        return f"{FIRMS_API_BASE}/{self._api_key}/{satellite}/{area}/1"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=2, max=30),
        retry=retry_if_exception_type((aiohttp.ClientError,)),
        reraise=True,
    )
    async def _fetch_csv(self, url: str) -> str:
        """Fetch CSV data from FIRMS API.

        Retried up to three attempts on ``aiohttp.ClientError``.

        Raises:
            RuntimeError: If the HTTP session has not been created.
            ValueError: If the server answers with HTML instead of CSV.
            aiohttp.ClientError: On transport or HTTP status errors.
        """
        if not self._session:
            raise RuntimeError("Session not initialized")
        async with self._session.get(url) as resp:
            # An HTML body is checked before the status so the preview is logged.
            content_type = resp.headers.get("Content-Type", "")
            if "text/html" in content_type:
                text = await resp.text()
                logger.error(
                    "FIRMS returned HTML (likely auth error)",
                    extra={"status": resp.status, "preview": text[:200]},
                )
                raise ValueError("FIRMS returned HTML instead of CSV")
            resp.raise_for_status()
            return await resp.text()

    @staticmethod
    def _optional_float(value: str | None) -> float | None:
        """Convert a CSV cell to float; empty or absent cells become None."""
        return float(value) if value else None

    def _parse_csv(self, csv_text: str, satellite: str) -> list[dict[str, Any]]:
        """Parse FIRMS CSV response into list of dicts.

        Rows that lack a required field or hold a non-numeric value are
        logged and skipped; they never abort the rest of the batch.

        Args:
            csv_text: Raw response body.
            satellite: Fallback for the ``satellite`` column when the CSV has
                no such column.

        Returns:
            One dict per valid row with typed values.
        """
        required = ("latitude", "longitude", "acq_date", "acq_time")
        reader = csv.DictReader(StringIO(csv_text))
        fieldnames = list(reader.fieldnames or [])
        missing = [column for column in required if column not in fieldnames]
        if fieldnames and missing:
            # A non-CSV body (e.g. a plain-text error) would otherwise look
            # like "zero hotspots" without any trace in the logs.
            logger.warning(
                "FIRMS response is missing required CSV columns",
                extra={
                    "satellite": satellite,
                    "missing": missing,
                    "header": ",".join(str(f) for f in fieldnames)[:200],
                },
            )
            return []

        rows: list[dict[str, Any]] = []
        for row in reader:
            try:
                acq_date = row["acq_date"]
                acq_time = row["acq_time"]
                if not acq_date or not acq_time:
                    # DictReader pads truncated rows with None rather than
                    # raising KeyError, so absence must be checked explicitly.
                    raise ValueError("missing acq_date/acq_time")
                lat = float(row["latitude"])
                lon = float(row["longitude"])
                confidence_cell = row.get("confidence", "n")
                confidence_raw = "n" if confidence_cell is None else confidence_cell.lower()
                confidence = CONFIDENCE_MAP.get(confidence_raw, "nominal")
                rows.append({
                    "latitude": lat,
                    "longitude": lon,
                    "bright_ti4": self._optional_float(row.get("bright_ti4")),
                    "bright_ti5": self._optional_float(row.get("bright_ti5")),
                    "scan": self._optional_float(row.get("scan")),
                    "track": self._optional_float(row.get("track")),
                    "acq_date": acq_date,
                    "acq_time": acq_time,
                    "satellite": row.get("satellite", satellite),
                    "instrument": row.get("instrument", "VIIRS"),
                    "confidence": confidence,
                    "confidence_raw": confidence_raw,
                    "version": row.get("version", ""),
                    "frp": self._optional_float(row.get("frp")),
                    "daynight": row.get("daynight", ""),
                })
            except (KeyError, TypeError, ValueError) as e:
                logger.warning(
                    "Failed to parse FIRMS row",
                    extra={"error": str(e), "row": dict(row)},
                )
                continue
        return rows

    def _row_to_event(self, row: dict[str, Any], satellite: str) -> Event:
        """Convert a parsed CSV row to an Event.

        Args:
            row: A dict produced by ``_parse_csv``.
            satellite: FIRMS source identifier the row was fetched for.

        Returns:
            The event; its ``time`` falls back to the current UTC time when
            the acquisition date/time cannot be parsed.
        """
        satellite_short = SATELLITE_SHORT.get(satellite, satellite.lower().replace("_nrt", ""))
        confidence = row["confidence"]
        severity = SEVERITY_MAP.get(confidence, 1)

        acq_date = row["acq_date"]
        acq_time = row["acq_time"]
        # acq_time is HHMM. Pad to four digits before parsing: an unpadded
        # value such as "38" would otherwise be read by strptime as 03:08
        # instead of 00:38.
        padded_time = str(acq_time).zfill(4)
        try:
            event_time = datetime.strptime(
                f"{acq_date} {padded_time}", "%Y-%m-%d %H%M"
            ).replace(tzinfo=timezone.utc)
        except ValueError:
            logger.warning(
                "FIRMS acquisition time unparseable, using current time",
                extra={"acq_date": acq_date, "acq_time": acq_time},
            )
            event_time = datetime.now(timezone.utc)

        lat = row["latitude"]
        lon = row["longitude"]
        # The stable ID keeps the raw acq_time so existing dedup entries
        # continue to match.
        stable_id = self._build_stable_id(satellite, acq_date, acq_time, lat, lon)
        geo = Geo(
            centroid=(lon, lat),  # GeoJSON order: lon, lat
            bbox=(lon, lat, lon, lat),  # Point bbox
            regions=[],
            primary_region=None,
        )
        return Event(
            id=stable_id,
            source="central/adapters/firms",
            category=f"fire.hotspot.{satellite_short}.{confidence}",
            time=event_time,
            expires=None,
            severity=severity,
            geo=geo,
            data=row,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Poll FIRMS API for fire hotspots.

        Yields each hotspot not yet published. An ID is marked as published
        only after the consumer resumes the generator, i.e. after the event
        was handed over. A failure for one satellite is logged and does not
        stop the remaining satellites.
        """
        if not self._api_key:
            # Try to fetch again in case it was added
            self._api_key = await self._config_store.get_api_key(self._api_key_alias)
            if not self._api_key:
                logger.warning(
                    "FIRMS API key still not available, skipping poll",
                    extra={"alias": self._api_key_alias},
                )
                return
        if not self.region:
            logger.warning("FIRMS region not configured, skipping poll")
            return

        # Sweep old dedup entries periodically
        self.sweep_old_ids()

        total_features = 0
        total_new = 0
        for satellite in self._satellites:
            url = self._build_url(satellite)
            if not url:
                continue
            try:
                csv_text = await self._fetch_csv(url)
                rows = self._parse_csv(csv_text, satellite)
                feature_count = len(rows)
                total_features += feature_count
                new_count = 0
                for row in rows:
                    stable_id = self._build_stable_id(
                        satellite,
                        row["acq_date"],
                        row["acq_time"],
                        row["latitude"],
                        row["longitude"],
                    )
                    if self.is_published(stable_id):
                        continue
                    event = self._row_to_event(row, satellite)
                    yield event
                    self.mark_published(stable_id)
                    new_count += 1
                total_new += new_count
                logger.info(
                    "FIRMS satellite poll completed",
                    extra={
                        "satellite": satellite,
                        "feature_count": feature_count,
                        "new_count": new_count,
                    },
                )
            except Exception as e:
                # HTTP errors stringify with the request URL, which carries
                # the API key as a path segment. Mask it, and do not attach
                # the traceback, which would repeat the unmasked message.
                logger.error(
                    "FIRMS poll failed for satellite",
                    extra={"satellite": satellite, "error": self._redact(str(e))},
                )
                continue

        logger.info(
            "FIRMS poll completed",
            extra={
                "total_features": total_features,
                "total_new": total_new,
                "satellites": self._satellites,
            },
        )
def subject_for_fire_hotspot(ev: Event) -> str:
    """Return the NATS subject for a fire hotspot event.

    The event category already has the form
    ``fire.hotspot.<satellite>.<confidence>``, so the subject is simply that
    category under the ``central.`` prefix:
    ``central.fire.hotspot.<satellite>.<confidence>``.
    """
    return "central.{}".format(ev.category)

View file

@ -1,400 +1,400 @@
"""USGS Earthquake Hazards Program adapter."""
import logging
import sqlite3
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiohttp
from shapely.geometry import Point, box as shapely_box
from tenacity import (
retry,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception_type,
)
from central.adapter import SourceAdapter
from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore
from central.models import Event, Geo
logger = logging.getLogger(__name__)

# Base URL of the USGS GeoJSON summary feeds; "<feed>.geojson" is appended.
USGS_FEED_BASE = "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary"

# Feed names accepted in the adapter's "feed" setting.
VALID_FEEDS = {"all_" + window for window in ("hour", "day", "week", "month")}
def magnitude_tier(mag: float) -> str:
    """Return the tier label for an earthquake magnitude.

    Bands are half-open on the upper bound: below 3.0 is "minor", below 4.0
    "light", below 5.0 "moderate", below 6.0 "strong", below 7.0 "major",
    and everything else "great".
    """
    bands = (
        (3.0, "minor"),
        (4.0, "light"),
        (5.0, "moderate"),
        (6.0, "strong"),
        (7.0, "major"),
    )
    # First band whose upper bound exceeds the magnitude wins.
    return next((label for ceiling, label in bands if mag < ceiling), "great")
def magnitude_to_severity(mag: float) -> int:
    """Return the severity level (0-5) for an earthquake magnitude.

    The level is the index of the first upper bound in 3.0, 4.0, 5.0, 6.0,
    7.0 that the magnitude stays below; magnitudes of 7.0 and above map to 5.
    """
    for level, ceiling in enumerate((3.0, 4.0, 5.0, 6.0, 7.0)):
        if mag < ceiling:
            return level
    return 5
class USGSQuakeAdapter(SourceAdapter):
    """USGS Earthquake Hazards Program adapter.

    Downloads one USGS GeoJSON summary feed per poll, keeps the features
    located inside the configured region and yields each one once, using the
    SQLite ``published_ids`` table to remember what was already published.
    """

    name = "usgs_quake"

    def __init__(
        self,
        config: AdapterConfig,
        config_store: ConfigStore,  # Unused, accepted for signature uniformity
        cursor_db_path: Path,
    ) -> None:
        """Read settings; no I/O happens here.

        Args:
            config: Adapter configuration; ``settings`` may carry ``feed``
                and ``region``.
            config_store: Unused.
            cursor_db_path: Path of the SQLite file holding dedup state.
        """
        self._cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        self._db: sqlite3.Connection | None = None

        self._feed: str = config.settings.get("feed", "all_hour")
        if self._feed not in VALID_FEEDS:
            logger.warning(
                "Invalid feed setting, using all_hour",
                extra={"feed": self._feed, "valid": list(VALID_FEEDS)},
            )
            self._feed = "all_hour"

        self.region: RegionConfig | None = None
        self._region_box = None
        self._set_region(config.settings.get("region"))

    def _set_region(self, region_dict: Any) -> None:
        """Set ``region`` and its shapely box from the raw ``region`` setting.

        Both objects are built before either attribute is assigned, so a
        construction error leaves the previous region in place.
        """
        if region_dict:
            region = RegionConfig(**region_dict)
            region_box = shapely_box(
                region.west,
                region.south,
                region.east,
                region.north,
            )
        else:
            region = None
            region_box = None
        self.region = region
        self._region_box = region_box

    async def apply_config(self, new_config: AdapterConfig) -> None:
        """Apply new configuration from hot-reload.

        An invalid feed name keeps the current feed. The region is applied
        first so that an invalid region does not leave the feed changed.
        """
        region_dict = new_config.settings.get("region")
        self._set_region(region_dict)

        new_feed = new_config.settings.get("feed", "all_hour")
        if new_feed in VALID_FEEDS:
            self._feed = new_feed
        else:
            logger.warning(
                "Invalid feed in new config, keeping current",
                extra={"new_feed": new_feed, "current": self._feed},
            )

        logger.info(
            "USGS quake config applied",
            extra={
                "region": region_dict,
                "feed": self._feed,
            },
        )

    async def startup(self) -> None:
        """Initialize HTTP session and dedup tracker."""
        self._session = aiohttp.ClientSession(
            headers={"User-Agent": "Central/1.0 (earthquake monitoring)"},
            timeout=aiohttp.ClientTimeout(total=30),
        )

        # Dedup tracker (shared sqlite DB); rows are namespaced by the
        # ``adapter`` column.
        self._db = sqlite3.connect(str(self._cursor_db_path))
        self._db.execute("""
            CREATE TABLE IF NOT EXISTS published_ids (
                adapter TEXT NOT NULL,
                event_id TEXT NOT NULL,
                first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (adapter, event_id)
            )
        """)
        self._db.execute("""
            CREATE INDEX IF NOT EXISTS published_ids_last_seen
            ON published_ids (last_seen)
        """)
        self._db.commit()

        # Sweep old entries on startup (7 days for quakes)
        self.sweep_old_ids()

        logger.info(
            "USGS quake adapter started",
            extra={
                "region": {
                    "north": self.region.north,
                    "south": self.region.south,
                    "east": self.region.east,
                    "west": self.region.west,
                } if self.region else None,
                "feed": self._feed,
            },
        )

    async def shutdown(self) -> None:
        """Close HTTP session and database."""
        if self._session:
            await self._session.close()
            self._session = None
        if self._db:
            self._db.close()
            self._db = None
        logger.info("USGS quake adapter shut down")

    def is_published(self, event_id: str) -> bool:
        """Return True if *event_id* is recorded for this adapter.

        Returns False when the database has not been opened.
        """
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, event_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, event_id: str) -> None:
        """Record *event_id* as published, refreshing ``last_seen`` if present."""
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
            last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, event_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Remove published_ids older than 7 days. Returns count deleted."""
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-7 days')",
            (self.name,),
        )
        self._db.commit()
        count = cur.rowcount
        if count > 0:
            logger.info("USGS quake swept old dedup entries", extra={"count": count})
        return count

    def _build_url(self) -> str:
        """Build USGS GeoJSON feed URL."""
        return f"{USGS_FEED_BASE}/{self._feed}.geojson"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=1, max=15),
        retry=retry_if_exception_type((aiohttp.ClientError,)),
        reraise=True,
    )
    async def _fetch_geojson(self) -> dict[str, Any]:
        """Fetch GeoJSON data from USGS.

        Retried up to three attempts on ``aiohttp.ClientError``.

        Raises:
            RuntimeError: If the HTTP session has not been created.
            aiohttp.ClientError: On transport or HTTP status errors.
        """
        if not self._session:
            raise RuntimeError("Session not initialized")
        url = self._build_url()
        async with self._session.get(url) as resp:
            resp.raise_for_status()
            return await resp.json()

    def _point_in_region(self, lon: float, lat: float) -> bool:
        """Check if point intersects region bbox using shapely.

        Returns True when no region box is configured.
        """
        if self._region_box is None:
            return True
        point = Point(lon, lat)
        return self._region_box.intersects(point)

    def _feature_to_event(self, feature: dict[str, Any]) -> Event | None:
        """Convert a GeoJSON feature to an Event.

        Returns:
            The event, or None when the feature is malformed, has no
            magnitude, or lies outside the configured region.
        """
        if not isinstance(feature, dict):
            logger.warning("Feature is not an object", extra={"feature": repr(feature)[:200]})
            return None

        # ``or`` also covers keys that are present but null in the JSON;
        # a plain ``.get(key, default)`` would hand back None for those.
        props = feature.get("properties") or {}
        geometry = feature.get("geometry") or {}
        coords = geometry.get("coordinates") or []

        event_id = feature.get("id")
        if not event_id:
            logger.warning("Feature missing id", extra={"properties": props})
            return None

        # Get magnitude - skip if null/missing (PM decision)
        mag = props.get("mag")
        if mag is None:
            logger.debug(
                "Skipping event with null magnitude",
                extra={"id": event_id, "place": props.get("place")},
            )
            return None
        try:
            mag = float(mag)
        except (TypeError, ValueError):
            logger.warning(
                "Invalid magnitude value",
                extra={"id": event_id, "mag": mag},
            )
            return None

        # Coordinates are [lon, lat, depth]
        if len(coords) < 2:
            logger.warning("Feature missing coordinates", extra={"id": event_id})
            return None
        lon, lat = coords[0], coords[1]
        depth = coords[2] if len(coords) > 2 else None
        if not all(
            isinstance(value, (int, float)) and not isinstance(value, bool)
            for value in (lon, lat)
        ):
            # Non-numeric positions would raise inside the region test.
            logger.warning(
                "Feature has non-numeric coordinates",
                extra={"id": event_id, "coordinates": coords[:3]},
            )
            return None

        if not self._point_in_region(lon, lat):
            return None

        # Event time is milliseconds since epoch; fall back to "now" when it
        # is absent or cannot be converted.
        time_ms = props.get("time")
        if time_ms is not None:
            try:
                event_time = datetime.fromtimestamp(time_ms / 1000, tz=timezone.utc)
            except (TypeError, ValueError, OverflowError, OSError):
                event_time = datetime.now(timezone.utc)
        else:
            event_time = datetime.now(timezone.utc)

        tier = magnitude_tier(mag)
        severity = magnitude_to_severity(mag)

        geo = Geo(
            centroid=(lon, lat),
            bbox=(lon, lat, lon, lat),
            regions=[],
            primary_region=None,
        )

        data = {
            "magnitude": mag,
            "place": props.get("place"),
            "time_ms": time_ms,
            "updated_ms": props.get("updated"),
            "tz": props.get("tz"),
            "url": props.get("url"),
            "detail": props.get("detail"),
            "felt": props.get("felt"),
            "cdi": props.get("cdi"),
            "mmi": props.get("mmi"),
            "alert": props.get("alert"),
            "status": props.get("status"),
            "tsunami": props.get("tsunami"),
            "sig": props.get("sig"),
            "net": props.get("net"),
            "code": props.get("code"),
            "ids": props.get("ids"),
            "sources": props.get("sources"),
            "types": props.get("types"),
            "nst": props.get("nst"),
            "dmin": props.get("dmin"),
            "rms": props.get("rms"),
            "gap": props.get("gap"),
            "magType": props.get("magType"),
            "type": props.get("type"),
            "title": props.get("title"),
            "longitude": lon,
            "latitude": lat,
            "depth": depth,
        }

        return Event(
            id=event_id,
            source="central/adapters/usgs_quake",
            category=f"quake.event.{tier}",
            time=event_time,
            expires=None,
            severity=severity,
            geo=geo,
            data=data,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Poll USGS for earthquake data.

        Yields each in-region feature not yet published. An ID is marked as
        published only after the consumer resumes the generator. Fetch
        errors are logged and re-raised.
        """
        if not self.region:
            logger.warning("USGS quake region not configured, skipping poll")
            return

        # Sweep old dedup entries periodically
        self.sweep_old_ids()

        try:
            data = await self._fetch_geojson()
        except Exception as e:
            logger.error("Failed to fetch USGS data", extra={"error": str(e)})
            raise

        features = data.get("features", [])
        metadata = data.get("metadata", {})
        logger.info(
            "USGS quake poll completed",
            extra={
                "feature_count": len(features),
                "title": metadata.get("title"),
                "generated": metadata.get("generated"),
            },
        )

        new_count = 0
        for feature in features:
            event = self._feature_to_event(feature)
            if event is None:
                continue
            if self.is_published(event.id):
                continue
            yield event
            self.mark_published(event.id)
            new_count += 1

        logger.info("USGS quake yielded events", extra={"count": new_count})
"""USGS Earthquake Hazards Program adapter."""
import logging
import sqlite3
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiohttp
from shapely.geometry import Point, box as shapely_box
from tenacity import (
retry,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception_type,
)
from central.adapter import SourceAdapter
from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore
from central.models import Event, Geo
logger = logging.getLogger(__name__)
# USGS GeoJSON feed base URL
USGS_FEED_BASE = "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary"
# Valid feed options
VALID_FEEDS = {"all_hour", "all_day", "all_week", "all_month"}
def magnitude_tier(mag: float) -> str:
    """Return the tier label for an earthquake magnitude.

    Tiers are half-open ranges checked in ascending order: below 3.0 is
    "minor", below 4.0 "light", below 5.0 "moderate", below 6.0 "strong",
    below 7.0 "major"; anything else is "great".
    """
    ladder = (
        (3.0, "minor"),
        (4.0, "light"),
        (5.0, "moderate"),
        (6.0, "strong"),
        (7.0, "major"),
    )
    for upper_bound, label in ladder:
        if mag < upper_bound:
            return label
    return "great"
def magnitude_to_severity(mag: float) -> int:
    """Return the severity level (0-5) for an earthquake magnitude.

    The level is the number of whole-magnitude cutoffs (3.0 through 7.0)
    that ``mag`` has reached: below 3.0 is 0, 7.0 and above is 5.
    """
    cutoffs = (3.0, 4.0, 5.0, 6.0, 7.0)
    for level, cutoff in enumerate(cutoffs):
        if mag < cutoff:
            return level
    return len(cutoffs)
class USGSQuakeAdapter(SourceAdapter):
"""USGS Earthquake Hazards Program adapter."""
name = "usgs_quake"
def __init__(
    self,
    config: AdapterConfig,
    config_store: ConfigStore,  # Unused, accepted for signature uniformity
    cursor_db_path: Path,
) -> None:
    """Store settings; no network or database resources are opened here.

    Args:
        config: Adapter configuration; ``settings["feed"]`` and
            ``settings["region"]`` are read.
        config_store: Ignored.
        cursor_db_path: Path of the sqlite file opened later by ``startup``.
    """
    self._cursor_db_path = cursor_db_path
    self._session: aiohttp.ClientSession | None = None
    self._db: sqlite3.Connection | None = None

    # Feed name, falling back to the hourly feed when the value is unknown
    requested_feed: str = config.settings.get("feed", "all_hour")
    if requested_feed in VALID_FEEDS:
        self._feed: str = requested_feed
    else:
        logger.warning(
            "Invalid feed setting, using all_hour",
            extra={"feed": requested_feed, "valid": list(VALID_FEEDS)},
        )
        self._feed = "all_hour"

    # Region plus the matching shapely box used for point filtering
    region_settings = config.settings.get("region")
    self.region: RegionConfig | None = (
        RegionConfig(**region_settings) if region_settings else None
    )
    if self.region is None:
        self._region_box = None
    else:
        self._region_box = shapely_box(
            self.region.west,
            self.region.south,
            self.region.east,
            self.region.north,
        )
async def apply_config(self, new_config: AdapterConfig) -> None:
    """Apply new configuration from hot-reload.

    The region is parsed and its bounding box built *before* any attribute
    is changed, so a malformed region leaves the adapter running with its
    previous feed and region instead of a half-applied mix.

    Args:
        new_config: Configuration whose ``settings["feed"]`` and
            ``settings["region"]`` replace the current ones. An unknown
            feed is ignored with a warning; a missing/empty region clears
            the region.

    Raises:
        Exception: whatever ``RegionConfig(**region)`` or the box
            construction raises for invalid region settings.
    """
    settings = new_config.settings

    # Validate the region first; this is the only step that can raise.
    region_dict = settings.get("region")
    if region_dict:
        new_region: RegionConfig | None = RegionConfig(**region_dict)
        new_box = shapely_box(
            new_region.west,
            new_region.south,
            new_region.east,
            new_region.north,
        )
    else:
        new_region = None
        new_box = None

    # Update feed
    new_feed = settings.get("feed", "all_hour")
    if new_feed in VALID_FEEDS:
        self._feed = new_feed
    else:
        logger.warning(
            "Invalid feed in new config, keeping current",
            extra={"new_feed": new_feed, "current": self._feed},
        )

    # Update region (both attributes together so they never disagree)
    self.region = new_region
    self._region_box = new_box

    logger.info(
        "USGS quake config applied",
        extra={
            "region": region_dict,
            "feed": self._feed,
        },
    )
async def startup(self) -> None:
    """Open the HTTP session and the sqlite dedup store.

    Creates the ``published_ids`` table and its ``last_seen`` index when
    they do not exist yet, then sweeps expired dedup rows.
    """
    self._session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=30),
        headers={"User-Agent": "Central/1.0 (earthquake monitoring)"},
    )

    # Dedup tracker (shared sqlite DB)
    db = sqlite3.connect(str(self._cursor_db_path))
    self._db = db
    db.execute("""
        CREATE TABLE IF NOT EXISTS published_ids (
            adapter TEXT NOT NULL,
            event_id TEXT NOT NULL,
            first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (adapter, event_id)
        )
    """)
    db.execute("""
        CREATE INDEX IF NOT EXISTS published_ids_last_seen
        ON published_ids (last_seen)
    """)
    db.commit()

    # Sweep old entries on startup (7 days for quakes)
    self.sweep_old_ids()

    region = self.region
    if region:
        region_fields = {
            "north": region.north,
            "south": region.south,
            "east": region.east,
            "west": region.west,
        }
    else:
        region_fields = None
    logger.info(
        "USGS quake adapter started",
        extra={"region": region_fields, "feed": self._feed},
    )
async def shutdown(self) -> None:
    """Release the HTTP session and the sqlite connection, if open.

    Each handle is reset to ``None`` after it has been closed, so calling
    this when nothing is open only emits the log line.
    """
    session = self._session
    if session:
        await session.close()
        self._session = None

    db = self._db
    if db:
        db.close()
        self._db = None

    logger.info("USGS quake adapter shut down")
def is_published(self, event_id: str) -> bool:
    """Return True when ``event_id`` is recorded for this adapter.

    Returns False when the dedup database has not been opened.
    """
    db = self._db
    if not db:
        return False
    match = db.execute(
        "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
        (self.name, event_id),
    ).fetchone()
    return match is not None
def mark_published(self, event_id: str) -> None:
    """Record ``event_id`` as published and commit immediately.

    A first sighting inserts a row; a repeat only refreshes ``last_seen``.
    Does nothing when the dedup database has not been opened.
    """
    db = self._db
    if not db:
        return
    upsert_sql = """
        INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
        VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
        ON CONFLICT (adapter, event_id) DO UPDATE SET
        last_seen = CURRENT_TIMESTAMP
    """
    db.execute(upsert_sql, (self.name, event_id))
    db.commit()
def sweep_old_ids(self) -> int:
    """Delete this adapter's dedup rows not seen for more than 7 days.

    Returns:
        Number of rows deleted; 0 when the dedup database is not open.
    """
    db = self._db
    if not db:
        return 0
    deletion = db.execute(
        "DELETE FROM published_ids WHERE adapter = ? AND last_seen < datetime('now', '-7 days')",
        (self.name,),
    )
    db.commit()
    removed = deletion.rowcount
    if removed > 0:
        logger.info("USGS quake swept old dedup entries", extra={"count": removed})
    return removed
def _build_url(self) -> str:
    """Return the GeoJSON summary URL for the currently selected feed."""
    feed_file = self._feed + ".geojson"
    return "/".join((USGS_FEED_BASE, feed_file))
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential_jitter(initial=1, max=15),
    retry=retry_if_exception_type((aiohttp.ClientError,)),
    reraise=True,
)
async def _fetch_geojson(self) -> dict[str, Any]:
    """Download the configured feed and return its decoded JSON body.

    Up to three attempts are made when an ``aiohttp.ClientError`` occurs;
    the last error is re-raised as-is.

    Raises:
        RuntimeError: if ``startup`` has not created the HTTP session.
        aiohttp.ClientError: on transport failures or a non-success status.
    """
    session = self._session
    if not session:
        raise RuntimeError("Session not initialized")

    async with session.get(self._build_url()) as response:
        response.raise_for_status()
        return await response.json()
def _point_in_region(self, lon: float, lat: float) -> bool:
"""Check if point intersects region bbox using shapely."""
if self._region_box is None:
return True
point = Point(lon, lat)
return self._region_box.intersects(point)
def _feature_to_event(self, feature: dict[str, Any]) -> Event | None:
    """Convert a GeoJSON feature to an Event.

    Args:
        feature: One entry of the feed's ``features`` array.

    Returns:
        The Event, or None when the feature is unusable: missing id,
        null / non-numeric / non-finite magnitude, missing or non-numeric
        coordinates, or a location outside the configured region.
    """
    import math  # local import: the module header does not import math

    # GeoJSON permits "geometry": null; a .get(key, {}) default does not
    # help when the key is present but null, hence the `or` fallbacks.
    props = feature.get("properties") or {}
    geometry = feature.get("geometry") or {}
    coords = geometry.get("coordinates") or []

    # Validate required fields
    event_id = feature.get("id")
    if not event_id:
        logger.warning("Feature missing id", extra={"properties": props})
        return None

    # Get magnitude - skip if null/missing (PM decision)
    mag = props.get("mag")
    if mag is None:
        logger.debug(
            "Skipping event with null magnitude",
            extra={"id": event_id, "place": props.get("place")},
        )
        return None
    try:
        mag = float(mag)
    except (TypeError, ValueError):
        mag = math.nan
    # NaN compares False against every threshold and +inf exceeds them
    # all, so either would be classified as tier "great" / severity 5.
    if not math.isfinite(mag):
        logger.warning(
            "Invalid magnitude value",
            extra={"id": event_id, "mag": props.get("mag")},
        )
        return None

    # Get coordinates [lon, lat, depth]
    if len(coords) < 2:
        logger.warning("Feature missing coordinates", extra={"id": event_id})
        return None
    lon, lat = coords[0], coords[1]
    depth = coords[2] if len(coords) > 2 else None
    if not all(
        isinstance(value, (int, float)) and not isinstance(value, bool)
        for value in (lon, lat)
    ):
        # null or string coordinates would otherwise blow up in shapely
        logger.warning("Feature missing coordinates", extra={"id": event_id})
        return None

    # Region filter
    if not self._point_in_region(lon, lat):
        return None

    # Parse event time (milliseconds since epoch); fall back to "now"
    time_ms = props.get("time")
    event_time = None
    if time_ms is not None:
        try:
            event_time = datetime.fromtimestamp(time_ms / 1000, tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError):
            # fromtimestamp raises OverflowError for out-of-range values
            event_time = None
    if event_time is None:
        event_time = datetime.now(timezone.utc)

    # Build tier and severity
    tier = magnitude_tier(mag)
    severity = magnitude_to_severity(mag)

    # Build geo (a point: bbox collapses onto the centroid)
    geo = Geo(
        centroid=(lon, lat),
        bbox=(lon, lat, lon, lat),
        regions=[],
        primary_region=None,
    )

    # Build data payload; these properties are copied under their own name
    passthrough = (
        "tz", "url", "detail", "felt", "cdi", "mmi", "alert", "status",
        "tsunami", "sig", "net", "code", "ids", "sources", "types", "nst",
        "dmin", "rms", "gap", "magType", "type", "title",
    )
    data = {
        "magnitude": mag,
        "place": props.get("place"),
        "time_ms": time_ms,
        "updated_ms": props.get("updated"),
        **{key: props.get(key) for key in passthrough},
        "longitude": lon,
        "latitude": lat,
        "depth": depth,
    }

    return Event(
        id=event_id,
        source="central/adapters/usgs_quake",
        category=f"quake.event.{tier}",
        time=event_time,
        expires=None,
        severity=severity,
        geo=geo,
        data=data,
    )
async def poll(self) -> AsyncIterator[Event]:
    """Yield USGS earthquake events that have not been published yet.

    Without a configured region the poll logs a warning and yields nothing.
    Otherwise it sweeps stale dedup rows, downloads the feed, converts each
    feature and yields the ones not already recorded as published. An event
    is recorded as published only once the caller resumes the generator
    past its ``yield``.

    Raises:
        Exception: whatever ``_fetch_geojson`` raised, re-raised after
            being logged.
    """
    if not self.region:
        logger.warning("USGS quake region not configured, skipping poll")
        return

    # Sweep old dedup entries periodically
    self.sweep_old_ids()

    try:
        payload = await self._fetch_geojson()
    except Exception as exc:
        logger.error("Failed to fetch USGS data", extra={"error": str(exc)})
        raise

    features = payload.get("features", [])
    feed_meta = payload.get("metadata", {})
    poll_summary = {
        "feature_count": len(features),
        "title": feed_meta.get("title"),
        "generated": feed_meta.get("generated"),
    }
    logger.info("USGS quake poll completed", extra=poll_summary)

    yielded = 0
    for feature in features:
        event = self._feature_to_event(feature)
        # Skip features that were filtered out or already published.
        if event is None or self.is_published(event.id):
            continue
        yield event
        self.mark_published(event.id)
        yielded += 1

    logger.info("USGS quake yielded events", extra={"count": yielded})

View file

@ -1,353 +1,353 @@
"""Central archive consumer - JetStream to TimescaleDB."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from typing import Any
import asyncpg
import nats
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
from central.bootstrap_config import get_settings
CONSUMER_NAME = "archive"
STREAM_NAME = "CENTRAL_WX"
SUBJECT_FILTER = "central.wx.>"
BATCH_SIZE = 100
FETCH_TIMEOUT = 5.0
ACK_WAIT = 30
class JsonFormatter(logging.Formatter):
    """JSON log formatter for structured logging.

    Each record becomes one JSON object with ``ts`` (time of formatting,
    UTC ISO-8601), ``level``, ``logger`` and ``msg``, plus ``exc`` when
    exception info is attached and every caller-supplied ``extra`` field.
    """

    # Attributes the logging module itself puts on a LogRecord; any other
    # key in record.__dict__ is treated as an ``extra`` field.
    _RESERVED_ATTRS = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialised as a single JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str is deliberate: an ``extra`` value json cannot encode
        # (Path, datetime, exception, ...) must not make format() raise and
        # lose the whole log line.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
"""Configure JSON logging to stdout."""
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(JsonFormatter())
logging.root.handlers = [handler]
logging.root.setLevel(logging.INFO)
logger = logging.getLogger("central.archive")
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
"""Build PostGIS geometry from event geo data."""
if not geo_data:
return None
bbox = geo_data.get("bbox")
centroid = geo_data.get("centroid")
if bbox and len(bbox) == 4:
# Create polygon from bbox
min_lon, min_lat, max_lon, max_lat = bbox
return json.dumps({
"type": "Polygon",
"coordinates": [[
[min_lon, min_lat],
[max_lon, min_lat],
[max_lon, max_lat],
[min_lon, max_lat],
[min_lon, min_lat],
]]
})
elif centroid and len(centroid) == 2:
# Create point from centroid
return json.dumps({
"type": "Point",
"coordinates": centroid
})
return None
class ArchiveConsumer:
"""Archive consumer process."""
def __init__(self, nats_url: str, postgres_dsn: str) -> None:
self._nats_url = nats_url
self._postgres_dsn = postgres_dsn
self._nc: nats.NATS | None = None
self._js: JetStreamContext | None = None
self._pool: asyncpg.Pool | None = None
self._shutdown_event = asyncio.Event()
async def connect(self) -> None:
"""Connect to NATS and PostgreSQL."""
self._nc = await nats.connect(self._nats_url)
self._js = self._nc.jetstream()
logger.info("Connected to NATS", extra={"url": self._nats_url})
self._pool = await asyncpg.create_pool(
self._postgres_dsn,
min_size=1,
max_size=5,
)
logger.info("Connected to PostgreSQL")
async def disconnect(self) -> None:
"""Disconnect from NATS and PostgreSQL."""
if self._pool:
await self._pool.close()
self._pool = None
if self._nc:
await self._nc.drain()
await self._nc.close()
self._nc = None
self._js = None
logger.info("Disconnected")
async def _ensure_consumer(self) -> None:
"""Ensure the durable consumer exists."""
if not self._js:
return
try:
await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
except nats.js.errors.NotFoundError:
consumer_config = ConsumerConfig(
durable_name=CONSUMER_NAME,
deliver_policy=DeliverPolicy.ALL,
ack_policy=AckPolicy.EXPLICIT,
ack_wait=ACK_WAIT,
filter_subject=SUBJECT_FILTER,
)
await self._js.add_consumer(STREAM_NAME, consumer_config)
logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})
async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
"""Process a single message and insert into database."""
try:
envelope = json.loads(msg.data.decode())
except json.JSONDecodeError as e:
logger.warning("Invalid JSON in message", extra={"error": str(e)})
await msg.ack()
return
event_data = envelope.get("data", {})
geo_data = event_data.get("geo")
event_id = envelope.get("id")
source = event_data.get("source", "")
category = event_data.get("category", "")
time_str = event_data.get("time")
expires_str = event_data.get("expires")
severity = event_data.get("severity")
regions = event_data.get("geo", {}).get("regions", [])
primary_region = event_data.get("geo", {}).get("primary_region")
# Parse timestamps
event_time = None
if time_str:
try:
event_time = datetime.fromisoformat(time_str.replace("Z", "+00:00"))
except (ValueError, TypeError):
pass
expires_time = None
if expires_str:
try:
expires_time = datetime.fromisoformat(expires_str.replace("Z", "+00:00"))
except (ValueError, TypeError):
pass
if not event_id or not event_time:
logger.warning(
"Message missing required fields",
extra={"id": event_id, "time": time_str}
)
await msg.ack()
return
geom_json = _build_geom_sql(geo_data)
try:
if geom_json:
await conn.execute(
"""
INSERT INTO events (id, source, category, time, expires, severity,
geom, regions, primary_region, payload)
VALUES ($1, $2, $3, $4, $5, $6,
ST_GeomFromGeoJSON($7), $8, $9, $10)
ON CONFLICT (id, time) DO UPDATE SET
source = EXCLUDED.source,
category = EXCLUDED.category,
expires = EXCLUDED.expires,
severity = EXCLUDED.severity,
geom = EXCLUDED.geom,
regions = EXCLUDED.regions,
primary_region = EXCLUDED.primary_region,
payload = EXCLUDED.payload
""",
event_id, source, category, event_time, expires_time, severity,
geom_json, regions, primary_region, json.dumps(envelope)
)
else:
await conn.execute(
"""
INSERT INTO events (id, source, category, time, expires, severity,
geom, regions, primary_region, payload)
VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
ON CONFLICT (id, time) DO UPDATE SET
source = EXCLUDED.source,
category = EXCLUDED.category,
expires = EXCLUDED.expires,
severity = EXCLUDED.severity,
geom = EXCLUDED.geom,
regions = EXCLUDED.regions,
primary_region = EXCLUDED.primary_region,
payload = EXCLUDED.payload
""",
event_id, source, category, event_time, expires_time, severity,
regions, primary_region, json.dumps(envelope)
)
await msg.ack()
logger.info("Archived event", extra={"id": event_id, "category": category})
except Exception as e:
logger.error(
"Failed to insert event",
extra={"id": event_id, "error": str(e)}
)
# Don't ack - let it be redelivered
async def _consume_loop(self) -> None:
"""Main consume loop."""
if not self._js or not self._pool:
return
await self._ensure_consumer()
sub = await self._js.pull_subscribe(
SUBJECT_FILTER,
durable=CONSUMER_NAME,
stream=STREAM_NAME,
)
logger.info(
"Subscribed to stream",
extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
)
while not self._shutdown_event.is_set():
try:
msgs = await sub.fetch(
batch=BATCH_SIZE,
timeout=FETCH_TIMEOUT,
)
if msgs:
async with self._pool.acquire() as conn:
for msg in msgs:
await self._process_message(msg, conn)
except nats.errors.TimeoutError:
# No messages available, continue
pass
except asyncio.CancelledError:
break
except Exception as e:
logger.exception("Error in consume loop", extra={"error": str(e)})
await asyncio.sleep(1)
logger.info("Consume loop stopped")
async def start(self) -> None:
"""Start the consumer."""
await self.connect()
logger.info("Archive consumer ready")
async def run(self) -> None:
"""Run the consume loop until shutdown."""
await self._consume_loop()
async def stop(self) -> None:
"""Stop the consumer gracefully."""
logger.info("Archive consumer shutting down")
self._shutdown_event.set()
await self.disconnect()
logger.info("Archive consumer stopped")
async def async_main() -> None:
"""Async entry point."""
setup_logging()
settings = get_settings()
logger.info(
"Archive starting",
extra={
"nats_url": settings.nats_url,
},
)
consumer = ArchiveConsumer(
nats_url=settings.nats_url,
postgres_dsn=settings.db_dsn,
)
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
def handle_signal() -> None:
shutdown_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, handle_signal)
await consumer.start()
# Run consumer in background
consume_task = asyncio.create_task(consumer.run())
# Wait for shutdown signal
await shutdown_event.wait()
consumer._shutdown_event.set()
consume_task.cancel()
try:
await consume_task
except asyncio.CancelledError:
pass
await consumer.stop()
def main() -> None:
"""Entry point."""
asyncio.run(async_main())
if __name__ == "__main__":
main()
"""Central archive consumer - JetStream to TimescaleDB."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from typing import Any
import asyncpg
import nats
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
from central.bootstrap_config import get_settings
CONSUMER_NAME = "archive"
STREAM_NAME = "CENTRAL_WX"
SUBJECT_FILTER = "central.wx.>"
BATCH_SIZE = 100
FETCH_TIMEOUT = 5.0
ACK_WAIT = 30
class JsonFormatter(logging.Formatter):
    """JSON log formatter for structured logging.

    Each record becomes one JSON object with ``ts`` (time of formatting,
    UTC ISO-8601), ``level``, ``logger`` and ``msg``, plus ``exc`` when
    exception info is attached and every caller-supplied ``extra`` field.
    """

    # Attributes the logging module itself puts on a LogRecord; any other
    # key in record.__dict__ is treated as an ``extra`` field.
    _RESERVED_ATTRS = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialised as a single JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str is deliberate: an ``extra`` value json cannot encode
        # (Path, datetime, exception, ...) must not make format() raise and
        # lose the whole log line.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Route all root-logger output to stdout as JSON at INFO level.

    Any handlers previously attached to the root logger are replaced.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())

    root = logging.root
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)
logger = logging.getLogger("central.archive")
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
"""Build PostGIS geometry from event geo data."""
if not geo_data:
return None
bbox = geo_data.get("bbox")
centroid = geo_data.get("centroid")
if bbox and len(bbox) == 4:
# Create polygon from bbox
min_lon, min_lat, max_lon, max_lat = bbox
return json.dumps({
"type": "Polygon",
"coordinates": [[
[min_lon, min_lat],
[max_lon, min_lat],
[max_lon, max_lat],
[min_lon, max_lat],
[min_lon, min_lat],
]]
})
elif centroid and len(centroid) == 2:
# Create point from centroid
return json.dumps({
"type": "Point",
"coordinates": centroid
})
return None
class ArchiveConsumer:
"""Archive consumer process."""
def __init__(self, nats_url: str, postgres_dsn: str) -> None:
    """Remember the connection targets; nothing is connected yet.

    Args:
        nats_url: URL handed to ``nats.connect`` by ``connect``.
        postgres_dsn: DSN handed to ``asyncpg.create_pool`` by ``connect``.
    """
    # Connection targets
    self._nats_url = nats_url
    self._postgres_dsn = postgres_dsn

    # Live handles, populated by connect()
    self._nc: nats.NATS | None = None
    self._js: JetStreamContext | None = None
    self._pool: asyncpg.Pool | None = None

    # Set to make the consume loop exit
    self._shutdown_event = asyncio.Event()
async def connect(self) -> None:
    """Open the NATS connection (with JetStream) and the PostgreSQL pool."""
    nc = await nats.connect(self._nats_url)
    self._nc = nc
    self._js = nc.jetstream()
    logger.info("Connected to NATS", extra={"url": self._nats_url})

    pool_options = {"min_size": 1, "max_size": 5}
    self._pool = await asyncpg.create_pool(self._postgres_dsn, **pool_options)
    logger.info("Connected to PostgreSQL")
async def disconnect(self) -> None:
    """Close the PostgreSQL pool, then drain and close NATS.

    Handles that were never opened are skipped; each handle is reset to
    ``None`` once it has been closed.
    """
    pool = self._pool
    if pool:
        await pool.close()
        self._pool = None

    nc = self._nc
    if nc:
        await nc.drain()
        await nc.close()
        self._nc = None
        self._js = None

    logger.info("Disconnected")
async def _ensure_consumer(self) -> None:
    """Create the durable consumer on the stream unless it already exists.

    Does nothing when JetStream is not connected.
    """
    js = self._js
    if not js:
        return

    try:
        await js.consumer_info(STREAM_NAME, CONSUMER_NAME)
    except nats.js.errors.NotFoundError:
        # First run: explicit-ack durable consumer starting from the
        # beginning of the stream.
        await js.add_consumer(
            STREAM_NAME,
            ConsumerConfig(
                durable_name=CONSUMER_NAME,
                deliver_policy=DeliverPolicy.ALL,
                ack_policy=AckPolicy.EXPLICIT,
                ack_wait=ACK_WAIT,
                filter_subject=SUBJECT_FILTER,
            ),
        )
        logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})
    else:
        logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
    """Process a single message and insert into database.

    Malformed messages (undecodable bytes, invalid JSON, a non-object
    envelope, missing id or unparseable time) are logged and acked so they
    are not redelivered forever. A well-formed event is upserted into
    ``events`` keyed on (id, time) and acked; when the insert fails the
    message is left un-acked so it is redelivered.

    Args:
        msg: Message exposing ``data`` (bytes) and an awaitable ``ack()``.
        conn: Connection used for the insert.
    """
    try:
        envelope = json.loads(msg.data.decode())
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.warning("Invalid JSON in message", extra={"error": str(e)})
        await msg.ack()
        return

    # Valid JSON need not be an object; anything else has no fields to read
    # and would raise AttributeError below, poisoning the batch.
    if not isinstance(envelope, dict):
        logger.warning(
            "Invalid JSON in message",
            extra={"error": f"envelope is {type(envelope).__name__}, expected object"},
        )
        await msg.ack()
        return

    event_data = envelope.get("data")
    if not isinstance(event_data, dict):
        event_data = {}

    # "geo" may be present but null; .get("geo", {}) would then return None.
    geo_data = event_data.get("geo")
    if not isinstance(geo_data, dict):
        geo_data = None
    geo_fields = geo_data or {}

    event_id = envelope.get("id")
    source = event_data.get("source", "")
    category = event_data.get("category", "")
    time_str = event_data.get("time")
    expires_str = event_data.get("expires")
    severity = event_data.get("severity")
    regions = geo_fields.get("regions", [])
    primary_region = geo_fields.get("primary_region")

    # Parse timestamps; AttributeError covers non-string values (.replace)
    event_time = None
    if time_str:
        try:
            event_time = datetime.fromisoformat(time_str.replace("Z", "+00:00"))
        except (ValueError, TypeError, AttributeError):
            pass

    expires_time = None
    if expires_str:
        try:
            expires_time = datetime.fromisoformat(expires_str.replace("Z", "+00:00"))
        except (ValueError, TypeError, AttributeError):
            pass

    if not event_id or not event_time:
        logger.warning(
            "Message missing required fields",
            extra={"id": event_id, "time": time_str}
        )
        await msg.ack()
        return

    geom_json = _build_geom_sql(geo_data)

    # Shared by both INSERT variants below
    upsert_clause = """
        ON CONFLICT (id, time) DO UPDATE SET
            source = EXCLUDED.source,
            category = EXCLUDED.category,
            expires = EXCLUDED.expires,
            severity = EXCLUDED.severity,
            geom = EXCLUDED.geom,
            regions = EXCLUDED.regions,
            primary_region = EXCLUDED.primary_region,
            payload = EXCLUDED.payload
    """

    try:
        if geom_json:
            await conn.execute(
                """
                INSERT INTO events (id, source, category, time, expires, severity,
                                    geom, regions, primary_region, payload)
                VALUES ($1, $2, $3, $4, $5, $6,
                        ST_GeomFromGeoJSON($7), $8, $9, $10)
                """ + upsert_clause,
                event_id, source, category, event_time, expires_time, severity,
                geom_json, regions, primary_region, json.dumps(envelope)
            )
        else:
            await conn.execute(
                """
                INSERT INTO events (id, source, category, time, expires, severity,
                                    geom, regions, primary_region, payload)
                VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
                """ + upsert_clause,
                event_id, source, category, event_time, expires_time, severity,
                regions, primary_region, json.dumps(envelope)
            )

        await msg.ack()
        logger.info("Archived event", extra={"id": event_id, "category": category})

    except Exception as e:
        logger.error(
            "Failed to insert event",
            extra={"id": event_id, "error": str(e)}
        )
        # Don't ack - let it be redelivered
async def _consume_loop(self) -> None:
    """Pull batches from the durable consumer until shutdown is requested.

    Returns immediately when JetStream or the pool is not connected. A
    fetch timeout just means no messages were available; any other error
    is logged and followed by a one-second pause before the next fetch.
    Cancellation ends the loop.
    """
    if not (self._js and self._pool):
        return

    await self._ensure_consumer()

    subscription = await self._js.pull_subscribe(
        SUBJECT_FILTER,
        durable=CONSUMER_NAME,
        stream=STREAM_NAME,
    )
    logger.info(
        "Subscribed to stream",
        extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
    )

    while not self._shutdown_event.is_set():
        try:
            batch = await subscription.fetch(
                batch=BATCH_SIZE,
                timeout=FETCH_TIMEOUT,
            )
            if not batch:
                continue
            # One pooled connection serves the whole batch
            async with self._pool.acquire() as conn:
                for message in batch:
                    await self._process_message(message, conn)
        except nats.errors.TimeoutError:
            # No messages available, continue
            continue
        except asyncio.CancelledError:
            break
        except Exception as exc:
            logger.exception("Error in consume loop", extra={"error": str(exc)})
            await asyncio.sleep(1)

    logger.info("Consume loop stopped")
async def start(self) -> None:
    """Connect to NATS and PostgreSQL and report readiness."""
    await self.connect()
    logger.info("Archive consumer ready")
async def run(self) -> None:
    """Consume messages until the consume loop returns."""
    await self._consume_loop()
async def stop(self) -> None:
    """Signal the consume loop to exit, then close all connections."""
    logger.info("Archive consumer shutting down")

    # Flag first so the loop sees it before its connections go away
    self._shutdown_event.set()
    await self.disconnect()

    logger.info("Archive consumer stopped")
async def async_main() -> None:
    """Async entry point.

    Runs the archive consumer until SIGTERM/SIGINT arrives *or* the consume
    task ends on its own. If the task died with an exception (for example
    the stream does not exist) that exception is re-raised after cleanup,
    so the process exits instead of idling with no consumer running.
    Connections are closed on every exit path, including a failed startup.
    """
    setup_logging()

    settings = get_settings()

    logger.info(
        "Archive starting",
        extra={
            "nats_url": settings.nats_url,
        },
    )

    consumer = ArchiveConsumer(
        nats_url=settings.nats_url,
        postgres_dsn=settings.db_dsn,
    )

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    shutdown_waiter = asyncio.create_task(shutdown_event.wait())
    try:
        await consumer.start()

        # Run consumer in background
        consume_task = asyncio.create_task(consumer.run())

        # Wake on a shutdown signal or on the consumer ending by itself
        await asyncio.wait(
            {consume_task, shutdown_waiter},
            return_when=asyncio.FIRST_COMPLETED,
        )

        consumer._shutdown_event.set()
        consume_task.cancel()  # no-op when the task has already finished
        try:
            # Re-raises the task's own exception, if it crashed
            await consume_task
        except asyncio.CancelledError:
            pass
    finally:
        shutdown_waiter.cancel()
        await consumer.stop()
def main() -> None:
    """Run the archive consumer in a fresh event loop until it returns."""
    entry = async_main()
    asyncio.run(entry)
if __name__ == "__main__":
main()

View file

@ -1,75 +1,75 @@
"""Central CLI commands."""
import argparse
import asyncio
import sys
async def config_store_check() -> int:
"""Smoke test for config store connectivity.
Connects via bootstrap_config, lists adapters, and verifies crypto.
Returns 0 on success, 1 on failure.
"""
from central.bootstrap_config import get_settings
from central.config_store import ConfigStore
from central.crypto import decrypt, encrypt
settings = get_settings()
print(f"Connecting to: {settings.db_dsn.split('@')[1]}") # Hide password
try:
store = await ConfigStore.create(settings.db_dsn)
except Exception as e:
print(f"ERROR: Failed to connect to database: {e}")
return 1
try:
# List adapters
adapters = await store.list_adapters()
print(f"\nAdapters ({len(adapters)}):")
for adapter in adapters:
print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
print(f" settings: {adapter.settings}")
# Test crypto
test_plaintext = b"config_store_check_test"
try:
ciphertext = encrypt(test_plaintext)
decrypted = decrypt(ciphertext)
if decrypted == test_plaintext:
print("\ncrypto: ok")
else:
print("\ncrypto: FAILED (round-trip mismatch)")
return 1
except Exception as e:
print(f"\ncrypto: FAILED ({e})")
return 1
print("\nAll checks passed.")
return 0
finally:
await store.close()
def main_config_store_check() -> None:
"""Entry point for central-cli config-store-check."""
sys.exit(asyncio.run(config_store_check()))
def main() -> None:
"""Main CLI entry point."""
parser = argparse.ArgumentParser(description="Central CLI")
subparsers = parser.add_subparsers(dest="command", required=True)
subparsers.add_parser("config-store-check", help="Test config store connectivity")
args = parser.parse_args()
if args.command == "config-store-check":
main_config_store_check()
if __name__ == "__main__":
main()
"""Central CLI commands."""
import argparse
import asyncio
import sys
async def config_store_check() -> int:
"""Smoke test for config store connectivity.
Connects via bootstrap_config, lists adapters, and verifies crypto.
Returns 0 on success, 1 on failure.
"""
from central.bootstrap_config import get_settings
from central.config_store import ConfigStore
from central.crypto import decrypt, encrypt
settings = get_settings()
print(f"Connecting to: {settings.db_dsn.split('@')[1]}") # Hide password
try:
store = await ConfigStore.create(settings.db_dsn)
except Exception as e:
print(f"ERROR: Failed to connect to database: {e}")
return 1
try:
# List adapters
adapters = await store.list_adapters()
print(f"\nAdapters ({len(adapters)}):")
for adapter in adapters:
print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
print(f" settings: {adapter.settings}")
# Test crypto
test_plaintext = b"config_store_check_test"
try:
ciphertext = encrypt(test_plaintext)
decrypted = decrypt(ciphertext)
if decrypted == test_plaintext:
print("\ncrypto: ok")
else:
print("\ncrypto: FAILED (round-trip mismatch)")
return 1
except Exception as e:
print(f"\ncrypto: FAILED ({e})")
return 1
print("\nAll checks passed.")
return 0
finally:
await store.close()
def main_config_store_check() -> None:
    """Run the config-store smoke test and exit with its status code."""
    exit_code = asyncio.run(config_store_check())
    sys.exit(exit_code)
def main() -> None:
    """Parse the command line and dispatch to the chosen subcommand.

    A subcommand is mandatory; ``config-store-check`` is the only one
    registered here.
    """
    cli = argparse.ArgumentParser(description="Central CLI")
    commands = cli.add_subparsers(dest="command", required=True)
    commands.add_parser("config-store-check", help="Test config store connectivity")

    chosen = cli.parse_args().command
    if chosen == "config-store-check":
        main_config_store_check()
if __name__ == "__main__":
main()

View file

@ -1,332 +1,332 @@
"""Database-backed configuration store.
Provides async access to the config schema tables with support for
Postgres LISTEN/NOTIFY for real-time config change notifications.
"""
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from typing import Any
import asyncpg
from central.config_models import AdapterConfig, StreamConfig
from central.crypto import decrypt, encrypt
logger = logging.getLogger(__name__)
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
    """Register ``json.dumps`` / ``json.loads`` as the ``jsonb`` codec on ``conn``."""
    codec_options = {
        "encoder": json.dumps,
        "decoder": json.loads,
        "schema": "pg_catalog",
    }
    await conn.set_type_codec("jsonb", **codec_options)
class ConfigStore:
"""Async interface to the config schema in Postgres."""
def __init__(self, pool: asyncpg.Pool) -> None:
    """Wrap an existing connection pool; every query method acquires from it."""
    self._pool = pool
@classmethod
async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
    """Build a store backed by a newly created connection pool.

    Args:
        dsn: PostgreSQL DSN for the pool.
        min_size: Minimum pool size.
        max_size: Maximum pool size.

    Every pooled connection is initialised with the jsonb codec.
    """
    pool_options = {
        "min_size": min_size,
        "max_size": max_size,
        "init": _setup_json_codec,
    }
    new_pool = await asyncpg.create_pool(dsn, **pool_options)
    return cls(new_pool)
async def close(self) -> None:
    """Shut down the underlying connection pool."""
    pool = self._pool
    await pool.close()
# -------------------------------------------------------------------------
# Adapter configuration
# -------------------------------------------------------------------------
async def get_adapter(self, name: str) -> AdapterConfig | None:
    """Return the stored configuration of adapter ``name``, or None if absent."""
    query = """
        SELECT name, enabled, cadence_s, settings, paused_at, updated_at
        FROM config.adapters
        WHERE name = $1
    """
    async with self._pool.acquire() as conn:
        record = await conn.fetchrow(query, name)
        return None if record is None else AdapterConfig(**dict(record))
async def list_adapters(self) -> list[AdapterConfig]:
    """Return every adapter configuration, ordered by adapter name."""
    query = """
        SELECT name, enabled, cadence_s, settings, paused_at, updated_at
        FROM config.adapters
        ORDER BY name
    """
    async with self._pool.acquire() as conn:
        records = await conn.fetch(query)
        configs = []
        for record in records:
            configs.append(AdapterConfig(**dict(record)))
        return configs
async def upsert_adapter(
    self,
    name: str,
    enabled: bool,
    cadence_s: int,
    settings: dict[str, Any],
) -> None:
    """Create adapter ``name`` or overwrite its enabled/cadence/settings.

    ``updated_at`` is set to the database's ``now()`` in both cases.
    """
    statement = """
        INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
        VALUES ($1, $2, $3, $4, now())
        ON CONFLICT (name) DO UPDATE SET
        enabled = EXCLUDED.enabled,
        cadence_s = EXCLUDED.cadence_s,
        settings = EXCLUDED.settings,
        updated_at = now()
    """
    # settings is passed as a dict; the jsonb codec encodes it
    params = (name, enabled, cadence_s, settings)
    async with self._pool.acquire() as conn:
        await conn.execute(statement, *params)
async def pause_adapter(self, name: str) -> None:
    """Mark adapter ``name`` as paused by stamping ``paused_at`` with ``now()``."""
    statement = """
        UPDATE config.adapters
        SET paused_at = now(), updated_at = now()
        WHERE name = $1
    """
    async with self._pool.acquire() as conn:
        await conn.execute(statement, name)
async def unpause_adapter(self, name: str) -> None:
    """Resume an adapter by resetting ``paused_at`` to NULL.

    Args:
        name: Adapter to unpause. A name with no matching row updates nothing.
    """
    statement = """
        UPDATE config.adapters
        SET paused_at = NULL, updated_at = now()
        WHERE name = $1
        """
    async with self._pool.acquire() as conn:
        await conn.execute(statement, name)
# -------------------------------------------------------------------------
# Stream configuration
# -------------------------------------------------------------------------
async def get_stream(self, name: str) -> StreamConfig | None:
    """Fetch one stream's configuration by name.

    Args:
        name: Stream name to look up in ``config.streams``.

    Returns:
        A ``StreamConfig`` built from the row's columns, or ``None`` when no
        row matches.
    """
    query = """
        SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
        FROM config.streams
        WHERE name = $1
        """
    async with self._pool.acquire() as conn:
        record = await conn.fetchrow(query, name)
    return None if record is None else StreamConfig(**dict(record))
async def list_streams(self) -> list[StreamConfig]:
    """Return every row of ``config.streams`` as ``StreamConfig``, ordered by name."""
    query = """
        SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
        FROM config.streams
        ORDER BY name
        """
    async with self._pool.acquire() as conn:
        records = await conn.fetch(query)
    return [StreamConfig(**dict(record)) for record in records]
async def upsert_stream(self, name: str, max_age_s: int) -> None:
    """Insert a stream row or update its ``max_age_s`` (operator-facing).

    Only ``max_age_s`` and ``updated_at`` are written; other columns are
    left to their defaults or current values.

    Args:
        name: Stream name (the conflict target).
        max_age_s: New value for the ``max_age_s`` column.
    """
    statement = """
        INSERT INTO config.streams (name, max_age_s, updated_at)
        VALUES ($1, $2, now())
        ON CONFLICT (name) DO UPDATE SET
            max_age_s = EXCLUDED.max_age_s,
            updated_at = now()
        """
    async with self._pool.acquire() as conn:
        await conn.execute(statement, name, max_age_s)
async def update_stream_max_bytes(self, name: str, max_bytes: int) -> None:
    """Set a stream's ``max_bytes`` (supervisor-internal).

    The statement writes only ``max_bytes`` and ``updated_at``. Per the
    original author's note, the NOTIFY trigger is column-filtered and fires
    only on ``max_age_s`` changes, so this update does not notify listeners.

    Args:
        name: Stream whose row is updated; no matching row updates nothing.
        max_bytes: New value for the ``max_bytes`` column.
    """
    statement = """
        UPDATE config.streams
        SET max_bytes = $2, updated_at = now()
        WHERE name = $1
        """
    async with self._pool.acquire() as conn:
        await conn.execute(statement, name, max_bytes)
# -------------------------------------------------------------------------
# API key management
# -------------------------------------------------------------------------
async def set_api_key(self, alias: str, plaintext_value: str) -> None:
    """Encrypt an API key and store it under *alias*.

    The plaintext is UTF-8 encoded and passed through ``encrypt`` before it
    reaches the database. Re-using an alias replaces the stored value and
    stamps ``rotated_at``.

    Args:
        alias: Lookup name for the key (the conflict target).
        plaintext_value: The secret to store.
    """
    statement = """
        INSERT INTO config.api_keys (alias, encrypted_value)
        VALUES ($1, $2)
        ON CONFLICT (alias) DO UPDATE SET
            encrypted_value = EXCLUDED.encrypted_value,
            rotated_at = now()
        """
    ciphertext = encrypt(plaintext_value.encode("utf-8"))
    async with self._pool.acquire() as conn:
        await conn.execute(statement, alias, ciphertext)
async def get_api_key(self, alias: str) -> str | None:
    """Look up an API key by alias and return it decrypted.

    A successful lookup also stamps ``last_used_at`` on the row.

    Args:
        alias: Lookup name of the key.

    Returns:
        The decrypted key as text, or ``None`` when the alias is unknown.
    """
    select_sql = """
        SELECT encrypted_value FROM config.api_keys WHERE alias = $1
        """
    touch_sql = """
        UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
        """
    async with self._pool.acquire() as conn:
        record = await conn.fetchrow(select_sql, alias)
        if record is None:
            return None
        # Record the access on the same connection before decrypting.
        await conn.execute(touch_sql, alias)
    ciphertext = record["encrypted_value"]
    return decrypt(ciphertext).decode("utf-8")
async def delete_api_key(self, alias: str) -> bool:
    """Remove the API key stored under *alias*.

    Returns:
        ``True`` when the command status reports exactly one deleted row,
        ``False`` otherwise.
    """
    async with self._pool.acquire() as conn:
        status = await conn.execute(
            "DELETE FROM config.api_keys WHERE alias = $1", alias
        )
    # The status string has the form "DELETE <rowcount>".
    deleted = status == "DELETE 1"
    return deleted
# -------------------------------------------------------------------------
# Change notifications
# -------------------------------------------------------------------------
async def listen_for_changes(
    self,
    callback: Callable[[str, str], Awaitable[None] | None],
) -> None:
    """Listen for config changes via Postgres NOTIFY.

    Runs forever, calling ``callback(table, key)`` for every notification
    received on the ``config_changed`` channel. The callback can be sync or
    async; a coroutine it returns is scheduled as a task on the running loop
    and any exception it raises is logged.

    On connection loss, automatically reconnects with exponential backoff
    (1s, doubling up to 30s). Cancellation (via ``task.cancel()``)
    propagates cleanly.

    Args:
        callback: Function called with (table_name, row_key) on each change.

    Raises:
        asyncio.CancelledError: When the surrounding task is cancelled.
    """
    backoff = 1.0
    max_backoff = 30.0

    # The event loop keeps only weak references to tasks, so hold strong
    # references here until each callback task has finished.
    callback_tasks: set[asyncio.Task[Any]] = set()

    def _on_callback_done(task: asyncio.Task[Any]) -> None:
        callback_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Config change callback failed", exc_info=exc)

    def notification_handler(
        _conn: asyncpg.Connection,
        _pid: int,
        _channel: str,
        payload: str,
    ) -> None:
        # payload format: "table_name:key"; a payload without ":" yields key "".
        table, _, key = payload.partition(":")
        result = callback(table, key)
        if asyncio.iscoroutine(result):
            task = asyncio.create_task(result)
            callback_tasks.add(task)
            task.add_done_callback(_on_callback_done)

    while True:
        conn = None
        try:
            conn = await self._pool.acquire()
            await conn.add_listener("config_changed", notification_handler)
            logger.info("Config listener connected to database")
            # Reset only once the listener is registered, so repeated
            # failures during setup still back off.
            backoff = 1.0

            try:
                # Keep connection alive with periodic keepalive
                while True:
                    await asyncio.sleep(60)
                    await conn.execute("SELECT 1")
            finally:
                try:
                    await conn.remove_listener("config_changed", notification_handler)
                except Exception:
                    # Best effort: a dead connection must not replace the
                    # exception in flight (notably CancelledError).
                    logger.debug(
                        "Could not remove config listener", exc_info=True
                    )

        except asyncio.CancelledError:
            # Cancellation must propagate cleanly
            logger.info("Config listener cancelled")
            raise

        except (
            asyncpg.PostgresConnectionError,
            asyncpg.InterfaceError,
            ConnectionResetError,
            OSError,
        ) as e:
            logger.warning(
                "Config listener connection lost, reconnecting in %.1fs: %s",
                backoff,
                e,
            )
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)

        except Exception:
            # Unexpected error - log and retry with backoff
            logger.exception(
                "Config listener unexpected error, reconnecting in %.1fs",
                backoff,
            )
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)

        finally:
            if conn is not None:
                try:
                    await self._pool.release(conn)
                except Exception:
                    # Connection may already be invalid; releasing is best effort.
                    logger.debug(
                        "Could not release config listener connection",
                        exc_info=True,
                    )
"""Database-backed configuration store.
Provides async access to the config schema tables with support for
Postgres LISTEN/NOTIFY for real-time config change notifications.
"""
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from typing import Any
import asyncpg
from central.config_models import AdapterConfig, StreamConfig
from central.crypto import decrypt, encrypt
logger = logging.getLogger(__name__)
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
    """Register ``json.dumps``/``json.loads`` as the ``pg_catalog.jsonb`` codec on *conn*."""
    codec_options = {
        "encoder": json.dumps,
        "decoder": json.loads,
        "schema": "pg_catalog",
    }
    await conn.set_type_codec("jsonb", **codec_options)
class ConfigStore:
    """Async interface to the config schema in Postgres.

    Wraps an asyncpg pool and exposes helpers for ``config.adapters``,
    ``config.streams`` and ``config.api_keys``, plus a long-running listener
    for the ``config_changed`` NOTIFY channel.
    """

    def __init__(self, pool: asyncpg.Pool) -> None:
        """Wrap an existing pool; use ``create`` to build one from a DSN."""
        self._pool = pool

    @classmethod
    async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
        """Create a ConfigStore with a new connection pool.

        Each pooled connection is initialised with ``_setup_json_codec``.

        Args:
            dsn: Postgres connection string.
            min_size: Minimum number of pooled connections.
            max_size: Maximum number of pooled connections.
        """
        pool = await asyncpg.create_pool(
            dsn,
            min_size=min_size,
            max_size=max_size,
            init=_setup_json_codec,
        )
        return cls(pool)

    async def close(self) -> None:
        """Close the connection pool."""
        await self._pool.close()

    # -------------------------------------------------------------------------
    # Adapter configuration
    # -------------------------------------------------------------------------

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Get configuration for a specific adapter, or ``None`` if absent."""
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT name, enabled, cadence_s, settings, paused_at, updated_at
                FROM config.adapters
                WHERE name = $1
                """,
                name,
            )
        if row is None:
            return None
        return AdapterConfig(**dict(row))

    async def list_adapters(self) -> list[AdapterConfig]:
        """List all configured adapters, ordered by name."""
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT name, enabled, cadence_s, settings, paused_at, updated_at
                FROM config.adapters
                ORDER BY name
                """
            )
        return [AdapterConfig(**dict(row)) for row in rows]

    async def upsert_adapter(
        self,
        name: str,
        enabled: bool,
        cadence_s: int,
        settings: dict[str, Any],
    ) -> None:
        """Insert or update an adapter configuration.

        On conflict only ``enabled``, ``cadence_s``, ``settings`` and
        ``updated_at`` are rewritten; ``paused_at`` is left untouched.
        """
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
                VALUES ($1, $2, $3, $4, now())
                ON CONFLICT (name) DO UPDATE SET
                    enabled = EXCLUDED.enabled,
                    cadence_s = EXCLUDED.cadence_s,
                    settings = EXCLUDED.settings,
                    updated_at = now()
                """,
                name,
                enabled,
                cadence_s,
                settings,  # Will be encoded as JSON by the codec
            )

    async def pause_adapter(self, name: str) -> None:
        """Pause an adapter by setting paused_at."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE config.adapters
                SET paused_at = now(), updated_at = now()
                WHERE name = $1
                """,
                name,
            )

    async def unpause_adapter(self, name: str) -> None:
        """Unpause an adapter by clearing paused_at."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE config.adapters
                SET paused_at = NULL, updated_at = now()
                WHERE name = $1
                """,
                name,
            )

    # -------------------------------------------------------------------------
    # Stream configuration
    # -------------------------------------------------------------------------

    async def get_stream(self, name: str) -> StreamConfig | None:
        """Get configuration for a specific stream, or ``None`` if absent."""
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
                FROM config.streams
                WHERE name = $1
                """,
                name,
            )
        if row is None:
            return None
        return StreamConfig(**dict(row))

    async def list_streams(self) -> list[StreamConfig]:
        """List all configured streams, ordered by name."""
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT name, max_age_s, max_bytes, managed_max_bytes, updated_at
                FROM config.streams
                ORDER BY name
                """
            )
        return [StreamConfig(**dict(row)) for row in rows]

    async def upsert_stream(self, name: str, max_age_s: int) -> None:
        """Insert or update a stream's max_age_s (operator-facing)."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO config.streams (name, max_age_s, updated_at)
                VALUES ($1, $2, now())
                ON CONFLICT (name) DO UPDATE SET
                    max_age_s = EXCLUDED.max_age_s,
                    updated_at = now()
                """,
                name,
                max_age_s,
            )

    async def update_stream_max_bytes(self, name: str, max_bytes: int) -> None:
        """Update a stream's max_bytes (supervisor-internal).

        This update only touches max_bytes, which does NOT trigger
        the column-filtered NOTIFY (only max_age_s changes fire NOTIFY).
        """
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE config.streams
                SET max_bytes = $2, updated_at = now()
                WHERE name = $1
                """,
                name,
                max_bytes,
            )

    # -------------------------------------------------------------------------
    # API key management
    # -------------------------------------------------------------------------

    async def set_api_key(self, alias: str, plaintext_value: str) -> None:
        """Store an API key, encrypting it with the master key.

        Re-using an alias replaces the stored value and stamps ``rotated_at``.
        """
        encrypted = encrypt(plaintext_value.encode("utf-8"))
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO config.api_keys (alias, encrypted_value)
                VALUES ($1, $2)
                ON CONFLICT (alias) DO UPDATE SET
                    encrypted_value = EXCLUDED.encrypted_value,
                    rotated_at = now()
                """,
                alias,
                encrypted,
            )

    async def get_api_key(self, alias: str) -> str | None:
        """Retrieve and decrypt an API key by alias.

        A successful lookup also stamps ``last_used_at``. Returns ``None``
        when the alias is unknown.
        """
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT encrypted_value FROM config.api_keys WHERE alias = $1
                """,
                alias,
            )
            if row is None:
                return None
            # Update last_used_at
            await conn.execute(
                """
                UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
                """,
                alias,
            )
        return decrypt(row["encrypted_value"]).decode("utf-8")

    async def delete_api_key(self, alias: str) -> bool:
        """Delete an API key. Returns True if key existed."""
        async with self._pool.acquire() as conn:
            result = await conn.execute(
                "DELETE FROM config.api_keys WHERE alias = $1", alias
            )
        return result == "DELETE 1"

    # -------------------------------------------------------------------------
    # Change notifications
    # -------------------------------------------------------------------------

    async def listen_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Listen for config changes via Postgres NOTIFY.

        Runs forever, calling ``callback(table, key)`` for every notification
        received on the ``config_changed`` channel. The callback can be sync
        or async; a coroutine it returns is scheduled as a task on the
        running loop and any exception it raises is logged.

        On connection loss, automatically reconnects with exponential backoff
        (1s, doubling up to 30s). Cancellation (via ``task.cancel()``)
        propagates cleanly.

        Args:
            callback: Function called with (table_name, row_key) on each change.

        Raises:
            asyncio.CancelledError: When the surrounding task is cancelled.
        """
        backoff = 1.0
        max_backoff = 30.0

        # The event loop keeps only weak references to tasks, so hold strong
        # references here until each callback task has finished.
        callback_tasks: set[asyncio.Task[Any]] = set()

        def _on_callback_done(task: asyncio.Task[Any]) -> None:
            callback_tasks.discard(task)
            if task.cancelled():
                return
            exc = task.exception()
            if exc is not None:
                logger.error("Config change callback failed", exc_info=exc)

        def notification_handler(
            _conn: asyncpg.Connection,
            _pid: int,
            _channel: str,
            payload: str,
        ) -> None:
            # payload format: "table_name:key"; no ":" yields key "".
            table, _, key = payload.partition(":")
            result = callback(table, key)
            if asyncio.iscoroutine(result):
                task = asyncio.create_task(result)
                callback_tasks.add(task)
                task.add_done_callback(_on_callback_done)

        while True:
            conn = None
            try:
                conn = await self._pool.acquire()
                await conn.add_listener("config_changed", notification_handler)
                logger.info("Config listener connected to database")
                # Reset only once the listener is registered, so repeated
                # failures during setup still back off.
                backoff = 1.0

                try:
                    # Keep connection alive with periodic keepalive
                    while True:
                        await asyncio.sleep(60)
                        await conn.execute("SELECT 1")
                finally:
                    try:
                        await conn.remove_listener(
                            "config_changed", notification_handler
                        )
                    except Exception:
                        # Best effort: a dead connection must not replace the
                        # exception in flight (notably CancelledError).
                        logger.debug(
                            "Could not remove config listener", exc_info=True
                        )

            except asyncio.CancelledError:
                # Cancellation must propagate cleanly
                logger.info("Config listener cancelled")
                raise

            except (
                asyncpg.PostgresConnectionError,
                asyncpg.InterfaceError,
                ConnectionResetError,
                OSError,
            ) as e:
                logger.warning(
                    "Config listener connection lost, reconnecting in %.1fs: %s",
                    backoff,
                    e,
                )
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)

            except Exception:
                # Unexpected error - log and retry with backoff
                logger.exception(
                    "Config listener unexpected error, reconnecting in %.1fs",
                    backoff,
                )
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)

            finally:
                if conn is not None:
                    try:
                        await self._pool.release(conn)
                    except Exception:
                        # Connection may already be invalid; best effort.
                        logger.debug(
                            "Could not release config listener connection",
                            exc_info=True,
                        )

View file

@ -1,111 +1,111 @@
"""Cryptographic primitives for secret storage.
Uses AES-256-GCM for authenticated encryption. The master key is read
from the path specified in bootstrap config on first use and cached.
"""
import base64
import os
from functools import lru_cache
from pathlib import Path
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
# AES-256 requires 32-byte key
KEY_SIZE = 32
# GCM nonce size (96 bits recommended by NIST)
NONCE_SIZE = 12
class CryptoError(Exception):
    """Base exception for crypto operations.

    ``KeyLoadError`` and ``DecryptionError`` derive from this class, so
    catching it covers every error this module raises deliberately.
    """
class KeyLoadError(CryptoError):
    """Failed to load master key (missing, unreadable, undecodable, or wrong size)."""
class DecryptionError(CryptoError):
    """Failed to decrypt ciphertext (too short, wrong key, or tampered data)."""
@lru_cache
def _load_master_key(path: Path) -> bytes:
    """Load and decode the base64-encoded master key from *path*.

    Successful results are memoised per path by ``lru_cache``; call
    ``_load_master_key.cache_clear()`` to force a re-read.

    Args:
        path: File containing the base64 text of the key.

    Returns:
        The decoded key, exactly ``KEY_SIZE`` bytes long.

    Raises:
        KeyLoadError: If the file is missing, cannot be read or decoded, or
            decodes to a length other than ``KEY_SIZE``. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        key_b64 = path.read_text().strip()
        key = base64.b64decode(key_b64)
    except FileNotFoundError as e:
        raise KeyLoadError(f"Master key file not found: {path}") from e
    except Exception as e:
        # Boundary wrap: any read/decode failure surfaces as KeyLoadError.
        raise KeyLoadError(f"Failed to read master key from {path}: {e}") from e

    if len(key) != KEY_SIZE:
        raise KeyLoadError(
            f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
        )

    return key
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
    """Encrypt *plaintext* with AES-256-GCM under the master key.

    A fresh random nonce of ``NONCE_SIZE`` bytes is drawn for every call and
    prepended to the output. No associated data is used.

    Args:
        plaintext: Data to encrypt.
        key_path: Path to master key file. If None, uses default from
            bootstrap config.

    Returns:
        Ciphertext in format: nonce (12 bytes) || ciphertext || tag (16 bytes)
    """
    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    cipher = AESGCM(_load_master_key(key_path))
    nonce = os.urandom(NONCE_SIZE)
    # AESGCM.encrypt returns the ciphertext with the 16-byte tag appended.
    return nonce + cipher.encrypt(nonce, plaintext, associated_data=None)
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
    """Decrypt ciphertext using AES-256-GCM.

    Args:
        ciphertext: Data in format: nonce || ciphertext || tag
        key_path: Path to master key file. If None, uses default from
            bootstrap config.

    Returns:
        Decrypted plaintext.

    Raises:
        DecryptionError: If the input is shorter than nonce + tag, or if
            decryption fails (wrong key or tampered data). The underlying
            exception is chained as ``__cause__``.
        KeyLoadError: If the master key cannot be loaded.
    """
    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    if len(ciphertext) < NONCE_SIZE + 16:  # nonce + minimum tag
        raise DecryptionError("Ciphertext too short")

    key = _load_master_key(key_path)
    nonce = ciphertext[:NONCE_SIZE]
    ciphertext_with_tag = ciphertext[NONCE_SIZE:]

    aesgcm = AESGCM(key)
    try:
        plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
    except Exception as e:
        # Boundary wrap: callers only need to handle DecryptionError.
        raise DecryptionError(f"Decryption failed: {e}") from e

    return plaintext
def clear_key_cache() -> None:
    """Drop every memoised master key so the next use re-reads the key file.

    Use after key rotation.
    """
    _load_master_key.cache_clear()
"""Cryptographic primitives for secret storage.
Uses AES-256-GCM for authenticated encryption. The master key is read
from the path specified in bootstrap config on first use and cached.
"""
import base64
import os
from functools import lru_cache
from pathlib import Path
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
# AES-256 requires 32-byte key
KEY_SIZE = 32
# GCM nonce size (96 bits recommended by NIST)
NONCE_SIZE = 12
class CryptoError(Exception):
    """Base exception for crypto operations.

    ``KeyLoadError`` and ``DecryptionError`` derive from this class, so
    catching it covers every error this module raises deliberately.
    """
class KeyLoadError(CryptoError):
    """Failed to load master key (missing, unreadable, undecodable, or wrong size)."""
class DecryptionError(CryptoError):
    """Failed to decrypt ciphertext (too short, wrong key, or tampered data)."""
@lru_cache
def _load_master_key(path: Path) -> bytes:
    """Load and decode the base64-encoded master key from *path*.

    Successful results are memoised per path by ``lru_cache``; call
    ``_load_master_key.cache_clear()`` to force a re-read.

    Args:
        path: File containing the base64 text of the key.

    Returns:
        The decoded key, exactly ``KEY_SIZE`` bytes long.

    Raises:
        KeyLoadError: If the file is missing, cannot be read or decoded, or
            decodes to a length other than ``KEY_SIZE``. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        key_b64 = path.read_text().strip()
        key = base64.b64decode(key_b64)
    except FileNotFoundError as e:
        raise KeyLoadError(f"Master key file not found: {path}") from e
    except Exception as e:
        # Boundary wrap: any read/decode failure surfaces as KeyLoadError.
        raise KeyLoadError(f"Failed to read master key from {path}: {e}") from e

    if len(key) != KEY_SIZE:
        raise KeyLoadError(
            f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
        )

    return key
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
    """Encrypt *plaintext* with AES-256-GCM under the master key.

    A fresh random nonce of ``NONCE_SIZE`` bytes is drawn for every call and
    prepended to the output. No associated data is used.

    Args:
        plaintext: Data to encrypt.
        key_path: Path to master key file. If None, uses default from
            bootstrap config.

    Returns:
        Ciphertext in format: nonce (12 bytes) || ciphertext || tag (16 bytes)
    """
    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    cipher = AESGCM(_load_master_key(key_path))
    nonce = os.urandom(NONCE_SIZE)
    # AESGCM.encrypt returns the ciphertext with the 16-byte tag appended.
    return nonce + cipher.encrypt(nonce, plaintext, associated_data=None)
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
    """Decrypt ciphertext using AES-256-GCM.

    Args:
        ciphertext: Data in format: nonce || ciphertext || tag
        key_path: Path to master key file. If None, uses default from
            bootstrap config.

    Returns:
        Decrypted plaintext.

    Raises:
        DecryptionError: If the input is shorter than nonce + tag, or if
            decryption fails (wrong key or tampered data). The underlying
            exception is chained as ``__cause__``.
        KeyLoadError: If the master key cannot be loaded.
    """
    if key_path is None:
        from central.bootstrap_config import get_settings

        key_path = get_settings().master_key_path

    if len(ciphertext) < NONCE_SIZE + 16:  # nonce + minimum tag
        raise DecryptionError("Ciphertext too short")

    key = _load_master_key(key_path)
    nonce = ciphertext[:NONCE_SIZE]
    ciphertext_with_tag = ciphertext[NONCE_SIZE:]

    aesgcm = AESGCM(key)
    try:
        plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
    except Exception as e:
        # Boundary wrap: callers only need to handle DecryptionError.
        raise DecryptionError(f"Decryption failed: {e}") from e

    return plaintext
def clear_key_cache() -> None:
    """Drop every memoised master key so the next use re-reads the key file.

    Use after key rotation.
    """
    _load_master_key.cache_clear()

View file

@ -1,125 +1,125 @@
"""Simple database migration runner.
Tracks applied migrations in a `schema_migrations` table. Migrations are
plain SQL files in `sql/migrations/` named with numeric prefixes:
001_create_config_schema.sql
002_add_operators_table.sql
...
Usage:
central-migrate [--dry-run]
"""
import argparse
import asyncio
import sys
from pathlib import Path
import asyncpg
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
    """Create the ``schema_migrations`` bookkeeping table when it is missing.

    The statement uses ``IF NOT EXISTS``, so calling this repeatedly is safe.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS schema_migrations (
            version TEXT PRIMARY KEY,
            applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
    """
    await conn.execute(ddl)
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
    """Read the versions recorded in ``schema_migrations`` into a set."""
    records = await conn.fetch("SELECT version FROM schema_migrations")
    applied_versions = {record["version"] for record in records}
    return applied_versions
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
    """Collect the ``*.sql`` files directly inside *migrations_dir*.

    Args:
        migrations_dir: Directory to scan (not searched recursively).

    Returns:
        ``(version, path)`` tuples sorted by path, where ``version`` is the
        file name without its extension (e.g. "001_create_config_schema").
        An empty list when the directory does not exist.
    """
    if not migrations_dir.exists():
        return []

    sql_files = sorted(migrations_dir.glob("*.sql"))
    return [(sql_file.stem, sql_file) for sql_file in sql_files]
async def apply_migration(
    conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
) -> None:
    """Run one migration file and record it, or preview it in dry-run mode.

    Args:
        conn: Open database connection.
        version: Identifier stored in ``schema_migrations.version``.
        sql_path: File whose text is executed as the migration.
        dry_run: When true, print the version and at most the first 200
            characters of the SQL; nothing is executed.
    """
    sql = sql_path.read_text()

    if not dry_run:
        # The migration SQL and its bookkeeping row share one transaction.
        async with conn.transaction():
            await conn.execute(sql)
            await conn.execute(
                "INSERT INTO schema_migrations (version) VALUES ($1)", version
            )
        print(f"Applied: {version}")
        return

    print(f"[DRY RUN] Would apply: {version}")
    preview = f" SQL: {sql[:200]}..." if len(sql) > 200 else f" SQL: {sql}"
    print(preview)
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
    """Apply (or, in dry-run mode, preview) every pending migration in order.

    The bookkeeping table is created if missing even when *dry_run* is set.

    Args:
        dsn: Postgres connection string.
        dry_run: Forwarded to ``apply_migration``.

    Returns:
        Number of pending migrations processed; in dry-run mode this counts
        the migrations that would have been applied.
    """
    conn = await asyncpg.connect(dsn)
    try:
        await ensure_migrations_table(conn)
        already_applied = await get_applied_migrations(conn)

        pending = [
            entry
            for entry in discover_migrations(MIGRATIONS_DIR)
            if entry[0] not in already_applied
        ]
        if not pending:
            print("No pending migrations.")
            return 0

        print(f"Found {len(pending)} pending migration(s).")
        for version, path in pending:
            await apply_migration(conn, version, path, dry_run)
        return len(pending)
    finally:
        # Always close the connection, including when a migration fails.
        await conn.close()
async def async_main() -> None:
    """Parse CLI flags, load bootstrap settings, and run pending migrations."""
    parser = argparse.ArgumentParser(description="Run database migrations")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be applied without executing",
    )
    args = parser.parse_args()

    # Imported here, as in the original, rather than at module import time.
    from central.bootstrap_config import get_settings

    dsn = get_settings().db_dsn
    applied_count = await run_migrations(dsn, dry_run=args.dry_run)

    if args.dry_run or applied_count <= 0:
        return
    print(f"Successfully applied {applied_count} migration(s).")
def main() -> None:
    """Synchronous entry point: run ``async_main`` to completion with ``asyncio.run``."""
    coro = async_main()
    asyncio.run(coro)
if __name__ == "__main__":
main()
"""Simple database migration runner.
Tracks applied migrations in a `schema_migrations` table. Migrations are
plain SQL files in `sql/migrations/` named with numeric prefixes:
001_create_config_schema.sql
002_add_operators_table.sql
...
Usage:
central-migrate [--dry-run]
"""
import argparse
import asyncio
import sys
from pathlib import Path
import asyncpg
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
    """Create the ``schema_migrations`` bookkeeping table when it is missing.

    The statement uses ``IF NOT EXISTS``, so calling this repeatedly is safe.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS schema_migrations (
            version TEXT PRIMARY KEY,
            applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
    """
    await conn.execute(ddl)
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
    """Read the versions recorded in ``schema_migrations`` into a set."""
    records = await conn.fetch("SELECT version FROM schema_migrations")
    applied_versions = {record["version"] for record in records}
    return applied_versions
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
    """Collect the ``*.sql`` files directly inside *migrations_dir*.

    Args:
        migrations_dir: Directory to scan (not searched recursively).

    Returns:
        ``(version, path)`` tuples sorted by path, where ``version`` is the
        file name without its extension (e.g. "001_create_config_schema").
        An empty list when the directory does not exist.
    """
    if not migrations_dir.exists():
        return []

    sql_files = sorted(migrations_dir.glob("*.sql"))
    return [(sql_file.stem, sql_file) for sql_file in sql_files]
async def apply_migration(
    conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
) -> None:
    """Run one migration file and record it, or preview it in dry-run mode.

    Args:
        conn: Open database connection.
        version: Identifier stored in ``schema_migrations.version``.
        sql_path: File whose text is executed as the migration.
        dry_run: When true, print the version and at most the first 200
            characters of the SQL; nothing is executed.
    """
    sql = sql_path.read_text()

    if not dry_run:
        # The migration SQL and its bookkeeping row share one transaction.
        async with conn.transaction():
            await conn.execute(sql)
            await conn.execute(
                "INSERT INTO schema_migrations (version) VALUES ($1)", version
            )
        print(f"Applied: {version}")
        return

    print(f"[DRY RUN] Would apply: {version}")
    preview = f" SQL: {sql[:200]}..." if len(sql) > 200 else f" SQL: {sql}"
    print(preview)
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
    """Apply (or, in dry-run mode, preview) every pending migration in order.

    The bookkeeping table is created if missing even when *dry_run* is set.

    Args:
        dsn: Postgres connection string.
        dry_run: Forwarded to ``apply_migration``.

    Returns:
        Number of pending migrations processed; in dry-run mode this counts
        the migrations that would have been applied.
    """
    conn = await asyncpg.connect(dsn)
    try:
        await ensure_migrations_table(conn)
        already_applied = await get_applied_migrations(conn)

        pending = [
            entry
            for entry in discover_migrations(MIGRATIONS_DIR)
            if entry[0] not in already_applied
        ]
        if not pending:
            print("No pending migrations.")
            return 0

        print(f"Found {len(pending)} pending migration(s).")
        for version, path in pending:
            await apply_migration(conn, version, path, dry_run)
        return len(pending)
    finally:
        # Always close the connection, including when a migration fails.
        await conn.close()
async def async_main() -> None:
    """Parse CLI flags, load bootstrap settings, and run pending migrations."""
    parser = argparse.ArgumentParser(description="Run database migrations")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be applied without executing",
    )
    args = parser.parse_args()

    # Imported here, as in the original, rather than at module import time.
    from central.bootstrap_config import get_settings

    dsn = get_settings().db_dsn
    applied_count = await run_migrations(dsn, dry_run=args.dry_run)

    if args.dry_run or applied_count <= 0:
        return
    print(f"Successfully applied {applied_count} migration(s).")
def main() -> None:
    """Synchronous entry point: run ``async_main`` to completion with ``asyncio.run``."""
    coro = async_main()
    asyncio.run(coro)
if __name__ == "__main__":
main()

View file

@ -38,7 +38,6 @@ def subject_for_event(ev: Event) -> str:
Dispatch by category prefix:
- fire.*: returns central.<category> directly
- quake.*: returns central.<category> directly
- wx.*: uses weather alert subject logic
Weather alert subjects:
@ -49,18 +48,11 @@ def subject_for_event(ev: Event) -> str:
Fire hotspot subjects:
central.fire.hotspot.<satellite>.<confidence>
Quake event subjects:
central.quake.event.<magnitude_tier>
"""
# Fire events: subject is just central.<category>
if ev.category.startswith("fire."):
return f"central.{ev.category}"
# Quake events: subject is just central.<category>
if ev.category.startswith("quake."):
return f"central.{ev.category}"
# Weather events: use geo-based subject logic
prefix = "central.wx"

View file

@ -1,262 +1,262 @@
"""JetStream stream manager for retention configuration."""
import logging
import re
from pathlib import Path
from typing import Any
from nats.js import JetStreamContext
from nats.js.api import StreamConfig, DiscardPolicy, RetentionPolicy
from central.config_models import StreamConfig as StreamConfigModel
logger = logging.getLogger(__name__)
# Constants
ONE_GB = 1024 * 1024 * 1024 # 1 GiB in bytes
NATS_CONFIG_PATH = Path("/etc/nats/nats-server.conf")
class StreamManager:
"""Manages JetStream stream configuration and retention."""
def __init__(self, js: JetStreamContext) -> None:
self._js = js
self._server_max_file_store: int | None = None
async def server_max_file_store_bytes(self) -> int:
"""Get the server's max_file_store setting in bytes.
Parses the NATS server config file and caches the result.
Returns a default of 20GB if config cannot be read.
"""
if self._server_max_file_store is not None:
return self._server_max_file_store
default_value = 20 * ONE_GB # 20GB default
try:
config_text = NATS_CONFIG_PATH.read_text()
# Parse max_file_store value (supports GB/MB/KB suffixes)
match = re.search(r'max_file_store:\s*(\d+)(GB|MB|KB|G|M|K)?', config_text, re.IGNORECASE)
if match:
value = int(match.group(1))
suffix = (match.group(2) or "").upper()
if suffix in ("GB", "G"):
value *= ONE_GB
elif suffix in ("MB", "M"):
value *= 1024 * 1024
elif suffix in ("KB", "K"):
value *= 1024
# else: assume bytes
self._server_max_file_store = value
logger.info(
"Parsed server max_file_store",
extra={"max_file_store_bytes": value},
)
return value
logger.warning(
"max_file_store not found in config, using default",
extra={"default": default_value},
)
self._server_max_file_store = default_value
return default_value
except Exception as e:
logger.warning(
"Failed to read NATS config, using default",
extra={"error": str(e), "default": default_value},
)
self._server_max_file_store = default_value
return default_value
def _compute_ceiling(self, server_max: int) -> int:
"""Compute per-stream ceiling as 30% of server max_file_store."""
return int(server_max * 0.30)
async def ensure_stream(
self,
name: str,
subjects: list[str],
config: StreamConfigModel,
) -> None:
"""Ensure a stream exists with the given configuration.
Creates the stream if it doesn't exist, or updates it if it does.
Always enforces: discard=old, max_msgs=-1 (unlimited).
"""
server_max = await self.server_max_file_store_bytes()
ceiling = self._compute_ceiling(server_max)
# Clamp max_bytes to [1GB, ceiling]
max_bytes = max(ONE_GB, min(config.max_bytes, ceiling))
stream_config = StreamConfig(
name=name,
subjects=subjects,
retention=RetentionPolicy.LIMITS,
discard=DiscardPolicy.OLD,
max_age=config.max_age_s,
max_bytes=max_bytes,
max_msgs=-1, # Unlimited messages
)
try:
# Try to get existing stream
existing = await self._js.stream_info(name)
# Update if config differs
await self._js.update_stream(config=stream_config)
logger.info(
"Updated stream",
extra={
"stream": name,
"max_age_s": config.max_age_s,
"max_bytes": max_bytes,
},
)
except Exception as e:
if "stream not found" in str(e).lower():
# Create new stream
await self._js.add_stream(config=stream_config)
logger.info(
"Created stream",
extra={
"stream": name,
"subjects": subjects,
"max_age_s": config.max_age_s,
"max_bytes": max_bytes,
},
)
else:
raise
async def apply_retention(self, name: str, config: StreamConfigModel) -> None:
"""Apply retention settings to an existing stream.
Updates max_age and max_bytes. Always enforces discard=old, max_msgs=-1.
"""
server_max = await self.server_max_file_store_bytes()
ceiling = self._compute_ceiling(server_max)
# Clamp max_bytes to [1GB, ceiling]
max_bytes = max(ONE_GB, min(config.max_bytes, ceiling))
try:
# Get current stream config
info = await self._js.stream_info(name)
current = info.config
# Build updated config
updated = StreamConfig(
name=name,
subjects=current.subjects,
retention=RetentionPolicy.LIMITS,
discard=DiscardPolicy.OLD,
max_age=config.max_age_s,
max_bytes=max_bytes,
max_msgs=-1,
)
await self._js.update_stream(config=updated)
logger.info(
"Applied retention",
extra={
"stream": name,
"max_age_s": config.max_age_s,
"max_bytes": max_bytes,
},
)
except Exception as e:
logger.error(
"Failed to apply retention",
extra={"stream": name, "error": str(e)},
)
raise
async def recompute_max_bytes(self, name: str, max_age_s: int) -> int:
"""Recompute max_bytes based on observed throughput.
Formula: rate × max_age × 1.5 safety margin, clamped to [1GB, ceiling].
Returns the computed max_bytes value.
"""
server_max = await self.server_max_file_store_bytes()
ceiling = self._compute_ceiling(server_max)
try:
info = await self._js.stream_info(name)
current_bytes = info.state.bytes
current_msgs = info.state.messages
# Get stream age from first message
first_seq = info.state.first_seq
last_seq = info.state.last_seq
if current_msgs == 0 or last_seq == 0:
# No messages yet, use floor
return ONE_GB
# Estimate message age span (approximation)
# Use stream's configured max_age as the observation window
configured_max_age = info.config.max_age
if configured_max_age > 0:
# Rate = current_bytes / configured_max_age (in seconds)
rate_per_second = current_bytes / configured_max_age
else:
# Fallback: assume 1 day of data
rate_per_second = current_bytes / 86400
# Project bytes needed for new max_age with 1.5x safety margin
projected = int(rate_per_second * max_age_s * 1.5)
# Clamp to [1GB, ceiling]
result = max(ONE_GB, min(projected, ceiling))
logger.info(
"Recomputed max_bytes",
extra={
"stream": name,
"current_bytes": current_bytes,
"rate_per_second": rate_per_second,
"max_age_s": max_age_s,
"projected": projected,
"result": result,
"ceiling": ceiling,
},
)
return result
except Exception as e:
logger.error(
"Failed to recompute max_bytes, using floor",
extra={"stream": name, "error": str(e)},
)
return ONE_GB
async def get_stream_stats(self, name: str) -> dict[str, Any]:
    """Return a monitoring snapshot of one stream.

    Args:
        name: Name of the stream to query.

    Returns:
        A dict with the stream name, stored bytes and messages, configured
        max_bytes and max_age, and consumer count. If the lookup fails, a
        dict holding only the stream name and the error text is returned
        instead of raising.
    """
    try:
        snapshot = await self._js.stream_info(name)
        state = snapshot.state
        limits = snapshot.config
        stats: dict[str, Any] = {"stream": name}
        stats["bytes"] = state.bytes
        stats["messages"] = state.messages
        stats["max_bytes"] = limits.max_bytes
        stats["max_age_s"] = limits.max_age
        stats["consumers"] = state.consumer_count
        return stats
    except Exception as exc:
        # Monitoring is best effort: report the failure in-band.
        failure = str(exc)
        logger.error(
            "Failed to get stream stats",
            extra={"stream": name, "error": failure},
        )
        return {"stream": name, "error": failure}
"""JetStream stream manager for retention configuration."""
import logging
import re
from pathlib import Path
from typing import Any
from nats.js import JetStreamContext
from nats.js.api import StreamConfig, DiscardPolicy, RetentionPolicy
from central.config_models import StreamConfig as StreamConfigModel
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# One gibibyte (2**30 bytes); used as the floor for per-stream max_bytes.
ONE_GB = 1 << 30
# NATS server configuration file that is parsed for max_file_store.
NATS_CONFIG_PATH = Path("/etc/nats/nats-server.conf")
class StreamManager:
    """Manages JetStream stream configuration and retention.

    Every stream written by this class uses limits-based retention,
    ``discard=old`` and an unlimited message count. ``max_bytes`` is always
    clamped to ``[ONE_GB, ceiling]``, where the ceiling is 30% of the server's
    ``max_file_store``.
    """

    def __init__(self, js: JetStreamContext) -> None:
        """Bind the manager to a JetStream context.

        Args:
            js: JetStream context used for every stream API call.
        """
        self._js = js
        # Cache for server_max_file_store_bytes(); None until first resolved.
        self._server_max_file_store: int | None = None

    @staticmethod
    def _parse_max_file_store(config_text: str) -> int | None:
        """Extract ``max_file_store`` in bytes from NATS server config text.

        Accepts ``key: value``, ``key = value`` and ``key value`` forms and
        ignores ``#`` / ``//`` comments. Size suffixes follow the NATS config
        convention: a bare letter (``K``, ``M``, ``G``, ``T``, ``P``, ``E``)
        is a decimal multiplier, while the ``B``, ``i`` and ``iB`` forms
        (``KB``, ``Ki``, ``KiB``, ...) are binary. No suffix means bytes.

        Args:
            config_text: Full text of the server configuration file.

        Returns:
            The configured size in bytes, or None when the setting is absent
            or carries an unrecognised suffix.
        """
        multipliers = {"": 1}
        for power, letter in enumerate("kmgtpe", start=1):
            multipliers[letter] = 1000**power
            for tail in ("b", "i", "ib"):
                multipliers[letter + tail] = 1024**power

        # Drop comments first so a commented-out setting is never picked up.
        # A marker only counts at line start or after whitespace, which keeps
        # URLs such as nats://host intact.
        uncommented = re.sub(
            r"(?:^|(?<=\s))(?:#|//).*$", "", config_text, flags=re.MULTILINE
        )
        match = re.search(
            r"\bmax_file_store\b[ \t]*[:=]?[ \t]*(\d+)([a-z]*)",
            uncommented,
            re.IGNORECASE,
        )
        if match is None:
            return None
        multiplier = multipliers.get(match.group(2).lower())
        if multiplier is None:
            return None
        return int(match.group(1)) * multiplier

    async def server_max_file_store_bytes(self) -> int:
        """Get the server's max_file_store setting in bytes.

        Parses the NATS server config file and caches the result (including
        the fallback) for the lifetime of this manager.

        Returns:
            The configured size in bytes, or a default of 20GB if the config
            cannot be read or does not contain a usable value.
        """
        if self._server_max_file_store is not None:
            return self._server_max_file_store

        default_value = 20 * ONE_GB  # 20GB default
        try:
            value = self._parse_max_file_store(
                NATS_CONFIG_PATH.read_text(encoding="utf-8")
            )
        except (OSError, ValueError) as e:
            # Missing/unreadable file or undecodable content: best effort.
            logger.warning(
                "Failed to read NATS config, using default",
                extra={"error": str(e), "default": default_value},
            )
            value = default_value
        else:
            if value is None:
                logger.warning(
                    "max_file_store not found in config, using default",
                    extra={"default": default_value},
                )
                value = default_value
            else:
                logger.info(
                    "Parsed server max_file_store",
                    extra={"max_file_store_bytes": value},
                )

        self._server_max_file_store = value
        return value

    def _compute_ceiling(self, server_max: int) -> int:
        """Compute per-stream ceiling as 30% of server max_file_store."""
        return int(server_max * 0.30)

    async def _bounded_max_bytes(self, requested: int) -> int:
        """Clamp a requested max_bytes to [ONE_GB, ceiling].

        The floor wins when the ceiling is itself below ONE_GB.
        """
        ceiling = self._compute_ceiling(await self.server_max_file_store_bytes())
        return max(ONE_GB, min(requested, ceiling))

    @staticmethod
    def _limits_config(
        name: str,
        subjects: list[str] | None,
        max_age_s: float,
        max_bytes: int,
    ) -> StreamConfig:
        """Build the stream config with the policies this manager enforces."""
        return StreamConfig(
            name=name,
            subjects=subjects,
            retention=RetentionPolicy.LIMITS,
            discard=DiscardPolicy.OLD,
            max_age=max_age_s,
            max_bytes=max_bytes,
            max_msgs=-1,  # Unlimited messages
        )

    async def ensure_stream(
        self,
        name: str,
        subjects: list[str],
        config: StreamConfigModel,
    ) -> None:
        """Ensure a stream exists with the given configuration.

        Creates the stream if it doesn't exist, or updates it if it does.
        Always enforces: discard=old, max_msgs=-1 (unlimited).

        Args:
            name: Stream name.
            subjects: Subjects the stream should capture.
            config: Retention settings; ``max_bytes`` is clamped to
                [1GB, ceiling] before being applied.

        Raises:
            Exception: Any JetStream error other than "stream not found" is
                propagated unchanged.
        """
        max_bytes = await self._bounded_max_bytes(config.max_bytes)
        stream_config = self._limits_config(
            name, subjects, config.max_age_s, max_bytes
        )

        try:
            # Probe for the stream; a missing stream surfaces as an exception.
            await self._js.stream_info(name)
            # The stream exists: push the desired config unconditionally.
            await self._js.update_stream(config=stream_config)
        except Exception as e:
            if "stream not found" not in str(e).lower():
                raise
            # Create new stream
            await self._js.add_stream(config=stream_config)
            logger.info(
                "Created stream",
                extra={
                    "stream": name,
                    "subjects": subjects,
                    "max_age_s": config.max_age_s,
                    "max_bytes": max_bytes,
                },
            )
        else:
            logger.info(
                "Updated stream",
                extra={
                    "stream": name,
                    "max_age_s": config.max_age_s,
                    "max_bytes": max_bytes,
                },
            )

    async def apply_retention(self, name: str, config: StreamConfigModel) -> None:
        """Apply retention settings to an existing stream.

        Updates max_age and max_bytes while keeping the stream's current
        subjects. Always enforces discard=old, max_msgs=-1.

        Args:
            name: Name of an existing stream.
            config: Retention settings; ``max_bytes`` is clamped to
                [1GB, ceiling] before being applied.

        Raises:
            Exception: Any JetStream error is logged and re-raised.
        """
        max_bytes = await self._bounded_max_bytes(config.max_bytes)

        try:
            # Get current stream config so the subjects are carried over.
            info = await self._js.stream_info(name)
            updated = self._limits_config(
                name, info.config.subjects, config.max_age_s, max_bytes
            )
            await self._js.update_stream(config=updated)
            logger.info(
                "Applied retention",
                extra={
                    "stream": name,
                    "max_age_s": config.max_age_s,
                    "max_bytes": max_bytes,
                },
            )
        except Exception as e:
            logger.error(
                "Failed to apply retention",
                extra={"stream": name, "error": str(e)},
            )
            raise

    async def recompute_max_bytes(self, name: str, max_age_s: int) -> int:
        """Recompute max_bytes based on observed throughput.

        Formula: rate × max_age × 1.5 safety margin, clamped to [1GB, ceiling].
        The rate is an approximation: the stream's current size divided by its
        currently configured max_age (or one day when no max_age is set).

        Args:
            name: Name of the stream to inspect.
            max_age_s: The new retention window, in seconds, to size for.

        Returns:
            The computed max_bytes value. ONE_GB is returned when the stream
            is empty or when the stream info cannot be obtained.
        """
        server_max = await self.server_max_file_store_bytes()
        ceiling = self._compute_ceiling(server_max)

        try:
            info = await self._js.stream_info(name)
            current_bytes = info.state.bytes

            if info.state.messages == 0 or info.state.last_seq == 0:
                # No messages yet, use floor
                return ONE_GB

            # Use the stream's configured max_age as the observation window.
            # An unset (None) or non-positive max_age means "no age limit",
            # so fall back to assuming the data represents one day.
            # NOTE(review): assumes the client reports max_age in seconds.
            window_s = info.config.max_age or 0
            if window_s <= 0:
                window_s = 86400
            rate_per_second = current_bytes / window_s

            # Project bytes needed for new max_age with 1.5x safety margin
            projected = int(rate_per_second * max_age_s * 1.5)

            # Clamp to [1GB, ceiling]; the floor wins if ceiling is below 1GB.
            result = max(ONE_GB, min(projected, ceiling))

            logger.info(
                "Recomputed max_bytes",
                extra={
                    "stream": name,
                    "current_bytes": current_bytes,
                    "rate_per_second": rate_per_second,
                    "max_age_s": max_age_s,
                    "projected": projected,
                    "result": result,
                    "ceiling": ceiling,
                },
            )
            return result

        except Exception as e:
            # Best effort: sizing must never fail the caller, so report and
            # fall back to the floor.
            logger.error(
                "Failed to recompute max_bytes, using floor",
                extra={"stream": name, "error": str(e)},
            )
            return ONE_GB

    async def get_stream_stats(self, name: str) -> dict[str, Any]:
        """Get current stream statistics for monitoring.

        Args:
            name: Name of the stream to query.

        Returns:
            A dict with the stream name, stored bytes and messages, configured
            max_bytes and max_age, and consumer count. If the lookup fails, a
            dict holding only the stream name and the error text is returned
            instead of raising.
        """
        try:
            info = await self._js.stream_info(name)
            return {
                "stream": name,
                "bytes": info.state.bytes,
                "messages": info.state.messages,
                "max_bytes": info.config.max_bytes,
                "max_age_s": info.config.max_age,
                "consumers": info.state.consumer_count,
            }
        except Exception as e:
            logger.error(
                "Failed to get stream stats",
                extra={"stream": name, "error": str(e)},
            )
            return {"stream": name, "error": str(e)}