Merge pull request #1 from zvx-echo6/feature/1a-config-storage

feat(config): Phase 1a-2 config storage primitives
This commit is contained in:
malice 2026-05-15 19:49:21 -06:00 committed by GitHub
commit ee081c9bc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1570 additions and 161 deletions

View file

@ -15,14 +15,18 @@ dependencies = [
"aiolimiter>=1.2.1", "aiolimiter>=1.2.1",
"asyncpg>=0.31.0", "asyncpg>=0.31.0",
"cloudevents>=2.0.0", "cloudevents>=2.0.0",
"cryptography>=44.0.0",
"nats-py>=2.14.0", "nats-py>=2.14.0",
"pydantic>=2,<3", "pydantic>=2,<3",
"pydantic-settings>=2.7.0",
"tenacity>=9.1.4", "tenacity>=9.1.4",
] ]
[project.scripts] [project.scripts]
central-supervisor = "central.supervisor:main" central-supervisor = "central.supervisor:main"
central-archive = "central.archive:main" central-archive = "central.archive:main"
central-migrate = "central.migrate:main"
central-cli = "central.cli:main"
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["src/central"] packages = ["src/central"]

View file

@ -0,0 +1,64 @@
-- Migration: 001_create_config_schema
-- Creates the config schema with adapters and api_keys tables.
-- Also seeds the NWS adapter row from current TOML config.
-- Create config schema
CREATE SCHEMA config;
-- Adapters configuration table
CREATE TABLE config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- API keys table (encrypted values)
CREATE TABLE config.api_keys (
alias TEXT PRIMARY KEY,
encrypted_value BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
rotated_at TIMESTAMPTZ,
last_used_at TIMESTAMPTZ
);
-- Notify function for config changes
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
-- Handle different table structures
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSIF TG_TABLE_NAME = 'api_keys' THEN
key_value := COALESCE(NEW.alias, OLD.alias, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql;
-- Trigger for adapters table
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
-- Trigger for api_keys table
CREATE TRIGGER api_keys_notify
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
-- Seed NWS adapter from current TOML config values
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES (
'nws',
true,
60,
'{"states": ["ID", "OR", "WA", "MT", "WY", "UT", "NV"], "contact_email": "mj@k7zvx.com"}'::jsonb
);

View file

@ -0,0 +1,21 @@
-- Migration: 002_add_updated_at_trigger_and_index
-- Adds auto-update trigger for updated_at column on adapters table
-- Adds partial index for efficient enabled adapter queries
-- Auto-update trigger for updated_at
CREATE OR REPLACE FUNCTION config.set_updated_at()
RETURNS trigger AS $$
BEGIN
NEW.updated_at := now();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER adapters_set_updated_at
BEFORE UPDATE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.set_updated_at();
-- Partial index for enabled adapters (common query pattern)
CREATE INDEX adapters_enabled_idx
ON config.adapters (enabled)
WHERE enabled = true;

View file

@ -0,0 +1,46 @@
"""Bootstrap configuration from environment variables.
This module provides early-stage configuration loading from environment
variables or a .env file. Used before the database-backed config store
is available.
"""
from functools import lru_cache
from pathlib import Path
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Bootstrap settings loaded from environment or .env file."""
model_config = SettingsConfigDict(
env_prefix="CENTRAL_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
db_dsn: str = Field(description="PostgreSQL connection string")
nats_url: str = Field(default="nats://localhost:4222", description="NATS server URL")
master_key_path: Path = Field(
default=Path("/etc/central/master.key"),
description="Path to AES-256 master key file",
)
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
default="INFO",
description="Logging level",
)
@lru_cache
def get_settings(env_file: Path | None = None) -> Settings:
"""Load settings, optionally from a specific .env file.
Results are cached. Call get_settings.cache_clear() to reload.
"""
if env_file is not None:
return Settings(_env_file=env_file)
return Settings()

75
src/central/cli.py Normal file
View file

@ -0,0 +1,75 @@
"""Central CLI commands."""
import argparse
import asyncio
import sys
async def config_store_check() -> int:
"""Smoke test for config store connectivity.
Connects via bootstrap_config, lists adapters, and verifies crypto.
Returns 0 on success, 1 on failure.
"""
from central.bootstrap_config import get_settings
from central.config_store import ConfigStore
from central.crypto import decrypt, encrypt
settings = get_settings()
print(f"Connecting to: {settings.db_dsn.split('@')[1]}") # Hide password
try:
store = await ConfigStore.create(settings.db_dsn)
except Exception as e:
print(f"ERROR: Failed to connect to database: {e}")
return 1
try:
# List adapters
adapters = await store.list_adapters()
print(f"\nAdapters ({len(adapters)}):")
for adapter in adapters:
print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
print(f" settings: {adapter.settings}")
# Test crypto
test_plaintext = b"config_store_check_test"
try:
ciphertext = encrypt(test_plaintext)
decrypted = decrypt(ciphertext)
if decrypted == test_plaintext:
print("\ncrypto: ok")
else:
print("\ncrypto: FAILED (round-trip mismatch)")
return 1
except Exception as e:
print(f"\ncrypto: FAILED ({e})")
return 1
print("\nAll checks passed.")
return 0
finally:
await store.close()
def main_config_store_check() -> None:
"""Entry point for central-cli config-store-check."""
sys.exit(asyncio.run(config_store_check()))
def main() -> None:
"""Main CLI entry point."""
parser = argparse.ArgumentParser(description="Central CLI")
subparsers = parser.add_subparsers(dest="command", required=True)
subparsers.add_parser("config-store-check", help="Test config store connectivity")
args = parser.parse_args()
if args.command == "config-store-check":
main_config_store_check()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,39 @@
"""Pydantic models for database-backed configuration."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AdapterConfig(BaseModel):
"""Configuration for a single adapter."""
name: str = Field(description="Unique adapter identifier")
enabled: bool = Field(default=True, description="Whether adapter is active")
cadence_s: int = Field(description="Poll interval in seconds")
settings: dict[str, Any] = Field(
default_factory=dict, description="Adapter-specific settings"
)
paused_at: datetime | None = Field(
default=None, description="When adapter was paused, if paused"
)
updated_at: datetime = Field(description="Last configuration update time")
@property
def is_paused(self) -> bool:
"""Check if adapter is currently paused."""
return self.paused_at is not None
class ApiKeyInfo(BaseModel):
"""Metadata about an API key (without the decrypted value)."""
alias: str = Field(description="Key identifier/alias")
created_at: datetime = Field(description="When key was created")
rotated_at: datetime | None = Field(
default=None, description="Last rotation time"
)
last_used_at: datetime | None = Field(
default=None, description="Last usage time"
)

269
src/central/config_store.py Normal file
View file

@ -0,0 +1,269 @@
"""Database-backed configuration store.
Provides async access to the config schema tables with support for
Postgres LISTEN/NOTIFY for real-time config change notifications.
"""
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from typing import Any
import asyncpg
from central.config_models import AdapterConfig
from central.crypto import decrypt, encrypt
logger = logging.getLogger(__name__)
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
"""Set up JSON codec for asyncpg connection."""
await conn.set_type_codec(
"jsonb",
encoder=json.dumps,
decoder=json.loads,
schema="pg_catalog",
)
class ConfigStore:
"""Async interface to the config schema in Postgres."""
def __init__(self, pool: asyncpg.Pool) -> None:
self._pool = pool
@classmethod
async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
"""Create a ConfigStore with a new connection pool."""
pool = await asyncpg.create_pool(
dsn,
min_size=min_size,
max_size=max_size,
init=_setup_json_codec,
)
return cls(pool)
async def close(self) -> None:
"""Close the connection pool."""
await self._pool.close()
# -------------------------------------------------------------------------
# Adapter configuration
# -------------------------------------------------------------------------
async def get_adapter(self, name: str) -> AdapterConfig | None:
"""Get configuration for a specific adapter."""
async with self._pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
FROM config.adapters
WHERE name = $1
""",
name,
)
if row is None:
return None
return AdapterConfig(**dict(row))
async def list_adapters(self) -> list[AdapterConfig]:
"""List all configured adapters."""
async with self._pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
FROM config.adapters
ORDER BY name
"""
)
return [AdapterConfig(**dict(row)) for row in rows]
async def upsert_adapter(
self,
name: str,
enabled: bool,
cadence_s: int,
settings: dict[str, Any],
) -> None:
"""Insert or update an adapter configuration."""
async with self._pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
VALUES ($1, $2, $3, $4, now())
ON CONFLICT (name) DO UPDATE SET
enabled = EXCLUDED.enabled,
cadence_s = EXCLUDED.cadence_s,
settings = EXCLUDED.settings,
updated_at = now()
""",
name,
enabled,
cadence_s,
settings, # Will be encoded as JSON by the codec
)
async def pause_adapter(self, name: str) -> None:
"""Pause an adapter by setting paused_at."""
async with self._pool.acquire() as conn:
await conn.execute(
"""
UPDATE config.adapters
SET paused_at = now(), updated_at = now()
WHERE name = $1
""",
name,
)
async def unpause_adapter(self, name: str) -> None:
"""Unpause an adapter by clearing paused_at."""
async with self._pool.acquire() as conn:
await conn.execute(
"""
UPDATE config.adapters
SET paused_at = NULL, updated_at = now()
WHERE name = $1
""",
name,
)
# -------------------------------------------------------------------------
# API key management
# -------------------------------------------------------------------------
async def set_api_key(self, alias: str, plaintext_value: str) -> None:
"""Store an API key, encrypting it with the master key."""
encrypted = encrypt(plaintext_value.encode("utf-8"))
async with self._pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO config.api_keys (alias, encrypted_value)
VALUES ($1, $2)
ON CONFLICT (alias) DO UPDATE SET
encrypted_value = EXCLUDED.encrypted_value,
rotated_at = now()
""",
alias,
encrypted,
)
async def get_api_key(self, alias: str) -> str | None:
"""Retrieve and decrypt an API key by alias."""
async with self._pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT encrypted_value FROM config.api_keys WHERE alias = $1
""",
alias,
)
if row is not None:
# Update last_used_at
await conn.execute(
"""
UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
""",
alias,
)
if row is None:
return None
return decrypt(row["encrypted_value"]).decode("utf-8")
async def delete_api_key(self, alias: str) -> bool:
"""Delete an API key. Returns True if key existed."""
async with self._pool.acquire() as conn:
result = await conn.execute(
"DELETE FROM config.api_keys WHERE alias = $1", alias
)
return result == "DELETE 1"
# -------------------------------------------------------------------------
# Change notifications
# -------------------------------------------------------------------------
async def listen_for_changes(
self,
callback: Callable[[str, str], Awaitable[None] | None],
) -> None:
"""Listen for config changes via Postgres NOTIFY.
Runs forever, calling callback(table, key) each time a change is
detected. The callback can be sync or async.
On connection loss, automatically reconnects with exponential backoff.
Cancellation (via task.cancel()) propagates cleanly.
Args:
callback: Function called with (table_name, row_key) on each change.
"""
backoff = 1.0
max_backoff = 30.0
while True:
conn = None
try:
conn = await self._pool.acquire()
logger.info("Config listener connected to database")
backoff = 1.0 # Reset backoff on successful connect
def notification_handler(
conn: asyncpg.Connection,
pid: int,
channel: str,
payload: str,
) -> None:
# payload format: "table_name:key"
if ":" in payload:
table, key = payload.split(":", 1)
else:
table, key = payload, ""
result = callback(table, key)
if asyncio.iscoroutine(result):
asyncio.create_task(result)
await conn.add_listener("config_changed", notification_handler)
try:
# Keep connection alive with periodic keepalive
while True:
await asyncio.sleep(60)
await conn.execute("SELECT 1")
finally:
await conn.remove_listener("config_changed", notification_handler)
except asyncio.CancelledError:
# Cancellation must propagate cleanly
logger.info("Config listener cancelled")
raise
except (
asyncpg.PostgresConnectionError,
asyncpg.InterfaceError,
ConnectionResetError,
OSError,
) as e:
logger.warning(
"Config listener connection lost, reconnecting in %.1fs: %s",
backoff,
e,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, max_backoff)
except Exception as e:
# Unexpected error - log and retry with backoff
logger.exception(
"Config listener unexpected error, reconnecting in %.1fs",
backoff,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, max_backoff)
finally:
if conn is not None:
try:
await self._pool.release(conn)
except Exception:
pass # Connection may already be invalid

111
src/central/crypto.py Normal file
View file

@ -0,0 +1,111 @@
"""Cryptographic primitives for secret storage.
Uses AES-256-GCM for authenticated encryption. The master key is read
from the path specified in bootstrap config on first use and cached.
"""
import base64
import os
from functools import lru_cache
from pathlib import Path
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
# AES-256 requires 32-byte key
KEY_SIZE = 32
# GCM nonce size (96 bits recommended by NIST)
NONCE_SIZE = 12
class CryptoError(Exception):
"""Base exception for crypto operations."""
class KeyLoadError(CryptoError):
"""Failed to load master key."""
class DecryptionError(CryptoError):
"""Failed to decrypt ciphertext (wrong key or tampered data)."""
@lru_cache
def _load_master_key(path: Path) -> bytes:
"""Load and decode the base64-encoded master key from file."""
try:
key_b64 = path.read_text().strip()
key = base64.b64decode(key_b64)
except FileNotFoundError:
raise KeyLoadError(f"Master key file not found: {path}")
except Exception as e:
raise KeyLoadError(f"Failed to read master key from {path}: {e}")
if len(key) != KEY_SIZE:
raise KeyLoadError(
f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
)
return key
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
"""Encrypt plaintext using AES-256-GCM.
Args:
plaintext: Data to encrypt.
key_path: Path to master key file. If None, uses default from
bootstrap config.
Returns:
Ciphertext in format: nonce (12 bytes) || ciphertext || tag (16 bytes)
"""
if key_path is None:
from central.bootstrap_config import get_settings
key_path = get_settings().master_key_path
key = _load_master_key(key_path)
nonce = os.urandom(NONCE_SIZE)
aesgcm = AESGCM(key)
# GCM appends the 16-byte tag to the ciphertext
ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, associated_data=None)
return nonce + ciphertext_with_tag
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
"""Decrypt ciphertext using AES-256-GCM.
Args:
ciphertext: Data in format: nonce || ciphertext || tag
key_path: Path to master key file. If None, uses default from
bootstrap config.
Returns:
Decrypted plaintext.
Raises:
DecryptionError: If decryption fails (wrong key or tampered data).
"""
if key_path is None:
from central.bootstrap_config import get_settings
key_path = get_settings().master_key_path
if len(ciphertext) < NONCE_SIZE + 16: # nonce + minimum tag
raise DecryptionError("Ciphertext too short")
key = _load_master_key(key_path)
nonce = ciphertext[:NONCE_SIZE]
ciphertext_with_tag = ciphertext[NONCE_SIZE:]
aesgcm = AESGCM(key)
try:
plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
except Exception as e:
raise DecryptionError(f"Decryption failed: {e}")
return plaintext
def clear_key_cache() -> None:
"""Clear the cached master key. Use after key rotation."""
_load_master_key.cache_clear()

125
src/central/migrate.py Normal file
View file

@ -0,0 +1,125 @@
"""Simple database migration runner.
Tracks applied migrations in a `schema_migrations` table. Migrations are
plain SQL files in `sql/migrations/` named with numeric prefixes:
001_create_config_schema.sql
002_add_operators_table.sql
...
Usage:
central-migrate [--dry-run]
"""
import argparse
import asyncio
import sys
from pathlib import Path
import asyncpg
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
"""Create the schema_migrations table if it doesn't exist."""
await conn.execute("""
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
"""Return set of already-applied migration versions."""
rows = await conn.fetch("SELECT version FROM schema_migrations")
return {row["version"] for row in rows}
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
"""Find all .sql files in migrations directory, sorted by name.
Returns list of (version, path) tuples where version is the filename
without extension.
"""
if not migrations_dir.exists():
return []
migrations = []
for f in sorted(migrations_dir.glob("*.sql")):
version = f.stem # e.g., "001_create_config_schema"
migrations.append((version, f))
return migrations
async def apply_migration(
conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
) -> None:
"""Apply a single migration."""
sql = sql_path.read_text()
if dry_run:
print(f"[DRY RUN] Would apply: {version}")
print(f" SQL: {sql[:200]}..." if len(sql) > 200 else f" SQL: {sql}")
return
async with conn.transaction():
await conn.execute(sql)
await conn.execute(
"INSERT INTO schema_migrations (version) VALUES ($1)", version
)
print(f"Applied: {version}")
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
"""Run all pending migrations.
Returns number of migrations applied.
"""
conn = await asyncpg.connect(dsn)
try:
await ensure_migrations_table(conn)
applied = await get_applied_migrations(conn)
pending = [
(v, p) for v, p in discover_migrations(MIGRATIONS_DIR) if v not in applied
]
if not pending:
print("No pending migrations.")
return 0
print(f"Found {len(pending)} pending migration(s).")
for version, path in pending:
await apply_migration(conn, version, path, dry_run)
return len(pending)
finally:
await conn.close()
async def async_main() -> None:
"""Async entry point."""
parser = argparse.ArgumentParser(description="Run database migrations")
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be applied without executing",
)
args = parser.parse_args()
from central.bootstrap_config import get_settings
settings = get_settings()
count = await run_migrations(settings.db_dsn, dry_run=args.dry_run)
if count > 0 and not args.dry_run:
print(f"Successfully applied {count} migration(s).")
def main() -> None:
"""Entry point."""
asyncio.run(async_main())
if __name__ == "__main__":
main()

18
tests/README.md Normal file
View file

@ -0,0 +1,18 @@
# Central Tests
## Test Database
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
By default, tests connect to:
```
postgresql://central_test:testpass@localhost/central_test
```
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
environment variable:
```bash
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
uv run pytest tests/test_config_store.py
```

View file

@ -0,0 +1,123 @@
"""Tests for bootstrap configuration."""
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
import pytest
from central.bootstrap_config import Settings, get_settings
class TestSettingsFromEnv:
"""Test loading settings from environment variables."""
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Settings are read from CENTRAL_* environment variables."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
settings = Settings()
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
assert settings.nats_url == "nats://10.0.0.1:4222"
assert settings.master_key_path == Path("/tmp/test.key")
assert settings.log_level == "DEBUG"
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Default values are used when env vars not set."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
# Clear any existing env vars that might interfere
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
settings = Settings()
assert settings.nats_url == "nats://localhost:4222"
assert settings.master_key_path == Path("/etc/central/master.key")
assert settings.log_level == "INFO"
class TestSettingsFromFile:
"""Test loading settings from .env file."""
def test_reads_from_env_file(self, tmp_path: Path) -> None:
"""Settings are read from .env file when env vars not present."""
env_file = tmp_path / ".env"
env_file.write_text(
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
"CENTRAL_NATS_URL=nats://file.local:4222\n"
"CENTRAL_LOG_LEVEL=WARNING\n"
)
# Create settings pointing to the temp .env file
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
assert settings.nats_url == "nats://file.local:4222"
assert settings.log_level == "WARNING"
def test_env_vars_override_file(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Environment variables take precedence over .env file."""
env_file = tmp_path / ".env"
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://env@localhost/envdb"
class TestSettingsValidation:
"""Test settings validation and error handling."""
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear error when required CENTRAL_DB_DSN is missing."""
# Ensure no env vars or .env file provides the DSN
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
with pytest.raises(Exception) as exc_info:
# Use a non-existent .env file path to ensure no fallback
Settings(_env_file=Path("/nonexistent/.env"))
# pydantic-settings raises ValidationError for missing required fields
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Invalid log level values are rejected."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
with pytest.raises(Exception):
Settings()
class TestGetSettings:
"""Test the cached settings loader."""
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""get_settings() returns cached instance."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
s2 = get_settings()
assert s1 is s2
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""cache_clear() forces reload on next call."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
get_settings.cache_clear()
s2 = get_settings()
assert s1.db_dsn != s2.db_dsn

339
tests/test_config_store.py Normal file
View file

@ -0,0 +1,339 @@
"""Tests for database-backed configuration store.
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
environment variable to override the default test database connection.
"""
import asyncio
import base64
import os
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN - uses central_test database with well-known test password.
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
# Create schema if not exists
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
# Create tables if not exist
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.api_keys (
alias TEXT PRIMARY KEY,
encrypted_value BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
rotated_at TIMESTAMPTZ,
last_used_at TIMESTAMPTZ
)
""")
# Create notify function with proper key detection
await db_conn.execute("""
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSIF TG_TABLE_NAME = 'api_keys' THEN
key_value := COALESCE(NEW.alias, OLD.alias, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
""")
# Create triggers if not exist
await db_conn.execute("""
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
await db_conn.execute("""
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
CREATE TRIGGER api_keys_notify
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
# Clean tables
await db_conn.execute("DELETE FROM config.adapters")
await db_conn.execute("DELETE FROM config.api_keys")
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
"""Create a ConfigStore connected to the test database."""
store = await ConfigStore.create(TEST_DB_DSN)
yield store
await store.close()
class TestAdapterConfig:
"""Tests for adapter configuration operations."""
@pytest.mark.asyncio
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
"""Can insert and retrieve adapter config."""
await config_store.upsert_adapter(
name="test_adapter",
enabled=True,
cadence_s=120,
settings={"key": "value"},
)
adapter = await config_store.get_adapter("test_adapter")
assert adapter is not None
assert adapter.name == "test_adapter"
assert adapter.enabled is True
assert adapter.cadence_s == 120
assert adapter.settings == {"key": "value"}
@pytest.mark.asyncio
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
"""Getting nonexistent adapter returns None."""
adapter = await config_store.get_adapter("does_not_exist")
assert adapter is None
@pytest.mark.asyncio
async def test_list_adapters(self, config_store: ConfigStore) -> None:
"""Can list all adapters."""
await config_store.upsert_adapter("adapter_a", True, 60, {})
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
adapters = await config_store.list_adapters()
assert len(adapters) == 2
names = [a.name for a in adapters]
assert "adapter_a" in names
assert "adapter_b" in names
@pytest.mark.asyncio
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
"""Upsert updates existing adapter."""
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
adapter = await config_store.get_adapter("updater")
assert adapter is not None
assert adapter.enabled is False
assert adapter.cadence_s == 120
assert adapter.settings == {"v": 2}
@pytest.mark.asyncio
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
"""Can pause and unpause adapter."""
await config_store.upsert_adapter("pausable", True, 60, {})
await config_store.pause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is True
await config_store.unpause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is False
class TestApiKeys:
"""Tests for API key operations."""
@pytest.mark.asyncio
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
"""Can store and retrieve encrypted API key."""
await config_store.set_api_key("test_key", "super_secret_value")
value = await config_store.get_api_key("test_key")
assert value == "super_secret_value"
@pytest.mark.asyncio
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
"""Getting nonexistent key returns None."""
value = await config_store.get_api_key("does_not_exist")
assert value is None
@pytest.mark.asyncio
async def test_key_rotation(self, config_store: ConfigStore) -> None:
"""Updating key sets rotated_at."""
await config_store.set_api_key("rotate_me", "value1")
await config_store.set_api_key("rotate_me", "value2")
value = await config_store.get_api_key("rotate_me")
assert value == "value2"
@pytest.mark.asyncio
async def test_delete_key(self, config_store: ConfigStore) -> None:
"""Can delete API key."""
await config_store.set_api_key("delete_me", "value")
deleted = await config_store.delete_api_key("delete_me")
assert deleted is True
value = await config_store.get_api_key("delete_me")
assert value is None
@pytest.mark.asyncio
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
"""Deleting nonexistent key returns False."""
deleted = await config_store.delete_api_key("never_existed")
assert deleted is False
class TestNotifications:
"""Tests for LISTEN/NOTIFY functionality."""
@pytest.mark.asyncio
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when adapter is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
# Start listener in background
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
# Give listener time to subscribe
await asyncio.sleep(0.1)
# Trigger a change
await config_store.upsert_adapter("notify_test", True, 60, {})
# Wait for notification (with timeout)
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "adapters"
assert notifications[0][1] == "notify_test"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when API key is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
await asyncio.sleep(0.1)
await config_store.set_api_key("notify_key", "secret")
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "api_keys"
assert notifications[0][1] == "notify_key"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
class TestListenerReconnect:
"""Tests for listener reconnection on connection loss."""
@pytest.mark.asyncio
async def test_listener_cancellation_propagates(
self, config_store: ConfigStore
) -> None:
"""Cancellation cleanly stops the listener without reconnect loop."""
async def callback(table: str, key: str) -> None:
pass
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
# Give listener time to start
await asyncio.sleep(0.1)
# Cancel and verify it stops
listen_task.cancel()
try:
await asyncio.wait_for(listen_task, timeout=2.0)
except asyncio.CancelledError:
pass # Expected
except asyncio.TimeoutError:
pytest.fail("Listener did not stop after cancellation")
assert listen_task.cancelled() or listen_task.done()

175
tests/test_crypto.py Normal file
View file

@ -0,0 +1,175 @@
"""Tests for cryptographic primitives."""
import base64
import os
from pathlib import Path
import pytest
from central.crypto import (
KEY_SIZE,
DecryptionError,
KeyLoadError,
clear_key_cache,
decrypt,
encrypt,
)
@pytest.fixture
def master_key(tmp_path: Path) -> Path:
"""Create a valid master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "master.key"
key_path.write_text(base64.b64encode(key).decode())
clear_key_cache()
return key_path
@pytest.fixture
def wrong_key(tmp_path: Path) -> Path:
"""Create a different master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "wrong.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
class TestEncryptDecrypt:
"""Test encrypt/decrypt round-trip."""
def test_round_trip(self, master_key: Path) -> None:
"""Encrypting then decrypting returns original plaintext."""
plaintext = b"Hello, Central!"
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_empty(self, master_key: Path) -> None:
"""Empty plaintext encrypts and decrypts correctly."""
plaintext = b""
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_large(self, master_key: Path) -> None:
"""Large plaintext encrypts and decrypts correctly."""
plaintext = os.urandom(1024 * 1024) # 1MB
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
"""Same plaintext produces different ciphertext (random nonce)."""
plaintext = b"test"
ct1 = encrypt(plaintext, key_path=master_key)
ct2 = encrypt(plaintext, key_path=master_key)
assert ct1 != ct2
# But both decrypt to same plaintext
assert decrypt(ct1, key_path=master_key) == plaintext
assert decrypt(ct2, key_path=master_key) == plaintext
class TestDecryptionFailures:
"""Test AEAD authentication catches tampering."""
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
"""Decryption with wrong key raises DecryptionError."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
clear_key_cache() # Clear cache so wrong_key is loaded
with pytest.raises(DecryptionError):
decrypt(ciphertext, key_path=wrong_key)
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
"""Modified ciphertext is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the ciphertext (after nonce, before tag)
tampered = bytearray(ciphertext)
tampered[15] ^= 0x01 # Flip one bit
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_tampered_tag_fails(self, master_key: Path) -> None:
"""Modified authentication tag is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the last byte (part of the tag)
tampered = bytearray(ciphertext)
tampered[-1] ^= 0x01
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
"""Truncated ciphertext is rejected."""
ciphertext = b"tooshort"
with pytest.raises(DecryptionError, match="too short"):
decrypt(ciphertext, key_path=master_key)
class TestKeyLoading:
"""Test master key loading."""
def test_missing_key_file(self, tmp_path: Path) -> None:
"""Missing key file raises KeyLoadError."""
clear_key_cache()
missing = tmp_path / "nonexistent.key"
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=missing)
def test_invalid_key_size(self, tmp_path: Path) -> None:
"""Key file with wrong size raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text(base64.b64encode(b"tooshort").decode())
with pytest.raises(KeyLoadError, match="Invalid master key size"):
encrypt(b"test", key_path=bad_key)
def test_invalid_base64(self, tmp_path: Path) -> None:
"""Invalid base64 in key file raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text("not valid base64!!!")
with pytest.raises(KeyLoadError):
encrypt(b"test", key_path=bad_key)
def test_key_cached(self, master_key: Path) -> None:
"""Key is cached after first load."""
# First encryption loads the key
encrypt(b"test1", key_path=master_key)
# Delete the file
master_key.unlink()
# Second encryption should still work (cached)
ciphertext = encrypt(b"test2", key_path=master_key)
assert len(ciphertext) > 0
def test_cache_clear(self, master_key: Path) -> None:
"""clear_key_cache forces reload."""
encrypt(b"test", key_path=master_key)
master_key.unlink()
clear_key_cache()
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=master_key)

View file

@ -1,161 +1,161 @@
"""Smoke tests for Central models and CloudEvents wire format.""" """Smoke tests for Central models and CloudEvents wire format."""
from datetime import datetime, timezone from datetime import datetime, timezone
import pytest import pytest
from central.models import Event, Geo, subject_for_event from central.models import Event, Geo, subject_for_event
from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config
from central.cloudevents_wire import wrap_event from central.cloudevents_wire import wrap_event
@pytest.fixture @pytest.fixture
def sample_geo() -> Geo: def sample_geo() -> Geo:
"""Sample Geo object for testing.""" """Sample Geo object for testing."""
return Geo( return Geo(
centroid=(-116.2, 43.6), centroid=(-116.2, 43.6),
bbox=(-116.5, 43.4, -115.9, 43.8), bbox=(-116.5, 43.4, -115.9, 43.8),
regions=["US-ID-Ada", "US-ID-Canyon"], regions=["US-ID-Ada", "US-ID-Canyon"],
primary_region="US-ID-Ada", primary_region="US-ID-Ada",
) )
@pytest.fixture @pytest.fixture
def sample_event(sample_geo: Geo) -> Event: def sample_event(sample_geo: Geo) -> Event:
"""Sample Event object for testing.""" """Sample Event object for testing."""
return Event( return Event(
id="urn:central:nws:alert:KBOI-202401151200-SVR", id="urn:central:nws:alert:KBOI-202401151200-SVR",
source="central/adapters/nws", source="central/adapters/nws",
category="wx.alert.severe_thunderstorm_warning", category="wx.alert.severe_thunderstorm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc), time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc), expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc),
severity=3, severity=3,
geo=sample_geo, geo=sample_geo,
data={"headline": "Severe Thunderstorm Warning", "urgency": "Immediate"}, data={"headline": "Severe Thunderstorm Warning", "urgency": "Immediate"},
) )
@pytest.fixture @pytest.fixture
def sample_config() -> Config: def sample_config() -> Config:
"""Sample Config object for testing.""" """Sample Config object for testing."""
return Config( return Config(
adapters={ adapters={
"nws": NWSAdapterConfig( "nws": NWSAdapterConfig(
enabled=True, enabled=True,
cadence_s=60, cadence_s=60,
states=["ID", "MT"], states=["ID", "MT"],
contact_email="test@example.com", contact_email="test@example.com",
) )
}, },
cloudevents=CloudEventsConfig( cloudevents=CloudEventsConfig(
type_prefix="central", type_prefix="central",
source="central.echo6.mesh", source="central.local",
schema_version="1.0", schema_version="1.0",
), ),
nats=NATSConfig(url="nats://localhost:4222"), nats=NATSConfig(url="nats://localhost:4222"),
postgres=PostgresConfig(dsn="postgresql://user:pass@localhost/db"), postgres=PostgresConfig(dsn="postgresql://user:pass@localhost/db"),
) )
class TestSubjectForEvent: class TestSubjectForEvent:
"""Tests for subject_for_event helper.""" """Tests for subject_for_event helper."""
def test_county_subject(self, sample_event: Event) -> None: def test_county_subject(self, sample_event: Event) -> None:
"""County codes produce county subject.""" """County codes produce county subject."""
subject = subject_for_event(sample_event) subject = subject_for_event(sample_event)
assert subject == "central.wx.alert.us.id.county.ada" assert subject == "central.wx.alert.us.id.county.ada"
def test_zone_subject(self, sample_geo: Geo) -> None: def test_zone_subject(self, sample_geo: Geo) -> None:
"""Zone codes produce zone subject.""" """Zone codes produce zone subject."""
geo = Geo( geo = Geo(
centroid=sample_geo.centroid, centroid=sample_geo.centroid,
bbox=sample_geo.bbox, bbox=sample_geo.bbox,
regions=["US-ID-Z033"], regions=["US-ID-Z033"],
primary_region="US-ID-Z033", primary_region="US-ID-Z033",
) )
event = Event( event = Event(
id="test-zone", id="test-zone",
source="test", source="test",
category="wx.alert.winter_storm_warning", category="wx.alert.winter_storm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc), time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo, geo=geo,
data={}, data={},
) )
subject = subject_for_event(event) subject = subject_for_event(event)
assert subject == "central.wx.alert.us.id.zone.z033" assert subject == "central.wx.alert.us.id.zone.z033"
def test_unknown_subject(self, sample_event: Event) -> None: def test_unknown_subject(self, sample_event: Event) -> None:
"""Missing primary_region produces unknown subject.""" """Missing primary_region produces unknown subject."""
geo = Geo(regions=[], primary_region=None) geo = Geo(regions=[], primary_region=None)
event = Event( event = Event(
id="test-unknown", id="test-unknown",
source="test", source="test",
category="wx.alert.test", category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc), time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo, geo=geo,
data={}, data={},
) )
subject = subject_for_event(event) subject = subject_for_event(event)
assert subject == "central.wx.alert.us.unknown" assert subject == "central.wx.alert.us.unknown"
def test_custom_prefix(self, sample_event: Event) -> None: def test_custom_prefix(self, sample_event: Event) -> None:
"""Custom prefix is used in subject.""" """Custom prefix is used in subject."""
subject = subject_for_event(sample_event, prefix="myapp.events") subject = subject_for_event(sample_event, prefix="myapp.events")
assert subject == "myapp.events.alert.us.id.county.ada" assert subject == "myapp.events.alert.us.id.county.ada"
class TestCloudEventsWire: class TestCloudEventsWire:
"""Tests for CloudEvents wire format.""" """Tests for CloudEvents wire format."""
def test_required_fields_present( def test_required_fields_present(
self, sample_event: Event, sample_config: Config self, sample_event: Event, sample_config: Config
) -> None: ) -> None:
"""Required CloudEvents fields are present.""" """Required CloudEvents fields are present."""
envelope, msg_id = wrap_event(sample_event, sample_config) envelope, msg_id = wrap_event(sample_event, sample_config)
assert msg_id == sample_event.id assert msg_id == sample_event.id
assert envelope["id"] == sample_event.id assert envelope["id"] == sample_event.id
assert envelope["source"] == sample_config.cloudevents.source assert envelope["source"] == sample_config.cloudevents.source
assert envelope["type"] == "central.wx.alert.severe_thunderstorm_warning.v1" assert envelope["type"] == "central.wx.alert.severe_thunderstorm_warning.v1"
assert envelope["specversion"] == "1.0" assert envelope["specversion"] == "1.0"
assert "time" in envelope assert "time" in envelope
assert envelope["datacontenttype"] == "application/json" assert envelope["datacontenttype"] == "application/json"
assert "data" in envelope assert "data" in envelope
def test_extension_attributes_lowercase( def test_extension_attributes_lowercase(
self, sample_event: Event, sample_config: Config self, sample_event: Event, sample_config: Config
) -> None: ) -> None:
"""Extension attributes are lowercase with no underscores.""" """Extension attributes are lowercase with no underscores."""
envelope, _ = wrap_event(sample_event, sample_config) envelope, _ = wrap_event(sample_event, sample_config)
# Check that extension attributes exist and are lowercase # Check that extension attributes exist and are lowercase
assert envelope["centralschemaversion"] == "1.0" assert envelope["centralschemaversion"] == "1.0"
assert envelope["centralcategory"] == "wx.alert.severe_thunderstorm_warning" assert envelope["centralcategory"] == "wx.alert.severe_thunderstorm_warning"
assert envelope["centralseverity"] == 3 assert envelope["centralseverity"] == 3
# Verify no uppercase or underscores in extension names # Verify no uppercase or underscores in extension names
for key in ["centralschemaversion", "centralcategory", "centralseverity"]: for key in ["centralschemaversion", "centralcategory", "centralseverity"]:
assert key.islower() assert key.islower()
assert "_" not in key assert "_" not in key
def test_severity_none_omits_centralseverity( def test_severity_none_omits_centralseverity(
self, sample_geo: Geo, sample_config: Config self, sample_geo: Geo, sample_config: Config
) -> None: ) -> None:
"""When severity is None, centralseverity is omitted entirely.""" """When severity is None, centralseverity is omitted entirely."""
event = Event( event = Event(
id="test-no-severity", id="test-no-severity",
source="test", source="test",
category="wx.alert.test", category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc), time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
severity=None, # Explicitly None severity=None, # Explicitly None
geo=sample_geo, geo=sample_geo,
data={}, data={},
) )
envelope, _ = wrap_event(event, sample_config) envelope, _ = wrap_event(event, sample_config)
# centralseverity should not be present at all # centralseverity should not be present at all
assert "centralseverity" not in envelope assert "centralseverity" not in envelope
# Other extensions should still be present # Other extensions should still be present
assert "centralschemaversion" in envelope assert "centralschemaversion" in envelope
assert "centralcategory" in envelope assert "centralcategory" in envelope