Merge pull request #2 from zvx-echo6/feature/1a-service-cutover

feat(config): Phase 1a-3 service cutover to DB-backed config
This commit is contained in:
malice 2026-05-15 21:08:41 -06:00 committed by GitHub
commit b3788d556d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 2501 additions and 604 deletions

View file

@ -1,342 +1,353 @@
"""Central archive consumer - JetStream to TimescaleDB."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from typing import Any
import asyncpg
import nats
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
from central.config import load_config, Config
# TOML file handed to load_config() in async_main().
CONFIG_PATH = "/etc/central/central.toml"
# Durable consumer name used for consumer_info/add_consumer/pull_subscribe.
CONSUMER_NAME = "archive"
# JetStream stream the durable consumer is attached to.
STREAM_NAME = "CENTRAL_WX"
# Subject filter used both for the consumer config and the pull subscription.
SUBJECT_FILTER = "central.wx.>"
# Maximum number of messages requested per sub.fetch() call.
BATCH_SIZE = 100
# Timeout passed to sub.fetch() (presumably seconds - confirm against nats-py).
FETCH_TIMEOUT = 5.0
# ack_wait passed to ConsumerConfig (presumably seconds - confirm against nats-py).
ACK_WAIT = 30
class JsonFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every line carries ``ts`` (UTC time at formatting), ``level``, ``logger``
    and ``msg``; ``exc`` is added when the record has exception info, and any
    non-standard record attribute (i.e. values passed via ``extra={...}``) is
    copied through under its own key.
    """

    # Record attribute names that are NOT copied into the JSON output.
    # Anything else found on the record is treated as a caller-supplied extra.
    _RESERVED_ATTRS = frozenset({
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialized as a JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str: extras such as Path, datetime or exception objects are
        # not JSON-serializable. Without a fallback json.dumps raises
        # TypeError inside the handler and the log line is never written.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Send all root-logger output to stdout as JSON at INFO level."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())
    root = logging.root
    # Replace (rather than extend) the handler list so only the JSON handler
    # is attached to the root logger.
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)


logger = logging.getLogger("central.archive")
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
"""Build PostGIS geometry from event geo data."""
if not geo_data:
return None
bbox = geo_data.get("bbox")
centroid = geo_data.get("centroid")
if bbox and len(bbox) == 4:
# Create polygon from bbox
min_lon, min_lat, max_lon, max_lat = bbox
return json.dumps({
"type": "Polygon",
"coordinates": [[
[min_lon, min_lat],
[max_lon, min_lat],
[max_lon, max_lat],
[min_lon, max_lat],
[min_lon, min_lat],
]]
})
elif centroid and len(centroid) == 2:
# Create point from centroid
return json.dumps({
"type": "Point",
"coordinates": centroid
})
return None
class ArchiveConsumer:
    """Archive consumer process.

    Pulls messages from stream ``STREAM_NAME`` through the durable consumer
    ``CONSUMER_NAME`` and upserts each event into the ``events`` table.
    Messages that can never be archived (undecodable payload, non-object
    envelope, missing id or time) are acknowledged and dropped; messages
    whose insert fails are left unacknowledged so they are redelivered.
    """

    def __init__(self, config: Config) -> None:
        """Store the configuration; no connections are opened here."""
        self.config = config
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._pool: asyncpg.Pool | None = None
        self._shutdown_event = asyncio.Event()

    async def connect(self) -> None:
        """Connect to NATS and PostgreSQL."""
        self._nc = await nats.connect(self.config.nats.url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self.config.nats.url})
        self._pool = await asyncpg.create_pool(
            self.config.postgres.dsn,
            min_size=1,
            max_size=5,
        )
        logger.info("Connected to PostgreSQL")

    async def disconnect(self) -> None:
        """Disconnect from NATS and PostgreSQL."""
        if self._pool:
            await self._pool.close()
            self._pool = None
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected")

    async def _ensure_consumer(self) -> None:
        """Create the durable consumer if it does not exist yet."""
        if not self._js:
            return
        try:
            await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
            logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
        except nats.js.errors.NotFoundError:
            consumer_config = ConsumerConfig(
                durable_name=CONSUMER_NAME,
                deliver_policy=DeliverPolicy.ALL,
                ack_policy=AckPolicy.EXPLICIT,
                ack_wait=ACK_WAIT,
                filter_subject=SUBJECT_FILTER,
            )
            await self._js.add_consumer(STREAM_NAME, consumer_config)
            logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse an ISO-8601 string (``Z`` suffix allowed); ``None`` if unusable."""
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            # AttributeError covers non-string values (no .replace method).
            return None

    async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
        """Process a single message and insert it into the database.

        Args:
            msg: JetStream message; ``msg.data`` holds the JSON envelope.
            conn: Database connection used for the upsert.

        The message is acked when it was archived or can never be archived;
        it is left unacked when the insert fails so it is redelivered.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning("Invalid JSON in message", extra={"error": str(e)})
            await msg.ack()
            return
        if not isinstance(envelope, dict):
            # Valid JSON but not an object: it can never carry id/data, and
            # letting .get() blow up would redeliver it forever.
            logger.warning(
                "Invalid JSON in message",
                extra={"error": "envelope is not a JSON object"},
            )
            await msg.ack()
            return

        # "data" and "geo" may be present but null (or the wrong type);
        # normalise so the lookups below cannot raise AttributeError.
        event_data = envelope.get("data")
        if not isinstance(event_data, dict):
            event_data = {}
        geo_data = event_data.get("geo")
        if not isinstance(geo_data, dict):
            geo_data = None
        geo = geo_data or {}

        event_id = envelope.get("id")
        source = event_data.get("source", "")
        category = event_data.get("category", "")
        time_str = event_data.get("time")
        expires_str = event_data.get("expires")
        severity = event_data.get("severity")
        regions = geo.get("regions", [])
        primary_region = geo.get("primary_region")

        event_time = self._parse_timestamp(time_str)
        expires_time = self._parse_timestamp(expires_str)

        if not event_id or not event_time:
            logger.warning(
                "Message missing required fields",
                extra={"id": event_id, "time": time_str}
            )
            await msg.ack()
            return

        geom_json = _build_geom_sql(geo_data)
        try:
            if geom_json:
                await conn.execute(
                    """
                    INSERT INTO events (id, source, category, time, expires, severity,
                    geom, regions, primary_region, payload)
                    VALUES ($1, $2, $3, $4, $5, $6,
                    ST_GeomFromGeoJSON($7), $8, $9, $10)
                    ON CONFLICT (id, time) DO UPDATE SET
                    source = EXCLUDED.source,
                    category = EXCLUDED.category,
                    expires = EXCLUDED.expires,
                    severity = EXCLUDED.severity,
                    geom = EXCLUDED.geom,
                    regions = EXCLUDED.regions,
                    primary_region = EXCLUDED.primary_region,
                    payload = EXCLUDED.payload
                    """,
                    event_id, source, category, event_time, expires_time, severity,
                    geom_json, regions, primary_region, json.dumps(envelope)
                )
            else:
                await conn.execute(
                    """
                    INSERT INTO events (id, source, category, time, expires, severity,
                    geom, regions, primary_region, payload)
                    VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
                    ON CONFLICT (id, time) DO UPDATE SET
                    source = EXCLUDED.source,
                    category = EXCLUDED.category,
                    expires = EXCLUDED.expires,
                    severity = EXCLUDED.severity,
                    geom = EXCLUDED.geom,
                    regions = EXCLUDED.regions,
                    primary_region = EXCLUDED.primary_region,
                    payload = EXCLUDED.payload
                    """,
                    event_id, source, category, event_time, expires_time, severity,
                    regions, primary_region, json.dumps(envelope)
                )
            await msg.ack()
            logger.info("Archived event", extra={"id": event_id, "category": category})
        except Exception as e:
            logger.error(
                "Failed to insert event",
                extra={"id": event_id, "error": str(e)}
            )
            # Don't ack - let it be redelivered

    async def _consume_loop(self) -> None:
        """Fetch batches and archive them until shutdown is requested."""
        if not self._js or not self._pool:
            return
        await self._ensure_consumer()
        sub = await self._js.pull_subscribe(
            SUBJECT_FILTER,
            durable=CONSUMER_NAME,
            stream=STREAM_NAME,
        )
        logger.info(
            "Subscribed to stream",
            extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
        )
        while not self._shutdown_event.is_set():
            try:
                msgs = await sub.fetch(
                    batch=BATCH_SIZE,
                    timeout=FETCH_TIMEOUT,
                )
                if msgs:
                    # One pooled connection serves the whole batch.
                    async with self._pool.acquire() as conn:
                        for msg in msgs:
                            await self._process_message(msg, conn)
            except nats.errors.TimeoutError:
                # No messages available, continue
                pass
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Error in consume loop", extra={"error": str(e)})
                await asyncio.sleep(1)
        logger.info("Consume loop stopped")

    async def start(self) -> None:
        """Start the consumer."""
        await self.connect()
        logger.info("Archive consumer ready")

    async def run(self) -> None:
        """Run the consume loop until shutdown."""
        await self._consume_loop()

    async def stop(self) -> None:
        """Stop the consumer gracefully."""
        logger.info("Archive consumer shutting down")
        self._shutdown_event.set()
        await self.disconnect()
        logger.info("Archive consumer stopped")
async def async_main() -> None:
    """Run the archive consumer until a signal arrives or the consume task ends.

    SIGTERM/SIGINT request a graceful shutdown. If the consume task finishes
    on its own (for example because subscribing failed), the process no
    longer sits idle waiting for a signal: it shuts down and re-raises the
    task's exception so the failure is visible to the caller.

    Raises:
        Exception: whatever exception terminated the consume task.
    """
    setup_logging()
    config = load_config(CONFIG_PATH)
    consumer = ArchiveConsumer(config)
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    await consumer.start()
    # Run consumer in background
    consume_task = asyncio.create_task(consumer.run())
    signal_task = asyncio.create_task(shutdown_event.wait())
    # Wake up on either a shutdown signal or the consume task ending.
    await asyncio.wait(
        {consume_task, signal_task},
        return_when=asyncio.FIRST_COMPLETED,
    )

    consumer._shutdown_event.set()
    consume_task.cancel()
    signal_task.cancel()
    # return_exceptions=True collects cancellation and failures as values so
    # cleanup below always runs.
    consume_result, _ = await asyncio.gather(
        consume_task, signal_task, return_exceptions=True
    )
    if isinstance(consume_result, Exception):
        logger.error(
            "Consume task failed",
            extra={"error": str(consume_result)},
            exc_info=consume_result,
        )
    await consumer.stop()
    if isinstance(consume_result, Exception):
        raise consume_result
def main() -> None:
    """Synchronous entry point: drive async_main() to completion."""
    entry_coro = async_main()
    asyncio.run(entry_coro)


if __name__ == "__main__":
    main()
"""Central archive consumer - JetStream to TimescaleDB."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from typing import Any
import asyncpg
import nats
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
from central.bootstrap_config import get_settings
# Durable consumer name used for consumer_info/add_consumer/pull_subscribe.
CONSUMER_NAME = "archive"
# JetStream stream the durable consumer is attached to.
STREAM_NAME = "CENTRAL_WX"
# Subject filter used both for the consumer config and the pull subscription.
SUBJECT_FILTER = "central.wx.>"
# Maximum number of messages requested per sub.fetch() call.
BATCH_SIZE = 100
# Timeout passed to sub.fetch() (presumably seconds - confirm against nats-py).
FETCH_TIMEOUT = 5.0
# ack_wait passed to ConsumerConfig (presumably seconds - confirm against nats-py).
ACK_WAIT = 30
class JsonFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every line carries ``ts`` (UTC time at formatting), ``level``, ``logger``
    and ``msg``; ``exc`` is added when the record has exception info, and any
    non-standard record attribute (i.e. values passed via ``extra={...}``) is
    copied through under its own key.
    """

    # Record attribute names that are NOT copied into the JSON output.
    # Anything else found on the record is treated as a caller-supplied extra.
    _RESERVED_ATTRS = frozenset({
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialized as a JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str: extras such as Path, datetime or exception objects are
        # not JSON-serializable. Without a fallback json.dumps raises
        # TypeError inside the handler and the log line is never written.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Send all root-logger output to stdout as JSON at INFO level."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())
    root = logging.root
    # Replace (rather than extend) the handler list so only the JSON handler
    # is attached to the root logger.
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)


logger = logging.getLogger("central.archive")
def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
"""Build PostGIS geometry from event geo data."""
if not geo_data:
return None
bbox = geo_data.get("bbox")
centroid = geo_data.get("centroid")
if bbox and len(bbox) == 4:
# Create polygon from bbox
min_lon, min_lat, max_lon, max_lat = bbox
return json.dumps({
"type": "Polygon",
"coordinates": [[
[min_lon, min_lat],
[max_lon, min_lat],
[max_lon, max_lat],
[min_lon, max_lat],
[min_lon, min_lat],
]]
})
elif centroid and len(centroid) == 2:
# Create point from centroid
return json.dumps({
"type": "Point",
"coordinates": centroid
})
return None
class ArchiveConsumer:
    """Archive consumer process.

    Pulls messages from stream ``STREAM_NAME`` through the durable consumer
    ``CONSUMER_NAME`` and upserts each event into the ``events`` table.
    Messages that can never be archived (undecodable payload, non-object
    envelope, missing id or time) are acknowledged and dropped; messages
    whose insert fails are left unacknowledged so they are redelivered.
    """

    def __init__(self, nats_url: str, postgres_dsn: str) -> None:
        """Store connection parameters; no connections are opened here.

        Args:
            nats_url: URL passed to ``nats.connect``.
            postgres_dsn: DSN passed to ``asyncpg.create_pool``.
        """
        self._nats_url = nats_url
        self._postgres_dsn = postgres_dsn
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._pool: asyncpg.Pool | None = None
        self._shutdown_event = asyncio.Event()

    async def connect(self) -> None:
        """Connect to NATS and PostgreSQL."""
        self._nc = await nats.connect(self._nats_url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self._nats_url})
        self._pool = await asyncpg.create_pool(
            self._postgres_dsn,
            min_size=1,
            max_size=5,
        )
        logger.info("Connected to PostgreSQL")

    async def disconnect(self) -> None:
        """Disconnect from NATS and PostgreSQL."""
        if self._pool:
            await self._pool.close()
            self._pool = None
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected")

    async def _ensure_consumer(self) -> None:
        """Create the durable consumer if it does not exist yet."""
        if not self._js:
            return
        try:
            await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
            logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
        except nats.js.errors.NotFoundError:
            consumer_config = ConsumerConfig(
                durable_name=CONSUMER_NAME,
                deliver_policy=DeliverPolicy.ALL,
                ack_policy=AckPolicy.EXPLICIT,
                ack_wait=ACK_WAIT,
                filter_subject=SUBJECT_FILTER,
            )
            await self._js.add_consumer(STREAM_NAME, consumer_config)
            logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse an ISO-8601 string (``Z`` suffix allowed); ``None`` if unusable."""
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            # AttributeError covers non-string values (no .replace method).
            return None

    async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
        """Process a single message and insert it into the database.

        Args:
            msg: JetStream message; ``msg.data`` holds the JSON envelope.
            conn: Database connection used for the upsert.

        The message is acked when it was archived or can never be archived;
        it is left unacked when the insert fails so it is redelivered.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning("Invalid JSON in message", extra={"error": str(e)})
            await msg.ack()
            return
        if not isinstance(envelope, dict):
            # Valid JSON but not an object: it can never carry id/data, and
            # letting .get() blow up would redeliver it forever.
            logger.warning(
                "Invalid JSON in message",
                extra={"error": "envelope is not a JSON object"},
            )
            await msg.ack()
            return

        # "data" and "geo" may be present but null (or the wrong type);
        # normalise so the lookups below cannot raise AttributeError.
        event_data = envelope.get("data")
        if not isinstance(event_data, dict):
            event_data = {}
        geo_data = event_data.get("geo")
        if not isinstance(geo_data, dict):
            geo_data = None
        geo = geo_data or {}

        event_id = envelope.get("id")
        source = event_data.get("source", "")
        category = event_data.get("category", "")
        time_str = event_data.get("time")
        expires_str = event_data.get("expires")
        severity = event_data.get("severity")
        regions = geo.get("regions", [])
        primary_region = geo.get("primary_region")

        event_time = self._parse_timestamp(time_str)
        expires_time = self._parse_timestamp(expires_str)

        if not event_id or not event_time:
            logger.warning(
                "Message missing required fields",
                extra={"id": event_id, "time": time_str}
            )
            await msg.ack()
            return

        geom_json = _build_geom_sql(geo_data)
        try:
            if geom_json:
                await conn.execute(
                    """
                    INSERT INTO events (id, source, category, time, expires, severity,
                    geom, regions, primary_region, payload)
                    VALUES ($1, $2, $3, $4, $5, $6,
                    ST_GeomFromGeoJSON($7), $8, $9, $10)
                    ON CONFLICT (id, time) DO UPDATE SET
                    source = EXCLUDED.source,
                    category = EXCLUDED.category,
                    expires = EXCLUDED.expires,
                    severity = EXCLUDED.severity,
                    geom = EXCLUDED.geom,
                    regions = EXCLUDED.regions,
                    primary_region = EXCLUDED.primary_region,
                    payload = EXCLUDED.payload
                    """,
                    event_id, source, category, event_time, expires_time, severity,
                    geom_json, regions, primary_region, json.dumps(envelope)
                )
            else:
                await conn.execute(
                    """
                    INSERT INTO events (id, source, category, time, expires, severity,
                    geom, regions, primary_region, payload)
                    VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
                    ON CONFLICT (id, time) DO UPDATE SET
                    source = EXCLUDED.source,
                    category = EXCLUDED.category,
                    expires = EXCLUDED.expires,
                    severity = EXCLUDED.severity,
                    geom = EXCLUDED.geom,
                    regions = EXCLUDED.regions,
                    primary_region = EXCLUDED.primary_region,
                    payload = EXCLUDED.payload
                    """,
                    event_id, source, category, event_time, expires_time, severity,
                    regions, primary_region, json.dumps(envelope)
                )
            await msg.ack()
            logger.info("Archived event", extra={"id": event_id, "category": category})
        except Exception as e:
            logger.error(
                "Failed to insert event",
                extra={"id": event_id, "error": str(e)}
            )
            # Don't ack - let it be redelivered

    async def _consume_loop(self) -> None:
        """Fetch batches and archive them until shutdown is requested."""
        if not self._js or not self._pool:
            return
        await self._ensure_consumer()
        sub = await self._js.pull_subscribe(
            SUBJECT_FILTER,
            durable=CONSUMER_NAME,
            stream=STREAM_NAME,
        )
        logger.info(
            "Subscribed to stream",
            extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
        )
        while not self._shutdown_event.is_set():
            try:
                msgs = await sub.fetch(
                    batch=BATCH_SIZE,
                    timeout=FETCH_TIMEOUT,
                )
                if msgs:
                    # One pooled connection serves the whole batch.
                    async with self._pool.acquire() as conn:
                        for msg in msgs:
                            await self._process_message(msg, conn)
            except nats.errors.TimeoutError:
                # No messages available, continue
                pass
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Error in consume loop", extra={"error": str(e)})
                await asyncio.sleep(1)
        logger.info("Consume loop stopped")

    async def start(self) -> None:
        """Start the consumer."""
        await self.connect()
        logger.info("Archive consumer ready")

    async def run(self) -> None:
        """Run the consume loop until shutdown."""
        await self._consume_loop()

    async def stop(self) -> None:
        """Stop the consumer gracefully."""
        logger.info("Archive consumer shutting down")
        self._shutdown_event.set()
        await self.disconnect()
        logger.info("Archive consumer stopped")
async def async_main() -> None:
    """Run the archive consumer until a signal arrives or the consume task ends.

    Connection parameters come from ``get_settings()``. SIGTERM/SIGINT
    request a graceful shutdown. If the consume task finishes on its own
    (for example because subscribing failed), the process no longer sits
    idle waiting for a signal: it shuts down and re-raises the task's
    exception so the failure is visible to the caller.

    Raises:
        Exception: whatever exception terminated the consume task.
    """
    setup_logging()
    settings = get_settings()
    logger.info(
        "Archive starting",
        extra={
            "nats_url": settings.nats_url,
            "config_source": settings.config_source,
        },
    )
    consumer = ArchiveConsumer(
        nats_url=settings.nats_url,
        postgres_dsn=settings.db_dsn,
    )
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    await consumer.start()
    # Run consumer in background
    consume_task = asyncio.create_task(consumer.run())
    signal_task = asyncio.create_task(shutdown_event.wait())
    # Wake up on either a shutdown signal or the consume task ending.
    await asyncio.wait(
        {consume_task, signal_task},
        return_when=asyncio.FIRST_COMPLETED,
    )

    consumer._shutdown_event.set()
    consume_task.cancel()
    signal_task.cancel()
    # return_exceptions=True collects cancellation and failures as values so
    # cleanup below always runs.
    consume_result, _ = await asyncio.gather(
        consume_task, signal_task, return_exceptions=True
    )
    if isinstance(consume_result, Exception):
        logger.error(
            "Consume task failed",
            extra={"error": str(consume_result)},
            exc_info=consume_result,
        )
    await consumer.stop()
    if isinstance(consume_result, Exception):
        raise consume_result
def main() -> None:
    """Synchronous entry point: drive async_main() to completion."""
    entry_coro = async_main()
    asyncio.run(entry_coro)


if __name__ == "__main__":
    main()

View file

@ -9,7 +9,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Literal
from pydantic import Field
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
@ -33,6 +33,21 @@ class Settings(BaseSettings):
default="INFO",
description="Logging level",
)
config_source: Literal["toml", "db"] = Field(
default="toml",
description="Configuration source: 'toml' for TOML file, 'db' for database",
)
config_toml_path: Path = Field(
default=Path("/etc/central/central.toml"),
description="Path to TOML config file (when config_source=toml)",
)
@field_validator("config_source")
@classmethod
def validate_config_source(cls, v: str) -> str:
    """Accept only the two supported configuration sources.

    Raises:
        ValueError: if ``v`` is neither ``"toml"`` nor ``"db"``.
    """
    if v == "toml" or v == "db":
        return v
    raise ValueError(f"config_source must be 'toml' or 'db', got {v!r}")
@lru_cache

View file

@ -0,0 +1,15 @@
"""CloudEvents configuration constants.
These are the protocol-level constants for CloudEvents envelope format.
CloudEvents envelope format is part of the Central protocol contract
and is not operator-configurable.
"""
from central.config import CloudEventsConfig
# CloudEvents protocol constants.
# wrap_event() falls back to this value when it is called without a config;
# it supplies the event type prefix, the CloudEvents "source" attribute and
# the "centralschemaversion" extension value.
CLOUDEVENTS_CONFIG = CloudEventsConfig(
type_prefix="central",
source="central.echo6.co",
schema_version="1.0",
)

View file

@ -1,30 +1,48 @@
"""CloudEvents wire format helpers."""
from typing import Any
from typing import Any, Union
from cloudevents.v1.http import CloudEvent
from central.config import Config
from central.config import Config, CloudEventsConfig
from central.cloudevents_constants import CLOUDEVENTS_CONFIG
from central.models import Event
def wrap_event(event: Event, config: Config) -> tuple[dict[str, Any], str]:
def wrap_event(
event: Event,
config: Union[Config, CloudEventsConfig, None] = None,
) -> tuple[dict[str, Any], str]:
"""
Wrap an Event into a CNCF CloudEvents v1.0 JSON envelope.
Args:
event: The event to wrap
config: Either a full Config object, a CloudEventsConfig object,
or None to use defaults.
Returns:
A tuple of (envelope_dict, msg_id) where msg_id is the
CloudEvent id for use as Nats-Msg-Id header.
"""
# Resolve CloudEventsConfig from various input types
if config is None:
ce_config = CLOUDEVENTS_CONFIG
elif isinstance(config, CloudEventsConfig):
ce_config = config
else:
# It's a full Config object
ce_config = config.cloudevents
# Build CE type: {prefix}.{category}.v1
ce_type = f"{config.cloudevents.type_prefix}.{event.category}.v1"
ce_type = f"{ce_config.type_prefix}.{event.category}.v1"
# Serialize event data
event_data = event.model_dump(mode="json")
# Build extension attributes - lowercase, no underscores per CE spec
extensions: dict[str, Any] = {
"centralschemaversion": config.cloudevents.schema_version,
"centralschemaversion": ce_config.schema_version,
"centralcategory": event.category,
}
@ -36,7 +54,7 @@ def wrap_event(event: Event, config: Config) -> tuple[dict[str, Any], str]:
ce = CloudEvent(
attributes={
"id": event.id,
"source": config.cloudevents.source,
"source": ce_config.source,
"type": ce_type,
"time": event.time.isoformat(),
"datacontenttype": "application/json",

View file

@ -0,0 +1,187 @@
"""Configuration source abstraction.
Provides a unified interface for loading adapter configuration from
either TOML files or the database-backed config store.
"""
import logging
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any, Protocol, runtime_checkable
import tomllib
from central.config import NWSAdapterConfig
from central.config_models import AdapterConfig
from central.config_store import ConfigStore
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
@runtime_checkable
class ConfigSource(Protocol):
    """Structural interface shared by every adapter-configuration backend."""

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return the adapters that are currently enabled."""

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return the configuration stored for adapter ``name``."""

    async def watch_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Report configuration changes through ``callback``.

        The TOML source returns immediately without ever calling it. The DB
        source runs forever and invokes ``callback(table, key)`` per change.
        """

    async def close(self) -> None:
        """Release whatever resources the source holds."""
class TomlConfigSource:
    """Adapter configuration read once from a TOML file.

    This is the legacy path: the file is parsed lazily on first use and is
    never re-read, so there is no hot-reload.
    """

    def __init__(self, toml_path: Path) -> None:
        self._toml_path = toml_path
        self._adapters: dict[str, AdapterConfig] = {}
        self._loaded = False

    def _load(self) -> None:
        """Parse the TOML file into AdapterConfig objects (first call only)."""
        if self._loaded:
            return

        with self._toml_path.open("rb") as fh:
            document = tomllib.load(fh)
        adapter_tables = document.get("adapters", {})

        from datetime import datetime, timezone

        # Every adapter from one load shares the same updated_at stamp.
        loaded_at = datetime.now(timezone.utc)
        reserved_keys = ("enabled", "cadence_s")

        for adapter_name, table in adapter_tables.items():
            is_enabled = table.get("enabled", True)
            cadence = table.get("cadence_s", 60)
            # Everything that is not a scheduling key is adapter-specific.
            adapter_settings = {
                key: value
                for key, value in table.items()
                if key not in reserved_keys
            }
            self._adapters[adapter_name] = AdapterConfig(
                name=adapter_name,
                enabled=is_enabled,
                cadence_s=cadence,
                settings=adapter_settings,
                paused_at=None,
                updated_at=loaded_at,
            )

        self._loaded = True
        logger.info(
            "Loaded TOML config",
            extra={"path": str(self._toml_path), "adapters": list(self._adapters)},
        )

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return adapters that are enabled and not paused."""
        self._load()
        active: list[AdapterConfig] = []
        for adapter_cfg in self._adapters.values():
            if adapter_cfg.enabled and not adapter_cfg.is_paused:
                active.append(adapter_cfg)
        return active

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return the adapter called ``name``, or ``None`` if the file has none."""
        self._load()
        return self._adapters.get(name)

    async def watch_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Return immediately; ``callback`` is never invoked for TOML."""
        logger.debug("TOML config source does not support hot-reload")

    async def close(self) -> None:
        """Nothing to release for a TOML source."""
        return None
class DbConfigSource:
    """Adapter configuration served by a ConfigStore.

    Every call is delegated to the wrapped store; change notification is
    delegated to the store's ``listen_for_changes``.
    """

    def __init__(self, config_store: ConfigStore) -> None:
        self._store = config_store

    @classmethod
    async def create(cls, dsn: str) -> "DbConfigSource":
        """Build a source around a ConfigStore created from ``dsn``."""
        return cls(await ConfigStore.create(dsn))

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return stored adapters that are enabled and not paused."""
        active: list[AdapterConfig] = []
        for adapter_cfg in await self._store.list_adapters():
            if adapter_cfg.enabled and not adapter_cfg.is_paused:
                active.append(adapter_cfg)
        return active

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return whatever the store holds for adapter ``name``."""
        found = await self._store.get_adapter(name)
        return found

    async def watch_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Hand ``callback`` to the store's change listener and await it."""
        await self._store.listen_for_changes(callback)

    async def close(self) -> None:
        """Close the wrapped config store."""
        await self._store.close()
async def create_config_source(
    source_type: str,
    dsn: str | None = None,
    toml_path: Path | None = None,
) -> ConfigSource:
    """Build the config source selected by ``source_type``.

    Args:
        source_type: ``"toml"`` or ``"db"``.
        dsn: PostgreSQL DSN; mandatory when ``source_type`` is ``"db"``.
        toml_path: TOML file path; mandatory when ``source_type`` is ``"toml"``.

    Returns:
        A TomlConfigSource or a DbConfigSource.

    Raises:
        ValueError: unknown ``source_type``, or the argument the chosen
            source needs was not supplied.
    """
    if source_type not in ("toml", "db"):
        raise ValueError(f"Unknown config source type: {source_type}")

    if source_type == "toml":
        if toml_path is None:
            raise ValueError("toml_path required for toml config source")
        return TomlConfigSource(toml_path)

    if dsn is None:
        raise ValueError("dsn required for db config source")
    return await DbConfigSource.create(dsn)

View file

@ -1,255 +1,638 @@
"""Central supervisor - adapter scheduler and event publisher."""
import asyncio
import json
import logging
import signal
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import nats
from nats.js import JetStreamContext
from central.adapters.nws import NWSAdapter
from central.cloudevents_wire import wrap_event
from central.config import load_config, Config
from central.models import subject_for_event
# Cursor database path handed to NWSAdapter as cursor_db_path.
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
# TOML file handed to load_config() in async_main().
CONFIG_PATH = "/etc/central/central.toml"
class JsonFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every line carries ``ts`` (UTC time at formatting), ``level``, ``logger``
    and ``msg``; ``exc`` is added when the record has exception info. A
    record attribute literally named ``extra`` is merged into the object,
    and any non-standard record attribute (i.e. values passed via
    ``extra={...}``) is copied through under its own key.
    """

    # Record attribute names that are NOT copied into the JSON output.
    # Anything else found on the record is treated as a caller-supplied extra.
    _RESERVED_ATTRS = frozenset({
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialized as a JSON string."""
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        if hasattr(record, "extra"):
            log_obj.update(record.extra)
        # Include any extra fields passed via extra={}
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str: extras such as Path, datetime or exception objects are
        # not JSON-serializable. Without a fallback json.dumps raises
        # TypeError inside the handler and the log line is never written.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Send all root-logger output to stdout as JSON at INFO level."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())
    root = logging.root
    # Replace (rather than extend) the handler list so only the JSON handler
    # is attached to the root logger.
    root.handlers = [stdout_handler]
    root.setLevel(logging.INFO)


logger = logging.getLogger("central.supervisor")
class Supervisor:
    """Main supervisor process.

    Connects to NATS, runs one poll loop per initialized adapter plus a
    heartbeat loop, wraps adapter events as CloudEvents and publishes them
    through JetStream. Status and heartbeat messages go out on
    ``central.meta.*`` subjects.
    """

    def __init__(self, config: Config) -> None:
        """Store the configuration; nothing is connected or started here."""
        self.config = config
        self._nc: nats.NATS | None = None
        self._js: JetStreamContext | None = None
        self._adapters: list[NWSAdapter] = []
        self._tasks: list[asyncio.Task[None]] = []
        self._shutdown_event = asyncio.Event()
        self._start_time = datetime.now(timezone.utc)

    async def connect(self) -> None:
        """Connect to NATS."""
        self._nc = await nats.connect(self.config.nats.url)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS", extra={"url": self.config.nats.url})

    async def disconnect(self) -> None:
        """Disconnect from NATS."""
        if self._nc:
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
        logger.info("Disconnected from NATS")

    async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None:
        """Publish a meta event (no Nats-Msg-Id)."""
        if not self._nc:
            return
        payload = json.dumps(data).encode()
        await self._nc.publish(subject, payload)

    async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None:
        """Publish an event with dedup header."""
        if not self._js:
            return
        payload = json.dumps(envelope).encode()
        await self._js.publish(
            subject,
            payload,
            headers={"Nats-Msg-Id": msg_id},
        )

    async def _wait_or_shutdown(self, timeout: float) -> None:
        """Sleep up to ``timeout`` seconds, returning early on shutdown."""
        try:
            await asyncio.wait_for(self._shutdown_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            pass

    async def _run_adapter(self, adapter: NWSAdapter) -> None:
        """Poll ``adapter`` on its cadence until shutdown is requested.

        Each cycle publishes unseen events, then a status message, then
        sweeps old published IDs and sleeps for the rest of the cadence.
        """
        status_subject = f"central.meta.adapter.{adapter.name}.status"
        while not self._shutdown_event.is_set():
            poll_start = datetime.now(timezone.utc)
            try:
                async for event in adapter.poll():
                    # Dedup check
                    if adapter.is_published(event.id):
                        adapter.bump_last_seen(event.id)
                        continue
                    # Build CloudEvent
                    envelope, msg_id = wrap_event(event, self.config)
                    subject = subject_for_event(event)
                    # Publish
                    await self._publish_event(subject, envelope, msg_id)
                    adapter.mark_published(event.id)
                    logger.info(
                        "Published event",
                        extra={"id": event.id, "subject": subject, "category": event.category}
                    )
                # Publish success status
                await self._publish_meta(
                    status_subject,
                    {"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
                )
            except Exception as e:
                logger.exception("Adapter poll failed", extra={"adapter": adapter.name})
                # The failure report is best-effort: if NATS itself is the
                # problem this publish raises too, and an unguarded raise
                # here would end the adapter loop for good.
                try:
                    await self._publish_meta(
                        status_subject,
                        {
                            "ok": False,
                            "error": str(e),
                            "ts": datetime.now(timezone.utc).isoformat()
                        }
                    )
                except Exception:
                    logger.exception(
                        "Failed to publish adapter failure status",
                        extra={"adapter": adapter.name},
                    )
            # Sweep old IDs
            swept = adapter.sweep_old_ids()
            if swept > 0:
                logger.info("Swept old published IDs", extra={"count": swept})
            # Sleep until next cadence
            elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds()
            await self._wait_or_shutdown(max(0, adapter.cadence_s - elapsed))

    async def _heartbeat_loop(self) -> None:
        """Publish a heartbeat every 30 seconds until shutdown."""
        while not self._shutdown_event.is_set():
            uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds()
            # A failed publish must not end the heartbeat task; log and retry
            # on the next tick.
            try:
                await self._publish_meta(
                    "central.meta.heartbeat",
                    {"ts": datetime.now(timezone.utc).isoformat(), "uptime_s": uptime}
                )
            except Exception:
                logger.exception("Failed to publish heartbeat")
            await self._wait_or_shutdown(30)

    async def start(self) -> None:
        """Connect, initialize enabled adapters and start background tasks."""
        await self.connect()
        # Initialize adapters
        if self.config.adapters.get("nws") and self.config.adapters["nws"].enabled:
            adapter = NWSAdapter(
                config=self.config.adapters["nws"],
                cursor_db_path=CURSOR_DB_PATH,
            )
            await adapter.startup()
            self._adapters.append(adapter)
            logger.info("NWS adapter initialized")
        # Start adapter tasks
        for adapter in self._adapters:
            task = asyncio.create_task(self._run_adapter(adapter))
            self._tasks.append(task)
        # Start heartbeat
        self._tasks.append(asyncio.create_task(self._heartbeat_loop()))
        logger.info("Supervisor started", extra={"adapters": [a.name for a in self._adapters]})

    async def stop(self) -> None:
        """Stop the supervisor gracefully.

        Background tasks are cancelled and awaited together; a task that
        already died with an exception is logged rather than re-raised, so
        adapter shutdown and the NATS disconnect always run.
        """
        logger.info("Supervisor shutting down")
        self._shutdown_event.set()
        # Cancel tasks
        for task in self._tasks:
            task.cancel()
        results = await asyncio.gather(*self._tasks, return_exceptions=True)
        for result in results:
            # CancelledError is the expected outcome and is not an Exception
            # subclass; anything matching here is a real earlier failure.
            if isinstance(result, Exception):
                logger.error(
                    "Background task failed before shutdown",
                    extra={"error": str(result)},
                    exc_info=result,
                )
        # Shutdown adapters
        for adapter in self._adapters:
            await adapter.shutdown()
        await self.disconnect()
        logger.info("Supervisor stopped")
async def async_main() -> None:
    """Start the supervisor and keep it running until SIGTERM or SIGINT."""
    setup_logging()
    supervisor = Supervisor(load_config(CONFIG_PATH))

    running_loop = asyncio.get_running_loop()
    stop_requested = asyncio.Event()
    # Either signal simply flips the event awaited below.
    for signum in (signal.SIGTERM, signal.SIGINT):
        running_loop.add_signal_handler(signum, stop_requested.set)

    await supervisor.start()
    # Block here until a shutdown signal arrives.
    await stop_requested.wait()
    await supervisor.stop()
def main() -> None:
    """Synchronous entry point: drive async_main() to completion."""
    entry_coro = async_main()
    asyncio.run(entry_coro)


if __name__ == "__main__":
    main()
"""Central supervisor - adapter scheduler and event publisher."""
import asyncio
import json
import logging
import signal
import sys
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import nats
from nats.js import JetStreamContext
from central.adapters.nws import NWSAdapter
from central.cloudevents_wire import wrap_event
from central.config import load_config, Config, NWSAdapterConfig
from central.config_models import AdapterConfig
from central.config_source import ConfigSource, create_config_source
from central.bootstrap_config import get_settings
from central.models import subject_for_event
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
class JsonFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every record yields ``ts`` (current UTC time, ISO 8601), ``level``,
    ``logger`` and ``msg``. Exception info, when present, is rendered under
    ``exc``. Any attribute on the record that is not a standard ``LogRecord``
    field (i.e. whatever the caller passed via ``extra=``) is copied into the
    object as well.
    """

    # Standard LogRecord attributes; everything else on the record is treated
    # as caller-supplied context. Built once instead of on every loop pass.
    _RESERVED_ATTRS = frozenset((
        "name", "msg", "args", "created", "filename", "funcName",
        "levelname", "levelno", "lineno", "module", "msecs",
        "pathname", "process", "processName", "relativeCreated",
        "stack_info", "exc_info", "exc_text", "thread", "threadName",
        "taskName", "message",
    ))

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as a JSON string.

        Values that ``json`` cannot encode natively (datetimes, paths, ...)
        are rendered with ``str()`` so that a log call never raises from
        inside the logging machinery.
        """
        log_obj: dict[str, Any] = {
            "ts": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            log_obj["exc"] = self.formatException(record.exc_info)
        # A record may carry a literal ``extra`` attribute; its items are
        # flattened into the top-level object.
        if hasattr(record, "extra"):
            log_obj.update(record.extra)
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_obj[key] = value
        # default=str: fall back to the string form for non-JSON-native
        # values rather than raising TypeError while logging.
        return json.dumps(log_obj, default=str)
def setup_logging() -> None:
    """Route all logging to stdout as JSON at INFO level.

    Replaces whatever handlers the root logger had with a single stream
    handler that uses ``JsonFormatter``.
    """
    root_logger = logging.getLogger()
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JsonFormatter())
    root_logger.handlers = [stdout_handler]
    root_logger.setLevel(logging.INFO)
logger = logging.getLogger("central.supervisor")
@dataclass
class AdapterState:
    """Bookkeeping the supervisor keeps for one scheduled adapter."""

    # Adapter name; used as the key in the supervisor's state table.
    name: str
    # The adapter instance whose poll() is driven by the loop task.
    adapter: NWSAdapter
    # Most recently applied configuration for this adapter.
    config: AdapterConfig
    # Task running the poll loop, or None when no loop has been started.
    task: asyncio.Task[None] | None = None
    # When the last poll attempt finished; None until the first poll.
    last_completed_poll: datetime | None = None
    # Set to wake the poll loop out of its wait.
    cancel_event: asyncio.Event = field(default_factory=asyncio.Event)

    @property
    def is_running(self) -> bool:
        """Return True while a poll-loop task exists and has not finished."""
        loop_task = self.task
        if loop_task is None:
            return False
        return not loop_task.done()
class Supervisor:
"""Main supervisor process."""
def __init__(
    self,
    config_source: ConfigSource,
    nats_url: str,
    cloudevents_config: Any = None,
) -> None:
    """Store collaborators and initialise empty runtime state.

    Args:
        config_source: Provider of adapter configuration and change events.
        nats_url: URL passed to ``nats.connect`` by ``connect()``.
        cloudevents_config: Forwarded to ``wrap_event`` for every event;
            ``None`` lets ``wrap_event`` fall back to its defaults.
    """
    # Injected dependencies.
    self._config_source = config_source
    self._nats_url = nats_url
    self._cloudevents_config = cloudevents_config

    # NATS handles; populated by connect().
    self._nc: nats.NATS | None = None
    self._js: JetStreamContext | None = None

    # Adapter bookkeeping and background tasks.
    self._adapter_states: dict[str, AdapterState] = {}
    self._tasks: list[asyncio.Task[None]] = []
    self._config_watch_task: asyncio.Task[None] | None = None

    # Lifecycle state.
    self._start_time = datetime.now(timezone.utc)
    self._shutdown_event = asyncio.Event()
    self._lock = asyncio.Lock()
async def connect(self) -> None:
    """Open the NATS connection and obtain a JetStream context from it."""
    client = await nats.connect(self._nats_url)
    self._nc = client
    self._js = client.jetstream()
    # NOTE(review): the URL is logged verbatim; if it can embed credentials,
    # consider redacting it before logging.
    logger.info("Connected to NATS", extra={"url": self._nats_url})
async def disconnect(self) -> None:
    """Drain and close the NATS connection, if one is open.

    The stored client and JetStream references are cleared before any
    network call, and ``close()`` is attempted even when ``drain()`` raises,
    so a failed drain cannot leave an open connection or stale handles
    behind. An exception raised by ``drain()``/``close()`` still propagates
    to the caller.
    """
    client = self._nc
    if not client:
        return
    # Drop our references first: whatever happens below, this supervisor no
    # longer considers itself connected.
    self._nc = None
    self._js = None
    try:
        await client.drain()
    finally:
        await client.close()
    logger.info("Disconnected from NATS")
async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None:
    """Publish *data* as JSON on *subject* over core NATS.

    No ``Nats-Msg-Id`` header is attached. Does nothing when there is no
    NATS connection.
    """
    client = self._nc
    if client:
        await client.publish(subject, json.dumps(data).encode())
async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None:
    """Publish *envelope* as JSON through JetStream with a dedup header.

    *msg_id* is sent as the ``Nats-Msg-Id`` header. Returns without
    publishing when no JetStream context is available; the caller cannot
    tell that case apart from a successful publish.
    """
    jetstream = self._js
    if not jetstream:
        return
    body = json.dumps(envelope).encode()
    dedup_headers = {"Nats-Msg-Id": msg_id}
    await jetstream.publish(subject, body, headers=dedup_headers)
def _adapter_config_to_nws_config(self, config: AdapterConfig) -> NWSAdapterConfig:
    """Build an ``NWSAdapterConfig`` from a generic ``AdapterConfig``.

    ``states`` and ``contact_email`` are read from ``config.settings`` and
    default to an empty list and an empty string respectively.
    """
    adapter_settings = config.settings
    state_codes = adapter_settings.get("states", [])
    email = adapter_settings.get("contact_email", "")
    return NWSAdapterConfig(
        enabled=config.enabled,
        cadence_s=config.cadence_s,
        states=state_codes,
        contact_email=email,
    )
async def _run_adapter_loop(self, state: AdapterState) -> None:
    """Poll one adapter repeatedly until shutdown or until it is disabled.

    Each iteration waits until ``last_completed_poll + cadence_s`` (the very
    first poll runs immediately), then iterates ``state.adapter.poll()``.
    Events the adapter reports as already published only get their
    last-seen marker bumped; all others are wrapped, published through
    JetStream and then marked as published. After every poll attempt a
    status message is published on ``central.meta.adapter.<name>.status``
    and old published IDs are swept.

    Setting ``state.cancel_event`` while the loop is waiting makes it
    recompute the wait from the current ``state.config``; if the supervisor
    shutdown event is set as well, the loop exits instead.

    A failing poll never ends the loop: the error is logged, a failure
    status is published on a best-effort basis, and the next attempt is
    scheduled one cadence later.

    Args:
        state: Runtime state of the adapter to drive. ``last_completed_poll``
            is updated after every poll attempt, successful or not.
    """
    status_subject = f"central.meta.adapter.{state.name}.status"

    while not self._shutdown_event.is_set():
        # Calculate next poll time based on rate-limit guarantee
        now = datetime.now(timezone.utc)
        if state.last_completed_poll is not None:
            next_poll_at = state.last_completed_poll.timestamp() + state.config.cadence_s
            wait_time = max(0, next_poll_at - now.timestamp())
        else:
            # First poll - run immediately
            wait_time = 0

        if wait_time > 0:
            logger.debug(
                "Waiting for next poll",
                extra={
                    "adapter": state.name,
                    "wait_s": wait_time,
                    "next_poll": datetime.fromtimestamp(
                        now.timestamp() + wait_time, tz=timezone.utc
                    ).isoformat(),
                },
            )
            # Wait for either timeout or cancel signal
            try:
                await asyncio.wait_for(
                    state.cancel_event.wait(),
                    timeout=wait_time,
                )
            except asyncio.TimeoutError:
                # Wait elapsed normally: time to poll.
                pass
            else:
                # Cancel event was set - check if we should exit or reschedule
                if self._shutdown_event.is_set():
                    break
                # Clear the cancel event and re-evaluate schedule
                state.cancel_event.clear()
                continue

        # Check shutdown before polling
        if self._shutdown_event.is_set():
            break

        # Check if adapter is still enabled
        if not state.config.enabled or state.config.is_paused:
            logger.info(
                "Adapter disabled/paused, stopping loop",
                extra={"adapter": state.name},
            )
            break

        try:
            async for event in state.adapter.poll():
                # Dedup check
                if state.adapter.is_published(event.id):
                    state.adapter.bump_last_seen(event.id)
                    continue

                # Build CloudEvent (uses defaults if no config provided)
                envelope, msg_id = wrap_event(event, self._cloudevents_config)
                subject = subject_for_event(event)

                # Publish first, mark afterwards, so a failed publish is
                # not recorded as done.
                await self._publish_event(subject, envelope, msg_id)
                state.adapter.mark_published(event.id)

                logger.info(
                    "Published event",
                    extra={"id": event.id, "subject": subject, "category": event.category}
                )

            # Publish success status
            await self._publish_meta(
                status_subject,
                {"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
            )
            # Mark poll completion time for rate limiting
            state.last_completed_poll = datetime.now(timezone.utc)

        except Exception as e:
            logger.exception("Adapter poll failed", extra={"adapter": state.name})
            # The status publish uses the same NATS connection that may have
            # caused the failure; if it raised here the whole loop task
            # would die silently, so treat it as best-effort.
            try:
                await self._publish_meta(
                    status_subject,
                    {
                        "ok": False,
                        "error": str(e),
                        "ts": datetime.now(timezone.utc).isoformat()
                    }
                )
            except Exception:
                logger.exception(
                    "Failed to publish adapter failure status",
                    extra={"adapter": state.name},
                )
            # Still mark completion time to avoid tight retry loops
            state.last_completed_poll = datetime.now(timezone.utc)

        # Sweep old IDs
        swept = state.adapter.sweep_old_ids()
        if swept > 0:
            logger.info("Swept old published IDs", extra={"count": swept})
async def _start_adapter(self, config: AdapterConfig) -> None:
    """Start (or restart) the poll loop for the adapter described by *config*.

    Three cases:

    * No state is tracked under ``config.name``: a new ``NWSAdapter`` is
      created and started, but only for the name ``"nws"``; any other name
      is logged as unknown and ignored.
    * State exists and its loop is running: a warning is logged, nothing
      else happens.
    * State exists but the loop is not running: the state object is reused
      with a freshly built adapter, so ``last_completed_poll`` survives and
      the poll loop keeps honouring the rate limit.
    """
    name = config.name
    state = self._adapter_states.get(name)

    if state is None:
        # Never tracked before - only the NWS adapter type is known.
        if name != "nws":
            logger.warning(
                "Unknown adapter type",
                extra={"adapter": name},
            )
            return

        adapter = NWSAdapter(
            config=self._adapter_config_to_nws_config(config),
            cursor_db_path=CURSOR_DB_PATH,
        )
        await adapter.startup()

        state = AdapterState(name=name, adapter=adapter, config=config)
        state.task = asyncio.create_task(self._run_adapter_loop(state))
        self._adapter_states[name] = state

        logger.info(
            "Adapter started",
            extra={
                "adapter": name,
                "cadence_s": config.cadence_s,
            },
        )
        return

    if state.is_running:
        logger.warning(
            "Adapter already running",
            extra={"adapter": name},
        )
        return

    # Previously stopped: keep the state object, swap in the new config and
    # a freshly initialised adapter, then relaunch the loop.
    state.config = config
    state.cancel_event.clear()
    state.adapter = NWSAdapter(
        config=self._adapter_config_to_nws_config(config),
        cursor_db_path=CURSOR_DB_PATH,
    )
    await state.adapter.startup()
    state.task = asyncio.create_task(self._run_adapter_loop(state))

    # The next-poll time below is computed for the log line only.
    last_poll = state.last_completed_poll
    if last_poll:
        next_poll_at = datetime.fromtimestamp(
            last_poll.timestamp() + config.cadence_s,
            tz=timezone.utc,
        )
        if next_poll_at <= datetime.now(timezone.utc):
            next_poll_at = datetime.now(timezone.utc)
    else:
        next_poll_at = datetime.now(timezone.utc)

    logger.info(
        "Adapter restarted",
        extra={
            "adapter": name,
            "cadence_s": config.cadence_s,
            "preserved_last_poll": last_poll.isoformat() if last_poll else None,
            "next_poll": next_poll_at.isoformat(),
        },
    )
async def _stop_adapter(self, name: str) -> None:
    """Stop the named adapter's poll loop while keeping its state entry.

    The entry stays in ``self._adapter_states`` (including
    ``last_completed_poll``) so that a later restart can keep honouring the
    rate limit. Unknown or already-stopped adapters are ignored. Use
    ``_remove_adapter()`` to drop the entry altogether.
    """
    state = self._adapter_states.get(name)
    if state is None or not state.is_running:
        # Nothing tracked, or the loop is already stopped.
        return

    # Wake the loop if it is waiting, then cancel the task and wait for it.
    state.cancel_event.set()
    loop_task = state.task
    if loop_task:
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            pass
        state.task = None

    await state.adapter.shutdown()

    last_poll = state.last_completed_poll
    logger.info(
        "Adapter stopped",
        extra={
            "adapter": name,
            "preserved_last_poll": last_poll.isoformat() if last_poll else None,
        },
    )
async def _remove_adapter(self, name: str) -> None:
    """Fully remove an adapter, dropping all preserved state.

    Called when an adapter is deleted from the database (not just disabled),
    and for every tracked adapter during supervisor shutdown.

    The entry is popped from ``self._adapter_states`` first, so the adapter
    is no longer tracked even if stopping it raises. Unknown names are
    ignored.

    Args:
        name: Key of the adapter in ``self._adapter_states``.
    """
    state = self._adapter_states.pop(name, None)
    if state is None:
        return
    # Stop if running
    # NOTE(review): shutdown() is only reached for a running adapter here;
    # one stopped via _stop_adapter() has already been shut down there.
    if state.is_running:
        # Wake the loop out of its wait, then cancel and wait for the task.
        state.cancel_event.set()
        if state.task:
            state.task.cancel()
            try:
                await state.task
            except asyncio.CancelledError:
                pass
        await state.adapter.shutdown()
    logger.info(
        "Adapter removed",
        extra={"adapter": name},
    )
async def _reschedule_adapter(
    self,
    name: str,
    new_config: AdapterConfig,
) -> None:
    """Apply *new_config* to a running adapter and wake its poll loop.

    The poll loop recomputes its wait from ``last_completed_poll`` plus the
    new cadence, so the next poll is anchored to the previous poll rather
    than to the time of the change. If the adapter is unknown or not
    running, this simply delegates to ``_start_adapter``.

    Only the cadence and, for ``"nws"``, the upper-cased state list are
    pushed onto the live adapter object.
    """
    state = self._adapter_states.get(name)
    if state is None or not state.is_running:
        # Nothing to reschedule - start (or restart) it instead.
        await self._start_adapter(new_config)
        return

    previous_cadence = state.config.cadence_s
    cadence = new_config.cadence_s

    state.config = new_config
    state.adapter.cadence_s = cadence

    if name == "nws":
        converted = self._adapter_config_to_nws_config(new_config)
        state.adapter.states = {code.upper() for code in converted.states}

    # Computed for the log line only; the loop does its own calculation.
    last_poll = state.last_completed_poll
    if last_poll:
        next_poll_at = datetime.fromtimestamp(
            last_poll.timestamp() + cadence,
            tz=timezone.utc,
        )
    else:
        next_poll_at = datetime.now(timezone.utc)

    logger.info(
        "Rescheduled adapter",
        extra={
            "adapter": name,
            "old_cadence_s": previous_cadence,
            "new_cadence_s": cadence,
            "next_poll": next_poll_at.isoformat(),
        },
    )

    # Wake the loop so it re-evaluates its schedule with the new cadence.
    state.cancel_event.set()
async def _on_config_change(self, table: str, key: str) -> None:
    """Handle a configuration change notification.

    Passed to ``ConfigSource.watch_for_changes`` as the callback. Only
    changes to the ``adapters`` table are acted upon; *key* is then the
    adapter name. The current config is re-read from the config source and
    the adapter is removed, stopped, started or rescheduled accordingly.
    The whole decision runs under ``self._lock`` so that notifications are
    applied one at a time.

    Args:
        table: Name of the table that changed.
        key: Identifier of the changed row (the adapter name for
            ``adapters``).
    """
    if table != "adapters":
        return
    adapter_name = key
    logger.info(
        "Config change received",
        extra={"table": table, "key": key},
    )
    async with self._lock:
        # Fetch the current config for this adapter
        new_config = await self._config_source.get_adapter(adapter_name)
        current_state = self._adapter_states.get(adapter_name)
        if new_config is None:
            # Adapter was deleted - fully remove, don't just stop
            if current_state:
                await self._remove_adapter(adapter_name)
                logger.info(
                    "Adapter deleted, removed",
                    extra={"adapter": adapter_name},
                )
            return
        if not new_config.enabled or new_config.is_paused:
            # Adapter disabled or paused - stop but preserve state
            if current_state and current_state.is_running:
                await self._stop_adapter(adapter_name)
                logger.info(
                    "Adapter disabled/paused, stopped",
                    extra={
                        "adapter": adapter_name,
                        "enabled": new_config.enabled,
                        "paused": new_config.is_paused,
                    },
                )
            return
        if current_state is None or not current_state.is_running:
            # Adapter was enabled or created - start (will reuse state if exists)
            await self._start_adapter(new_config)
            logger.info(
                "Adapter enabled, started",
                extra={"adapter": adapter_name},
            )
        else:
            # Adapter config changed (cadence, settings)
            await self._reschedule_adapter(adapter_name, new_config)
async def _heartbeat_loop(self) -> None:
    """Publish a heartbeat on ``central.meta.heartbeat`` every 30 seconds.

    Each heartbeat carries the current UTC timestamp and the supervisor
    uptime in seconds. The loop ends as soon as the shutdown event is set.

    A failed publish is logged and the loop carries on; otherwise a single
    NATS error would end the heartbeat task for good and leave a stored
    exception behind for whoever awaits the task at shutdown.
    """
    while not self._shutdown_event.is_set():
        now = datetime.now(timezone.utc)
        uptime = (now - self._start_time).total_seconds()
        try:
            await self._publish_meta(
                "central.meta.heartbeat",
                {"ts": now.isoformat(), "uptime_s": uptime}
            )
        except Exception:
            logger.exception("Heartbeat publish failed")

        # Sleep for the heartbeat interval, but wake immediately on shutdown.
        try:
            await asyncio.wait_for(
                self._shutdown_event.wait(),
                timeout=30
            )
        except asyncio.TimeoutError:
            pass
async def start(self) -> None:
    """Connect to NATS and launch adapters, config watcher and heartbeat.

    Adapters returned by ``list_enabled_adapters()`` are started one after
    another. The config watcher runs as its own task and reports changes to
    ``_on_config_change``.
    """
    await self.connect()

    for adapter_config in await self._config_source.list_enabled_adapters():
        await self._start_adapter(adapter_config)

    # For a DB source the watcher runs forever; for TOML it returns immediately.
    watcher = self._config_source.watch_for_changes(self._on_config_change)
    self._config_watch_task = asyncio.create_task(watcher)

    heartbeat = asyncio.create_task(self._heartbeat_loop())
    self._tasks.append(heartbeat)

    logger.info(
        "Supervisor started",
        extra={"adapters": list(self._adapter_states)},
    )
async def stop(self) -> None:
    """Stop the supervisor gracefully.

    Order of operations: set the shutdown event, cancel the config watcher
    and the other background tasks, remove (and thereby stop) every tracked
    adapter, close the config source, and disconnect from NATS.

    A background task that had already failed on its own does not abort the
    shutdown: its exception is logged and cleanup continues. The config
    source is closed and NATS is disconnected even if an earlier cleanup
    step raises.
    """
    logger.info("Supervisor shutting down")
    self._shutdown_event.set()

    # Cancel config watcher, heartbeat and other tasks
    background: list[asyncio.Task[None]] = []
    if self._config_watch_task:
        background.append(self._config_watch_task)
    background.extend(self._tasks)
    for task in background:
        task.cancel()

    # return_exceptions=True collects each task's outcome instead of
    # re-raising it here, so one crashed task cannot skip the cleanup below.
    outcomes = await asyncio.gather(*background, return_exceptions=True)
    for outcome in outcomes:
        # CancelledError derives from BaseException and is the expected
        # outcome; only genuine failures are reported.
        if isinstance(outcome, Exception):
            logger.error(
                "Background task had failed before shutdown",
                exc_info=outcome,
            )

    try:
        # Remove all adapters (full cleanup)
        for name in list(self._adapter_states):
            await self._remove_adapter(name)
    finally:
        try:
            # Close config source
            await self._config_source.close()
        finally:
            await self.disconnect()

    logger.info("Supervisor stopped")
async def async_main() -> None:
    """Async entry point.

    Configures logging, builds the config source selected by the bootstrap
    settings, starts the supervisor and runs until SIGTERM or SIGINT.
    ``Supervisor.stop()`` is always called once the supervisor object
    exists, including when ``start()`` fails part-way, so the config source
    and NATS connection are released.
    """
    setup_logging()
    settings = get_settings()
    logger.info(
        "Config source: %s",
        settings.config_source,
        extra={"config_source": settings.config_source},
    )

    # Create config source based on setting
    config_source = await create_config_source(
        source_type=settings.config_source,
        dsn=settings.db_dsn,
        toml_path=settings.config_toml_path,
    )

    # CloudEvents config: try TOML first, fall back to code defaults
    # (CloudEvents envelope format is protocol-level, not operator-configurable)
    # NOTE(review): the whole object returned by load_config() is handed to
    # the supervisor as ``cloudevents_config`` - confirm wrap_event expects
    # that rather than a sub-section of it.
    cloudevents_config = None
    if settings.config_source == "toml":
        try:
            cloudevents_config = load_config(str(settings.config_toml_path))
        except Exception:
            # Best effort: defaults are used instead, but say so rather than
            # hiding a broken or missing TOML file.
            logger.warning(
                "Could not load CloudEvents config from TOML, using defaults",
                exc_info=True,
                extra={"toml_path": str(settings.config_toml_path)},
            )

    supervisor = Supervisor(
        config_source=config_source,
        nats_url=settings.nats_url,
        cloudevents_config=cloudevents_config,
    )
    logger.info(
        "CloudEvents config: %s",
        "TOML" if cloudevents_config else "defaults",
        extra={"cloudevents_source": "toml" if cloudevents_config else "defaults"},
    )

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    try:
        await supervisor.start()
        # Wait for shutdown signal
        await shutdown_event.wait()
    finally:
        await supervisor.stop()
def main() -> None:
    """Synchronous entry point: drive ``async_main`` with ``asyncio.run``."""
    entry_coro = async_main()
    asyncio.run(entry_coro)
if __name__ == "__main__":
main()

41
systemd/README.md Normal file
View file

@ -0,0 +1,41 @@
# Systemd Unit Files
These unit files configure Central services for systemd.
## Installation
```bash
# Copy unit files
sudo cp central-supervisor.service /etc/systemd/system/
sudo cp central-archive.service /etc/systemd/system/
# Reload systemd
sudo systemctl daemon-reload
# Enable and start services
sudo systemctl enable --now central-supervisor
sudo systemctl enable --now central-archive
```
## Configuration
Both services load environment variables from `/etc/central/central.env`:
```bash
CENTRAL_DB_DSN=postgresql://central:password@localhost/central
CENTRAL_NATS_URL=nats://localhost:4222
CENTRAL_CONFIG_SOURCE=db
CENTRAL_MASTER_KEY_PATH=/etc/central/master.key
```
## Service Dependencies
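- **central-supervisor**: Requires NATS server, plus PostgreSQL when `CENTRAL_CONFIG_SOURCE=db` (adapter configuration is then read from the database)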
- **central-archive**: Requires NATS server and PostgreSQL
## Logs
```bash
journalctl -u central-supervisor -f
journalctl -u central-archive -f
```

View file

@ -10,6 +10,7 @@ User=central
Group=central
WorkingDirectory=/opt/central
Environment=HOME=/opt/central
EnvironmentFile=/etc/central/central.env
ExecStart=/opt/central/.venv/bin/central-archive
Restart=on-failure
RestartSec=5

View file

@ -10,6 +10,7 @@ User=central
Group=central
WorkingDirectory=/opt/central
Environment=HOME=/opt/central
EnvironmentFile=/etc/central/central.env
ExecStart=/opt/central/.venv/bin/central-supervisor
Restart=on-failure
RestartSec=5

285
tests/test_config_source.py Normal file
View file

@ -0,0 +1,285 @@
"""Tests for configuration source abstraction."""
import asyncio
import base64
import os
from datetime import datetime, timezone
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_source import (
ConfigSource,
TomlConfigSource,
DbConfigSource,
create_config_source,
)
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Write a random base64-encoded key of ``KEY_SIZE`` bytes and return its path.

    Session-scoped, so all tests in the run share one key file.
    """
    encoded_key = base64.b64encode(os.urandom(KEY_SIZE)).decode()
    key_file = tmp_path_factory.mktemp("keys") / "master.key"
    key_file.write_text(encoded_key)
    return key_file
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Point the environment at the test DSN and key file for every test.

    The key cache is cleared first so each test starts without a cached key.
    """
    clear_key_cache()
    test_environment = (
        ("CENTRAL_DB_DSN", TEST_DB_DSN),
        ("CENTRAL_MASTER_KEY_PATH", str(master_key_path)),
    )
    for variable, value in test_environment:
        monkeypatch.setenv(variable, value)
class TestTomlConfigSource:
    """Exercise ``TomlConfigSource`` against a temporary TOML file."""

    @pytest.fixture
    def toml_file(self, tmp_path: Path) -> Path:
        """Write a config with one enabled and one disabled adapter; return its path."""
        config_path = tmp_path / "central.toml"
        config_path.write_text("""
[adapters.nws]
enabled = true
cadence_s = 60
states = ["ID", "MT"]
contact_email = "test@example.com"
[adapters.disabled_adapter]
enabled = false
cadence_s = 300
states = []
contact_email = "test@example.com"
[cloudevents]
type_prefix = "central"
source = "central.local"
schema_version = "1.0"
[nats]
url = "nats://localhost:4222"
[postgres]
dsn = "postgresql://user:pass@localhost/db"
""")
        return config_path

    @pytest.mark.asyncio
    async def test_list_enabled_adapters(self, toml_file: Path) -> None:
        """Only the enabled adapter is listed."""
        enabled = await TomlConfigSource(toml_file).list_enabled_adapters()
        assert len(enabled) == 1
        only = enabled[0]
        assert only.name == "nws"
        assert only.enabled is True
        assert only.cadence_s == 60

    @pytest.mark.asyncio
    async def test_get_adapter(self, toml_file: Path) -> None:
        """Looking up an adapter by name exposes its settings."""
        found = await TomlConfigSource(toml_file).get_adapter("nws")
        assert found is not None
        assert found.name == "nws"
        assert found.settings["states"] == ["ID", "MT"]
        assert found.settings["contact_email"] == "test@example.com"

    @pytest.mark.asyncio
    async def test_get_nonexistent_adapter(self, toml_file: Path) -> None:
        """An unknown adapter name yields None."""
        missing = await TomlConfigSource(toml_file).get_adapter("does_not_exist")
        assert missing is None

    @pytest.mark.asyncio
    async def test_watch_for_changes_returns_immediately(self, toml_file: Path) -> None:
        """Watching a TOML source returns at once and never calls back."""
        source = TomlConfigSource(toml_file)
        received: list[tuple[str, str]] = []

        async def callback(table: str, key: str) -> None:
            received.append((table, key))

        # The timeout turns a blocking watcher into a test failure.
        await asyncio.wait_for(
            source.watch_for_changes(callback),
            timeout=1.0,
        )
        assert not received

    @pytest.mark.asyncio
    async def test_implements_protocol(self, toml_file: Path) -> None:
        """A TomlConfigSource passes the ConfigSource isinstance check."""
        assert isinstance(TomlConfigSource(toml_file), ConfigSource)
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
    """Yield a raw asyncpg connection to the test database, closed afterwards."""
    connection = await asyncpg.connect(TEST_DB_DSN)
    yield connection
    # Runs once the test that used the fixture has finished.
    await connection.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
    """Create ``config.adapters`` if needed and empty it before the test."""
    setup_statements = (
        "CREATE SCHEMA IF NOT EXISTS config",
        """
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""",
        "DELETE FROM config.adapters",
    )
    # Executed strictly in order: schema, table, then wipe.
    for statement in setup_statements:
        await db_conn.execute(statement)
class TestDbConfigSource:
    """Exercise ``DbConfigSource`` against the test database."""

    @pytest_asyncio.fixture
    async def db_source(self, clean_config_schema: None) -> DbConfigSource:
        """Yield a source bound to the (emptied) test database; close it afterwards."""
        source = await DbConfigSource.create(TEST_DB_DSN)
        yield source
        await source.close()

    @pytest.mark.asyncio
    async def test_list_enabled_adapters_empty(self, db_source: DbConfigSource) -> None:
        """With no rows the enabled list is empty."""
        assert await db_source.list_enabled_adapters() == []

    @pytest.mark.asyncio
    async def test_list_enabled_adapters(
        self, db_source: DbConfigSource, db_conn: asyncpg.Connection
    ) -> None:
        """Disabled and paused rows are left out of the enabled list."""
        insert_rows = """
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES
('enabled_adapter', true, 60, '{"key": "value"}'::jsonb),
('disabled_adapter', false, 60, '{}'::jsonb),
('paused_adapter', true, 60, '{}'::jsonb)
"""
        pause_one = """
UPDATE config.adapters
SET paused_at = now()
WHERE name = 'paused_adapter'
"""
        await db_conn.execute(insert_rows)
        await db_conn.execute(pause_one)

        enabled = await db_source.list_enabled_adapters()
        assert len(enabled) == 1
        assert enabled[0].name == "enabled_adapter"

    @pytest.mark.asyncio
    async def test_get_adapter(
        self, db_source: DbConfigSource, db_conn: asyncpg.Connection
    ) -> None:
        """A stored row is returned with its cadence and settings."""
        await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES ('test_adapter', true, 120, '{"states": ["ID"]}'::jsonb)
""")
        found = await db_source.get_adapter("test_adapter")
        assert found is not None
        assert found.name == "test_adapter"
        assert found.cadence_s == 120
        assert found.settings == {"states": ["ID"]}

    @pytest.mark.asyncio
    async def test_get_nonexistent_adapter(self, db_source: DbConfigSource) -> None:
        """An unknown adapter name yields None."""
        missing = await db_source.get_adapter("does_not_exist")
        assert missing is None

    @pytest.mark.asyncio
    async def test_implements_protocol(self, db_source: DbConfigSource) -> None:
        """A DbConfigSource passes the ConfigSource isinstance check."""
        assert isinstance(db_source, ConfigSource)
class TestCreateConfigSource:
    """Exercise the ``create_config_source`` factory."""

    @pytest.fixture
    def toml_file(self, tmp_path: Path) -> Path:
        """Write a minimal TOML config and return its path."""
        config_path = tmp_path / "central.toml"
        config_path.write_text("""
[adapters.nws]
enabled = true
cadence_s = 60
states = []
contact_email = "test@example.com"
[cloudevents]
[nats]
[postgres]
dsn = "postgresql://test@localhost/test"
""")
        return config_path

    @pytest.mark.asyncio
    async def test_create_toml_source(self, toml_file: Path) -> None:
        """Type 'toml' produces a TomlConfigSource."""
        created = await create_config_source(
            source_type="toml",
            toml_path=toml_file,
        )
        assert isinstance(created, TomlConfigSource)
        await created.close()

    @pytest.mark.asyncio
    async def test_create_db_source(self, clean_config_schema: None) -> None:
        """Type 'db' produces a DbConfigSource."""
        created = await create_config_source(
            source_type="db",
            dsn=TEST_DB_DSN,
        )
        assert isinstance(created, DbConfigSource)
        await created.close()

    @pytest.mark.asyncio
    async def test_create_toml_requires_path(self) -> None:
        """Type 'toml' without a path is rejected."""
        with pytest.raises(ValueError, match="toml_path required"):
            await create_config_source(source_type="toml")

    @pytest.mark.asyncio
    async def test_create_db_requires_dsn(self) -> None:
        """Type 'db' without a DSN is rejected."""
        with pytest.raises(ValueError, match="dsn required"):
            await create_config_source(source_type="db")

    @pytest.mark.asyncio
    async def test_create_unknown_type_raises(self) -> None:
        """An unrecognised source type is rejected."""
        with pytest.raises(ValueError, match="Unknown config source type"):
            await create_config_source(source_type="unknown")

View file

@ -0,0 +1,394 @@
"""Tests for supervisor hot-reload and rate-limiting behavior."""
import asyncio
import base64
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import asyncpg
import pytest
import pytest_asyncio
from central.config_models import AdapterConfig
from central.config_source import DbConfigSource
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Write a random base64-encoded key of ``KEY_SIZE`` bytes and return its path.

    Session-scoped, so all tests in the run share one key file.
    """
    encoded_key = base64.b64encode(os.urandom(KEY_SIZE)).decode()
    key_file = tmp_path_factory.mktemp("keys") / "master.key"
    key_file.write_text(encoded_key)
    return key_file
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Point the environment at the test DSN and key file for every test.

    The key cache is cleared first so each test starts without a cached key.
    """
    clear_key_cache()
    test_environment = (
        ("CENTRAL_DB_DSN", TEST_DB_DSN),
        ("CENTRAL_MASTER_KEY_PATH", str(master_key_path)),
    )
    for variable, value in test_environment:
        monkeypatch.setenv(variable, value)
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
    """Yield a raw asyncpg connection to the test database, closed afterwards."""
    connection = await asyncpg.connect(TEST_DB_DSN)
    yield connection
    # Runs once the test that used the fixture has finished.
    await connection.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
    """Prepare ``config.adapters`` with its NOTIFY trigger and empty the table.

    The trigger function sends ``<table>:<adapter name>`` on the
    ``config_changed`` channel for every inserted, updated or deleted row.
    """
    create_schema = "CREATE SCHEMA IF NOT EXISTS config"
    create_table = """
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
"""
    # Create notify trigger
    create_notify_function = """
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
"""
    recreate_trigger = """
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
"""
    wipe_table = "DELETE FROM config.adapters"

    # Order matters: schema, table, function, trigger, then the wipe.
    for statement in (
        create_schema,
        create_table,
        create_notify_function,
        recreate_trigger,
        wipe_table,
    ):
        await db_conn.execute(statement)
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
    """Yield a ConfigStore on the prepared test database; close it afterwards."""
    test_store = await ConfigStore.create(TEST_DB_DSN)
    yield test_store
    # Runs once the test that used the fixture has finished.
    await test_store.close()
class TestDbConfigSourceNotifications:
    """Check that ``DbConfigSource`` relays database NOTIFY events."""

    @pytest.mark.asyncio
    async def test_watch_receives_notifications(
        self,
        config_store: ConfigStore,
        db_conn: asyncpg.Connection,
    ) -> None:
        """Inserting an adapter row reaches the watch callback as (table, key)."""
        source = DbConfigSource(config_store)
        seen: list[tuple[str, str]] = []
        first_seen = asyncio.Event()

        async def callback(table: str, key: str) -> None:
            seen.append((table, key))
            first_seen.set()

        # The watcher runs in the background for the duration of the test.
        watcher = asyncio.create_task(source.watch_for_changes(callback))
        try:
            # Give the listener a moment to connect before triggering NOTIFY.
            await asyncio.sleep(0.2)

            # Write through the raw connection (not the store) so the
            # notification comes from the table trigger.
            await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES ('test_adapter', true, 60, '{}'::jsonb)
""")

            await asyncio.wait_for(first_seen.wait(), timeout=5.0)
            assert len(seen) >= 1
            assert seen[0] == ("adapters", "test_adapter")
        finally:
            watcher.cancel()
            try:
                await watcher
            except asyncio.CancelledError:
                pass
class TestRateLimitGuarantee:
"""Tests for rate-limit guarantees during hot-reload.
These tests verify the critical invariant: cadence changes must not
cause extra API calls before (last_poll + new_cadence).
"""
@pytest.mark.asyncio
async def test_cadence_change_respects_last_poll_time(self) -> None:
"""Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.
This is the core rate-limit guarantee test (gate 3).
"""
# Import supervisor module to access AdapterState
from central.supervisor import AdapterState
# Mock adapter
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Create adapter state with a known last_completed_poll time
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=60, # Original cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Simulate cadence change to 90 seconds
new_config = AdapterConfig(
name="test",
enabled=True,
cadence_s=90, # New cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
# Update state as reschedule would
state.config = new_config
state.adapter.cadence_s = 90
# Calculate expected next poll time
expected_next_poll = last_poll + timedelta(seconds=90)
now = datetime.now(timezone.utc)
expected_wait = max(0, (expected_next_poll - now).total_seconds())
# The wait time should be based on last_poll + new_cadence
# Since last_poll was 30 seconds ago and new cadence is 90,
# we should wait 60 more seconds (90 - 30 = 60)
actual_next_poll = last_poll.timestamp() + new_config.cadence_s
actual_wait = max(0, actual_next_poll - now.timestamp())
# Allow 1 second tolerance for timing
assert abs(actual_wait - 60) < 2, (
f"Expected ~60s wait, got {actual_wait}s. "
f"Rate limit violated: poll would happen before last_poll + new_cadence"
)
@pytest.mark.asyncio
async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
"""When last_poll + new_cadence is already past, poll immediately.
If operator increases cadence to 120s after a gap of 150s,
the poll should happen now (not wait another 120s).
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 150 seconds ago
last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=120, # Increased cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Calculate next poll time
now = datetime.now(timezone.utc)
next_poll_at = last_poll.timestamp() + config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
# Since 150 > 120, next poll should be immediate (wait_time ~= 0)
assert wait_time < 1, (
f"Expected immediate poll (wait ~0s), got {wait_time}s. "
f"After a gap exceeding new cadence, poll should happen now."
)
@pytest.mark.asyncio
async def test_enable_disable_enable_respects_rate_limit(self) -> None:
    """Re-enabling adapter schedules poll at last_poll + cadence.

    If adapter was disabled for a while and then re-enabled, the next
    poll should be at (last_completed_poll + cadence_s), not immediately
    (unless that time has already passed).
    """
    from central.supervisor import AdapterState

    mock_adapter = MagicMock()
    mock_adapter.name = "test"
    mock_adapter.cadence_s = 60

    # Last poll was 30 seconds ago, then adapter was disabled
    last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)

    # State as it looks once the adapter has been re-enabled
    state = AdapterState(
        name="test",
        adapter=mock_adapter,
        config=AdapterConfig(
            name="test",
            enabled=True,
            cadence_s=60,
            settings={},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        ),
        last_completed_poll=last_poll,
    )

    # Derive the schedule from the state itself rather than from the local
    # variables used to build it; otherwise the state object is never read.
    now = datetime.now(timezone.utc)
    next_poll_at = state.last_completed_poll.timestamp() + state.config.cadence_s
    wait_time = max(0, next_poll_at - now.timestamp())

    # Should wait ~30 more seconds (60 - 30 = 30); 2s tolerance for timing
    assert abs(wait_time - 30) < 2, (
        f"Expected ~30s wait after re-enable, got {wait_time}s. "
        "Rate limit violated on enable→disable→enable sequence."
    )
@pytest.mark.asyncio
async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
    """Multiple rapid cadence changes don't cause extra polls.

    If NOTIFY fires rapidly (60→90→120→90), the final schedule should
    still be based on last_completed_poll + final_cadence.
    """
    from central.supervisor import AdapterState

    mock_adapter = MagicMock()
    mock_adapter.name = "test"
    mock_adapter.cadence_s = 60

    # Last poll was 20 seconds ago
    last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)

    state = AdapterState(
        name="test",
        adapter=mock_adapter,
        config=AdapterConfig(
            name="test",
            enabled=True,
            cadence_s=60,
            settings={},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        ),
        last_completed_poll=last_poll,
    )

    # Simulate rapid cadence changes; the last one (90) must win
    for cadence in (90, 120, 90):
        state.config = AdapterConfig(
            name="test",
            enabled=True,
            cadence_s=cadence,
            settings={},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        )
        state.adapter.cadence_s = cadence

    # Read the cadence back from the state instead of hard-coding it, so the
    # loop above is actually what the assertion below depends on.
    final_cadence = state.config.cadence_s
    assert final_cadence == 90
    # The config churn must not have touched the last completed poll.
    assert state.last_completed_poll == last_poll

    # Final schedule should be last_poll + 90
    now = datetime.now(timezone.utc)
    next_poll_at = state.last_completed_poll.timestamp() + final_cadence
    wait_time = max(0, next_poll_at - now.timestamp())

    # Should wait ~70 seconds (90 - 20 = 70); 2s tolerance for timing
    assert abs(wait_time - 70) < 2, (
        f"Expected ~70s wait after rapid changes, got {wait_time}s. "
        "Multiple NOTIFYs should not cause extra polls."
    )
class TestBootstrapConfigFlag:
    """Tests for CENTRAL_CONFIG_SOURCE bootstrap flag."""

    def test_default_is_toml(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Default config_source is 'toml'."""
        from central.bootstrap_config import Settings

        # Guard against an ambient CENTRAL_CONFIG_SOURCE (the variable the
        # sibling tests set) masking the default. Presumably Settings still
        # reads the process environment when _env_file=None -- TODO confirm.
        monkeypatch.delenv("CENTRAL_CONFIG_SOURCE", raising=False)

        # Create settings with minimal required fields
        settings = Settings(
            db_dsn="postgresql://test@localhost/test",
            _env_file=None,
        )
        assert settings.config_source == "toml"

    def test_accepts_db(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """config_source accepts 'db' value."""
        from central.bootstrap_config import get_settings

        get_settings.cache_clear()
        monkeypatch.setenv("CENTRAL_CONFIG_SOURCE", "db")
        monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test@localhost/test")
        try:
            settings = get_settings()
            assert settings.config_source == "db"
        finally:
            # monkeypatch restores the environment but not the cache; drop the
            # Settings built from the patched env so later tests don't see it.
            get_settings.cache_clear()

    def test_rejects_invalid(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """config_source rejects invalid values."""
        from pydantic import ValidationError

        from central.bootstrap_config import get_settings

        get_settings.cache_clear()
        monkeypatch.setenv("CENTRAL_CONFIG_SOURCE", "invalid")
        monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test@localhost/test")
        try:
            with pytest.raises(ValidationError):
                get_settings()
        finally:
            # If validation unexpectedly succeeds, a bogus Settings object
            # would otherwise stay cached for the rest of the session.
            get_settings.cache_clear()

View file

@ -0,0 +1,546 @@
"""Integration tests for Supervisor hot-reload with enable/disable/enable flow.
These tests exercise the actual Supervisor._on_config_change code path,
not just AdapterState math in isolation. They verify the rate-limit
guarantee is maintained across adapter stop/start cycles.
IMPORTANT: These tests are designed to:
- FAIL on unfixed code (Test B fails because last_completed_poll is lost)
- PASS on fixed code (last_completed_poll is preserved across disable/enable)
"""
import asyncio
import base64
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from central.config_models import AdapterConfig
from central.crypto import KEY_SIZE, clear_key_cache
def adapter_is_running(state) -> bool:
    """Report whether the adapter behind *state* is currently running.

    Works with both state flavours: the fixed code exposes an ``is_running``
    attribute, while the unfixed code only has ``task``, which counts as
    running when it is set and not yet done.
    """
    if not hasattr(state, "is_running"):
        # Unfixed code path: inspect the task directly.
        task = state.task
        return task is not None and not task.done()
    return state.is_running
async def cleanup_adapter(supervisor, name: str) -> None:
    """Tear down the adapter called *name* on *supervisor*.

    Uses ``_remove_adapter`` when the supervisor has it (fixed code) and
    falls back to ``_stop_adapter`` otherwise (unfixed code, where stopping
    also pops the state).
    """
    if hasattr(supervisor, "_remove_adapter"):
        teardown_name = "_remove_adapter"
    else:
        teardown_name = "_stop_adapter"
    await getattr(supervisor, teardown_name)(name)
# DSN of the database used by these tests. Taken from the
# CENTRAL_TEST_DB_DSN environment variable when set, otherwise a local default.
TEST_DB_DSN = os.getenv(
    "CENTRAL_TEST_DB_DSN",
    "postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Write a random base64-encoded master key to a temp file, once per session.

    Returns the path of the ``master.key`` file, created inside a fresh
    ``keys`` temporary directory.
    """
    encoded_key = base64.b64encode(os.urandom(KEY_SIZE)).decode()
    key_file = tmp_path_factory.mktemp("keys") / "master.key"
    key_file.write_text(encoded_key)
    return key_file
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Reset the key cache and point the environment at the test DB and key.

    Applied automatically to every test: clears the crypto key cache, then
    sets CENTRAL_DB_DSN and CENTRAL_MASTER_KEY_PATH for the test's duration.
    """
    clear_key_cache()
    test_env = {
        "CENTRAL_DB_DSN": TEST_DB_DSN,
        "CENTRAL_MASTER_KEY_PATH": str(master_key_path),
    }
    for variable, value in test_env.items():
        monkeypatch.setenv(variable, value)
class MockConfigSource:
    """In-memory ConfigSource stand-in so Supervisor can be tested without a DB."""

    def __init__(self) -> None:
        # Adapter configs keyed by adapter name.
        self._adapters: dict[str, AdapterConfig] = {}

    def set_adapter(self, config: AdapterConfig | None, name: str | None = None) -> None:
        """Store *config* under its own name, or drop *name* when config is None.

        With ``config=None`` and no (or empty) *name*, nothing happens.
        """
        if config is not None:
            self._adapters[config.name] = config
            return
        if name:
            self._adapters.pop(name, None)

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return the stored configs that are enabled and not paused."""
        return [
            cfg
            for cfg in self._adapters.values()
            if cfg.enabled and not cfg.is_paused
        ]

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return the config stored under *name*, or None if there is none."""
        return self._adapters.get(name)

    async def watch_for_changes(self, callback) -> None:
        """Do nothing; tests trigger change notifications by hand."""
        return None

    async def close(self) -> None:
        """Do nothing; there is no resource to release."""
        return None
class MockNWSAdapter:
    """NWSAdapter test double that records every poll and publishes nothing."""

    def __init__(self, config, cursor_db_path) -> None:
        # cursor_db_path is accepted for signature compatibility and unused.
        self.config = config
        self.cadence_s = config.cadence_s
        self.states = {code.upper() for code in config.states}
        self.poll_count = 0
        self.poll_times: list[datetime] = []
        self._shutdown = False

    async def startup(self) -> None:
        """No-op startup hook."""
        return None

    async def shutdown(self) -> None:
        """Record that shutdown was requested."""
        self._shutdown = True

    async def poll(self):
        """Record the call (count and UTC time) and yield no events."""
        self.poll_count += 1
        self.poll_times.append(datetime.now(timezone.utc))
        return
        # Unreachable on purpose: the bare yield makes this an async generator.
        yield

    def is_published(self, event_id: str) -> bool:
        """Always report the event as not yet published."""
        return False

    def mark_published(self, event_id: str) -> None:
        """No-op."""
        return None

    def bump_last_seen(self, event_id: str) -> None:
        """No-op."""
        return None

    def sweep_old_ids(self) -> int:
        """Pretend nothing was swept."""
        return 0
@pytest.fixture
def mock_nats():
    """Mock NATS connection.

    Returns an AsyncMock client with an awaitable ``publish``. Its
    ``jetstream()`` is a plain synchronous call that returns a JetStream
    mock whose ``publish`` is awaitable.
    """
    mock_js = AsyncMock()
    mock_js.publish = AsyncMock()

    mock_nc = AsyncMock()
    mock_nc.publish = AsyncMock()
    # Attributes of an AsyncMock are AsyncMocks themselves, so calling
    # mock_nc.jetstream() without awaiting it would hand back a coroutine
    # rather than mock_js. The tests in this module call jetstream()
    # synchronously, so it has to be a regular MagicMock.
    mock_nc.jetstream = MagicMock(return_value=mock_js)
    return mock_nc
class TestEnableDisableEnableIntegration:
    """Integration tests for enable→disable→enable flow through Supervisor.

    These tests verify that _on_config_change → _stop_adapter → _start_adapter
    preserves last_completed_poll correctly.

    Every test starts the "nws" adapter with MockNWSAdapter patched in and
    tears it down in a ``finally`` block, so a failing assertion does not
    leave the adapter running.
    """

    @staticmethod
    def _nws_config(*, enabled: bool) -> AdapterConfig:
        """Build a fresh "nws" config (cadence 60s) with the given enabled flag."""
        return AdapterConfig(
            name="nws",
            enabled=enabled,
            cadence_s=60,
            settings={"states": ["ID"], "contact_email": "test@test.com"},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        )

    @staticmethod
    def _make_supervisor(config_source: MockConfigSource, mock_nats):
        """Create a Supervisor wired to the mock config source and mock NATS."""
        # Imported inside the test run, as the original tests did.
        from central.supervisor import Supervisor

        supervisor = Supervisor(
            config_source=config_source,
            nats_url="nats://localhost:4222",
            cloudevents_config=None,
        )
        supervisor._nc = mock_nats
        # Take the configured JetStream mock directly: if mock_nats.jetstream
        # is an AsyncMock, *calling* it would yield an un-awaited coroutine
        # instead of the mock.
        supervisor._js = mock_nats.jetstream.return_value
        return supervisor

    @staticmethod
    async def _push_config(supervisor, config_source: MockConfigSource, config: AdapterConfig) -> None:
        """Store *config* in the source and notify the supervisor of the change."""
        config_source.set_adapter(config)
        await supervisor._on_config_change("adapters", "nws")

    @staticmethod
    def _assert_stopped_state_preserved(supervisor, expected_last_poll: datetime) -> None:
        """Assert "nws" is stopped but its state and last poll time survived.

        THIS IS THE KEY CHECK: on unfixed code the state is None because
        stopping pops it; on fixed code it still exists and is not running.
        """
        state = supervisor._adapter_states.get("nws")
        assert state is not None, (
            "State was removed on stop! This violates the rate-limit guarantee. "
            "State should be preserved to maintain last_completed_poll."
        )
        assert not adapter_is_running(state)
        assert state.last_completed_poll == expected_last_poll

    @staticmethod
    def _assert_running_state_preserved(supervisor, expected_last_poll: datetime) -> None:
        """Assert "nws" is running again with the same last_completed_poll."""
        state = supervisor._adapter_states.get("nws")
        assert state is not None
        assert adapter_is_running(state)
        assert state.last_completed_poll == expected_last_poll

    @staticmethod
    async def _teardown(supervisor) -> None:
        """Signal shutdown and clean up "nws" if the supervisor still tracks it."""
        supervisor._shutdown_event.set()
        if "nws" in supervisor._adapter_states:
            await cleanup_adapter(supervisor, "nws")

    @pytest.mark.asyncio
    async def test_enable_disable_enable_gap_longer_than_cadence(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Test A: Re-enable after gap longer than cadence polls immediately.

        - Start adapter (cadence 60s)
        - Simulate completed poll 5 minutes ago
        - Disable adapter
        - Re-enable adapter
        - Assert next poll fires immediately (last+cadence is in past)
        - Assert exactly ONE poll happens, not multiple catch-up
        """
        config_source = MockConfigSource()
        initial_config = self._nws_config(enabled=True)
        config_source.set_adapter(initial_config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        # Patch NWSAdapter to use our mock
        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            try:
                await supervisor._start_adapter(initial_config)
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)

                # Simulate completed poll 5 minutes ago
                state.last_completed_poll = datetime.now(timezone.utc) - timedelta(minutes=5)
                saved_last_poll = state.last_completed_poll

                # Disable adapter
                await self._push_config(
                    supervisor, config_source, self._nws_config(enabled=False)
                )
                self._assert_stopped_state_preserved(supervisor, saved_last_poll)

                # Re-enable adapter
                await self._push_config(
                    supervisor, config_source, self._nws_config(enabled=True)
                )
                self._assert_running_state_preserved(supervisor, saved_last_poll)

                # last_poll was 5 minutes ago and cadence is 60s, so
                # last_poll + cadence is 4 minutes in the past: wait must be 0.
                now = datetime.now(timezone.utc)
                next_poll_at = saved_last_poll.timestamp() + 60  # cadence = 60s
                wait_time = max(0, next_poll_at - now.timestamp())
                assert wait_time == 0, (
                    f"Expected immediate poll (wait=0), got wait={wait_time}s. "
                    f"last_poll was {saved_last_poll}, now is {now}"
                )
            finally:
                await self._teardown(supervisor)

    @pytest.mark.asyncio
    async def test_enable_disable_enable_gap_shorter_than_cadence(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Test B: Re-enable after gap shorter than cadence respects rate limit.

        THIS IS THE KEY TEST that failed before the fix.

        - Start adapter (cadence 60s)
        - Simulate completed poll 10 seconds ago
        - Disable adapter
        - Re-enable adapter (still within cadence window)
        - Assert next poll fires at last_poll + 60s, NOT immediately
        """
        config_source = MockConfigSource()
        initial_config = self._nws_config(enabled=True)
        config_source.set_adapter(initial_config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            try:
                await supervisor._start_adapter(initial_config)
                state = supervisor._adapter_states.get("nws")
                assert state is not None

                # Simulate completed poll 10 seconds ago
                state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10)
                saved_last_poll = state.last_completed_poll

                # Disable adapter
                await self._push_config(
                    supervisor, config_source, self._nws_config(enabled=False)
                )
                self._assert_stopped_state_preserved(supervisor, saved_last_poll)

                # Re-enable adapter; no real time passes, only the
                # rate-limit arithmetic is checked.
                await self._push_config(
                    supervisor, config_source, self._nws_config(enabled=True)
                )
                self._assert_running_state_preserved(supervisor, saved_last_poll)

                # last_poll was ~10 seconds ago and cadence is 60s, so the
                # next poll is still ~50s in the future (60 - 10 = 50).
                now = datetime.now(timezone.utc)
                next_poll_at = saved_last_poll.timestamp() + 60
                wait_time = max(0, next_poll_at - now.timestamp())
                assert 45 < wait_time < 55, (
                    f"Expected ~50s wait (respecting rate limit), got wait={wait_time}s. "
                    "Rate limit violated: poll would happen before last_poll + cadence"
                )
            finally:
                await self._teardown(supervisor)

    @pytest.mark.asyncio
    async def test_enable_disable_delete_readd_fresh_state(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Test C: Delete then re-add clears preserved state.

        - Start adapter
        - Simulate completed poll
        - Disable adapter
        - DELETE adapter from DB (not just disable)
        - Re-add adapter with same name
        - Assert preserved timestamp is dropped (fresh adapter, immediate poll)
        """
        config_source = MockConfigSource()
        initial_config = self._nws_config(enabled=True)
        config_source.set_adapter(initial_config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            try:
                await supervisor._start_adapter(initial_config)
                state = supervisor._adapter_states.get("nws")
                assert state is not None

                # Simulate completed poll 10 seconds ago
                state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10)

                # Disable adapter
                await self._push_config(
                    supervisor, config_source, self._nws_config(enabled=False)
                )

                # DELETE adapter from DB (remove from config source)
                config_source.set_adapter(None, name="nws")
                await supervisor._on_config_change("adapters", "nws")

                # Verify adapter fully removed
                assert "nws" not in supervisor._adapter_states

                # Re-add adapter with same name
                await self._push_config(
                    supervisor, config_source, self._nws_config(enabled=True)
                )

                # Verify new adapter started fresh
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)

                # last_completed_poll should be None (fresh adapter)
                assert state.last_completed_poll is None, (
                    f"Expected None (fresh adapter), got {state.last_completed_poll}. "
                    "Preserved state not cleared on delete."
                )
            finally:
                await self._teardown(supervisor)

    @pytest.mark.asyncio
    async def test_stop_preserves_state_start_reuses_it(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Verify _stop_adapter preserves state and _start_adapter reuses it."""
        config_source = MockConfigSource()
        config = self._nws_config(enabled=True)
        config_source.set_adapter(config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            try:
                await supervisor._start_adapter(config)
                state = supervisor._adapter_states.get("nws")
                # Fail with a clear assertion rather than an AttributeError
                # on None if the adapter never registered.
                assert state is not None
                state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
                saved_poll = state.last_completed_poll

                # Stop adapter
                await supervisor._stop_adapter("nws")

                # State should still exist
                assert "nws" in supervisor._adapter_states
                state = supervisor._adapter_states["nws"]
                assert not adapter_is_running(state)
                assert state.last_completed_poll == saved_poll

                # Restart adapter; it should reuse the existing state
                await supervisor._start_adapter(config)
                self._assert_running_state_preserved(supervisor, saved_poll)
            finally:
                await self._teardown(supervisor)

    @pytest.mark.asyncio
    async def test_remove_adapter_clears_state(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Verify _remove_adapter fully clears state."""
        config_source = MockConfigSource()
        config = self._nws_config(enabled=True)
        config_source.set_adapter(config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            try:
                await supervisor._start_adapter(config)
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                state.last_completed_poll = datetime.now(timezone.utc)

                # Remove adapter
                await cleanup_adapter(supervisor, "nws")

                # State should be gone
                assert "nws" not in supervisor._adapter_states
            finally:
                # Does nothing beyond setting the shutdown event once the
                # removal above has succeeded.
                await self._teardown(supervisor)