From 31be17430d1ba61897da1dc0ee7ae7b07d2ef5b3 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Fri, 15 May 2026 21:29:01 +0000 Subject: [PATCH] runtime: NWS adapter, supervisor, archive consumer, systemd units Co-Authored-By: Claude Opus 4.5 --- pyproject.toml | 4 + src/central/__init__.py | 1 + src/central/adapters/nws.py | 453 +++++++++++++++++++++++++++++ src/central/archive.py | 342 ++++++++++++++++++++++ src/central/supervisor.py | 255 ++++++++++++++++ systemd/central-archive.service | 26 ++ systemd/central-supervisor.service | 26 ++ tests/test_nws_normalization.py | 373 ++++++++++++++++++++++++ 8 files changed, 1480 insertions(+) create mode 100644 src/central/adapters/nws.py create mode 100644 src/central/archive.py create mode 100644 src/central/supervisor.py create mode 100644 systemd/central-archive.service create mode 100644 systemd/central-supervisor.service create mode 100644 tests/test_nws_normalization.py diff --git a/pyproject.toml b/pyproject.toml index 63683b6..aa309dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,10 @@ dependencies = [ "tenacity>=9.1.4", ] +[project.scripts] +central-supervisor = "central.supervisor:main" +central-archive = "central.archive:main" + [tool.hatch.build.targets.wheel] packages = ["src/central"] diff --git a/src/central/__init__.py b/src/central/__init__.py index e69de29..3dc1f76 100644 --- a/src/central/__init__.py +++ b/src/central/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/src/central/adapters/nws.py b/src/central/adapters/nws.py new file mode 100644 index 0000000..bb77c4f --- /dev/null +++ b/src/central/adapters/nws.py @@ -0,0 +1,453 @@ +"""NWS (National Weather Service) alert adapter.""" + +import asyncio +import logging +import re +import sqlite3 +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import aiohttp +from aiolimiter import AsyncLimiter +from tenacity import ( + retry, + 
stop_after_attempt, + wait_exponential_jitter, + retry_if_exception_type, +) + +from central import __version__ +from central.adapter import SourceAdapter +from central.config import NWSAdapterConfig +from central.models import Event, Geo + +logger = logging.getLogger(__name__) + +# FIPS state codes to postal abbreviations +FIPS_TO_STATE: dict[str, str] = { + "01": "AL", "02": "AK", "04": "AZ", "05": "AR", "06": "CA", + "08": "CO", "09": "CT", "10": "DE", "11": "DC", "12": "FL", + "13": "GA", "15": "HI", "16": "ID", "17": "IL", "18": "IN", + "19": "IA", "20": "KS", "21": "KY", "22": "LA", "23": "ME", + "24": "MD", "25": "MA", "26": "MI", "27": "MN", "28": "MS", + "29": "MO", "30": "MT", "31": "NE", "32": "NV", "33": "NH", + "34": "NJ", "35": "NM", "36": "NY", "37": "NC", "38": "ND", + "39": "OH", "40": "OK", "41": "OR", "42": "PA", "44": "RI", + "45": "SC", "46": "SD", "47": "TN", "48": "TX", "49": "UT", + "50": "VT", "51": "VA", "53": "WA", "54": "WV", "55": "WI", + "56": "WY", "60": "AS", "66": "GU", "69": "MP", "72": "PR", + "78": "VI", +} + +SEVERITY_MAP: dict[str, int | None] = { + "Extreme": 4, + "Severe": 3, + "Moderate": 2, + "Minor": 1, + "Unknown": None, +} + +NWS_API_URL = "https://api.weather.gov/alerts/active" + + +def _snake_case(s: str) -> str: + """Convert a string to snake_case.""" + s = re.sub(r"[^a-zA-Z0-9\s]", "", s) + s = re.sub(r"\s+", "_", s.strip()) + return s.lower() + + +def _parse_datetime(s: str | None) -> datetime | None: + """Parse an ISO datetime string to UTC datetime.""" + if not s: + return None + try: + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + return dt.astimezone(timezone.utc) + except (ValueError, TypeError): + return None + + +def _compute_centroid(geometry: dict[str, Any] | None) -> tuple[float, float] | None: + """Compute centroid from GeoJSON geometry using arithmetic mean of vertices.""" + if not geometry: + return None + + geom_type = geometry.get("type") + coords = geometry.get("coordinates") + + if not 
coords: + return None + + all_points: list[tuple[float, float]] = [] + + if geom_type == "Point": + return (coords[0], coords[1]) + elif geom_type == "Polygon": + for ring in coords: + for point in ring: + all_points.append((point[0], point[1])) + elif geom_type == "MultiPolygon": + for polygon in coords: + for ring in polygon: + for point in ring: + all_points.append((point[0], point[1])) + else: + return None + + if not all_points: + return None + + avg_lon = sum(p[0] for p in all_points) / len(all_points) + avg_lat = sum(p[1] for p in all_points) / len(all_points) + return (avg_lon, avg_lat) + + +def _compute_bbox( + geometry: dict[str, Any] | None +) -> tuple[float, float, float, float] | None: + """Compute bounding box from GeoJSON geometry.""" + if not geometry: + return None + + geom_type = geometry.get("type") + coords = geometry.get("coordinates") + + if not coords: + return None + + all_points: list[tuple[float, float]] = [] + + if geom_type == "Point": + return (coords[0], coords[1], coords[0], coords[1]) + elif geom_type == "Polygon": + for ring in coords: + for point in ring: + all_points.append((point[0], point[1])) + elif geom_type == "MultiPolygon": + for polygon in coords: + for ring in polygon: + for point in ring: + all_points.append((point[0], point[1])) + else: + return None + + if not all_points: + return None + + min_lon = min(p[0] for p in all_points) + max_lon = max(p[0] for p in all_points) + min_lat = min(p[1] for p in all_points) + max_lat = max(p[1] for p in all_points) + return (min_lon, min_lat, max_lon, max_lat) + + +def _extract_states_from_codes( + same_codes: list[str], ugc_codes: list[str] +) -> set[str]: + """Extract state abbreviations from SAME and UGC codes.""" + states: set[str] = set() + + for code in same_codes: + if len(code) >= 2: + fips_state = code[:2] + if fips_state in FIPS_TO_STATE: + states.add(FIPS_TO_STATE[fips_state]) + + for code in ugc_codes: + if len(code) >= 2 and code[:2].isalpha(): + 
states.add(code[:2].upper()) + + return states + + +def _build_regions(same_codes: list[str], ugc_codes: list[str]) -> list[str]: + """Build sorted list of region strings from geocodes.""" + regions: set[str] = set() + + for code in same_codes: + if len(code) >= 2: + fips_state = code[:2] + if fips_state in FIPS_TO_STATE: + state = FIPS_TO_STATE[fips_state] + regions.add(f"US-{state}-FIPS{code}") + + for code in ugc_codes: + if len(code) >= 3 and code[:2].isalpha(): + state = code[:2].upper() + rest = code[2:] + if rest.startswith("C"): + regions.add(f"US-{state}-C{rest[1:]}") + elif rest.startswith("Z"): + regions.add(f"US-{state}-Z{rest[1:]}") + else: + regions.add(f"US-{state}-{rest}") + + return sorted(regions) + + +class NWSAdapter(SourceAdapter): + """National Weather Service alerts adapter.""" + + name = "nws" + + def __init__( + self, + config: NWSAdapterConfig, + cursor_db_path: Path, + ) -> None: + self.config = config + self.cadence_s = config.cadence_s + self.states = set(s.upper() for s in config.states) + self.cursor_db_path = cursor_db_path + self._session: aiohttp.ClientSession | None = None + self._limiter = AsyncLimiter(1, config.cadence_s) + self._db: sqlite3.Connection | None = None + + async def startup(self) -> None: + """Initialize HTTP session and cursor database.""" + user_agent = f"Central/{__version__} ({self.config.contact_email})" + self._session = aiohttp.ClientSession( + headers={"User-Agent": user_agent}, + timeout=aiohttp.ClientTimeout(total=30), + ) + + self._db = sqlite3.connect(str(self.cursor_db_path)) + self._db.execute(""" + CREATE TABLE IF NOT EXISTS adapter_cursors ( + adapter TEXT PRIMARY KEY, + cursor_data TEXT NOT NULL, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """) + self._db.execute(""" + CREATE TABLE IF NOT EXISTS published_ids ( + adapter TEXT NOT NULL, + event_id TEXT NOT NULL, + first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + 
async def shutdown(self) -> None:
    """Close the HTTP session and the cursor database.

    Safe to call more than once and before ``startup()``: both handles are
    detached from the adapter first, and only the ones that exist are closed.
    The SQLite connection is closed even when closing the HTTP session raises.
    """
    session, self._session = self._session, None
    db, self._db = self._db, None
    try:
        if session is not None:
            await session.close()
    finally:
        # Runs on the error path too, so the database file handle is
        # released whatever happens to the session.
        if db is not None:
            db.close()
    logger.info("NWS adapter shut down")
def sweep_old_ids(self, max_age_days: int = 8) -> int:
    """Delete this adapter's ``published_ids`` rows that have gone stale.

    Args:
        max_age_days: Rows whose ``last_seen`` is older than this many days
            are removed. Defaults to 8.

    Returns:
        The number of rows deleted; 0 when the database is not open.

    Raises:
        ValueError: If ``max_age_days`` is negative.
    """
    if max_age_days < 0:
        raise ValueError("max_age_days must be non-negative")
    if self._db is None:
        return 0
    cur = self._db.execute(
        # Scope the delete to this adapter, as every other query on the
        # table does; the table is keyed by (adapter, event_id).
        "DELETE FROM published_ids "
        "WHERE adapter = ? AND last_seen < datetime('now', ?)",
        (self.name, f"-{int(max_age_days)} days"),
    )
    self._db.commit()
    return cur.rowcount
async def poll(self) -> AsyncIterator[Event]:
    """Fetch the active-alerts feed and yield one Event per accepted feature.

    Features that ``_normalize_feature`` rejects (returns ``None``) or fails
    on (raises) are skipped; failures are logged. Nothing is yielded on a
    304 response.

    The ``Last-Modified`` cursor is stored only after the caller has consumed
    every event. If the caller stops iterating early, for example because
    publishing failed, the cursor keeps its old value and the next poll
    fetches the feed again instead of receiving 304.

    Raises:
        Exception: Whatever ``_fetch_alerts`` raised, after logging it.
    """
    try:
        status, data, last_modified = await self._fetch_alerts()
    except Exception as e:
        logger.error("Failed to fetch NWS alerts", extra={"error": str(e)})
        raise

    if status == 304:
        logger.info("NWS returned 304 Not Modified")
        return

    # "features": null would otherwise break len() below.
    features = (data.get("features") or []) if data else []
    logger.info(
        "NWS poll completed",
        extra={"status": status, "feature_count": len(features)}
    )

    yielded = 0
    for feature in features:
        try:
            event = self._normalize_feature(feature)
        except Exception as e:
            logger.warning(
                "Failed to normalize feature",
                extra={"error": str(e), "feature_id": feature.get("id")}
            )
            continue
        if event:
            yield event
            yielded += 1

    # Reached only when the consumer iterated to the end.
    if last_modified:
        self._set_cursor(last_modified)

    logger.info("NWS yielded events", extra={"count": yielded})
+ "taskName", "message", + ): + log_obj[key] = record.__dict__[key] + return json.dumps(log_obj) + + +def setup_logging() -> None: + """Configure JSON logging to stdout.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(JsonFormatter()) + logging.root.handlers = [handler] + logging.root.setLevel(logging.INFO) + + +logger = logging.getLogger("central.archive") + + +def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None: + """Build PostGIS geometry from event geo data.""" + if not geo_data: + return None + + bbox = geo_data.get("bbox") + centroid = geo_data.get("centroid") + + if bbox and len(bbox) == 4: + # Create polygon from bbox + min_lon, min_lat, max_lon, max_lat = bbox + return json.dumps({ + "type": "Polygon", + "coordinates": [[ + [min_lon, min_lat], + [max_lon, min_lat], + [max_lon, max_lat], + [min_lon, max_lat], + [min_lon, min_lat], + ]] + }) + elif centroid and len(centroid) == 2: + # Create point from centroid + return json.dumps({ + "type": "Point", + "coordinates": centroid + }) + + return None + + +class ArchiveConsumer: + """Archive consumer process.""" + + def __init__(self, config: Config) -> None: + self.config = config + self._nc: nats.NATS | None = None + self._js: JetStreamContext | None = None + self._pool: asyncpg.Pool | None = None + self._shutdown_event = asyncio.Event() + + async def connect(self) -> None: + """Connect to NATS and PostgreSQL.""" + self._nc = await nats.connect(self.config.nats.url) + self._js = self._nc.jetstream() + logger.info("Connected to NATS", extra={"url": self.config.nats.url}) + + self._pool = await asyncpg.create_pool( + self.config.postgres.dsn, + min_size=1, + max_size=5, + ) + logger.info("Connected to PostgreSQL") + + async def disconnect(self) -> None: + """Disconnect from NATS and PostgreSQL.""" + if self._pool: + await self._pool.close() + self._pool = None + if self._nc: + await self._nc.drain() + await self._nc.close() + self._nc = None + self._js = None + 
async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
    """Archive one stream message into the ``events`` table.

    Messages that can never be archived (undecodable bytes, invalid JSON, an
    envelope that is not a JSON object, or a missing id/time) are logged,
    acknowledged and dropped, since redelivery would fail the same way. When
    the insert itself fails the message is left unacknowledged so it is
    delivered again.

    Args:
        msg: Message from the pull subscription; ``data`` and ``ack()`` are used.
        conn: Database connection used for the upsert.
    """
    try:
        envelope = json.loads(msg.data.decode())
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.warning("Invalid JSON in message", extra={"error": str(e)})
        await msg.ack()
        return

    if not isinstance(envelope, dict):
        logger.warning(
            "Invalid JSON in message",
            extra={"error": "envelope is not a JSON object"}
        )
        await msg.ack()
        return

    def as_dict(value: Any) -> dict[str, Any]:
        # null or wrongly-typed sections are treated as absent
        return value if isinstance(value, dict) else {}

    def parse_ts(value: Any) -> datetime | None:
        if not isinstance(value, str) or not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    event_data = as_dict(envelope.get("data"))
    geo_data = as_dict(event_data.get("geo"))

    event_id = envelope.get("id")
    source = event_data.get("source", "")
    category = event_data.get("category", "")
    time_str = event_data.get("time")
    severity = event_data.get("severity")
    regions = geo_data.get("regions") or []
    primary_region = geo_data.get("primary_region")

    event_time = parse_ts(time_str)
    expires_time = parse_ts(event_data.get("expires"))

    if not event_id or not event_time:
        logger.warning(
            "Message missing required fields",
            extra={"id": event_id, "time": time_str}
        )
        await msg.ack()
        return

    geom_json = _build_geom_sql(geo_data)

    try:
        # One statement covers both cases: ST_GeomFromGeoJSON(NULL) is NULL,
        # so a missing geometry stores NULL in geom.
        await conn.execute(
            """
            INSERT INTO events (id, source, category, time, expires, severity,
                                geom, regions, primary_region, payload)
            VALUES ($1, $2, $3, $4, $5, $6,
                    ST_GeomFromGeoJSON($7::text), $8, $9, $10)
            ON CONFLICT (id, time) DO UPDATE SET
                source = EXCLUDED.source,
                category = EXCLUDED.category,
                expires = EXCLUDED.expires,
                severity = EXCLUDED.severity,
                geom = EXCLUDED.geom,
                regions = EXCLUDED.regions,
                primary_region = EXCLUDED.primary_region,
                payload = EXCLUDED.payload
            """,
            event_id, source, category, event_time, expires_time, severity,
            geom_json, regions, primary_region, json.dumps(envelope)
        )
    except Exception as e:
        logger.error(
            "Failed to insert event",
            extra={"id": event_id, "error": str(e)}
        )
        # Don't ack - let it be redelivered
        return

    await msg.ack()
    logger.info("Archived event", extra={"id": event_id, "category": category})
async def async_main() -> None:
    """Run the archive consumer until a signal arrives or the loop ends.

    SIGTERM and SIGINT request a graceful stop: the consume loop is flagged
    and given up to ``FETCH_TIMEOUT + ACK_WAIT`` seconds to finish its
    current batch before it is cancelled. ``consumer.stop()`` always runs
    once the consumer has started.

    Raises:
        RuntimeError: If the consume loop returned without a shutdown signal,
            so the process exits non-zero instead of idling.
        Exception: Whatever the consume loop raised.
    """
    setup_logging()

    config = load_config(CONFIG_PATH)
    consumer = ArchiveConsumer(config)

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, shutdown_event.set)

    await consumer.start()

    consume_task = asyncio.create_task(consumer.run())
    signal_task = asyncio.create_task(shutdown_event.wait())
    try:
        # Wake on whichever comes first; waiting on the signal alone would
        # leave the process idle forever if the loop died.
        await asyncio.wait(
            {consume_task, signal_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        stopped_by_signal = shutdown_event.is_set()

        consumer._shutdown_event.set()
        try:
            # Re-raises the loop's own exception if it crashed.
            await asyncio.wait_for(consume_task, timeout=FETCH_TIMEOUT + ACK_WAIT)
        except asyncio.TimeoutError:
            logger.warning("Consume loop did not stop in time; cancelled")
    finally:
        signal_task.cancel()
        await consumer.stop()

    if not stopped_by_signal:
        raise RuntimeError("Archive consume loop exited without a shutdown signal")
class JsonFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Each object carries ``ts``, ``level``, ``logger`` and ``msg``, an ``exc``
    traceback when the record has exception info, and every record attribute
    that is not part of the standard ``LogRecord`` fields, which is how
    values passed through ``extra={...}`` appear.
    """

    # Standard LogRecord attributes; anything else on a record is treated as
    # caller-supplied context. Built once instead of once per key.
    _RESERVED = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` rendered as a JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        # An "extra" attribute is merged only when it really is a mapping.
        extra = getattr(record, "extra", None)
        if isinstance(extra, dict):
            log_obj.update(extra)
        # Include any extra fields passed via extra={}
        for key, value in record.__dict__.items():
            if key not in self._RESERVED:
                log_obj[key] = value
        # default=str is deliberate: a value json cannot encode (datetime,
        # set, Path, ...) is stringified rather than losing the whole record.
        return json.dumps(log_obj, default=str)
async def _run_adapter(self, adapter: NWSAdapter) -> None:
    """Poll ``adapter`` on its cadence until shutdown is requested.

    Each cycle publishes every event the adapter has not published before,
    reports the outcome on ``central.meta.adapter.<name>.status``, sweeps
    stale dedup rows, then sleeps for the rest of ``adapter.cadence_s`` or
    until the shutdown event is set.

    A failure in any step is logged and the loop moves on to the next cycle.
    This includes the status publish and the sweep, so they cannot end the
    loop.
    """
    status_subject = f"central.meta.adapter.{adapter.name}.status"

    while not self._shutdown_event.is_set():
        poll_start = datetime.now(timezone.utc)
        status: dict[str, Any] = {"ok": True}
        try:
            async for event in adapter.poll():
                # Dedup check
                if adapter.is_published(event.id):
                    adapter.bump_last_seen(event.id)
                    continue

                # Build CloudEvent
                envelope, msg_id = wrap_event(event, self.config)
                subject = subject_for_event(event)

                # Record the id only after the publish call returned.
                await self._publish_event(subject, envelope, msg_id)
                adapter.mark_published(event.id)

                logger.info(
                    "Published event",
                    extra={"id": event.id, "subject": subject, "category": event.category}
                )
        except Exception as e:
            logger.exception("Adapter poll failed", extra={"adapter": adapter.name})
            status = {"ok": False, "error": str(e)}

        status["ts"] = datetime.now(timezone.utc).isoformat()
        try:
            await self._publish_meta(status_subject, status)
        except Exception:
            # Status is best-effort; failing to report it must not end
            # the loop.
            logger.exception(
                "Failed to publish adapter status",
                extra={"adapter": adapter.name}
            )

        # Sweep old IDs
        try:
            swept = adapter.sweep_old_ids()
        except Exception:
            logger.exception(
                "Failed to sweep published IDs",
                extra={"adapter": adapter.name}
            )
            swept = 0
        if swept > 0:
            logger.info("Swept old published IDs", extra={"count": swept})

        # Sleep until next cadence
        elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds()
        sleep_time = max(0, adapter.cadence_s - elapsed)
        try:
            await asyncio.wait_for(
                self._shutdown_event.wait(),
                timeout=sleep_time
            )
        except asyncio.TimeoutError:
            pass
async def async_main() -> None:
    """Run the supervisor until SIGTERM or SIGINT is received.

    ``supervisor.stop()`` runs in every case, including when ``start()``
    fails partway, so whatever was already opened is released before the
    error propagates.
    """
    setup_logging()

    config = load_config(CONFIG_PATH)
    supervisor = Supervisor(config)

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, shutdown_event.set)

    try:
        await supervisor.start()

        # Wait for shutdown signal
        await shutdown_event.wait()
    finally:
        await supervisor.stop()
+WorkingDirectory=/opt/central
+Environment=HOME=/opt/central
+ExecStart=/opt/central/.venv/bin/central-supervisor
+Restart=on-failure
+RestartSec=5
+LimitNOFILE=65536
+NoNewPrivileges=true
+ProtectSystem=full
+ProtectHome=true
+PrivateTmp=true
+ReadWritePaths=/var/lib/central
+StandardOutput=journal
+StandardError=journal
+
+[Install]
+WantedBy=multi-user.target
diff --git a/tests/test_nws_normalization.py b/tests/test_nws_normalization.py
new file mode 100644
index 0000000..72e2fdc
--- /dev/null
+++ b/tests/test_nws_normalization.py
@@ -0,0 +1,373 @@
+"""Tests for NWS adapter normalization."""
+
+from datetime import datetime, timezone
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from central.adapters.nws import (
+    NWSAdapter,
+    _snake_case,
+    _parse_datetime,
+    _extract_states_from_codes,
+    _build_regions,
+    _compute_centroid,
+    _compute_bbox,
+    SEVERITY_MAP,
+)
+from central.config import NWSAdapterConfig
+from central.models import subject_for_event
+
+
+# Sample NWS GeoJSON features for testing. SAME codes here: 6 digits, first 2 are state FIPS
+# (ID=16, OR=41, CA=06, WA=53). NOTE(review): NWS documents SAME as PSSCCC (e.g. 016001); confirm.
+SAMPLE_FEATURE_ID = {
+    "id": "urn:oid:2.49.0.1.840.0.a1b2c3d4e5f6",
+    "type": "Feature",
+    "geometry": {
+        "type": "Polygon",
+        "coordinates": [[
+            [-116.5, 43.5],
+            [-116.0, 43.5],
+            [-116.0, 44.0],
+            [-116.5, 44.0],
+            [-116.5, 43.5],
+        ]]
+    },
+    "properties": {
+        "id": "urn:oid:2.49.0.1.840.0.a1b2c3d4e5f6",
+        "event": "Severe Thunderstorm Warning",
+        "sent": "2026-05-15T12:00:00-06:00",
+        "expires": "2026-05-15T14:00:00-06:00",
+        "severity": "Severe",
+        "geocode": {
+            "SAME": ["160001"],  # Idaho state FIPS 16
+            "UGC": ["IDC001", "IDZ033"],
+        },
+    },
+}
+
+SAMPLE_FEATURE_OR = {
+    "id": "urn:oid:2.49.0.1.840.0.x1y2z3w4",
+    "type": "Feature",
+    "geometry": None,
+    "properties": {
+        "id": "urn:oid:2.49.0.1.840.0.x1y2z3w4",
+        "event": "Winter Storm Warning",
+        "sent": "2026-05-15T08:00:00Z",
+        "expires": "2026-05-16T08:00:00Z",
+        "severity": "Moderate",
+        "geocode": {
+            "SAME": ["410051"],  # Oregon state FIPS 41
+            "UGC": ["ORC051"],
+        },
+    },
+}
+
+SAMPLE_FEATURE_CA = {
+    "id": "urn:oid:2.49.0.1.840.0.ca1234",
+    "type": "Feature",
+    "geometry": {
+        "type": "Point",
+        "coordinates": [-118.25, 34.05],
+    },
+    "properties": {
+        "id": "urn:oid:2.49.0.1.840.0.ca1234",
+        "event": "Fire Weather Watch",
+        "sent": "2026-05-15T10:00:00-07:00",
+        "expires": "2026-05-16T18:00:00-07:00",
+        "severity": "Minor",
+        "geocode": {
+            "SAME": ["060037"],  # California state FIPS 06
+            "UGC": ["CAZ568"],
+        },
+    },
+}
+
+SAMPLE_FEATURE_UNKNOWN_SEVERITY = {
+    "id": "urn:oid:2.49.0.1.840.0.unk123",
+    "type": "Feature",
+    "geometry": None,
+    "properties": {
+        "id": "urn:oid:2.49.0.1.840.0.unk123",
+        "event": "Test Alert",
+        "sent": "2026-05-15T12:00:00Z",
+        "expires": None,
+        "severity": "Unknown",
+        "geocode": {
+            "SAME": ["530033"],  # Washington state FIPS 53
+            "UGC": ["WAC033"],
+        },
+    },
+}
+
+
+class TestSnakeCase:
+    """Tests for snake_case conversion."""
+
+    def test_spaces_to_underscores(self) -> None:
+        assert _snake_case("Severe Thunderstorm Warning") == "severe_thunderstorm_warning"
+
+    def test_removes_special_chars(self) -> None:
+        assert _snake_case("Fire Weather (Red Flag)") == "fire_weather_red_flag"
+
+    def test_lowercase(self) -> None:
+        assert _snake_case("TORNADO WARNING") == "tornado_warning"
+
+
+class TestParseDatetime:
+    """Tests for datetime parsing."""
+
+    def test_iso_with_offset(self) -> None:
+        result = _parse_datetime("2026-05-15T12:00:00-06:00")
+        assert result is not None
+        assert result.tzinfo == timezone.utc
+        assert result.hour == 18  # 12:00 MDT = 18:00 UTC
+
+    def test_iso_with_z(self) -> None:
+        result = _parse_datetime("2026-05-15T12:00:00Z")
+        assert result is not None
+        assert result.hour == 12
+
+    def test_none_input(self) -> None:
+        assert _parse_datetime(None) is None
+
+    def test_invalid_input(self) -> None:
+        assert _parse_datetime("not a date") is None
+
+
+class TestExtractStates:
+    """Tests for state extraction from geocodes."""
+
+    def test_same_codes(self) -> None:
+        # Idaho FIPS is 16
+        states = _extract_states_from_codes(["160001", "160003"], [])
+        assert states == {"ID"}
+
+    def test_ugc_codes(self) -> None:
+        states = _extract_states_from_codes([], ["IDC001", "ORC051"])
+        assert states == {"ID", "OR"}
+
+    def test_combined(self) -> None:
+        # Idaho FIPS is 16
+        states = _extract_states_from_codes(["160001"], ["WAC033"])
+        assert states == {"ID", "WA"}
+
+    def test_empty(self) -> None:
+        states = _extract_states_from_codes([], [])
+        assert states == set()
+
+
+class TestBuildRegions:
+    """Tests for region string building."""
+
+    def test_same_to_fips_region(self) -> None:
+        # Idaho FIPS is 16
+        regions = _build_regions(["160001"], [])
+        assert "US-ID-FIPS160001" in regions
+
+    def test_ugc_county(self) -> None:
+        regions = _build_regions([], ["IDC001"])
+        assert "US-ID-C001" in regions
+
+    def test_ugc_zone(self) -> None:
+        regions = _build_regions([], ["IDZ033"])
+        assert "US-ID-Z033" in regions
+
+    def test_sorted_alphabetically(self) -> None:
+        regions = _build_regions(["160001"], ["IDC001", "IDZ033"])
+        assert regions == sorted(regions)
+
+
+class TestStateFilter:
+    """Tests for state filtering."""
+
+    @pytest.fixture
+    def adapter(self, tmp_path: Path) -> NWSAdapter:
+        """Create adapter configured with states ID, OR, WA, MT, WY, UT and NV."""
+        config = NWSAdapterConfig(
+            enabled=True,
+            cadence_s=60,
+            states=["ID", "OR", "WA", "MT", "WY", "UT", "NV"],
+            contact_email="test@example.com",
+        )
+        return NWSAdapter(config, tmp_path / "test.db")
+
+    def test_accepts_id_feature(self, adapter: NWSAdapter) -> None:
+        event = adapter._normalize_feature(SAMPLE_FEATURE_ID)
+        assert event is not None
+        assert event.id == SAMPLE_FEATURE_ID["id"]
+
+    def test_accepts_or_feature(self, adapter: NWSAdapter) -> None:
+        event = adapter._normalize_feature(SAMPLE_FEATURE_OR)
+        assert event is not None
+        assert event.id == SAMPLE_FEATURE_OR["id"]
+
+    def test_rejects_ca_feature(self, adapter: NWSAdapter) -> None:
+        event = adapter._normalize_feature(SAMPLE_FEATURE_CA)
+        assert event is None
+
+
+class TestSeverityMapping:
+    """Tests for severity mapping."""
+
+    def test_extreme(self) -> None:
+        assert SEVERITY_MAP["Extreme"] == 4
+
+    def test_severe(self) -> None:
+        assert SEVERITY_MAP["Severe"] == 3
+
+    def test_moderate(self) -> None:
+        assert SEVERITY_MAP["Moderate"] == 2
+
+    def test_minor(self) -> None:
+        assert SEVERITY_MAP["Minor"] == 1
+
+    def test_unknown(self) -> None:
+        assert SEVERITY_MAP["Unknown"] is None
+
+    def test_unknown_severity_in_feature(self, tmp_path: Path) -> None:
+        config = NWSAdapterConfig(
+            enabled=True,
+            cadence_s=60,
+            states=["WA"],
+            contact_email="test@example.com",
+        )
+        adapter = NWSAdapter(config, tmp_path / "test.db")
+        event = adapter._normalize_feature(SAMPLE_FEATURE_UNKNOWN_SEVERITY)
+        assert event is not None
+        assert event.severity is None
+
+
+class TestSubjectDerivation:
+    """Tests for NATS subject derivation."""
+
+    @pytest.fixture
+    def adapter(self, tmp_path: Path) -> NWSAdapter:
+        config = NWSAdapterConfig(
+            enabled=True,
+            cadence_s=60,
+            states=["ID", "OR", "WA"],
+            contact_email="test@example.com",
+        )
+        return NWSAdapter(config, tmp_path / "test.db")
+
+    def test_county_subject(self, adapter: NWSAdapter) -> None:
+        event = adapter._normalize_feature(SAMPLE_FEATURE_ID)
+        assert event is not None
+        subject = subject_for_event(event)
+        # Primary region should be the alphabetically first region, which could be a
+        # county or a zone depending on sort order, so only the state prefix is asserted
+        assert subject.startswith("central.wx.alert.us.id.")
+
+    def test_zone_subject(self, adapter: NWSAdapter) -> None:
+        # Create feature with only zone codes
+        feature = {
+            "id": "urn:test:zone",
+            "geometry": None,
+            "properties": {
+                "event": "Test Alert",
+                "sent": "2026-05-15T12:00:00Z",
+                "severity": "Minor",
+                "geocode": {
+                    "SAME": [],
+                    "UGC": ["IDZ033"],
+                },
+            },
+        }
+        event = adapter._normalize_feature(feature)
+        assert event is not None
+        subject = subject_for_event(event)
+        assert "zone" in subject
+
+
+class TestRegionsSorted:
+    """Tests for regions list sorting."""
+
+    @pytest.fixture
+    def adapter(self, tmp_path: Path) -> NWSAdapter:
+        config = NWSAdapterConfig(
+            enabled=True,
+            cadence_s=60,
+            states=["ID"],
+            contact_email="test@example.com",
+        )
+        return NWSAdapter(config, tmp_path / "test.db")
+
+    def test_regions_alphabetically_sorted(self, adapter: NWSAdapter) -> None:
+        event = adapter._normalize_feature(SAMPLE_FEATURE_ID)
+        assert event is not None
+        assert event.geo.regions == sorted(event.geo.regions)
+
+    def test_primary_region_is_first(self, adapter: NWSAdapter) -> None:
+        event = adapter._normalize_feature(SAMPLE_FEATURE_ID)
+        assert event is not None
+        assert len(event.geo.regions) > 0
+        assert event.geo.primary_region == event.geo.regions[0]
+
+
+class TestDeduplication:
+    """Tests for event deduplication (stable Event.id across repeated normalization)."""
+
+    @pytest.fixture
+    def adapter(self, tmp_path: Path) -> NWSAdapter:
+        config = NWSAdapterConfig(
+            enabled=True,
+            cadence_s=60,
+            states=["ID"],
+            contact_email="test@example.com",
+        )
+        return NWSAdapter(config, tmp_path / "test.db")
+
+    def test_same_feature_same_id(self, adapter: NWSAdapter) -> None:
+        """Normalizing the same feature twice returns same Event.id."""
+        event1 = adapter._normalize_feature(SAMPLE_FEATURE_ID)
+        event2 = adapter._normalize_feature(SAMPLE_FEATURE_ID)
+        assert event1 is not None
+        assert event2 is not None
+        assert event1.id == event2.id
+
+
+class TestGeometry:
+    """Tests for geometry computation."""
+
+    def test_centroid_polygon(self) -> None:
+        geom = {
+            "type": "Polygon",
+            "coordinates": [[
+                [-116.5, 43.5],
+                [-116.0, 43.5],
+                [-116.0, 44.0],
+                [-116.5, 44.0],
+                [-116.5, 43.5],
+            ]]
+        }
+        centroid = _compute_centroid(geom)
+        assert centroid is not None
+        # Average of 5 vertices (including closing point)
+        # lon: (-116.5 + -116.0 + -116.0 + -116.5 + -116.5) / 5 = -116.3
+        # lat: (43.5 + 43.5 + 44.0 + 44.0 + 43.5) / 5 = 43.7
+        assert -116.4 < centroid[0] < -116.2
+        assert 43.6 < centroid[1] < 43.8
+
+    def test_bbox_polygon(self) -> None:
+        geom = {
+            "type": "Polygon",
+            "coordinates": [[
+                [-116.5, 43.5],
+                [-116.0, 43.5],
+                [-116.0, 44.0],
+                [-116.5, 44.0],
+                [-116.5, 43.5],
+            ]]
+        }
+        bbox = _compute_bbox(geom)
+        assert bbox is not None
+        assert bbox == (-116.5, 43.5, -116.0, 44.0)
+
+    def test_centroid_none_geometry(self) -> None:
+        assert _compute_centroid(None) is None
+
+    def test_bbox_none_geometry(self) -> None:
+        assert _compute_bbox(None) is None