mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-22 02:24:38 +02:00
Merge pull request #2 from zvx-echo6/feature/1a-service-cutover
feat(config): Phase 1a-3 service cutover to DB-backed config
This commit is contained in:
commit
b3788d556d
12 changed files with 2501 additions and 604 deletions
|
|
@ -13,9 +13,8 @@ import nats
|
|||
from nats.js import JetStreamContext
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
|
||||
|
||||
from central.config import load_config, Config
|
||||
from central.bootstrap_config import get_settings
|
||||
|
||||
CONFIG_PATH = "/etc/central/central.toml"
|
||||
CONSUMER_NAME = "archive"
|
||||
STREAM_NAME = "CENTRAL_WX"
|
||||
SUBJECT_FILTER = "central.wx.>"
|
||||
|
|
@ -93,8 +92,9 @@ def _build_geom_sql(geo_data: dict[str, Any] | None) -> str | None:
|
|||
class ArchiveConsumer:
|
||||
"""Archive consumer process."""
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
def __init__(self, nats_url: str, postgres_dsn: str) -> None:
|
||||
self._nats_url = nats_url
|
||||
self._postgres_dsn = postgres_dsn
|
||||
self._nc: nats.NATS | None = None
|
||||
self._js: JetStreamContext | None = None
|
||||
self._pool: asyncpg.Pool | None = None
|
||||
|
|
@ -102,12 +102,12 @@ class ArchiveConsumer:
|
|||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to NATS and PostgreSQL."""
|
||||
self._nc = await nats.connect(self.config.nats.url)
|
||||
self._nc = await nats.connect(self._nats_url)
|
||||
self._js = self._nc.jetstream()
|
||||
logger.info("Connected to NATS", extra={"url": self.config.nats.url})
|
||||
logger.info("Connected to NATS", extra={"url": self._nats_url})
|
||||
|
||||
self._pool = await asyncpg.create_pool(
|
||||
self.config.postgres.dsn,
|
||||
self._postgres_dsn,
|
||||
min_size=1,
|
||||
max_size=5,
|
||||
)
|
||||
|
|
@ -303,8 +303,19 @@ async def async_main() -> None:
|
|||
"""Async entry point."""
|
||||
setup_logging()
|
||||
|
||||
config = load_config(CONFIG_PATH)
|
||||
consumer = ArchiveConsumer(config)
|
||||
settings = get_settings()
|
||||
logger.info(
|
||||
"Archive starting",
|
||||
extra={
|
||||
"nats_url": settings.nats_url,
|
||||
"config_source": settings.config_source,
|
||||
},
|
||||
)
|
||||
|
||||
consumer = ArchiveConsumer(
|
||||
nats_url=settings.nats_url,
|
||||
postgres_dsn=settings.db_dsn,
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
shutdown_event = asyncio.Event()
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from functools import lru_cache
|
|||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
|
|
@ -33,6 +33,21 @@ class Settings(BaseSettings):
|
|||
default="INFO",
|
||||
description="Logging level",
|
||||
)
|
||||
config_source: Literal["toml", "db"] = Field(
|
||||
default="toml",
|
||||
description="Configuration source: 'toml' for TOML file, 'db' for database",
|
||||
)
|
||||
config_toml_path: Path = Field(
|
||||
default=Path("/etc/central/central.toml"),
|
||||
description="Path to TOML config file (when config_source=toml)",
|
||||
)
|
||||
|
||||
@field_validator("config_source")
@classmethod
def validate_config_source(cls, v: str) -> str:
    """Accept only the two supported configuration sources.

    Args:
        v: Candidate value for ``config_source``.

    Returns:
        ``v`` unchanged when it is ``"toml"`` or ``"db"``.

    Raises:
        ValueError: If ``v`` is any other value.
    """
    supported = ("toml", "db")
    if v in supported:
        return v
    raise ValueError(f"config_source must be 'toml' or 'db', got {v!r}")
|
||||
|
||||
|
||||
@lru_cache
|
||||
|
|
|
|||
15
src/central/cloudevents_constants.py
Normal file
15
src/central/cloudevents_constants.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""CloudEvents configuration constants.
|
||||
|
||||
These are the protocol-level constants for CloudEvents envelope format.
|
||||
CloudEvents envelope format is part of the Central protocol contract
|
||||
and is not operator-configurable.
|
||||
"""
|
||||
|
||||
from central.config import CloudEventsConfig
|
||||
|
||||
# CloudEvents protocol constants
|
||||
CLOUDEVENTS_CONFIG = CloudEventsConfig(
|
||||
type_prefix="central",
|
||||
source="central.echo6.co",
|
||||
schema_version="1.0",
|
||||
)
|
||||
|
|
@ -1,30 +1,48 @@
|
|||
"""CloudEvents wire format helpers."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from cloudevents.v1.http import CloudEvent
|
||||
|
||||
from central.config import Config
|
||||
from central.config import Config, CloudEventsConfig
|
||||
from central.cloudevents_constants import CLOUDEVENTS_CONFIG
|
||||
from central.models import Event
|
||||
|
||||
|
||||
def wrap_event(event: Event, config: Config) -> tuple[dict[str, Any], str]:
|
||||
def wrap_event(
|
||||
event: Event,
|
||||
config: Union[Config, CloudEventsConfig, None] = None,
|
||||
) -> tuple[dict[str, Any], str]:
|
||||
"""
|
||||
Wrap an Event into a CNCF CloudEvents v1.0 JSON envelope.
|
||||
|
||||
Args:
|
||||
event: The event to wrap
|
||||
config: Either a full Config object, a CloudEventsConfig object,
|
||||
or None to use defaults.
|
||||
|
||||
Returns:
|
||||
A tuple of (envelope_dict, msg_id) where msg_id is the
|
||||
CloudEvent id for use as Nats-Msg-Id header.
|
||||
"""
|
||||
# Resolve CloudEventsConfig from various input types
|
||||
if config is None:
|
||||
ce_config = CLOUDEVENTS_CONFIG
|
||||
elif isinstance(config, CloudEventsConfig):
|
||||
ce_config = config
|
||||
else:
|
||||
# It's a full Config object
|
||||
ce_config = config.cloudevents
|
||||
|
||||
# Build CE type: {prefix}.{category}.v1
|
||||
ce_type = f"{config.cloudevents.type_prefix}.{event.category}.v1"
|
||||
ce_type = f"{ce_config.type_prefix}.{event.category}.v1"
|
||||
|
||||
# Serialize event data
|
||||
event_data = event.model_dump(mode="json")
|
||||
|
||||
# Build extension attributes - lowercase, no underscores per CE spec
|
||||
extensions: dict[str, Any] = {
|
||||
"centralschemaversion": config.cloudevents.schema_version,
|
||||
"centralschemaversion": ce_config.schema_version,
|
||||
"centralcategory": event.category,
|
||||
}
|
||||
|
||||
|
|
@ -36,7 +54,7 @@ def wrap_event(event: Event, config: Config) -> tuple[dict[str, Any], str]:
|
|||
ce = CloudEvent(
|
||||
attributes={
|
||||
"id": event.id,
|
||||
"source": config.cloudevents.source,
|
||||
"source": ce_config.source,
|
||||
"type": ce_type,
|
||||
"time": event.time.isoformat(),
|
||||
"datacontenttype": "application/json",
|
||||
|
|
|
|||
187
src/central/config_source.py
Normal file
187
src/central/config_source.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Configuration source abstraction.
|
||||
|
||||
Provides a unified interface for loading adapter configuration from
|
||||
either TOML files or the database-backed config store.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import tomllib
|
||||
|
||||
from central.config import NWSAdapterConfig
|
||||
from central.config_models import AdapterConfig
|
||||
from central.config_store import ConfigStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
class ConfigSource(Protocol):
    """Structural interface shared by every adapter-configuration backend."""

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return the adapters that are currently enabled."""
        ...

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return the configuration of the adapter called ``name``, if any."""
        ...

    async def watch_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Report configuration changes through ``callback``.

        A TOML-backed source has nothing to watch and returns immediately.
        A database-backed source never returns on its own; it invokes
        ``callback(table, key)`` for each change it observes.
        """
        ...

    async def close(self) -> None:
        """Release whatever resources the source holds."""
        ...
|
||||
|
||||
|
||||
class TomlConfigSource:
    """Adapter configuration read once from a TOML file.

    This is the legacy configuration path; the file is parsed lazily on
    first access and never re-read, so hot-reload is not available.
    """

    def __init__(self, toml_path: Path) -> None:
        self._toml_path = toml_path
        # Parsed adapters keyed by name; filled in by _load().
        self._adapters: dict[str, AdapterConfig] = {}
        self._loaded = False

    def _load(self) -> None:
        """Parse the TOML file into ``AdapterConfig`` objects (first call only)."""
        if self._loaded:
            return

        from datetime import datetime, timezone

        with self._toml_path.open("rb") as handle:
            document = tomllib.load(handle)

        # One timestamp shared by every adapter loaded from this file.
        loaded_at = datetime.now(timezone.utc)
        reserved_keys = ("enabled", "cadence_s")

        for adapter_name, raw in document.get("adapters", {}).items():
            # The TOML table is flat; lift the two scheduling fields out and
            # keep every remaining key as adapter-specific settings.
            is_enabled = raw.get("enabled", True)
            cadence = raw.get("cadence_s", 60)

            extra_settings = {}
            for option, value in raw.items():
                if option not in reserved_keys:
                    extra_settings[option] = value

            self._adapters[adapter_name] = AdapterConfig(
                name=adapter_name,
                enabled=is_enabled,
                cadence_s=cadence,
                settings=extra_settings,
                paused_at=None,
                updated_at=loaded_at,
            )

        self._loaded = True
        logger.info(
            "Loaded TOML config",
            extra={"path": str(self._toml_path), "adapters": list(self._adapters)},
        )

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return the adapters that are enabled and not paused."""
        self._load()
        active: list[AdapterConfig] = []
        for candidate in self._adapters.values():
            if not candidate.enabled:
                continue
            if candidate.is_paused:
                continue
            active.append(candidate)
        return active

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return the adapter called ``name``, or ``None`` if the file has none."""
        self._load()
        return self._adapters.get(name)

    async def watch_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Return immediately; ``callback`` is never invoked for TOML."""
        logger.debug("TOML config source does not support hot-reload")

    async def close(self) -> None:
        """Do nothing; this source holds no resources."""
        return None
|
||||
|
||||
|
||||
class DbConfigSource:
    """Adapter configuration served by the Postgres-backed config store.

    Change notifications (hot-reload) are delegated to the store's
    LISTEN/NOTIFY support.
    """

    def __init__(self, config_store: ConfigStore) -> None:
        self._store = config_store

    @classmethod
    async def create(cls, dsn: str) -> "DbConfigSource":
        """Build a source around a freshly created ``ConfigStore`` for ``dsn``."""
        return cls(await ConfigStore.create(dsn))

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return the stored adapters that are enabled and not paused."""
        active: list[AdapterConfig] = []
        for candidate in await self._store.list_adapters():
            if not candidate.enabled:
                continue
            if candidate.is_paused:
                continue
            active.append(candidate)
        return active

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return whatever the store has for the adapter called ``name``."""
        found = await self._store.get_adapter(name)
        return found

    async def watch_for_changes(
        self,
        callback: Callable[[str, str], Awaitable[None] | None],
    ) -> None:
        """Hand ``callback`` to the store's change listener.

        The store is expected to run until cancelled, calling
        ``callback(table, key)`` for each change.
        """
        await self._store.listen_for_changes(callback)

    async def close(self) -> None:
        """Close the underlying config store."""
        await self._store.close()
|
||||
|
||||
|
||||
async def create_config_source(
    source_type: str,
    dsn: str | None = None,
    toml_path: Path | None = None,
) -> ConfigSource:
    """Build the config source selected by ``source_type``.

    Args:
        source_type: Either "toml" or "db".
        dsn: PostgreSQL DSN; mandatory when ``source_type`` is "db".
        toml_path: TOML file location; mandatory when ``source_type`` is "toml".

    Returns:
        A ``TomlConfigSource`` or a connected ``DbConfigSource``.

    Raises:
        ValueError: If ``source_type`` is not recognised, or the argument
            required by the chosen source is missing.
    """
    if source_type == "db":
        if dsn is None:
            raise ValueError("dsn required for db config source")
        return await DbConfigSource.create(dsn)

    if source_type != "toml":
        raise ValueError(f"Unknown config source type: {source_type}")

    if toml_path is None:
        raise ValueError("toml_path required for toml config source")
    return TomlConfigSource(toml_path)
|
||||
|
|
@ -5,6 +5,7 @@ import json
|
|||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -14,11 +15,13 @@ from nats.js import JetStreamContext
|
|||
|
||||
from central.adapters.nws import NWSAdapter
|
||||
from central.cloudevents_wire import wrap_event
|
||||
from central.config import load_config, Config
|
||||
from central.config import load_config, Config, NWSAdapterConfig
|
||||
from central.config_models import AdapterConfig
|
||||
from central.config_source import ConfigSource, create_config_source
|
||||
from central.bootstrap_config import get_settings
|
||||
from central.models import subject_for_event
|
||||
|
||||
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
|
||||
CONFIG_PATH = "/etc/central/central.toml"
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
|
|
@ -35,7 +38,6 @@ class JsonFormatter(logging.Formatter):
|
|||
log_obj["exc"] = self.formatException(record.exc_info)
|
||||
if hasattr(record, "extra"):
|
||||
log_obj.update(record.extra)
|
||||
# Include any extra fields passed via extra={}
|
||||
for key in record.__dict__:
|
||||
if key not in (
|
||||
"name", "msg", "args", "created", "filename", "funcName",
|
||||
|
|
@ -59,23 +61,49 @@ def setup_logging() -> None:
|
|||
logger = logging.getLogger("central.supervisor")
|
||||
|
||||
|
||||
@dataclass
class AdapterState:
    """Book-keeping the supervisor holds for one scheduled adapter."""

    # Adapter name; also the key under which the supervisor tracks this state.
    name: str
    # Live adapter instance that gets polled.
    adapter: NWSAdapter
    # Configuration currently applied to the adapter.
    config: AdapterConfig
    # Poll-loop task, or None when no loop has been started / it was stopped.
    task: asyncio.Task[None] | None = None
    # When the most recent poll finished (successfully or not), if ever.
    last_completed_poll: datetime | None = None
    # Set to wake the poll loop out of its wait between polls.
    cancel_event: asyncio.Event = field(default_factory=asyncio.Event)

    @property
    def is_running(self) -> bool:
        """Tell whether a poll-loop task exists and has not finished yet."""
        loop_task = self.task
        if loop_task is None:
            return False
        return not loop_task.done()
|
||||
|
||||
|
||||
class Supervisor:
|
||||
"""Main supervisor process."""
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
def __init__(
|
||||
self,
|
||||
config_source: ConfigSource,
|
||||
nats_url: str,
|
||||
cloudevents_config: Any = None,
|
||||
) -> None:
|
||||
self._config_source = config_source
|
||||
self._nats_url = nats_url
|
||||
self._cloudevents_config = cloudevents_config
|
||||
self._nc: nats.NATS | None = None
|
||||
self._js: JetStreamContext | None = None
|
||||
self._adapters: list[NWSAdapter] = []
|
||||
self._adapter_states: dict[str, AdapterState] = {}
|
||||
self._tasks: list[asyncio.Task[None]] = []
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._start_time = datetime.now(timezone.utc)
|
||||
self._config_watch_task: asyncio.Task[None] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to NATS."""
|
||||
self._nc = await nats.connect(self.config.nats.url)
|
||||
self._nc = await nats.connect(self._nats_url)
|
||||
self._js = self._nc.jetstream()
|
||||
logger.info("Connected to NATS", extra={"url": self.config.nats.url})
|
||||
logger.info("Connected to NATS", extra={"url": self._nats_url})
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from NATS."""
|
||||
|
|
@ -104,24 +132,82 @@ class Supervisor:
|
|||
headers={"Nats-Msg-Id": msg_id},
|
||||
)
|
||||
|
||||
async def _run_adapter(self, adapter: NWSAdapter) -> None:
|
||||
"""Run an adapter poll loop."""
|
||||
def _adapter_config_to_nws_config(self, config: AdapterConfig) -> NWSAdapterConfig:
    """Translate a unified ``AdapterConfig`` into an ``NWSAdapterConfig``.

    ``enabled`` and ``cadence_s`` are copied across directly. ``states`` and
    ``contact_email`` are read from ``config.settings``, defaulting to an
    empty list and an empty string when the key is absent.
    """
    adapter_settings = config.settings
    state_codes = adapter_settings.get("states", [])
    email = adapter_settings.get("contact_email", "")
    return NWSAdapterConfig(
        enabled=config.enabled,
        cadence_s=config.cadence_s,
        states=state_codes,
        contact_email=email,
    )
|
||||
|
||||
async def _run_adapter_loop(self, state: AdapterState) -> None:
|
||||
"""Run an adapter poll loop with rate-limit aware scheduling."""
|
||||
while not self._shutdown_event.is_set():
|
||||
# Calculate next poll time based on rate-limit guarantee
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if state.last_completed_poll is not None:
|
||||
next_poll_at = state.last_completed_poll.timestamp() + state.config.cadence_s
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
else:
|
||||
# First poll - run immediately
|
||||
wait_time = 0
|
||||
|
||||
if wait_time > 0:
|
||||
logger.debug(
|
||||
"Waiting for next poll",
|
||||
extra={
|
||||
"adapter": state.name,
|
||||
"wait_s": wait_time,
|
||||
"next_poll": datetime.fromtimestamp(
|
||||
now.timestamp() + wait_time, tz=timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
# Wait for either timeout or cancel signal
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
state.cancel_event.wait(),
|
||||
timeout=wait_time,
|
||||
)
|
||||
# Cancel event was set - check if we should exit or reschedule
|
||||
if self._shutdown_event.is_set():
|
||||
break
|
||||
# Clear the cancel event and re-evaluate schedule
|
||||
state.cancel_event.clear()
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
# Check shutdown before polling
|
||||
if self._shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Check if adapter is still enabled
|
||||
if not state.config.enabled or state.config.is_paused:
|
||||
logger.info(
|
||||
"Adapter disabled/paused, stopping loop",
|
||||
extra={"adapter": state.name},
|
||||
)
|
||||
break
|
||||
|
||||
poll_start = datetime.now(timezone.utc)
|
||||
try:
|
||||
async for event in adapter.poll():
|
||||
async for event in state.adapter.poll():
|
||||
# Dedup check
|
||||
if adapter.is_published(event.id):
|
||||
adapter.bump_last_seen(event.id)
|
||||
if state.adapter.is_published(event.id):
|
||||
state.adapter.bump_last_seen(event.id)
|
||||
continue
|
||||
|
||||
# Build CloudEvent
|
||||
envelope, msg_id = wrap_event(event, self.config)
|
||||
# Build CloudEvent (uses defaults if no config provided)
|
||||
envelope, msg_id = wrap_event(event, self._cloudevents_config)
|
||||
|
||||
subject = subject_for_event(event)
|
||||
|
||||
# Publish
|
||||
await self._publish_event(subject, envelope, msg_id)
|
||||
adapter.mark_published(event.id)
|
||||
state.adapter.mark_published(event.id)
|
||||
|
||||
logger.info(
|
||||
"Published event",
|
||||
|
|
@ -130,36 +216,292 @@ class Supervisor:
|
|||
|
||||
# Publish success status
|
||||
await self._publish_meta(
|
||||
f"central.meta.adapter.{adapter.name}.status",
|
||||
f"central.meta.adapter.{state.name}.status",
|
||||
{"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
|
||||
)
|
||||
|
||||
# Mark poll completion time for rate limiting
|
||||
state.last_completed_poll = datetime.now(timezone.utc)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Adapter poll failed", extra={"adapter": adapter.name})
|
||||
logger.exception("Adapter poll failed", extra={"adapter": state.name})
|
||||
await self._publish_meta(
|
||||
f"central.meta.adapter.{adapter.name}.status",
|
||||
f"central.meta.adapter.{state.name}.status",
|
||||
{
|
||||
"ok": False,
|
||||
"error": str(e),
|
||||
"ts": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
)
|
||||
# Still mark completion time to avoid tight retry loops
|
||||
state.last_completed_poll = datetime.now(timezone.utc)
|
||||
|
||||
# Sweep old IDs
|
||||
swept = adapter.sweep_old_ids()
|
||||
swept = state.adapter.sweep_old_ids()
|
||||
if swept > 0:
|
||||
logger.info("Swept old published IDs", extra={"count": swept})
|
||||
|
||||
# Sleep until next cadence
|
||||
elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds()
|
||||
sleep_time = max(0, adapter.cadence_s - elapsed)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_event.wait(),
|
||||
timeout=sleep_time
|
||||
async def _start_adapter(self, config: AdapterConfig) -> None:
|
||||
"""Start an adapter based on its configuration.
|
||||
|
||||
If the adapter was previously stopped (state exists but task is not running),
|
||||
reuses the existing state to preserve last_completed_poll for rate limiting.
|
||||
"""
|
||||
existing_state = self._adapter_states.get(config.name)
|
||||
|
||||
if existing_state is not None:
|
||||
if existing_state.is_running:
|
||||
logger.warning(
|
||||
"Adapter already running",
|
||||
extra={"adapter": config.name},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
# Adapter was stopped - restart with preserved state
|
||||
# Update config and restart the adapter
|
||||
existing_state.config = config
|
||||
existing_state.cancel_event.clear()
|
||||
|
||||
# Reinitialize the adapter with new config
|
||||
nws_config = self._adapter_config_to_nws_config(config)
|
||||
existing_state.adapter = NWSAdapter(
|
||||
config=nws_config,
|
||||
cursor_db_path=CURSOR_DB_PATH,
|
||||
)
|
||||
await existing_state.adapter.startup()
|
||||
|
||||
# Start the loop task
|
||||
existing_state.task = asyncio.create_task(
|
||||
self._run_adapter_loop(existing_state)
|
||||
)
|
||||
|
||||
# Calculate next poll time for logging
|
||||
if existing_state.last_completed_poll:
|
||||
next_poll_at = datetime.fromtimestamp(
|
||||
existing_state.last_completed_poll.timestamp() + config.cadence_s,
|
||||
tz=timezone.utc,
|
||||
)
|
||||
if next_poll_at <= datetime.now(timezone.utc):
|
||||
next_poll_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
next_poll_at = datetime.now(timezone.utc)
|
||||
|
||||
logger.info(
|
||||
"Adapter restarted",
|
||||
extra={
|
||||
"adapter": config.name,
|
||||
"cadence_s": config.cadence_s,
|
||||
"preserved_last_poll": existing_state.last_completed_poll.isoformat()
|
||||
if existing_state.last_completed_poll
|
||||
else None,
|
||||
"next_poll": next_poll_at.isoformat(),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
# New adapter - create fresh state
|
||||
if config.name == "nws":
|
||||
nws_config = self._adapter_config_to_nws_config(config)
|
||||
adapter = NWSAdapter(
|
||||
config=nws_config,
|
||||
cursor_db_path=CURSOR_DB_PATH,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
state = AdapterState(
|
||||
name=config.name,
|
||||
adapter=adapter,
|
||||
config=config,
|
||||
)
|
||||
state.task = asyncio.create_task(self._run_adapter_loop(state))
|
||||
self._adapter_states[config.name] = state
|
||||
|
||||
logger.info(
|
||||
"Adapter started",
|
||||
extra={
|
||||
"adapter": config.name,
|
||||
"cadence_s": config.cadence_s,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown adapter type",
|
||||
extra={"adapter": config.name},
|
||||
)
|
||||
|
||||
async def _stop_adapter(self, name: str) -> None:
    """Halt a running adapter while keeping its state entry.

    The entry in ``self._adapter_states`` (and with it
    ``last_completed_poll``) is left in place, so a later restart can keep
    honouring the rate-limit spacing. Nothing happens when the adapter is
    unknown or its loop is not running. Use ``_remove_adapter()`` to drop
    the state entirely.
    """
    state = self._adapter_states.get(name)
    if state is None or not state.is_running:
        # Unknown adapter, or its loop has already finished.
        return

    # Wake the loop in case it is waiting between polls, then cancel it.
    state.cancel_event.set()
    loop_task = state.task
    if loop_task:
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            pass
        state.task = None

    await state.adapter.shutdown()

    last_poll = state.last_completed_poll
    logger.info(
        "Adapter stopped",
        extra={
            "adapter": name,
            "preserved_last_poll": last_poll.isoformat() if last_poll else None,
        },
    )
|
||||
|
||||
async def _remove_adapter(self, name: str) -> None:
|
||||
"""Fully remove an adapter, dropping all preserved state.
|
||||
|
||||
Called when an adapter is deleted from the database (not just disabled).
|
||||
"""
|
||||
state = self._adapter_states.pop(name, None)
|
||||
if state is None:
|
||||
return
|
||||
|
||||
# Stop if running
|
||||
if state.is_running:
|
||||
state.cancel_event.set()
|
||||
if state.task:
|
||||
state.task.cancel()
|
||||
try:
|
||||
await state.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await state.adapter.shutdown()
|
||||
|
||||
logger.info(
|
||||
"Adapter removed",
|
||||
extra={"adapter": name},
|
||||
)
|
||||
|
||||
async def _reschedule_adapter(
    self,
    name: str,
    new_config: AdapterConfig,
) -> None:
    """Apply ``new_config`` to an adapter and have its loop re-plan.

    The rate-limit spacing is kept: the next poll is due at
    ``last_completed_poll + new cadence``, not ``now + new cadence``.
    An adapter that is unknown or not currently running is simply
    (re)started with the new configuration.
    """
    state = self._adapter_states.get(name)
    if state is None or not state.is_running:
        # Nothing is running to reschedule - start (or restart) instead.
        await self._start_adapter(new_config)
        return

    previous_cadence = state.config.cadence_s
    cadence = new_config.cadence_s

    # Swap in the new configuration and push the cadence to the adapter.
    state.config = new_config
    state.adapter.cadence_s = cadence

    # NWS also carries a states list that has to follow the config.
    if name == "nws":
        nws_config = self._adapter_config_to_nws_config(new_config)
        state.adapter.states = {code.upper() for code in nws_config.states}

    # Only used for the log entry below; the loop computes its own wait.
    last_poll = state.last_completed_poll
    next_poll_at = (
        datetime.fromtimestamp(last_poll.timestamp() + cadence, tz=timezone.utc)
        if last_poll
        else datetime.now(timezone.utc)
    )

    logger.info(
        "Rescheduled adapter",
        extra={
            "adapter": name,
            "old_cadence_s": previous_cadence,
            "new_cadence_s": cadence,
            "next_poll": next_poll_at.isoformat(),
        },
    )

    # Wake the loop so it recomputes its wait against the new cadence.
    state.cancel_event.set()
|
||||
|
||||
async def _on_config_change(self, table: str, key: str) -> None:
|
||||
"""Handle a configuration change notification.
|
||||
|
||||
Called when NOTIFY fires for config changes.
|
||||
"""
|
||||
if table != "adapters":
|
||||
return
|
||||
|
||||
adapter_name = key
|
||||
logger.info(
|
||||
"Config change received",
|
||||
extra={"table": table, "key": key},
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
# Fetch the current config for this adapter
|
||||
new_config = await self._config_source.get_adapter(adapter_name)
|
||||
current_state = self._adapter_states.get(adapter_name)
|
||||
|
||||
if new_config is None:
|
||||
# Adapter was deleted - fully remove, don't just stop
|
||||
if current_state:
|
||||
await self._remove_adapter(adapter_name)
|
||||
logger.info(
|
||||
"Adapter deleted, removed",
|
||||
extra={"adapter": adapter_name},
|
||||
)
|
||||
return
|
||||
|
||||
if not new_config.enabled or new_config.is_paused:
|
||||
# Adapter disabled or paused - stop but preserve state
|
||||
if current_state and current_state.is_running:
|
||||
await self._stop_adapter(adapter_name)
|
||||
logger.info(
|
||||
"Adapter disabled/paused, stopped",
|
||||
extra={
|
||||
"adapter": adapter_name,
|
||||
"enabled": new_config.enabled,
|
||||
"paused": new_config.is_paused,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if current_state is None or not current_state.is_running:
|
||||
# Adapter was enabled or created - start (will reuse state if exists)
|
||||
await self._start_adapter(new_config)
|
||||
logger.info(
|
||||
"Adapter enabled, started",
|
||||
extra={"adapter": adapter_name},
|
||||
)
|
||||
else:
|
||||
# Adapter config changed (cadence, settings)
|
||||
await self._reschedule_adapter(adapter_name, new_config)
|
||||
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
"""Publish periodic heartbeats."""
|
||||
|
|
@ -181,32 +523,38 @@ class Supervisor:
|
|||
"""Start the supervisor."""
|
||||
await self.connect()
|
||||
|
||||
# Initialize adapters
|
||||
if self.config.adapters.get("nws") and self.config.adapters["nws"].enabled:
|
||||
adapter = NWSAdapter(
|
||||
config=self.config.adapters["nws"],
|
||||
cursor_db_path=CURSOR_DB_PATH,
|
||||
)
|
||||
await adapter.startup()
|
||||
self._adapters.append(adapter)
|
||||
logger.info("NWS adapter initialized")
|
||||
# Load and start enabled adapters
|
||||
enabled_adapters = await self._config_source.list_enabled_adapters()
|
||||
for config in enabled_adapters:
|
||||
await self._start_adapter(config)
|
||||
|
||||
# Start adapter tasks
|
||||
for adapter in self._adapters:
|
||||
task = asyncio.create_task(self._run_adapter(adapter))
|
||||
self._tasks.append(task)
|
||||
# Start config watcher (for DB source, this runs forever; for TOML, returns immediately)
|
||||
self._config_watch_task = asyncio.create_task(
|
||||
self._config_source.watch_for_changes(self._on_config_change)
|
||||
)
|
||||
|
||||
# Start heartbeat
|
||||
self._tasks.append(asyncio.create_task(self._heartbeat_loop()))
|
||||
|
||||
logger.info("Supervisor started", extra={"adapters": [a.name for a in self._adapters]})
|
||||
logger.info(
|
||||
"Supervisor started",
|
||||
extra={"adapters": list(self._adapter_states.keys())},
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the supervisor gracefully."""
|
||||
logger.info("Supervisor shutting down")
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Cancel tasks
|
||||
# Cancel config watcher
|
||||
if self._config_watch_task:
|
||||
self._config_watch_task.cancel()
|
||||
try:
|
||||
await self._config_watch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Cancel heartbeat and other tasks
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
|
|
@ -214,9 +562,12 @@ class Supervisor:
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Shutdown adapters
|
||||
for adapter in self._adapters:
|
||||
await adapter.shutdown()
|
||||
# Remove all adapters (full cleanup)
|
||||
for name in list(self._adapter_states.keys()):
|
||||
await self._remove_adapter(name)
|
||||
|
||||
# Close config source
|
||||
await self._config_source.close()
|
||||
|
||||
await self.disconnect()
|
||||
logger.info("Supervisor stopped")
|
||||
|
|
@ -226,8 +577,40 @@ async def async_main() -> None:
|
|||
"""Async entry point."""
|
||||
setup_logging()
|
||||
|
||||
config = load_config(CONFIG_PATH)
|
||||
supervisor = Supervisor(config)
|
||||
settings = get_settings()
|
||||
logger.info(
|
||||
"Config source: %s",
|
||||
settings.config_source,
|
||||
extra={"config_source": settings.config_source},
|
||||
)
|
||||
|
||||
# Create config source based on setting
|
||||
config_source = await create_config_source(
|
||||
source_type=settings.config_source,
|
||||
dsn=settings.db_dsn,
|
||||
toml_path=settings.config_toml_path,
|
||||
)
|
||||
|
||||
# CloudEvents config: try TOML first, fall back to code defaults
|
||||
# (CloudEvents envelope format is protocol-level, not operator-configurable)
|
||||
cloudevents_config = None
|
||||
if settings.config_source == "toml":
|
||||
try:
|
||||
toml_config = load_config(str(settings.config_toml_path))
|
||||
cloudevents_config = toml_config
|
||||
except Exception:
|
||||
pass # Will use defaults from cloudevents_constants
|
||||
|
||||
supervisor = Supervisor(
|
||||
config_source=config_source,
|
||||
nats_url=settings.nats_url,
|
||||
cloudevents_config=cloudevents_config,
|
||||
)
|
||||
logger.info(
|
||||
"CloudEvents config: %s",
|
||||
"TOML" if cloudevents_config else "defaults",
|
||||
extra={"cloudevents_source": "toml" if cloudevents_config else "defaults"},
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
shutdown_event = asyncio.Event()
|
||||
|
|
|
|||
41
systemd/README.md
Normal file
41
systemd/README.md
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# Systemd Unit Files
|
||||
|
||||
These unit files configure Central services for systemd.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Copy unit files
|
||||
sudo cp central-supervisor.service /etc/systemd/system/
|
||||
sudo cp central-archive.service /etc/systemd/system/
|
||||
|
||||
# Reload systemd
|
||||
sudo systemctl daemon-reload
|
||||
|
||||
# Enable and start services
|
||||
sudo systemctl enable --now central-supervisor
|
||||
sudo systemctl enable --now central-archive
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Both services load environment variables from `/etc/central/central.env`:
|
||||
|
||||
```bash
|
||||
CENTRAL_DB_DSN=postgresql://central:password@localhost/central
|
||||
CENTRAL_NATS_URL=nats://localhost:4222
|
||||
CENTRAL_CONFIG_SOURCE=db
|
||||
CENTRAL_MASTER_KEY_PATH=/etc/central/master.key
|
||||
```
|
||||
|
||||
## Service Dependencies
|
||||
|
||||
- **central-supervisor**: Requires NATS server (and PostgreSQL when `CENTRAL_CONFIG_SOURCE=db`)
|
||||
- **central-archive**: Requires NATS server and PostgreSQL
|
||||
|
||||
## Logs
|
||||
|
||||
```bash
|
||||
journalctl -u central-supervisor -f
|
||||
journalctl -u central-archive -f
|
||||
```
|
||||
|
|
@ -10,6 +10,7 @@ User=central
|
|||
Group=central
|
||||
WorkingDirectory=/opt/central
|
||||
Environment=HOME=/opt/central
|
||||
EnvironmentFile=/etc/central/central.env
|
||||
ExecStart=/opt/central/.venv/bin/central-archive
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ User=central
|
|||
Group=central
|
||||
WorkingDirectory=/opt/central
|
||||
Environment=HOME=/opt/central
|
||||
EnvironmentFile=/etc/central/central.env
|
||||
ExecStart=/opt/central/.venv/bin/central-supervisor
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
|
|
|
|||
285
tests/test_config_source.py
Normal file
285
tests/test_config_source.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Tests for configuration source abstraction."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_source import (
|
||||
ConfigSource,
|
||||
TomlConfigSource,
|
||||
DbConfigSource,
|
||||
create_config_source,
|
||||
)
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Write a random, base64-encoded master key to a session temp file.

    The key is KEY_SIZE random bytes; the returned path points at
    ``master.key`` inside a fresh ``keys`` temp directory.
    """
    encoded_key = base64.b64encode(os.urandom(KEY_SIZE)).decode()
    target = tmp_path_factory.mktemp("keys") / "master.key"
    target.write_text(encoded_key)
    return target
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Reset the key cache and export the test DSN / key path for every test."""
    clear_key_cache()
    env_overrides = {
        "CENTRAL_DB_DSN": TEST_DB_DSN,
        "CENTRAL_MASTER_KEY_PATH": str(master_key_path),
    }
    # monkeypatch restores the previous environment when the test ends.
    for var_name, value in env_overrides.items():
        monkeypatch.setenv(var_name, value)
|
||||
|
||||
|
||||
class TestTomlConfigSource:
    """Exercise TomlConfigSource against a temporary TOML file."""

    @pytest.fixture
    def toml_file(self, tmp_path: Path) -> Path:
        """Write a config with one enabled and one disabled adapter.

        Returns the path of the written ``central.toml``.
        """
        config_path = tmp_path / "central.toml"
        config_path.write_text("""
[adapters.nws]
enabled = true
cadence_s = 60
states = ["ID", "MT"]
contact_email = "test@example.com"

[adapters.disabled_adapter]
enabled = false
cadence_s = 300
states = []
contact_email = "test@example.com"

[cloudevents]
type_prefix = "central"
source = "central.local"
schema_version = "1.0"

[nats]
url = "nats://localhost:4222"

[postgres]
dsn = "postgresql://user:pass@localhost/db"
""")
        return config_path

    @pytest.mark.asyncio
    async def test_list_enabled_adapters(self, toml_file: Path) -> None:
        """Only the adapter flagged ``enabled = true`` is listed."""
        enabled = await TomlConfigSource(toml_file).list_enabled_adapters()

        assert len(enabled) == 1
        nws = enabled[0]
        assert nws.name == "nws"
        assert nws.enabled is True
        assert nws.cadence_s == 60

    @pytest.mark.asyncio
    async def test_get_adapter(self, toml_file: Path) -> None:
        """Looking up ``nws`` yields its name and adapter-specific settings."""
        nws = await TomlConfigSource(toml_file).get_adapter("nws")

        assert nws is not None
        assert nws.name == "nws"
        assert nws.settings["states"] == ["ID", "MT"]
        assert nws.settings["contact_email"] == "test@example.com"

    @pytest.mark.asyncio
    async def test_get_nonexistent_adapter(self, toml_file: Path) -> None:
        """An unknown adapter name resolves to ``None``."""
        missing = await TomlConfigSource(toml_file).get_adapter("does_not_exist")
        assert missing is None

    @pytest.mark.asyncio
    async def test_watch_for_changes_returns_immediately(self, toml_file: Path) -> None:
        """Watching a TOML source finishes promptly without invoking the callback."""
        source = TomlConfigSource(toml_file)
        invocations: list[tuple[str, str]] = []

        async def record(table: str, key: str) -> None:
            invocations.append((table, key))

        # wait_for raises on timeout, so a blocking watcher fails the test.
        await asyncio.wait_for(source.watch_for_changes(record), timeout=1.0)
        assert not invocations

    @pytest.mark.asyncio
    async def test_implements_protocol(self, toml_file: Path) -> None:
        """A TomlConfigSource instance passes the ConfigSource isinstance check."""
        assert isinstance(TomlConfigSource(toml_file), ConfigSource)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
    """Yield a raw asyncpg connection to the test database, closed on teardown."""
    connection = await asyncpg.connect(TEST_DB_DSN)
    yield connection
    # Teardown: runs once the requesting test has finished.
    await connection.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
    """Create the ``config.adapters`` table if needed and empty it."""
    setup_statements = (
        "CREATE SCHEMA IF NOT EXISTS config",
        """
        CREATE TABLE IF NOT EXISTS config.adapters (
            name TEXT PRIMARY KEY,
            enabled BOOLEAN NOT NULL DEFAULT true,
            cadence_s INTEGER NOT NULL,
            settings JSONB NOT NULL DEFAULT '{}'::jsonb,
            paused_at TIMESTAMPTZ,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
    """,
        # Start every test from an empty adapter table.
        "DELETE FROM config.adapters",
    )
    for statement in setup_statements:
        await db_conn.execute(statement)
|
||||
|
||||
|
||||
class TestDbConfigSource:
    """Exercise DbConfigSource against the test PostgreSQL database."""

    @pytest_asyncio.fixture
    async def db_source(self, clean_config_schema: None) -> DbConfigSource:
        """Yield a DbConfigSource on the clean test schema; close it afterwards."""
        config_source = await DbConfigSource.create(TEST_DB_DSN)
        yield config_source
        await config_source.close()

    @pytest.mark.asyncio
    async def test_list_enabled_adapters_empty(self, db_source: DbConfigSource) -> None:
        """With no rows in config.adapters the enabled list is empty."""
        assert await db_source.list_enabled_adapters() == []

    @pytest.mark.asyncio
    async def test_list_enabled_adapters(
        self, db_source: DbConfigSource, db_conn: asyncpg.Connection
    ) -> None:
        """Disabled rows and rows with ``paused_at`` set are filtered out."""
        # Seed one enabled, one disabled and one (soon to be) paused adapter.
        await db_conn.execute("""
            INSERT INTO config.adapters (name, enabled, cadence_s, settings)
            VALUES
                ('enabled_adapter', true, 60, '{"key": "value"}'::jsonb),
                ('disabled_adapter', false, 60, '{}'::jsonb),
                ('paused_adapter', true, 60, '{}'::jsonb)
        """)
        await db_conn.execute("""
            UPDATE config.adapters
            SET paused_at = now()
            WHERE name = 'paused_adapter'
        """)

        listed = await db_source.list_enabled_adapters()

        assert len(listed) == 1
        assert listed[0].name == "enabled_adapter"

    @pytest.mark.asyncio
    async def test_get_adapter(
        self, db_source: DbConfigSource, db_conn: asyncpg.Connection
    ) -> None:
        """A stored row is returned with its name, cadence and JSON settings."""
        await db_conn.execute("""
            INSERT INTO config.adapters (name, enabled, cadence_s, settings)
            VALUES ('test_adapter', true, 120, '{"states": ["ID"]}'::jsonb)
        """)

        fetched = await db_source.get_adapter("test_adapter")

        assert fetched is not None
        assert fetched.name == "test_adapter"
        assert fetched.cadence_s == 120
        assert fetched.settings == {"states": ["ID"]}

    @pytest.mark.asyncio
    async def test_get_nonexistent_adapter(self, db_source: DbConfigSource) -> None:
        """An unknown adapter name resolves to ``None``."""
        assert await db_source.get_adapter("does_not_exist") is None

    @pytest.mark.asyncio
    async def test_implements_protocol(self, db_source: DbConfigSource) -> None:
        """A DbConfigSource instance passes the ConfigSource isinstance check."""
        assert isinstance(db_source, ConfigSource)
|
||||
|
||||
|
||||
class TestCreateConfigSource:
    """Exercise the create_config_source factory."""

    @pytest.fixture
    def toml_file(self, tmp_path: Path) -> Path:
        """Write a minimal TOML config and return its path."""
        config_path = tmp_path / "central.toml"
        config_path.write_text("""
[adapters.nws]
enabled = true
cadence_s = 60
states = []
contact_email = "test@example.com"

[cloudevents]
[nats]
[postgres]
dsn = "postgresql://test@localhost/test"
""")
        return config_path

    @pytest.mark.asyncio
    async def test_create_toml_source(self, toml_file: Path) -> None:
        """The 'toml' type produces a TomlConfigSource."""
        built = await create_config_source(source_type="toml", toml_path=toml_file)
        assert isinstance(built, TomlConfigSource)
        await built.close()

    @pytest.mark.asyncio
    async def test_create_db_source(self, clean_config_schema: None) -> None:
        """The 'db' type produces a DbConfigSource."""
        built = await create_config_source(source_type="db", dsn=TEST_DB_DSN)
        assert isinstance(built, DbConfigSource)
        await built.close()

    @pytest.mark.asyncio
    async def test_create_toml_requires_path(self) -> None:
        """Requesting 'toml' without a path raises ValueError."""
        with pytest.raises(ValueError, match="toml_path required"):
            await create_config_source(source_type="toml")

    @pytest.mark.asyncio
    async def test_create_db_requires_dsn(self) -> None:
        """Requesting 'db' without a DSN raises ValueError."""
        with pytest.raises(ValueError, match="dsn required"):
            await create_config_source(source_type="db")

    @pytest.mark.asyncio
    async def test_create_unknown_type_raises(self) -> None:
        """An unrecognised source type raises ValueError."""
        with pytest.raises(ValueError, match="Unknown config source type"):
            await create_config_source(source_type="unknown")
|
||||
394
tests/test_supervisor_hotreload.py
Normal file
394
tests/test_supervisor_hotreload.py
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
"""Tests for supervisor hot-reload and rate-limiting behavior."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_models import AdapterConfig
|
||||
from central.config_source import DbConfigSource
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Provide a session-wide master key file.

    The file holds KEY_SIZE random bytes, base64-encoded as text.
    """
    raw_key = os.urandom(KEY_SIZE)
    key_dir = tmp_path_factory.mktemp("keys")
    key_file = key_dir / "master.key"
    key_file.write_text(base64.b64encode(raw_key).decode())
    return key_file
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Clear the key cache, then export the test DSN and master key path."""
    clear_key_cache()
    for var_name, value in (
        ("CENTRAL_DB_DSN", TEST_DB_DSN),
        ("CENTRAL_MASTER_KEY_PATH", str(master_key_path)),
    ):
        # monkeypatch undoes each override when the test finishes.
        monkeypatch.setenv(var_name, value)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
    """Open a direct asyncpg connection for test setup; close it on teardown."""
    raw_conn = await asyncpg.connect(TEST_DB_DSN)
    yield raw_conn
    # Everything after the yield is fixture teardown.
    await raw_conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
    """Prepare an empty ``config.adapters`` table with its NOTIFY trigger.

    Creates the schema and table if missing, (re)installs the trigger that
    publishes ``<table>:<key>`` on the ``config_changed`` channel, and then
    deletes all adapter rows.
    """
    create_schema = "CREATE SCHEMA IF NOT EXISTS config"
    create_table = """
        CREATE TABLE IF NOT EXISTS config.adapters (
            name TEXT PRIMARY KEY,
            enabled BOOLEAN NOT NULL DEFAULT true,
            cadence_s INTEGER NOT NULL,
            settings JSONB NOT NULL DEFAULT '{}'::jsonb,
            paused_at TIMESTAMPTZ,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
    """
    # Trigger function: the payload key is the adapter name for the
    # adapters table and empty for any other table.
    create_notify_function = """
        CREATE OR REPLACE FUNCTION config.notify_config_change()
        RETURNS trigger AS $$
        DECLARE
            key_value TEXT;
        BEGIN
            IF TG_TABLE_NAME = 'adapters' THEN
                key_value := COALESCE(NEW.name, OLD.name, '');
            ELSE
                key_value := '';
            END IF;
            PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
            RETURN COALESCE(NEW, OLD);
        END;
        $$ LANGUAGE plpgsql
    """
    install_trigger = """
        DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
        CREATE TRIGGER adapters_notify
            AFTER INSERT OR UPDATE OR DELETE ON config.adapters
            FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
    """
    truncate_adapters = "DELETE FROM config.adapters"

    for statement in (
        create_schema,
        create_table,
        create_notify_function,
        install_trigger,
        truncate_adapters,
    ):
        await db_conn.execute(statement)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
    """Yield a ConfigStore bound to the test database; close it on teardown."""
    test_store = await ConfigStore.create(TEST_DB_DSN)
    yield test_store
    await test_store.close()
|
||||
|
||||
|
||||
class TestDbConfigSourceNotifications:
    """Check that DbConfigSource forwards PostgreSQL NOTIFY events."""

    @pytest.mark.asyncio
    async def test_watch_receives_notifications(
        self,
        config_store: ConfigStore,
        db_conn: asyncpg.Connection,
    ) -> None:
        """An INSERT into config.adapters reaches the watch callback."""
        source = DbConfigSource(config_store)
        received: list[tuple[str, str]] = []
        first_event = asyncio.Event()

        async def on_change(table: str, key: str) -> None:
            received.append((table, key))
            first_event.set()

        # The watcher runs in the background for the duration of the test.
        watcher = asyncio.create_task(source.watch_for_changes(on_change))

        try:
            # Give the listener a moment to subscribe before writing.
            await asyncio.sleep(0.2)

            # Write through the raw connection (not the store) so the
            # table trigger is what emits the NOTIFY.
            await db_conn.execute("""
                INSERT INTO config.adapters (name, enabled, cadence_s, settings)
                VALUES ('test_adapter', true, 60, '{}'::jsonb)
            """)

            await asyncio.wait_for(first_event.wait(), timeout=5.0)

            assert len(received) >= 1
            assert received[0] == ("adapters", "test_adapter")
        finally:
            watcher.cancel()
            try:
                await watcher
            except asyncio.CancelledError:
                pass
|
||||
|
||||
|
||||
class TestRateLimitGuarantee:
    """Tests for rate-limit guarantees during hot-reload.

    These tests pin the critical invariant: cadence changes must not
    cause extra API calls before (last_poll + new_cadence).

    NOTE(review): the wait time is recomputed inline from
    ``last_completed_poll`` and the cadence; the tests build an
    ``AdapterState`` but never call Supervisor scheduling code, so they
    document the expected arithmetic rather than exercise the supervisor.
    """

    # Allowed drift (seconds) between the expected and computed wait.
    _TOLERANCE_S = 2

    @staticmethod
    def _make_config(cadence_s: int) -> AdapterConfig:
        """Return an enabled, un-paused ``test`` adapter config with *cadence_s*."""
        return AdapterConfig(
            name="test",
            enabled=True,
            cadence_s=cadence_s,
            settings={},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        )

    @staticmethod
    def _make_state(config: AdapterConfig, last_poll: datetime):
        """Return an AdapterState wrapping a mock adapter polled at *last_poll*."""
        # Imported lazily, as the original tests did, so collecting this
        # module does not import the supervisor.
        from central.supervisor import AdapterState

        mock_adapter = MagicMock()
        mock_adapter.name = "test"
        mock_adapter.cadence_s = 60

        return AdapterState(
            name="test",
            adapter=mock_adapter,
            config=config,
            last_completed_poll=last_poll,
        )

    @staticmethod
    def _remaining_wait(last_poll: datetime, cadence_s: int) -> float:
        """Return seconds from now until ``last_poll + cadence_s`` (never negative)."""
        now = datetime.now(timezone.utc)
        next_poll_at = last_poll.timestamp() + cadence_s
        return max(0, next_poll_at - now.timestamp())

    @pytest.mark.asyncio
    async def test_cadence_change_respects_last_poll_time(self) -> None:
        """Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.

        This is the core rate-limit guarantee test (gate 3).
        """
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
        state = self._make_state(self._make_config(60), last_poll)

        # Simulate a cadence change to 90 seconds, as a reschedule would.
        new_config = self._make_config(90)
        state.config = new_config
        state.adapter.cadence_s = 90

        # last_poll was 30s ago and the new cadence is 90s, so 60s remain.
        actual_wait = self._remaining_wait(last_poll, new_config.cadence_s)

        assert abs(actual_wait - 60) < self._TOLERANCE_S, (
            f"Expected ~60s wait, got {actual_wait}s. "
            "Rate limit violated: poll would happen before last_poll + new_cadence"
        )

    @pytest.mark.asyncio
    async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
        """When last_poll + new_cadence is already past, poll immediately.

        If operator increases cadence to 120s after a gap of 150s,
        the poll should happen now (not wait another 120s).
        """
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
        config = self._make_config(120)
        # Constructed to confirm AdapterState accepts these fields.
        self._make_state(config, last_poll)

        wait_time = self._remaining_wait(last_poll, config.cadence_s)

        # 150 > 120, so the next poll is already due.
        assert wait_time < 1, (
            f"Expected immediate poll (wait ~0s), got {wait_time}s. "
            "After a gap exceeding new cadence, poll should happen now."
        )

    @pytest.mark.asyncio
    async def test_enable_disable_enable_respects_rate_limit(self) -> None:
        """Re-enabling adapter schedules poll at last_poll + cadence.

        If adapter was disabled for a while and then re-enabled, the next
        poll should be at (last_completed_poll + cadence_s), not immediately
        (unless that time has already passed).
        """
        # Last poll was 30 seconds ago, then the adapter was disabled.
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
        config = self._make_config(60)
        self._make_state(config, last_poll)

        wait_time = self._remaining_wait(last_poll, config.cadence_s)

        # 60 - 30 = 30 seconds remain.
        assert abs(wait_time - 30) < self._TOLERANCE_S, (
            f"Expected ~30s wait after re-enable, got {wait_time}s. "
            "Rate limit violated on enable→disable→enable sequence."
        )

    @pytest.mark.asyncio
    async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
        """Multiple rapid cadence changes don't cause extra polls.

        If NOTIFY fires rapidly (60→90→120→90), the final schedule should
        still be based on last_completed_poll + final_cadence.
        """
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)
        state = self._make_state(self._make_config(60), last_poll)

        # Simulate rapid cadence changes; the final cadence is 90.
        for cadence in [90, 120, 90]:
            state.config = self._make_config(cadence)
            state.adapter.cadence_s = cadence

        final_cadence = 90
        wait_time = self._remaining_wait(last_poll, final_cadence)

        # 90 - 20 = 70 seconds remain.
        assert abs(wait_time - 70) < self._TOLERANCE_S, (
            f"Expected ~70s wait after rapid changes, got {wait_time}s. "
            "Multiple NOTIFYs should not cause extra polls."
        )
|
||||
|
||||
|
||||
class TestBootstrapConfigFlag:
    """Tests for CENTRAL_CONFIG_SOURCE bootstrap flag."""

    def test_default_is_toml(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Default config_source is 'toml'."""
        from central.bootstrap_config import Settings

        # An ambient CENTRAL_CONFIG_SOURCE (e.g. exported in a developer
        # shell or CI job) would presumably override the default under test.
        monkeypatch.delenv("CENTRAL_CONFIG_SOURCE", raising=False)

        # Create settings with minimal required fields.
        settings = Settings(
            db_dsn="postgresql://test@localhost/test",
            _env_file=None,
        )
        assert settings.config_source == "toml"

    def test_accepts_db(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """config_source accepts 'db' value."""
        from central.bootstrap_config import get_settings

        get_settings.cache_clear()
        monkeypatch.setenv("CENTRAL_CONFIG_SOURCE", "db")
        monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test@localhost/test")

        try:
            settings = get_settings()
            assert settings.config_source == "db"
        finally:
            # monkeypatch restores the environment but not the settings
            # cache; drop the cached "db" settings so later tests do not
            # inherit them.
            get_settings.cache_clear()

    def test_rejects_invalid(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """config_source rejects invalid values."""
        from pydantic import ValidationError
        from central.bootstrap_config import get_settings

        get_settings.cache_clear()
        monkeypatch.setenv("CENTRAL_CONFIG_SOURCE", "invalid")
        monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test@localhost/test")

        try:
            with pytest.raises(ValidationError):
                get_settings()
        finally:
            # Leave no cached settings behind, whether or not the call raised.
            get_settings.cache_clear()
|
||||
546
tests/test_supervisor_integration.py
Normal file
546
tests/test_supervisor_integration.py
Normal file
|
|
@ -0,0 +1,546 @@
|
|||
"""Integration tests for Supervisor hot-reload with enable/disable/enable flow.
|
||||
|
||||
These tests exercise the actual Supervisor._on_config_change code path,
|
||||
not just AdapterState math in isolation. They verify the rate-limit
|
||||
guarantee is maintained across adapter stop/start cycles.
|
||||
|
||||
IMPORTANT: These tests are designed to:
|
||||
- FAIL on unfixed code (Test B fails because last_completed_poll is lost)
|
||||
- PASS on fixed code (last_completed_poll is preserved across disable/enable)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_models import AdapterConfig
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
|
||||
def adapter_is_running(state) -> bool:
    """Report whether *state* represents a running adapter.

    Uses ``state.is_running`` when the attribute exists (fixed code);
    otherwise falls back to checking that ``state.task`` is set and not
    done (unfixed code).
    """
    if not hasattr(state, 'is_running'):
        task = state.task
        return task is not None and not task.done()
    return state.is_running
|
||||
|
||||
|
||||
async def cleanup_adapter(supervisor, name: str) -> None:
    """Tear down adapter *name* on either supervisor variant.

    Prefers ``supervisor._remove_adapter`` (fixed code) and falls back to
    ``supervisor._stop_adapter`` when that method is absent (unfixed code).
    """
    teardown = (
        supervisor._remove_adapter
        if hasattr(supervisor, '_remove_adapter')
        else supervisor._stop_adapter
    )
    await teardown(name)
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Generate the session's master key file and return its path.

    KEY_SIZE random bytes are base64-encoded and written as text to
    ``master.key`` in a new ``keys`` temp directory.
    """
    key_text = base64.b64encode(os.urandom(KEY_SIZE)).decode()
    destination = tmp_path_factory.mktemp("keys") / "master.key"
    destination.write_text(key_text)
    return destination
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Before each test: clear the key cache and set the DSN / key path env vars."""
    clear_key_cache()
    key_location = str(master_key_path)
    monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
    monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", key_location)
|
||||
|
||||
|
||||
class MockConfigSource:
    """In-memory stand-in for a ConfigSource, so Supervisor tests need no DB."""

    def __init__(self) -> None:
        # Adapter configs keyed by adapter name.
        self._adapters: dict[str, AdapterConfig] = {}

    def set_adapter(self, config: AdapterConfig | None, name: str | None = None) -> None:
        """Store *config* under its own name, or drop *name* when config is None.

        With ``config=None`` and no *name* the call does nothing; removing
        an unknown name is also a no-op.
        """
        if config is not None:
            self._adapters[config.name] = config
            return
        if name:
            self._adapters.pop(name, None)

    async def list_enabled_adapters(self) -> list[AdapterConfig]:
        """Return the stored configs that are enabled and not paused."""
        active = []
        for candidate in self._adapters.values():
            if not candidate.enabled or candidate.is_paused:
                continue
            active.append(candidate)
        return active

    async def get_adapter(self, name: str) -> AdapterConfig | None:
        """Return the stored config for *name*, or None if absent."""
        return self._adapters.get(name)

    async def watch_for_changes(self, callback) -> None:
        """Return immediately; the mock never invokes *callback*."""
        return None

    async def close(self) -> None:
        """Nothing to release."""
        return None
|
||||
|
||||
|
||||
class MockNWSAdapter:
    """Test double for NWSAdapter that records polls and never emits events.

    It is patched in place of ``central.supervisor.NWSAdapter`` by the tests
    in this file. Each ``poll()`` call is counted and timestamped so tests can
    inspect when polling happened.
    """

    def __init__(self, config, cursor_db_path) -> None:
        # ``cursor_db_path`` is accepted only to match the constructor call;
        # the mock keeps no cursor database.
        self.config = config
        self.cadence_s = config.cadence_s
        self.states = {state.upper() for state in config.states}
        self.poll_count = 0
        self.poll_times: list[datetime] = []
        self._shutdown = False

    async def startup(self) -> None:
        """No resources to acquire."""
        return None

    async def shutdown(self) -> None:
        """Remember that shutdown was requested."""
        self._shutdown = True

    async def poll(self):
        """Record that a poll happened; the generator itself is empty."""
        polled_at = datetime.now(timezone.utc)
        self.poll_count += 1
        self.poll_times.append(polled_at)
        return
        yield  # unreachable, but its presence makes this an async generator

    def is_published(self, event_id: str) -> bool:
        """Report every event as not yet published."""
        return False

    def mark_published(self, event_id: str) -> None:
        """Ignore the request; nothing is tracked."""
        return None

    def bump_last_seen(self, event_id: str) -> None:
        """Ignore the request; nothing is tracked."""
        return None

    def sweep_old_ids(self) -> int:
        """Report that zero ids were swept."""
        return 0
|
||||
|
||||
|
||||
@pytest.fixture
def mock_nats():
    """Provide a mocked NATS client whose JetStream context is also mocked.

    ``publish`` on both the client and the JetStream context is an
    ``AsyncMock`` so it can be awaited. ``jetstream`` is a synchronous mock
    because the tests in this file call ``mock_nats.jetstream()`` without
    awaiting it.

    Returns:
        The mocked client; ``mock_nats.jetstream()`` returns the mocked
        JetStream context.
    """
    # Imported locally because the file-level imports are not guaranteed to
    # include MagicMock.
    from unittest.mock import MagicMock

    mock_js = AsyncMock()
    mock_js.publish = AsyncMock()

    mock_nc = AsyncMock()
    mock_nc.publish = AsyncMock()
    # Attributes auto-created on an AsyncMock are AsyncMocks themselves, so a
    # plain ``mock_nc.jetstream()`` would return an un-awaited coroutine
    # instead of ``mock_js``. An explicit MagicMock keeps the call synchronous.
    mock_nc.jetstream = MagicMock(return_value=mock_js)
    return mock_nc
|
||||
|
||||
|
||||
class TestEnableDisableEnableIntegration:
    """Integration tests for enable→disable→enable flow through Supervisor.

    These tests verify that _on_config_change → _stop_adapter → _start_adapter
    preserves last_completed_poll correctly.

    Every test drives a single "nws" adapter through a Supervisor whose NATS
    handles are replaced by the ``mock_nats`` fixture and whose NWSAdapter is
    patched to MockNWSAdapter.
    """

    # Cadence used for every adapter config and for the wait-time arithmetic.
    CADENCE_S = 60

    @classmethod
    def _nws_config(cls, *, enabled: bool = True) -> AdapterConfig:
        """Build the "nws" adapter config shared by all tests."""
        return AdapterConfig(
            name="nws",
            enabled=enabled,
            cadence_s=cls.CADENCE_S,
            settings={"states": ["ID"], "contact_email": "test@test.com"},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        )

    @staticmethod
    def _make_supervisor(config_source: MockConfigSource, mock_nats):
        """Create a Supervisor wired to the mocked NATS client."""
        # Imported inside the helper, mirroring the per-test local import the
        # tests used before.
        from central.supervisor import Supervisor

        supervisor = Supervisor(
            config_source=config_source,
            nats_url="nats://localhost:4222",
            cloudevents_config=None,
        )
        supervisor._nc = mock_nats
        # Read the configured return value rather than calling jetstream():
        # when ``jetstream`` is an auto-created AsyncMock attribute, calling it
        # returns an un-awaited coroutine, not the JetStream mock.
        supervisor._js = mock_nats.jetstream.return_value
        return supervisor

    @classmethod
    async def _apply_config(cls, supervisor, config_source, *, enabled: bool) -> None:
        """Store an enabled/disabled "nws" config and notify the supervisor."""
        config_source.set_adapter(cls._nws_config(enabled=enabled))
        await supervisor._on_config_change("adapters", "nws")

    @classmethod
    def _wait_until_next_poll(cls, last_poll: datetime) -> float:
        """Return seconds from now until ``last_poll + CADENCE_S`` (never negative)."""
        now = datetime.now(timezone.utc)
        next_poll_at = last_poll.timestamp() + cls.CADENCE_S
        return max(0, next_poll_at - now.timestamp())

    @pytest.mark.asyncio
    async def test_enable_disable_enable_gap_longer_than_cadence(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Test A: Re-enable after gap longer than cadence polls immediately.

        - Start adapter (cadence 60s)
        - Simulate completed poll 5 minutes ago
        - Disable adapter
        - Re-enable adapter
        - Assert last_completed_poll survives the stop/start cycle
        - Assert the computed wait before the next poll is 0
          (last+cadence is in past)

        The wait is recomputed here from the preserved timestamp; the test
        does not observe the adapter loop itself.
        """
        config_source = MockConfigSource()
        initial_config = self._nws_config(enabled=True)
        config_source.set_adapter(initial_config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        # Patch NWSAdapter to use our mock
        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            await supervisor._start_adapter(initial_config)
            # try/finally so the adapter is torn down even if an assertion fails.
            try:
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)

                # Simulate completed poll 5 minutes ago
                saved_last_poll = datetime.now(timezone.utc) - timedelta(minutes=5)
                state.last_completed_poll = saved_last_poll

                # Disable adapter
                await self._apply_config(supervisor, config_source, enabled=False)

                # Verify stopped but state preserved (THIS IS THE KEY CHECK)
                # On unfixed code, state will be NONE because pop() removes it
                # On fixed code, state still exists with is_running=False
                state = supervisor._adapter_states.get("nws")
                assert state is not None, (
                    "State was removed on stop! This violates the rate-limit guarantee. "
                    "State should be preserved to maintain last_completed_poll."
                )
                assert not adapter_is_running(state)
                assert state.last_completed_poll == saved_last_poll

                # Re-enable adapter
                await self._apply_config(supervisor, config_source, enabled=True)

                # Verify restarted with preserved last_completed_poll
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)
                assert state.last_completed_poll == saved_last_poll

                # last_poll was 5 minutes ago, cadence is 60s
                # next_poll_at = 5_minutes_ago + 60s = 4_minutes_ago
                # wait_time should be 0 (poll immediately)
                now = datetime.now(timezone.utc)
                wait_time = self._wait_until_next_poll(saved_last_poll)
                assert wait_time == 0, (
                    f"Expected immediate poll (wait=0), got wait={wait_time}s. "
                    f"last_poll was {saved_last_poll}, now is {now}"
                )
            finally:
                # Cleanup
                supervisor._shutdown_event.set()
                await cleanup_adapter(supervisor, "nws")

    @pytest.mark.asyncio
    async def test_enable_disable_enable_gap_shorter_than_cadence(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Test B: Re-enable after gap shorter than cadence respects rate limit.

        THIS IS THE KEY TEST that failed before the fix.

        - Start adapter (cadence 60s)
        - Simulate completed poll 10 seconds ago
        - Disable adapter
        - Re-enable adapter straight away (no real time passes in the test,
          so the adapter is still within the cadence window)
        - Assert last_completed_poll survives the stop/start cycle
        - Assert the computed wait is ~50s (last_poll + 60s), NOT 0
        """
        config_source = MockConfigSource()
        initial_config = self._nws_config(enabled=True)
        config_source.set_adapter(initial_config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            await supervisor._start_adapter(initial_config)
            # try/finally so the adapter is torn down even if an assertion fails.
            try:
                state = supervisor._adapter_states.get("nws")
                assert state is not None

                # Simulate completed poll 10 seconds ago
                saved_last_poll = datetime.now(timezone.utc) - timedelta(seconds=10)
                state.last_completed_poll = saved_last_poll

                # Disable adapter
                await self._apply_config(supervisor, config_source, enabled=False)

                # Verify stopped but state preserved (THIS IS THE KEY CHECK)
                # On unfixed code, state will be NONE because pop() removes it
                # On fixed code, state still exists with is_running=False
                state = supervisor._adapter_states.get("nws")
                assert state is not None, (
                    "State was removed on stop! This violates the rate-limit guarantee. "
                    "State should be preserved to maintain last_completed_poll."
                )
                assert not adapter_is_running(state)
                assert state.last_completed_poll == saved_last_poll

                # Re-enable adapter (we are only checking the rate limit logic,
                # so no extra delay is simulated)
                await self._apply_config(supervisor, config_source, enabled=True)

                # Verify restarted with preserved last_completed_poll
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)
                assert state.last_completed_poll == saved_last_poll

                # last_poll was ~10 seconds ago, cadence is 60s
                # wait_time should be ~50s (60 - 10 = 50)
                wait_time = self._wait_until_next_poll(saved_last_poll)
                assert 45 < wait_time < 55, (
                    f"Expected ~50s wait (respecting rate limit), got wait={wait_time}s. "
                    "Rate limit violated: poll would happen before last_poll + cadence"
                )
            finally:
                # Cleanup
                supervisor._shutdown_event.set()
                await cleanup_adapter(supervisor, "nws")

    @pytest.mark.asyncio
    async def test_enable_disable_delete_readd_fresh_state(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Test C: Delete then re-add clears preserved state.

        - Start adapter
        - Simulate completed poll
        - Disable adapter
        - DELETE adapter from DB (not just disable)
        - Re-add adapter with same name
        - Assert preserved timestamp is dropped (last_completed_poll is None)
        """
        config_source = MockConfigSource()
        initial_config = self._nws_config(enabled=True)
        config_source.set_adapter(initial_config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            await supervisor._start_adapter(initial_config)
            # try/finally so the adapter is torn down even if an assertion fails.
            try:
                state = supervisor._adapter_states.get("nws")
                assert state is not None

                # Simulate completed poll 10 seconds ago
                state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10)

                # Disable adapter
                await self._apply_config(supervisor, config_source, enabled=False)

                # DELETE adapter from DB (remove from config source)
                config_source.set_adapter(None, name="nws")
                await supervisor._on_config_change("adapters", "nws")

                # Verify adapter fully removed
                assert "nws" not in supervisor._adapter_states

                # Re-add adapter with same name
                await self._apply_config(supervisor, config_source, enabled=True)

                # Verify new adapter started fresh
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)
                # last_completed_poll should be None (fresh adapter)
                assert state.last_completed_poll is None, (
                    f"Expected None (fresh adapter), got {state.last_completed_poll}. "
                    "Preserved state not cleared on delete."
                )
            finally:
                # Cleanup
                supervisor._shutdown_event.set()
                await cleanup_adapter(supervisor, "nws")

    @pytest.mark.asyncio
    async def test_stop_preserves_state_start_reuses_it(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Verify _stop_adapter preserves state and _start_adapter reuses it."""
        config_source = MockConfigSource()
        config = self._nws_config(enabled=True)
        config_source.set_adapter(config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            await supervisor._start_adapter(config)
            # try/finally so the adapter is torn down even if an assertion fails.
            try:
                state = supervisor._adapter_states.get("nws")
                # Fail with a clear assertion instead of an AttributeError on None.
                assert state is not None
                saved_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
                state.last_completed_poll = saved_poll

                # Stop adapter
                await supervisor._stop_adapter("nws")

                # State should still exist
                assert "nws" in supervisor._adapter_states
                state = supervisor._adapter_states["nws"]
                assert not adapter_is_running(state)
                assert state.last_completed_poll == saved_poll

                # Restart adapter
                await supervisor._start_adapter(config)

                # Should reuse existing state
                state = supervisor._adapter_states.get("nws")
                assert state is not None
                assert adapter_is_running(state)
                assert state.last_completed_poll == saved_poll
            finally:
                # Cleanup
                supervisor._shutdown_event.set()
                await cleanup_adapter(supervisor, "nws")

    @pytest.mark.asyncio
    async def test_remove_adapter_clears_state(
        self, mock_nats, tmp_path: Path
    ) -> None:
        """Verify removing an adapter fully clears its state.

        Removal goes through the ``cleanup_adapter`` helper, which is defined
        outside this class (presumably wrapping ``_remove_adapter`` — confirm
        against the helper).
        """
        config_source = MockConfigSource()
        config = self._nws_config(enabled=True)
        config_source.set_adapter(config)
        supervisor = self._make_supervisor(config_source, mock_nats)

        with patch("central.supervisor.NWSAdapter", MockNWSAdapter):
            await supervisor._start_adapter(config)

            state = supervisor._adapter_states.get("nws")
            # Fail with a clear assertion instead of an AttributeError on None.
            assert state is not None
            state.last_completed_poll = datetime.now(timezone.utc)

            # Remove adapter
            await cleanup_adapter(supervisor, "nws")

            # State should be gone
            assert "nws" not in supervisor._adapter_states
|
||||
Loading…
Add table
Add a link
Reference in a new issue