runtime: NWS adapter, supervisor, archive consumer, systemd units

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Matt Johnson 2026-05-15 21:29:01 +00:00 committed by Ubuntu
commit 31be17430d
8 changed files with 1480 additions and 0 deletions

View file

@ -20,6 +20,10 @@ dependencies = [
"tenacity>=9.1.4",
]
[project.scripts]
central-supervisor = "central.supervisor:main"
central-archive = "central.archive:main"
[tool.hatch.build.targets.wheel]
packages = ["src/central"]

View file

@ -0,0 +1 @@
__version__ = "0.1.0"

453
src/central/adapters/nws.py Normal file
View file

@ -0,0 +1,453 @@
"""NWS (National Weather Service) alert adapter."""
import asyncio
import logging
import re
import sqlite3
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiohttp
from aiolimiter import AsyncLimiter
from tenacity import (
retry,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception_type,
)
from central import __version__
from central.adapter import SourceAdapter
from central.config import NWSAdapterConfig
from central.models import Event, Geo
logger = logging.getLogger(__name__)
# Two-digit FIPS state codes mapped to postal abbreviations.
# Covers the states, DC, and the codes AS, GU, MP, PR and VI. The helpers
# below look up the first two characters of each SAME geocode in this table.
FIPS_TO_STATE: dict[str, str] = {
"01": "AL", "02": "AK", "04": "AZ", "05": "AR", "06": "CA",
"08": "CO", "09": "CT", "10": "DE", "11": "DC", "12": "FL",
"13": "GA", "15": "HI", "16": "ID", "17": "IL", "18": "IN",
"19": "IA", "20": "KS", "21": "KY", "22": "LA", "23": "ME",
"24": "MD", "25": "MA", "26": "MI", "27": "MN", "28": "MS",
"29": "MO", "30": "MT", "31": "NE", "32": "NV", "33": "NH",
"34": "NJ", "35": "NM", "36": "NY", "37": "NC", "38": "ND",
"39": "OH", "40": "OK", "41": "OR", "42": "PA", "44": "RI",
"45": "SC", "46": "SD", "47": "TN", "48": "TX", "49": "UT",
"50": "VT", "51": "VA", "53": "WA", "54": "WV", "55": "WI",
"56": "WY", "60": "AS", "66": "GU", "69": "MP", "72": "PR",
"78": "VI",
}
# NWS severity label -> numeric severity (higher is more severe).
# "Unknown" maps to None; labels missing from this table also resolve to None
# because callers read it with SEVERITY_MAP.get().
SEVERITY_MAP: dict[str, int | None] = {
"Extreme": 4,
"Severe": 3,
"Moderate": 2,
"Minor": 1,
"Unknown": None,
}
# Endpoint requested by NWSAdapter._fetch_alerts on every poll.
NWS_API_URL = "https://api.weather.gov/alerts/active"
def _snake_case(s: str) -> str:
    """Return *s* lowercased with words joined by underscores.

    Every character that is not an ASCII letter, digit, or whitespace is
    dropped; leading/trailing whitespace is trimmed and each remaining run of
    whitespace becomes a single underscore.
    """
    alnum_and_spaces = re.sub(r"[^a-zA-Z0-9\s]", "", s).strip()
    return re.sub(r"\s+", "_", alnum_and_spaces).lower()
def _parse_datetime(s: str | None) -> datetime | None:
"""Parse an ISO datetime string to UTC datetime."""
if not s:
return None
try:
dt = datetime.fromisoformat(s.replace("Z", "+00:00"))
return dt.astimezone(timezone.utc)
except (ValueError, TypeError):
return None
def _compute_centroid(geometry: dict[str, Any] | None) -> tuple[float, float] | None:
"""Compute centroid from GeoJSON geometry using arithmetic mean of vertices."""
if not geometry:
return None
geom_type = geometry.get("type")
coords = geometry.get("coordinates")
if not coords:
return None
all_points: list[tuple[float, float]] = []
if geom_type == "Point":
return (coords[0], coords[1])
elif geom_type == "Polygon":
for ring in coords:
for point in ring:
all_points.append((point[0], point[1]))
elif geom_type == "MultiPolygon":
for polygon in coords:
for ring in polygon:
for point in ring:
all_points.append((point[0], point[1]))
else:
return None
if not all_points:
return None
avg_lon = sum(p[0] for p in all_points) / len(all_points)
avg_lat = sum(p[1] for p in all_points) / len(all_points)
return (avg_lon, avg_lat)
def _compute_bbox(
geometry: dict[str, Any] | None
) -> tuple[float, float, float, float] | None:
"""Compute bounding box from GeoJSON geometry."""
if not geometry:
return None
geom_type = geometry.get("type")
coords = geometry.get("coordinates")
if not coords:
return None
all_points: list[tuple[float, float]] = []
if geom_type == "Point":
return (coords[0], coords[1], coords[0], coords[1])
elif geom_type == "Polygon":
for ring in coords:
for point in ring:
all_points.append((point[0], point[1]))
elif geom_type == "MultiPolygon":
for polygon in coords:
for ring in polygon:
for point in ring:
all_points.append((point[0], point[1]))
else:
return None
if not all_points:
return None
min_lon = min(p[0] for p in all_points)
max_lon = max(p[0] for p in all_points)
min_lat = min(p[1] for p in all_points)
max_lat = max(p[1] for p in all_points)
return (min_lon, min_lat, max_lon, max_lat)
def _extract_states_from_codes(
    same_codes: list[str], ugc_codes: list[str]
) -> set[str]:
    """Collect the state abbreviations referenced by SAME and UGC geocodes.

    SAME codes contribute the state whose FIPS code matches their first two
    characters (unknown prefixes are ignored). UGC codes contribute their
    first two characters, upper-cased, when both are alphabetic.
    """
    from_same = {
        FIPS_TO_STATE[same[:2]]
        for same in same_codes
        if len(same) >= 2 and same[:2] in FIPS_TO_STATE
    }
    from_ugc = {
        ugc[:2].upper()
        for ugc in ugc_codes
        if len(ugc) >= 2 and ugc[:2].isalpha()
    }
    return from_same | from_ugc
def _build_regions(same_codes: list[str], ugc_codes: list[str]) -> list[str]:
    """Return the sorted, de-duplicated region strings for an alert's geocodes.

    SAME codes with a known FIPS state prefix become ``US-<ST>-FIPS<code>``.
    UGC codes of at least three characters with an alphabetic two-letter
    prefix become ``US-<ST>-<rest>``, where ``<rest>`` is everything after the
    prefix (for example ``C001`` or ``Z033``).
    """
    found: set[str] = set()

    for same in same_codes:
        if len(same) < 2 or same[:2] not in FIPS_TO_STATE:
            continue
        state = FIPS_TO_STATE[same[:2]]
        found.add(f"US-{state}-FIPS{same}")

    for ugc in ugc_codes:
        if len(ugc) < 3 or not ugc[:2].isalpha():
            continue
        state = ugc[:2].upper()
        rest = ugc[2:]
        # County ("C...") and zone ("Z...") codes keep their marker letter,
        # so every UGC code renders the same way.
        found.add(f"US-{state}-{rest}")

    return sorted(found)
class NWSAdapter(SourceAdapter):
    """National Weather Service alerts adapter.

    Polls the NWS active-alerts endpoint, keeps only alerts that touch one of
    the configured states, and normalizes each GeoJSON feature into an
    ``Event``. A SQLite database holds the conditional-request cursor and the
    ids of events that have already been published.
    """

    name = "nws"

    def __init__(
        self,
        config: NWSAdapterConfig,
        cursor_db_path: Path,
    ) -> None:
        """Store configuration; no I/O happens until ``startup()``.

        Args:
            config: Adapter settings (cadence, states, contact e-mail).
            cursor_db_path: Location of the SQLite cursor/dedup database.
        """
        self.config = config
        self.cadence_s = config.cadence_s
        self.states = {s.upper() for s in config.states}
        self.cursor_db_path = cursor_db_path
        self._session: aiohttp.ClientSession | None = None
        # One request per cadence window, shared by retries as well.
        self._limiter = AsyncLimiter(1, config.cadence_s)
        self._db: sqlite3.Connection | None = None

    async def startup(self) -> None:
        """Open the cursor database, create its schema, and open the HTTP session.

        The database is prepared first so that a schema failure cannot leave
        an open HTTP session behind.

        Raises:
            sqlite3.Error: If the database cannot be opened or initialized.
        """
        db = sqlite3.connect(str(self.cursor_db_path))
        try:
            db.execute("""
                CREATE TABLE IF NOT EXISTS adapter_cursors (
                    adapter TEXT PRIMARY KEY,
                    cursor_data TEXT NOT NULL,
                    updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
                )
            """)
            db.execute("""
                CREATE TABLE IF NOT EXISTS published_ids (
                    adapter TEXT NOT NULL,
                    event_id TEXT NOT NULL,
                    first_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    last_seen TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (adapter, event_id)
                )
            """)
            db.execute("""
                CREATE INDEX IF NOT EXISTS published_ids_last_seen
                ON published_ids (last_seen)
            """)
            db.commit()
        except sqlite3.Error:
            db.close()
            raise
        self._db = db

        user_agent = f"Central/{__version__} ({self.config.contact_email})"
        self._session = aiohttp.ClientSession(
            headers={"User-Agent": user_agent},
            timeout=aiohttp.ClientTimeout(total=30),
        )
        logger.info("NWS adapter started", extra={"states": list(self.states)})

    async def shutdown(self) -> None:
        """Close the HTTP session and the cursor database if they are open."""
        if self._session:
            await self._session.close()
            self._session = None
        if self._db:
            self._db.close()
            self._db = None
        logger.info("NWS adapter shut down")

    def _get_cursor(self) -> str | None:
        """Return the stored Last-Modified value, or ``None`` if there is none."""
        if not self._db:
            return None
        cur = self._db.execute(
            "SELECT cursor_data FROM adapter_cursors WHERE adapter = ?",
            (self.name,),
        )
        row = cur.fetchone()
        return row[0] if row else None

    def _set_cursor(self, last_modified: str) -> None:
        """Persist the Last-Modified header to send as If-Modified-Since next time."""
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO adapter_cursors (adapter, cursor_data, updated)
            VALUES (?, ?, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter) DO UPDATE SET
                cursor_data = excluded.cursor_data,
                updated = CURRENT_TIMESTAMP
            """,
            (self.name, last_modified),
        )
        self._db.commit()

    def is_published(self, event_id: str) -> bool:
        """Return True if *event_id* is recorded as published by this adapter.

        Returns False when the database is not open.
        """
        if not self._db:
            return False
        cur = self._db.execute(
            "SELECT 1 FROM published_ids WHERE adapter = ? AND event_id = ?",
            (self.name, event_id),
        )
        return cur.fetchone() is not None

    def mark_published(self, event_id: str) -> None:
        """Record *event_id* as published, or refresh ``last_seen`` if present."""
        if not self._db:
            return
        self._db.execute(
            """
            INSERT INTO published_ids (adapter, event_id, first_seen, last_seen)
            VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            ON CONFLICT (adapter, event_id) DO UPDATE SET
                last_seen = CURRENT_TIMESTAMP
            """,
            (self.name, event_id),
        )
        self._db.commit()

    def bump_last_seen(self, event_id: str) -> None:
        """Refresh ``last_seen`` for an already-recorded event."""
        if not self._db:
            return
        self._db.execute(
            "UPDATE published_ids SET last_seen = CURRENT_TIMESTAMP WHERE adapter = ? AND event_id = ?",
            (self.name, event_id),
        )
        self._db.commit()

    def sweep_old_ids(self) -> int:
        """Delete this adapter's published ids not seen for 8 days.

        Only rows belonging to this adapter are removed, matching the other
        queries on ``published_ids``.

        Returns:
            Number of rows deleted (0 when the database is not open).
        """
        if not self._db:
            return 0
        cur = self._db.execute(
            "DELETE FROM published_ids "
            "WHERE adapter = ? AND last_seen < datetime('now', '-8 days')",
            (self.name,),
        )
        self._db.commit()
        return cur.rowcount

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential_jitter(initial=1, max=60),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
        reraise=True,
    )
    async def _fetch_alerts(self) -> tuple[int, dict[str, Any] | None, str | None]:
        """Fetch active alerts, sending If-Modified-Since when a cursor exists.

        Returns:
            ``(status, body, last_modified)``. For a 304 response the body and
            header are ``None``.

        Raises:
            RuntimeError: If ``startup()`` has not created the session.
            aiohttp.ClientError, asyncio.TimeoutError: After the retry budget
                (5 attempts) is exhausted.
        """
        async with self._limiter:
            if not self._session:
                raise RuntimeError("Session not initialized")
            headers: dict[str, str] = {}
            cursor = self._get_cursor()
            if cursor:
                headers["If-Modified-Since"] = cursor
            async with self._session.get(NWS_API_URL, headers=headers) as resp:
                if resp.status in (429, 403):
                    retry_after = resp.headers.get("Retry-After", "60")
                    try:
                        wait_time = int(retry_after)
                    except ValueError:
                        # Retry-After was not a plain number of seconds.
                        wait_time = 60
                    logger.warning(
                        "Rate limited by NWS",
                        extra={"status": resp.status, "retry_after": wait_time},
                    )
                    # Honor the server's delay, then raise so the retry
                    # decorator schedules another attempt.
                    await asyncio.sleep(wait_time)
                    raise aiohttp.ClientError(f"Rate limited: {resp.status}")
                if resp.status == 304:
                    return (304, None, None)
                resp.raise_for_status()
                data = await resp.json()
                last_modified = resp.headers.get("Last-Modified")
                return (resp.status, data, last_modified)

    def _normalize_feature(self, feature: dict[str, Any]) -> Event | None:
        """Convert one GeoJSON alert feature into an ``Event``.

        Explicit JSON nulls for ``properties``, ``geocode``, the code lists,
        ``event`` or ``severity`` are treated the same as missing keys.

        Returns:
            The normalized event, or ``None`` when the alert touches none of
            the configured states, has no id, or has no parseable ``sent``
            time.
        """
        props = feature.get("properties") or {}
        geocode = props.get("geocode") or {}
        same_codes = geocode.get("SAME") or []
        ugc_codes = geocode.get("UGC") or []

        feature_states = _extract_states_from_codes(same_codes, ugc_codes)
        if not feature_states.intersection(self.states):
            return None

        event_id = feature.get("id")
        if not event_id:
            logger.warning("Feature missing id", extra={"properties": props})
            return None

        event_type = props.get("event") or "Unknown"
        category = f"wx.alert.{_snake_case(event_type)}"

        time = _parse_datetime(props.get("sent"))
        if not time:
            logger.warning("Feature missing sent time", extra={"id": event_id})
            return None
        expires = _parse_datetime(props.get("expires"))

        severity_str = props.get("severity") or "Unknown"
        severity = SEVERITY_MAP.get(severity_str)

        geometry = feature.get("geometry")
        centroid = _compute_centroid(geometry)
        bbox = _compute_bbox(geometry)
        regions = _build_regions(same_codes, ugc_codes)
        primary_region = regions[0] if regions else None

        geo = Geo(
            centroid=centroid,
            bbox=bbox,
            regions=regions,
            primary_region=primary_region,
        )
        return Event(
            id=event_id,
            source="central/adapters/nws",
            category=category,
            time=time,
            expires=expires,
            severity=severity,
            geo=geo,
            data=props,
        )

    async def poll(self) -> AsyncIterator[Event]:
        """Fetch the active alerts once and yield the normalized events.

        Features that fail normalization are logged and skipped. The
        Last-Modified cursor is saved only after every feature has been handed
        to the consumer: if the consumer stops early (for example because a
        publish failed), the next poll re-fetches the feed instead of
        receiving a 304 and losing the alerts that were never delivered.

        Raises:
            Exception: Whatever ``_fetch_alerts`` raised, after logging it.
        """
        try:
            status, data, last_modified = await self._fetch_alerts()
        except Exception as e:
            logger.error("Failed to fetch NWS alerts", extra={"error": str(e)})
            raise

        if status == 304:
            logger.info("NWS returned 304 Not Modified")
            return

        features = (data.get("features") or []) if isinstance(data, dict) else []
        logger.info(
            "NWS poll completed",
            extra={"status": status, "feature_count": len(features)},
        )

        yielded = 0
        for feature in features:
            try:
                event = self._normalize_feature(feature)
            except Exception as e:
                feature_id = feature.get("id") if isinstance(feature, dict) else None
                logger.warning(
                    "Failed to normalize feature",
                    extra={"error": str(e), "feature_id": feature_id},
                )
                continue
            if event is not None:
                yield event
                yielded += 1

        if last_modified:
            self._set_cursor(last_modified)
        logger.info("NWS yielded events", extra={"count": yielded})

342
src/central/archive.py Normal file
View file

@ -0,0 +1,342 @@
"""Central archive consumer - JetStream to TimescaleDB."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from typing import Any
import asyncpg
import nats
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
from central.config import load_config, Config
# Configuration file read by async_main() at startup.
CONFIG_PATH = "/etc/central/central.toml"
# Durable consumer name used for both consumer creation and pull_subscribe.
CONSUMER_NAME = "archive"
# JetStream stream the consumer is attached to.
STREAM_NAME = "CENTRAL_WX"
# Subject filter for the consumer ("central.wx.>").
SUBJECT_FILTER = "central.wx.>"
# Maximum number of messages requested per fetch() call.
BATCH_SIZE = 100
# Timeout passed to fetch(); a timeout is handled as "no messages available".
FETCH_TIMEOUT = 5.0
# Value passed as ack_wait when the consumer is created.
# NOTE(review): presumably seconds — confirm against the nats-py ConsumerConfig docs.
ACK_WAIT = 30
class JsonFormatter(logging.Formatter):
    """Render log records as single-line JSON objects.

    Each line carries ``ts``, ``level``, ``logger`` and ``msg``, an ``exc``
    traceback when exception info is attached, and every caller-supplied
    ``extra={...}`` field.
    """

    # Attributes the logging module itself sets on a LogRecord; anything else
    # found on the record is treated as a caller-supplied extra.
    _RESERVED_ATTRS = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as a JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str is deliberate: an extra that JSON cannot encode (a set,
        # datetime, exception, ...) is logged as its string form rather than
        # making the logging call itself fail.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Route all logging through one JSON-formatted stdout handler at INFO.

    Any handlers already installed on the root logger are replaced.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())

    root = logging.root
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)


logger = logging.getLogger("central.archive")
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
"""Build PostGIS geometry from event geo data."""
if not geo_data:
return None
bbox = geo_data.get("bbox")
centroid = geo_data.get("centroid")
if bbox and len(bbox) == 4:
# Create polygon from bbox
min_lon, min_lat, max_lon, max_lat = bbox
return json.dumps({
"type": "Polygon",
"coordinates": [[
[min_lon, min_lat],
[max_lon, min_lat],
[max_lon, max_lat],
[min_lon, max_lat],
[min_lon, min_lat],
]]
})
elif centroid and len(centroid) == 2:
# Create point from centroid
return json.dumps({
"type": "Point",
"coordinates": centroid
})
return None
class ArchiveConsumer:
    """Archive consumer process.

    Pulls event envelopes from the JetStream durable consumer and upserts
    them into the ``events`` table.
    """

    def __init__(self, config: Config) -> None:
        """Store configuration; connections are opened by ``connect()``."""
        self.config = config
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._pool: asyncpg.Pool | None = None
        self._shutdown_event = asyncio.Event()

    async def connect(self) -> None:
        """Connect to NATS and open the PostgreSQL connection pool."""
        self._nc = await nats.connect(self.config.nats.url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self.config.nats.url})
        self._pool = await asyncpg.create_pool(
            self.config.postgres.dsn,
            min_size=1,
            max_size=5,
        )
        logger.info("Connected to PostgreSQL")

    async def disconnect(self) -> None:
        """Close the PostgreSQL pool, then drain and close the NATS connection."""
        if self._pool:
            await self._pool.close()
            self._pool = None
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected")

    async def _ensure_consumer(self) -> None:
        """Create the durable consumer if it does not exist yet."""
        if not self._js:
            return
        try:
            await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
            logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
        except nats.js.errors.NotFoundError:
            consumer_config = ConsumerConfig(
                durable_name=CONSUMER_NAME,
                deliver_policy=DeliverPolicy.ALL,
                ack_policy=AckPolicy.EXPLICIT,
                ack_wait=ACK_WAIT,
                filter_subject=SUBJECT_FILTER,
            )
            await self._js.add_consumer(STREAM_NAME, consumer_config)
            logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse an ISO-8601 string; return None for missing or invalid input."""
        if not value or not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except (ValueError, TypeError):
            return None

    async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
        """Upsert one message into ``events`` and acknowledge it.

        Messages that can never be archived are acknowledged and dropped so
        they are not redelivered forever and do not abort the rest of the
        batch: undecodable bytes, invalid JSON, an envelope that is not a JSON
        object, or a missing id / unparseable time. A null or malformed
        ``data`` or ``geo`` section is treated as absent. A failed insert is
        left unacknowledged so the message is redelivered.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning("Invalid JSON in message", extra={"error": str(e)})
            await msg.ack()
            return
        if not isinstance(envelope, dict):
            logger.warning(
                "Message envelope is not a JSON object",
                extra={"envelope_type": type(envelope).__name__},
            )
            await msg.ack()
            return

        event_data = envelope.get("data")
        if not isinstance(event_data, dict):
            event_data = {}
        geo_data = event_data.get("geo")
        if not isinstance(geo_data, dict):
            # Covers an explicit ``"geo": null`` as well as a missing key.
            geo_data = None
        geo = geo_data or {}

        event_id = envelope.get("id")
        source = event_data.get("source", "")
        category = event_data.get("category", "")
        time_str = event_data.get("time")
        expires_str = event_data.get("expires")
        severity = event_data.get("severity")
        regions = geo.get("regions", [])
        primary_region = geo.get("primary_region")

        event_time = self._parse_timestamp(time_str)
        expires_time = self._parse_timestamp(expires_str)

        if not event_id or not event_time:
            logger.warning(
                "Message missing required fields",
                extra={"id": event_id, "time": time_str},
            )
            await msg.ack()
            return

        geom_json = _build_geom_sql(geo_data)
        try:
            if geom_json:
                await conn.execute(
                    """
                    INSERT INTO events (id, source, category, time, expires, severity,
                                        geom, regions, primary_region, payload)
                    VALUES ($1, $2, $3, $4, $5, $6,
                            ST_GeomFromGeoJSON($7), $8, $9, $10)
                    ON CONFLICT (id, time) DO UPDATE SET
                        source = EXCLUDED.source,
                        category = EXCLUDED.category,
                        expires = EXCLUDED.expires,
                        severity = EXCLUDED.severity,
                        geom = EXCLUDED.geom,
                        regions = EXCLUDED.regions,
                        primary_region = EXCLUDED.primary_region,
                        payload = EXCLUDED.payload
                    """,
                    event_id, source, category, event_time, expires_time, severity,
                    geom_json, regions, primary_region, json.dumps(envelope),
                )
            else:
                await conn.execute(
                    """
                    INSERT INTO events (id, source, category, time, expires, severity,
                                        geom, regions, primary_region, payload)
                    VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
                    ON CONFLICT (id, time) DO UPDATE SET
                        source = EXCLUDED.source,
                        category = EXCLUDED.category,
                        expires = EXCLUDED.expires,
                        severity = EXCLUDED.severity,
                        geom = EXCLUDED.geom,
                        regions = EXCLUDED.regions,
                        primary_region = EXCLUDED.primary_region,
                        payload = EXCLUDED.payload
                    """,
                    event_id, source, category, event_time, expires_time, severity,
                    regions, primary_region, json.dumps(envelope),
                )
            await msg.ack()
            logger.info("Archived event", extra={"id": event_id, "category": category})
        except Exception as e:
            logger.error(
                "Failed to insert event",
                extra={"id": event_id, "error": str(e)},
            )
            # Don't ack - let it be redelivered

    async def _consume_loop(self) -> None:
        """Fetch batches and archive them until shutdown is requested."""
        if not self._js or not self._pool:
            return
        await self._ensure_consumer()
        sub = await self._js.pull_subscribe(
            SUBJECT_FILTER,
            durable=CONSUMER_NAME,
            stream=STREAM_NAME,
        )
        logger.info(
            "Subscribed to stream",
            extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER},
        )
        while not self._shutdown_event.is_set():
            try:
                msgs = await sub.fetch(
                    batch=BATCH_SIZE,
                    timeout=FETCH_TIMEOUT,
                )
                if msgs:
                    # One pooled connection serves the whole batch.
                    async with self._pool.acquire() as conn:
                        for msg in msgs:
                            await self._process_message(msg, conn)
            except nats.errors.TimeoutError:
                # No messages available, continue
                pass
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Error in consume loop", extra={"error": str(e)})
                # Brief pause so a persistent failure does not spin.
                await asyncio.sleep(1)
        logger.info("Consume loop stopped")

    async def start(self) -> None:
        """Open the connections; call before ``run()``."""
        await self.connect()
        logger.info("Archive consumer ready")

    async def run(self) -> None:
        """Run the consume loop until shutdown."""
        await self._consume_loop()

    async def stop(self) -> None:
        """Signal the loop to stop and close all connections."""
        logger.info("Archive consumer shutting down")
        self._shutdown_event.set()
        await self.disconnect()
        logger.info("Archive consumer stopped")
async def async_main() -> None:
    """Run the archive consumer until SIGTERM or SIGINT arrives.

    Configures logging, loads the configuration, starts the consumer, runs its
    consume loop as a background task, and on a signal cancels that task and
    stops the consumer.
    """
    setup_logging()
    consumer = ArchiveConsumer(load_config(CONFIG_PATH))

    running_loop = asyncio.get_running_loop()
    stop_requested = asyncio.Event()
    for signum in (signal.SIGTERM, signal.SIGINT):
        # Either signal simply flips the event awaited below.
        running_loop.add_signal_handler(signum, stop_requested.set)

    await consumer.start()
    worker = asyncio.create_task(consumer.run())

    await stop_requested.wait()

    consumer._shutdown_event.set()
    worker.cancel()
    try:
        await worker
    except asyncio.CancelledError:
        pass
    await consumer.stop()
def main() -> None:
    """Console-script entry point: run ``async_main`` on a new event loop."""
    asyncio.run(async_main())


if __name__ == "__main__":
    main()

255
src/central/supervisor.py Normal file
View file

@ -0,0 +1,255 @@
"""Central supervisor - adapter scheduler and event publisher."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import nats
from nats.js import JetStreamContext
from central.adapters.nws import NWSAdapter
from central.cloudevents_wire import wrap_event
from central.config import load_config, Config
from central.models import subject_for_event
# SQLite file handed to NWSAdapter as its cursor/dedup database.
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
# Configuration file read by async_main() at startup.
CONFIG_PATH = "/etc/central/central.toml"
class JsonFormatter(logging.Formatter):
    """Render log records as single-line JSON objects.

    Each line carries ``ts``, ``level``, ``logger`` and ``msg``, an ``exc``
    traceback when exception info is attached, and every caller-supplied
    ``extra={...}`` field.
    """

    # Attributes the logging module itself sets on a LogRecord; anything else
    # found on the record is treated as a caller-supplied extra.
    _RESERVED_ATTRS = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as a JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        # A record attribute literally named "extra" is merged when it is a
        # mapping; any other type is only reported under the "extra" key by
        # the loop below instead of breaking the update() call.
        nested = getattr(record, "extra", None)
        if isinstance(nested, dict):
            log_obj.update(nested)
        # Include any extra fields passed via extra={}
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str is deliberate: an extra that JSON cannot encode (a set,
        # datetime, exception, ...) is logged as its string form rather than
        # making the logging call itself fail.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Route all logging through one JSON-formatted stdout handler at INFO.

    Any handlers already installed on the root logger are replaced.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())

    root = logging.root
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)


logger = logging.getLogger("central.supervisor")
class Supervisor:
    """Main supervisor process.

    Owns the NATS connection, runs one polling task per enabled adapter,
    publishes each new event to JetStream, and emits status and heartbeat
    messages on ``central.meta.*`` subjects.
    """

    def __init__(self, config: Config) -> None:
        """Store configuration; nothing is connected until ``start()``."""
        self.config = config
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._adapters: list[NWSAdapter] = []
        self._tasks: list[asyncio.Task[None]] = []
        self._shutdown_event = asyncio.Event()
        self._start_time = datetime.now(timezone.utc)

    async def connect(self) -> None:
        """Connect to NATS."""
        self._nc = await nats.connect(self.config.nats.url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self.config.nats.url})

    async def disconnect(self) -> None:
        """Drain and close the NATS connection if it is open."""
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected from NATS")

    async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None:
        """Publish a meta event (no Nats-Msg-Id); no-op when disconnected."""
        if not self._nc:
            return
        payload = json.dumps(data).encode()
        await self._nc.publish(subject, payload)

    async def _publish_meta_best_effort(self, subject: str, data: dict[str, Any]) -> None:
        """Publish a meta event, logging instead of raising on failure.

        Status and heartbeat messages are advisory; a failed publish must not
        terminate the loop that emits them.
        """
        try:
            await self._publish_meta(subject, data)
        except Exception:
            logger.exception("Failed to publish meta event", extra={"subject": subject})

    async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None:
        """Publish an event to JetStream with a dedup header.

        Raises:
            RuntimeError: If JetStream is not connected. Returning silently
                would let the caller record an unpublished event as published.
        """
        if not self._js:
            raise RuntimeError("JetStream not connected")
        payload = json.dumps(envelope).encode()
        await self._js.publish(
            subject,
            payload,
            headers={"Nats-Msg-Id": msg_id},
        )

    async def _run_adapter(self, adapter: NWSAdapter) -> None:
        """Poll *adapter* on its cadence until shutdown is requested.

        Each cycle publishes events not yet recorded as published, reports an
        ok/failed status message, sweeps expired dedup ids, then waits out the
        remainder of the cadence. Failures in any step are logged and the loop
        continues with the next cycle.
        """
        status_subject = f"central.meta.adapter.{adapter.name}.status"
        while not self._shutdown_event.is_set():
            poll_start = datetime.now(timezone.utc)
            try:
                async for event in adapter.poll():
                    # Dedup check
                    if adapter.is_published(event.id):
                        adapter.bump_last_seen(event.id)
                        continue
                    # Build CloudEvent
                    envelope, msg_id = wrap_event(event, self.config)
                    subject = subject_for_event(event)
                    # Publish first; only a successful publish is recorded.
                    await self._publish_event(subject, envelope, msg_id)
                    adapter.mark_published(event.id)
                    logger.info(
                        "Published event",
                        extra={"id": event.id, "subject": subject, "category": event.category},
                    )
            except Exception as e:
                logger.exception("Adapter poll failed", extra={"adapter": adapter.name})
                status: dict[str, Any] = {
                    "ok": False,
                    "error": str(e),
                    "ts": datetime.now(timezone.utc).isoformat(),
                }
            else:
                status = {"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
            await self._publish_meta_best_effort(status_subject, status)

            # Sweep old IDs
            try:
                swept = adapter.sweep_old_ids()
            except Exception:
                logger.exception(
                    "Failed to sweep published IDs", extra={"adapter": adapter.name}
                )
            else:
                if swept > 0:
                    logger.info("Swept old published IDs", extra={"count": swept})

            # Sleep until next cadence, waking early if shutdown is requested.
            elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds()
            sleep_time = max(0, adapter.cadence_s - elapsed)
            try:
                await asyncio.wait_for(
                    self._shutdown_event.wait(),
                    timeout=sleep_time,
                )
            except asyncio.TimeoutError:
                pass

    async def _heartbeat_loop(self) -> None:
        """Publish a heartbeat every 30 seconds until shutdown is requested."""
        while not self._shutdown_event.is_set():
            uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds()
            await self._publish_meta_best_effort(
                "central.meta.heartbeat",
                {"ts": datetime.now(timezone.utc).isoformat(), "uptime_s": uptime},
            )
            try:
                await asyncio.wait_for(
                    self._shutdown_event.wait(),
                    timeout=30,
                )
            except asyncio.TimeoutError:
                pass

    async def start(self) -> None:
        """Connect to NATS, start enabled adapters, and launch background tasks."""
        await self.connect()
        # Initialize adapters
        if self.config.adapters.get("nws") and self.config.adapters["nws"].enabled:
            adapter = NWSAdapter(
                config=self.config.adapters["nws"],
                cursor_db_path=CURSOR_DB_PATH,
            )
            await adapter.startup()
            self._adapters.append(adapter)
            logger.info("NWS adapter initialized")
        # Start adapter tasks
        for adapter in self._adapters:
            task = asyncio.create_task(self._run_adapter(adapter))
            self._tasks.append(task)
        # Start heartbeat
        self._tasks.append(asyncio.create_task(self._heartbeat_loop()))
        logger.info("Supervisor started", extra={"adapters": [a.name for a in self._adapters]})

    async def stop(self) -> None:
        """Stop the supervisor gracefully.

        Cancels every background task, shuts down each adapter, and
        disconnects from NATS. A task that already failed, or an adapter whose
        shutdown raises, is logged and does not prevent the remaining cleanup.
        """
        logger.info("Supervisor shutting down")
        self._shutdown_event.set()
        # Cancel tasks
        for task in self._tasks:
            task.cancel()
        # return_exceptions=True collects each task's outcome instead of
        # re-raising the first failure, so cleanup below always runs.
        results = await asyncio.gather(*self._tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                logger.error("Background task had failed", extra={"error": str(result)})
        # Shutdown adapters
        for adapter in self._adapters:
            try:
                await adapter.shutdown()
            except Exception:
                logger.exception("Adapter shutdown failed", extra={"adapter": adapter.name})
        await self.disconnect()
        logger.info("Supervisor stopped")
async def async_main() -> None:
    """Run the supervisor until SIGTERM or SIGINT arrives.

    Configures logging, loads the configuration, starts the supervisor, then
    blocks until a signal is received and stops it.
    """
    setup_logging()
    supervisor = Supervisor(load_config(CONFIG_PATH))

    running_loop = asyncio.get_running_loop()
    stop_requested = asyncio.Event()
    for signum in (signal.SIGTERM, signal.SIGINT):
        # Either signal simply flips the event awaited below.
        running_loop.add_signal_handler(signum, stop_requested.set)

    await supervisor.start()
    await stop_requested.wait()
    await supervisor.stop()
def main() -> None:
    """Console-script entry point: run ``async_main`` on a new event loop."""
    asyncio.run(async_main())


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,26 @@
[Unit]
Description=Central archive consumer (JetStream -> TimescaleDB)
After=network-online.target nats-server.service postgresql@16-main.service
Wants=network-online.target
Requires=nats-server.service postgresql@16-main.service
[Service]
Type=simple
User=central
Group=central
WorkingDirectory=/opt/central
Environment=HOME=/opt/central
ExecStart=/opt/central/.venv/bin/central-archive
Restart=on-failure
RestartSec=5
LimitNOFILE=65536
NoNewPrivileges=true
ProtectSystem=full
ProtectHome=true
PrivateTmp=true
ReadWritePaths=/var/lib/central
StandardOutput=journal
StandardError=journal
[Install]
WantedBy=multi-user.target

View file

@ -0,0 +1,26 @@
[Unit]
Description=Central supervisor (adapter scheduler + publisher)
After=network-online.target nats-server.service postgresql@16-main.service
Wants=network-online.target
Requires=nats-server.service
[Service]
Type=simple
User=central
Group=central
WorkingDirectory=/opt/central
Environment=HOME=/opt/central
ExecStart=/opt/central/.venv/bin/central-supervisor
Restart=on-failure
RestartSec=5
LimitNOFILE=65536
NoNewPrivileges=true
ProtectSystem=full
ProtectHome=true
PrivateTmp=true
ReadWritePaths=/var/lib/central
StandardOutput=journal
StandardError=journal
[Install]
WantedBy=multi-user.target

View file

@ -0,0 +1,373 @@
"""Tests for NWS adapter normalization."""
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from central.adapters.nws import (
NWSAdapter,
_snake_case,
_parse_datetime,
_extract_states_from_codes,
_build_regions,
_compute_centroid,
_compute_bbox,
SEVERITY_MAP,
)
from central.config import NWSAdapterConfig
from central.models import subject_for_event
# Sample NWS GeoJSON features for testing
# SAME codes: 6 digits, first 2 are state FIPS (ID=16, OR=41, CA=06, WA=53)

# Idaho alert with a polygon geometry, an offset timestamp, and both a county
# ("C") and a zone ("Z") UGC code.
SAMPLE_FEATURE_ID = {
"id": "urn:oid:2.49.0.1.840.0.a1b2c3d4e5f6",
"type": "Feature",
"geometry": {
"type": "Polygon",
"coordinates": [[
[-116.5, 43.5],
[-116.0, 43.5],
[-116.0, 44.0],
[-116.5, 44.0],
[-116.5, 43.5],
]]
},
"properties": {
"id": "urn:oid:2.49.0.1.840.0.a1b2c3d4e5f6",
"event": "Severe Thunderstorm Warning",
"sent": "2026-05-15T12:00:00-06:00",
"expires": "2026-05-15T14:00:00-06:00",
"severity": "Severe",
"geocode": {
"SAME": ["160001"], # Idaho state FIPS 16
"UGC": ["IDC001", "IDZ033"],
},
},
}
# Oregon alert with no geometry and "Z"-suffixed UTC timestamps.
SAMPLE_FEATURE_OR = {
"id": "urn:oid:2.49.0.1.840.0.x1y2z3w4",
"type": "Feature",
"geometry": None,
"properties": {
"id": "urn:oid:2.49.0.1.840.0.x1y2z3w4",
"event": "Winter Storm Warning",
"sent": "2026-05-15T08:00:00Z",
"expires": "2026-05-16T08:00:00Z",
"severity": "Moderate",
"geocode": {
"SAME": ["410051"], # Oregon state FIPS 41
"UGC": ["ORC051"],
},
},
}
# California alert with a point geometry; California is not in the state list
# used by the TestStateFilter fixture, so that test expects it to be rejected.
SAMPLE_FEATURE_CA = {
"id": "urn:oid:2.49.0.1.840.0.ca1234",
"type": "Feature",
"geometry": {
"type": "Point",
"coordinates": [-118.25, 34.05],
},
"properties": {
"id": "urn:oid:2.49.0.1.840.0.ca1234",
"event": "Fire Weather Watch",
"sent": "2026-05-15T10:00:00-07:00",
"expires": "2026-05-16T18:00:00-07:00",
"severity": "Minor",
"geocode": {
"SAME": ["060037"], # California state FIPS 06
"UGC": ["CAZ568"],
},
},
}
# Washington alert with severity "Unknown" and a null expiry.
SAMPLE_FEATURE_UNKNOWN_SEVERITY = {
"id": "urn:oid:2.49.0.1.840.0.unk123",
"type": "Feature",
"geometry": None,
"properties": {
"id": "urn:oid:2.49.0.1.840.0.unk123",
"event": "Test Alert",
"sent": "2026-05-15T12:00:00Z",
"expires": None,
"severity": "Unknown",
"geocode": {
"SAME": ["530033"], # Washington state FIPS 53
"UGC": ["WAC033"],
},
},
}
class TestSnakeCase:
    """Checks for ``_snake_case`` on representative event names."""

    def test_spaces_to_underscores(self) -> None:
        """Spaces between words turn into underscores."""
        converted = _snake_case("Severe Thunderstorm Warning")
        assert converted == "severe_thunderstorm_warning"

    def test_removes_special_chars(self) -> None:
        """Punctuation such as parentheses is dropped."""
        converted = _snake_case("Fire Weather (Red Flag)")
        assert converted == "fire_weather_red_flag"

    def test_lowercase(self) -> None:
        """Upper-case input is lowered."""
        converted = _snake_case("TORNADO WARNING")
        assert converted == "tornado_warning"
class TestParseDatetime:
    """Checks for ``_parse_datetime``."""

    def test_iso_with_offset(self) -> None:
        """An offset timestamp is converted to UTC."""
        parsed = _parse_datetime("2026-05-15T12:00:00-06:00")
        assert parsed is not None
        assert parsed.tzinfo == timezone.utc
        assert parsed.hour == 18  # 12:00 MDT = 18:00 UTC

    def test_iso_with_z(self) -> None:
        """A trailing ``Z`` is accepted as UTC."""
        parsed = _parse_datetime("2026-05-15T12:00:00Z")
        assert parsed is not None
        assert parsed.hour == 12

    def test_none_input(self) -> None:
        """``None`` in, ``None`` out."""
        parsed = _parse_datetime(None)
        assert parsed is None

    def test_invalid_input(self) -> None:
        """Unparseable text yields ``None``."""
        parsed = _parse_datetime("not a date")
        assert parsed is None
class TestExtractStates:
    """Checks for ``_extract_states_from_codes``."""

    def test_same_codes(self) -> None:
        """Two SAME codes with the Idaho prefix (FIPS 16) give one state."""
        same = ["160001", "160003"]
        assert _extract_states_from_codes(same, []) == {"ID"}

    def test_ugc_codes(self) -> None:
        """UGC codes contribute their two-letter prefix."""
        ugc = ["IDC001", "ORC051"]
        assert _extract_states_from_codes([], ugc) == {"ID", "OR"}

    def test_combined(self) -> None:
        """SAME (Idaho, FIPS 16) and UGC results are merged."""
        found = _extract_states_from_codes(["160001"], ["WAC033"])
        assert found == {"ID", "WA"}

    def test_empty(self) -> None:
        """No codes means no states."""
        found = _extract_states_from_codes([], [])
        assert found == set()
class TestBuildRegions:
    """Checks for ``_build_regions``."""

    def test_same_to_fips_region(self) -> None:
        """A SAME code with the Idaho prefix (FIPS 16) becomes a FIPS region."""
        built = _build_regions(["160001"], [])
        assert "US-ID-FIPS160001" in built

    def test_ugc_county(self) -> None:
        """A county UGC code keeps its ``C`` marker."""
        built = _build_regions([], ["IDC001"])
        assert "US-ID-C001" in built

    def test_ugc_zone(self) -> None:
        """A zone UGC code keeps its ``Z`` marker."""
        built = _build_regions([], ["IDZ033"])
        assert "US-ID-Z033" in built

    def test_sorted_alphabetically(self) -> None:
        """The returned list is already in sorted order."""
        built = _build_regions(["160001"], ["IDC001", "IDZ033"])
        assert built == sorted(built)
class TestStateFilter:
    """Checks that features outside the configured states are dropped."""

    @pytest.fixture
    def adapter(self, tmp_path: Path) -> NWSAdapter:
        """Build an adapter configured for ID, OR, WA, MT, WY, UT and NV."""
        db_path = tmp_path / "test.db"
        settings = NWSAdapterConfig(
            enabled=True,
            cadence_s=60,
            states=["ID", "OR", "WA", "MT", "WY", "UT", "NV"],
            contact_email="test@example.com",
        )
        return NWSAdapter(settings, db_path)

    def test_accepts_id_feature(self, adapter: NWSAdapter) -> None:
        """The Idaho sample is normalized and keeps its feature id."""
        normalized = adapter._normalize_feature(SAMPLE_FEATURE_ID)
        assert normalized is not None
        assert normalized.id == SAMPLE_FEATURE_ID["id"]

    def test_accepts_or_feature(self, adapter: NWSAdapter) -> None:
        """The Oregon sample is normalized and keeps its feature id."""
        normalized = adapter._normalize_feature(SAMPLE_FEATURE_OR)
        assert normalized is not None
        assert normalized.id == SAMPLE_FEATURE_OR["id"]

    def test_rejects_ca_feature(self, adapter: NWSAdapter) -> None:
        """The California sample is filtered out (CA is not configured)."""
        normalized = adapter._normalize_feature(SAMPLE_FEATURE_CA)
        assert normalized is None
class TestSeverityMapping:
    """Checks for ``SEVERITY_MAP`` and how an unknown severity is normalized."""

    def test_extreme(self) -> None:
        """``Extreme`` maps to level 4."""
        level = SEVERITY_MAP["Extreme"]
        assert level == 4

    def test_severe(self) -> None:
        """``Severe`` maps to level 3."""
        level = SEVERITY_MAP["Severe"]
        assert level == 3

    def test_moderate(self) -> None:
        """``Moderate`` maps to level 2."""
        level = SEVERITY_MAP["Moderate"]
        assert level == 2

    def test_minor(self) -> None:
        """``Minor`` maps to level 1."""
        level = SEVERITY_MAP["Minor"]
        assert level == 1

    def test_unknown(self) -> None:
        """``Unknown`` is present in the map with a ``None`` level."""
        level = SEVERITY_MAP["Unknown"]
        assert level is None

    def test_unknown_severity_in_feature(self, tmp_path: Path) -> None:
        """A feature with unknown severity still normalizes, with no severity."""
        settings = NWSAdapterConfig(
            enabled=True,
            cadence_s=60,
            states=["WA"],
            contact_email="test@example.com",
        )
        wa_adapter = NWSAdapter(settings, tmp_path / "test.db")
        normalized = wa_adapter._normalize_feature(SAMPLE_FEATURE_UNKNOWN_SEVERITY)
        assert normalized is not None
        assert normalized.severity is None
class TestSubjectDerivation:
    """Checks the NATS subject that ``subject_for_event`` derives from an event."""

    @pytest.fixture
    def adapter(self, tmp_path: Path) -> NWSAdapter:
        """Build an adapter configured for ID, OR and WA."""
        db_path = tmp_path / "test.db"
        settings = NWSAdapterConfig(
            enabled=True,
            cadence_s=60,
            states=["ID", "OR", "WA"],
            contact_email="test@example.com",
        )
        return NWSAdapter(settings, db_path)

    def test_county_subject(self, adapter: NWSAdapter) -> None:
        """The Idaho sample produces a subject under the Idaho prefix."""
        normalized = adapter._normalize_feature(SAMPLE_FEATURE_ID)
        assert normalized is not None
        derived = subject_for_event(normalized)
        # The primary region should be the alphabetically first one, which
        # could be a county or a zone depending on sort order, so only the
        # state-level prefix is pinned here.
        assert derived.startswith("central.wx.alert.us.id.")

    def test_zone_subject(self, adapter: NWSAdapter) -> None:
        """A feature carrying only a zone UGC code yields a zone subject."""
        # Geocode with no SAME codes and a single zone-type UGC code.
        zone_only_geocode = {
            "SAME": [],
            "UGC": ["IDZ033"],
        }
        zone_properties = {
            "event": "Test Alert",
            "sent": "2026-05-15T12:00:00Z",
            "severity": "Minor",
            "geocode": zone_only_geocode,
        }
        zone_feature = {
            "id": "urn:test:zone",
            "geometry": None,
            "properties": zone_properties,
        }
        normalized = adapter._normalize_feature(zone_feature)
        assert normalized is not None
        derived = subject_for_event(normalized)
        assert "zone" in derived
class TestRegionsSorted:
    """Checks the ordering of ``geo.regions`` on a normalized event."""

    @pytest.fixture
    def adapter(self, tmp_path: Path) -> NWSAdapter:
        """Build an adapter configured for Idaho only."""
        db_path = tmp_path / "test.db"
        settings = NWSAdapterConfig(
            enabled=True,
            cadence_s=60,
            states=["ID"],
            contact_email="test@example.com",
        )
        return NWSAdapter(settings, db_path)

    def test_regions_alphabetically_sorted(self, adapter: NWSAdapter) -> None:
        """The regions list equals its own sorted copy."""
        normalized = adapter._normalize_feature(SAMPLE_FEATURE_ID)
        assert normalized is not None
        region_list = normalized.geo.regions
        assert region_list == sorted(region_list)

    def test_primary_region_is_first(self, adapter: NWSAdapter) -> None:
        """``primary_region`` is the first entry of a non-empty regions list."""
        normalized = adapter._normalize_feature(SAMPLE_FEATURE_ID)
        assert normalized is not None
        geo = normalized.geo
        assert len(geo.regions) > 0
        assert geo.primary_region == geo.regions[0]
class TestDeduplication:
    """Checks that event ids are stable, which deduplication relies on."""

    @pytest.fixture
    def adapter(self, tmp_path: Path) -> NWSAdapter:
        """Build an adapter configured for Idaho only."""
        db_path = tmp_path / "test.db"
        settings = NWSAdapterConfig(
            enabled=True,
            cadence_s=60,
            states=["ID"],
            contact_email="test@example.com",
        )
        return NWSAdapter(settings, db_path)

    def test_same_feature_same_id(self, adapter: NWSAdapter) -> None:
        """Normalizing the same feature twice returns the same Event.id."""
        first = adapter._normalize_feature(SAMPLE_FEATURE_ID)
        second = adapter._normalize_feature(SAMPLE_FEATURE_ID)
        assert first is not None
        assert second is not None
        assert first.id == second.id
class TestGeometry:
    """Checks for ``_compute_centroid`` and ``_compute_bbox``."""

    @staticmethod
    def _square_polygon() -> dict:
        """Return a fresh closed-ring Polygon spanning -116.5..-116.0, 43.5..44.0."""
        ring = [
            [-116.5, 43.5],
            [-116.0, 43.5],
            [-116.0, 44.0],
            [-116.5, 44.0],
            [-116.5, 43.5],
        ]
        return {
            "type": "Polygon",
            "coordinates": [ring],
        }

    def test_centroid_polygon(self) -> None:
        """The centroid of the sample polygon lands near (-116.3, 43.7)."""
        center = _compute_centroid(self._square_polygon())
        assert center is not None
        # Expected values are the mean of all 5 ring vertices (closing point
        # included):
        #   lon: (-116.5 + -116.0 + -116.0 + -116.5 + -116.5) / 5 = -116.3
        #   lat: (43.5 + 43.5 + 44.0 + 44.0 + 43.5) / 5 = 43.7
        lon = center[0]
        lat = center[1]
        assert -116.4 < lon < -116.2
        assert 43.6 < lat < 43.8

    def test_bbox_polygon(self) -> None:
        """The bounding box is (min lon, min lat, max lon, max lat)."""
        bounds = _compute_bbox(self._square_polygon())
        assert bounds is not None
        expected = (-116.5, 43.5, -116.0, 44.0)
        assert bounds == expected

    def test_centroid_none_geometry(self) -> None:
        """A missing geometry has no centroid."""
        center = _compute_centroid(None)
        assert center is None

    def test_bbox_none_geometry(self) -> None:
        """A missing geometry has no bounding box."""
        bounds = _compute_bbox(None)
        assert bounds is None