Merge pull request #1 from zvx-echo6/feature/1a-config-storage

feat(config): Phase 1a-2 config storage primitives
This commit is contained in:
malice 2026-05-15 19:49:21 -06:00 committed by GitHub
commit ee081c9bc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1570 additions and 161 deletions

View file

@ -15,14 +15,18 @@ dependencies = [
"aiolimiter>=1.2.1", "aiolimiter>=1.2.1",
"asyncpg>=0.31.0", "asyncpg>=0.31.0",
"cloudevents>=2.0.0", "cloudevents>=2.0.0",
"cryptography>=44.0.0",
"nats-py>=2.14.0", "nats-py>=2.14.0",
"pydantic>=2,<3", "pydantic>=2,<3",
"pydantic-settings>=2.7.0",
"tenacity>=9.1.4", "tenacity>=9.1.4",
] ]
[project.scripts] [project.scripts]
central-supervisor = "central.supervisor:main" central-supervisor = "central.supervisor:main"
central-archive = "central.archive:main" central-archive = "central.archive:main"
central-migrate = "central.migrate:main"
central-cli = "central.cli:main"
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["src/central"] packages = ["src/central"]

View file

@ -0,0 +1,64 @@
-- Migration: 001_create_config_schema
-- Creates the config schema with adapters and api_keys tables.
-- Also seeds the NWS adapter row from current TOML config.
-- Create config schema
CREATE SCHEMA config;
-- Adapters configuration table
CREATE TABLE config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- API keys table (encrypted values)
CREATE TABLE config.api_keys (
alias TEXT PRIMARY KEY,
encrypted_value BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
rotated_at TIMESTAMPTZ,
last_used_at TIMESTAMPTZ
);
-- Notify function for config changes
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
-- Handle different table structures
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSIF TG_TABLE_NAME = 'api_keys' THEN
key_value := COALESCE(NEW.alias, OLD.alias, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql;
-- Trigger for adapters table
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
-- Trigger for api_keys table
CREATE TRIGGER api_keys_notify
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
-- Seed NWS adapter from current TOML config values
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES (
'nws',
true,
60,
'{"states": ["ID", "OR", "WA", "MT", "WY", "UT", "NV"], "contact_email": "mj@k7zvx.com"}'::jsonb
);

View file

@ -0,0 +1,21 @@
-- Migration: 002_add_updated_at_trigger_and_index
-- Adds auto-update trigger for updated_at column on adapters table
-- Adds partial index for efficient enabled adapter queries
-- Auto-update trigger for updated_at
CREATE OR REPLACE FUNCTION config.set_updated_at()
RETURNS trigger AS $$
BEGIN
NEW.updated_at := now();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER adapters_set_updated_at
BEFORE UPDATE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.set_updated_at();
-- Partial index for enabled adapters (common query pattern)
CREATE INDEX adapters_enabled_idx
ON config.adapters (enabled)
WHERE enabled = true;

View file

@ -0,0 +1,46 @@
"""Bootstrap configuration from environment variables.
This module provides early-stage configuration loading from environment
variables or a .env file. Used before the database-backed config store
is available.
"""
from functools import lru_cache
from pathlib import Path
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Bootstrap settings loaded from environment or .env file."""
model_config = SettingsConfigDict(
env_prefix="CENTRAL_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
db_dsn: str = Field(description="PostgreSQL connection string")
nats_url: str = Field(default="nats://localhost:4222", description="NATS server URL")
master_key_path: Path = Field(
default=Path("/etc/central/master.key"),
description="Path to AES-256 master key file",
)
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
default="INFO",
description="Logging level",
)
@lru_cache
def get_settings(env_file: Path | None = None) -> Settings:
"""Load settings, optionally from a specific .env file.
Results are cached. Call get_settings.cache_clear() to reload.
"""
if env_file is not None:
return Settings(_env_file=env_file)
return Settings()

75
src/central/cli.py Normal file
View file

@ -0,0 +1,75 @@
"""Central CLI commands."""
import argparse
import asyncio
import sys
async def config_store_check() -> int:
"""Smoke test for config store connectivity.
Connects via bootstrap_config, lists adapters, and verifies crypto.
Returns 0 on success, 1 on failure.
"""
from central.bootstrap_config import get_settings
from central.config_store import ConfigStore
from central.crypto import decrypt, encrypt
settings = get_settings()
print(f"Connecting to: {settings.db_dsn.split('@')[1]}") # Hide password
try:
store = await ConfigStore.create(settings.db_dsn)
except Exception as e:
print(f"ERROR: Failed to connect to database: {e}")
return 1
try:
# List adapters
adapters = await store.list_adapters()
print(f"\nAdapters ({len(adapters)}):")
for adapter in adapters:
print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
print(f" settings: {adapter.settings}")
# Test crypto
test_plaintext = b"config_store_check_test"
try:
ciphertext = encrypt(test_plaintext)
decrypted = decrypt(ciphertext)
if decrypted == test_plaintext:
print("\ncrypto: ok")
else:
print("\ncrypto: FAILED (round-trip mismatch)")
return 1
except Exception as e:
print(f"\ncrypto: FAILED ({e})")
return 1
print("\nAll checks passed.")
return 0
finally:
await store.close()
def main_config_store_check() -> None:
"""Entry point for central-cli config-store-check."""
sys.exit(asyncio.run(config_store_check()))
def main() -> None:
"""Main CLI entry point."""
parser = argparse.ArgumentParser(description="Central CLI")
subparsers = parser.add_subparsers(dest="command", required=True)
subparsers.add_parser("config-store-check", help="Test config store connectivity")
args = parser.parse_args()
if args.command == "config-store-check":
main_config_store_check()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,39 @@
"""Pydantic models for database-backed configuration."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
class AdapterConfig(BaseModel):
"""Configuration for a single adapter."""
name: str = Field(description="Unique adapter identifier")
enabled: bool = Field(default=True, description="Whether adapter is active")
cadence_s: int = Field(description="Poll interval in seconds")
settings: dict[str, Any] = Field(
default_factory=dict, description="Adapter-specific settings"
)
paused_at: datetime | None = Field(
default=None, description="When adapter was paused, if paused"
)
updated_at: datetime = Field(description="Last configuration update time")
@property
def is_paused(self) -> bool:
"""Check if adapter is currently paused."""
return self.paused_at is not None
class ApiKeyInfo(BaseModel):
"""Metadata about an API key (without the decrypted value)."""
alias: str = Field(description="Key identifier/alias")
created_at: datetime = Field(description="When key was created")
rotated_at: datetime | None = Field(
default=None, description="Last rotation time"
)
last_used_at: datetime | None = Field(
default=None, description="Last usage time"
)

269
src/central/config_store.py Normal file
View file

@ -0,0 +1,269 @@
"""Database-backed configuration store.
Provides async access to the config schema tables with support for
Postgres LISTEN/NOTIFY for real-time config change notifications.
"""
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from typing import Any
import asyncpg
from central.config_models import AdapterConfig
from central.crypto import decrypt, encrypt
logger = logging.getLogger(__name__)
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
"""Set up JSON codec for asyncpg connection."""
await conn.set_type_codec(
"jsonb",
encoder=json.dumps,
decoder=json.loads,
schema="pg_catalog",
)
class ConfigStore:
"""Async interface to the config schema in Postgres."""
def __init__(self, pool: asyncpg.Pool) -> None:
self._pool = pool
@classmethod
async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
"""Create a ConfigStore with a new connection pool."""
pool = await asyncpg.create_pool(
dsn,
min_size=min_size,
max_size=max_size,
init=_setup_json_codec,
)
return cls(pool)
async def close(self) -> None:
"""Close the connection pool."""
await self._pool.close()
# -------------------------------------------------------------------------
# Adapter configuration
# -------------------------------------------------------------------------
async def get_adapter(self, name: str) -> AdapterConfig | None:
"""Get configuration for a specific adapter."""
async with self._pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
FROM config.adapters
WHERE name = $1
""",
name,
)
if row is None:
return None
return AdapterConfig(**dict(row))
async def list_adapters(self) -> list[AdapterConfig]:
"""List all configured adapters."""
async with self._pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
FROM config.adapters
ORDER BY name
"""
)
return [AdapterConfig(**dict(row)) for row in rows]
async def upsert_adapter(
self,
name: str,
enabled: bool,
cadence_s: int,
settings: dict[str, Any],
) -> None:
"""Insert or update an adapter configuration."""
async with self._pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
VALUES ($1, $2, $3, $4, now())
ON CONFLICT (name) DO UPDATE SET
enabled = EXCLUDED.enabled,
cadence_s = EXCLUDED.cadence_s,
settings = EXCLUDED.settings,
updated_at = now()
""",
name,
enabled,
cadence_s,
settings, # Will be encoded as JSON by the codec
)
async def pause_adapter(self, name: str) -> None:
"""Pause an adapter by setting paused_at."""
async with self._pool.acquire() as conn:
await conn.execute(
"""
UPDATE config.adapters
SET paused_at = now(), updated_at = now()
WHERE name = $1
""",
name,
)
async def unpause_adapter(self, name: str) -> None:
"""Unpause an adapter by clearing paused_at."""
async with self._pool.acquire() as conn:
await conn.execute(
"""
UPDATE config.adapters
SET paused_at = NULL, updated_at = now()
WHERE name = $1
""",
name,
)
# -------------------------------------------------------------------------
# API key management
# -------------------------------------------------------------------------
async def set_api_key(self, alias: str, plaintext_value: str) -> None:
"""Store an API key, encrypting it with the master key."""
encrypted = encrypt(plaintext_value.encode("utf-8"))
async with self._pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO config.api_keys (alias, encrypted_value)
VALUES ($1, $2)
ON CONFLICT (alias) DO UPDATE SET
encrypted_value = EXCLUDED.encrypted_value,
rotated_at = now()
""",
alias,
encrypted,
)
async def get_api_key(self, alias: str) -> str | None:
"""Retrieve and decrypt an API key by alias."""
async with self._pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT encrypted_value FROM config.api_keys WHERE alias = $1
""",
alias,
)
if row is not None:
# Update last_used_at
await conn.execute(
"""
UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
""",
alias,
)
if row is None:
return None
return decrypt(row["encrypted_value"]).decode("utf-8")
async def delete_api_key(self, alias: str) -> bool:
"""Delete an API key. Returns True if key existed."""
async with self._pool.acquire() as conn:
result = await conn.execute(
"DELETE FROM config.api_keys WHERE alias = $1", alias
)
return result == "DELETE 1"
# -------------------------------------------------------------------------
# Change notifications
# -------------------------------------------------------------------------
async def listen_for_changes(
self,
callback: Callable[[str, str], Awaitable[None] | None],
) -> None:
"""Listen for config changes via Postgres NOTIFY.
Runs forever, calling callback(table, key) each time a change is
detected. The callback can be sync or async.
On connection loss, automatically reconnects with exponential backoff.
Cancellation (via task.cancel()) propagates cleanly.
Args:
callback: Function called with (table_name, row_key) on each change.
"""
backoff = 1.0
max_backoff = 30.0
while True:
conn = None
try:
conn = await self._pool.acquire()
logger.info("Config listener connected to database")
backoff = 1.0 # Reset backoff on successful connect
def notification_handler(
conn: asyncpg.Connection,
pid: int,
channel: str,
payload: str,
) -> None:
# payload format: "table_name:key"
if ":" in payload:
table, key = payload.split(":", 1)
else:
table, key = payload, ""
result = callback(table, key)
if asyncio.iscoroutine(result):
asyncio.create_task(result)
await conn.add_listener("config_changed", notification_handler)
try:
# Keep connection alive with periodic keepalive
while True:
await asyncio.sleep(60)
await conn.execute("SELECT 1")
finally:
await conn.remove_listener("config_changed", notification_handler)
except asyncio.CancelledError:
# Cancellation must propagate cleanly
logger.info("Config listener cancelled")
raise
except (
asyncpg.PostgresConnectionError,
asyncpg.InterfaceError,
ConnectionResetError,
OSError,
) as e:
logger.warning(
"Config listener connection lost, reconnecting in %.1fs: %s",
backoff,
e,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, max_backoff)
except Exception as e:
# Unexpected error - log and retry with backoff
logger.exception(
"Config listener unexpected error, reconnecting in %.1fs",
backoff,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, max_backoff)
finally:
if conn is not None:
try:
await self._pool.release(conn)
except Exception:
pass # Connection may already be invalid

111
src/central/crypto.py Normal file
View file

@ -0,0 +1,111 @@
"""Cryptographic primitives for secret storage.
Uses AES-256-GCM for authenticated encryption. The master key is read
from the path specified in bootstrap config on first use and cached.
"""
import base64
import os
from functools import lru_cache
from pathlib import Path
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
# AES-256 requires 32-byte key
KEY_SIZE = 32
# GCM nonce size (96 bits recommended by NIST)
NONCE_SIZE = 12
class CryptoError(Exception):
"""Base exception for crypto operations."""
class KeyLoadError(CryptoError):
"""Failed to load master key."""
class DecryptionError(CryptoError):
"""Failed to decrypt ciphertext (wrong key or tampered data)."""
@lru_cache
def _load_master_key(path: Path) -> bytes:
"""Load and decode the base64-encoded master key from file."""
try:
key_b64 = path.read_text().strip()
key = base64.b64decode(key_b64)
except FileNotFoundError:
raise KeyLoadError(f"Master key file not found: {path}")
except Exception as e:
raise KeyLoadError(f"Failed to read master key from {path}: {e}")
if len(key) != KEY_SIZE:
raise KeyLoadError(
f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
)
return key
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
"""Encrypt plaintext using AES-256-GCM.
Args:
plaintext: Data to encrypt.
key_path: Path to master key file. If None, uses default from
bootstrap config.
Returns:
Ciphertext in format: nonce (12 bytes) || ciphertext || tag (16 bytes)
"""
if key_path is None:
from central.bootstrap_config import get_settings
key_path = get_settings().master_key_path
key = _load_master_key(key_path)
nonce = os.urandom(NONCE_SIZE)
aesgcm = AESGCM(key)
# GCM appends the 16-byte tag to the ciphertext
ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, associated_data=None)
return nonce + ciphertext_with_tag
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
"""Decrypt ciphertext using AES-256-GCM.
Args:
ciphertext: Data in format: nonce || ciphertext || tag
key_path: Path to master key file. If None, uses default from
bootstrap config.
Returns:
Decrypted plaintext.
Raises:
DecryptionError: If decryption fails (wrong key or tampered data).
"""
if key_path is None:
from central.bootstrap_config import get_settings
key_path = get_settings().master_key_path
if len(ciphertext) < NONCE_SIZE + 16: # nonce + minimum tag
raise DecryptionError("Ciphertext too short")
key = _load_master_key(key_path)
nonce = ciphertext[:NONCE_SIZE]
ciphertext_with_tag = ciphertext[NONCE_SIZE:]
aesgcm = AESGCM(key)
try:
plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
except Exception as e:
raise DecryptionError(f"Decryption failed: {e}")
return plaintext
def clear_key_cache() -> None:
"""Clear the cached master key. Use after key rotation."""
_load_master_key.cache_clear()

125
src/central/migrate.py Normal file
View file

@ -0,0 +1,125 @@
"""Simple database migration runner.
Tracks applied migrations in a `schema_migrations` table. Migrations are
plain SQL files in `sql/migrations/` named with numeric prefixes:
001_create_config_schema.sql
002_add_operators_table.sql
...
Usage:
central-migrate [--dry-run]
"""
import argparse
import asyncio
import sys
from pathlib import Path
import asyncpg
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
"""Create the schema_migrations table if it doesn't exist."""
await conn.execute("""
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
"""Return set of already-applied migration versions."""
rows = await conn.fetch("SELECT version FROM schema_migrations")
return {row["version"] for row in rows}
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
"""Find all .sql files in migrations directory, sorted by name.
Returns list of (version, path) tuples where version is the filename
without extension.
"""
if not migrations_dir.exists():
return []
migrations = []
for f in sorted(migrations_dir.glob("*.sql")):
version = f.stem # e.g., "001_create_config_schema"
migrations.append((version, f))
return migrations
async def apply_migration(
conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
) -> None:
"""Apply a single migration."""
sql = sql_path.read_text()
if dry_run:
print(f"[DRY RUN] Would apply: {version}")
print(f" SQL: {sql[:200]}..." if len(sql) > 200 else f" SQL: {sql}")
return
async with conn.transaction():
await conn.execute(sql)
await conn.execute(
"INSERT INTO schema_migrations (version) VALUES ($1)", version
)
print(f"Applied: {version}")
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
"""Run all pending migrations.
Returns number of migrations applied.
"""
conn = await asyncpg.connect(dsn)
try:
await ensure_migrations_table(conn)
applied = await get_applied_migrations(conn)
pending = [
(v, p) for v, p in discover_migrations(MIGRATIONS_DIR) if v not in applied
]
if not pending:
print("No pending migrations.")
return 0
print(f"Found {len(pending)} pending migration(s).")
for version, path in pending:
await apply_migration(conn, version, path, dry_run)
return len(pending)
finally:
await conn.close()
async def async_main() -> None:
"""Async entry point."""
parser = argparse.ArgumentParser(description="Run database migrations")
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be applied without executing",
)
args = parser.parse_args()
from central.bootstrap_config import get_settings
settings = get_settings()
count = await run_migrations(settings.db_dsn, dry_run=args.dry_run)
if count > 0 and not args.dry_run:
print(f"Successfully applied {count} migration(s).")
def main() -> None:
"""Entry point."""
asyncio.run(async_main())
if __name__ == "__main__":
main()

18
tests/README.md Normal file
View file

@ -0,0 +1,18 @@
# Central Tests
## Test Database
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
By default, tests connect to:
```
postgresql://central_test:testpass@localhost/central_test
```
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
environment variable:
```bash
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
uv run pytest tests/test_config_store.py
```

View file

@ -0,0 +1,123 @@
"""Tests for bootstrap configuration."""
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
import pytest
from central.bootstrap_config import Settings, get_settings
class TestSettingsFromEnv:
"""Test loading settings from environment variables."""
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Settings are read from CENTRAL_* environment variables."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
settings = Settings()
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
assert settings.nats_url == "nats://10.0.0.1:4222"
assert settings.master_key_path == Path("/tmp/test.key")
assert settings.log_level == "DEBUG"
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Default values are used when env vars not set."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
# Clear any existing env vars that might interfere
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
settings = Settings()
assert settings.nats_url == "nats://localhost:4222"
assert settings.master_key_path == Path("/etc/central/master.key")
assert settings.log_level == "INFO"
class TestSettingsFromFile:
"""Test loading settings from .env file."""
def test_reads_from_env_file(self, tmp_path: Path) -> None:
"""Settings are read from .env file when env vars not present."""
env_file = tmp_path / ".env"
env_file.write_text(
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
"CENTRAL_NATS_URL=nats://file.local:4222\n"
"CENTRAL_LOG_LEVEL=WARNING\n"
)
# Create settings pointing to the temp .env file
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
assert settings.nats_url == "nats://file.local:4222"
assert settings.log_level == "WARNING"
def test_env_vars_override_file(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Environment variables take precedence over .env file."""
env_file = tmp_path / ".env"
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://env@localhost/envdb"
class TestSettingsValidation:
"""Test settings validation and error handling."""
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear error when required CENTRAL_DB_DSN is missing."""
# Ensure no env vars or .env file provides the DSN
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
with pytest.raises(Exception) as exc_info:
# Use a non-existent .env file path to ensure no fallback
Settings(_env_file=Path("/nonexistent/.env"))
# pydantic-settings raises ValidationError for missing required fields
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Invalid log level values are rejected."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
with pytest.raises(Exception):
Settings()
class TestGetSettings:
"""Test the cached settings loader."""
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""get_settings() returns cached instance."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
s2 = get_settings()
assert s1 is s2
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""cache_clear() forces reload on next call."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
get_settings.cache_clear()
s2 = get_settings()
assert s1.db_dsn != s2.db_dsn

339
tests/test_config_store.py Normal file
View file

@ -0,0 +1,339 @@
"""Tests for database-backed configuration store.
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
environment variable to override the default test database connection.
"""
import asyncio
import base64
import os
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN - uses central_test database with well-known test password.
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
# Create schema if not exists
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
# Create tables if not exist
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.api_keys (
alias TEXT PRIMARY KEY,
encrypted_value BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
rotated_at TIMESTAMPTZ,
last_used_at TIMESTAMPTZ
)
""")
# Create notify function with proper key detection
await db_conn.execute("""
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSIF TG_TABLE_NAME = 'api_keys' THEN
key_value := COALESCE(NEW.alias, OLD.alias, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
""")
# Create triggers if not exist
await db_conn.execute("""
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
await db_conn.execute("""
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
CREATE TRIGGER api_keys_notify
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
# Clean tables
await db_conn.execute("DELETE FROM config.adapters")
await db_conn.execute("DELETE FROM config.api_keys")
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
"""Create a ConfigStore connected to the test database."""
store = await ConfigStore.create(TEST_DB_DSN)
yield store
await store.close()
class TestAdapterConfig:
"""Tests for adapter configuration operations."""
@pytest.mark.asyncio
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
"""Can insert and retrieve adapter config."""
await config_store.upsert_adapter(
name="test_adapter",
enabled=True,
cadence_s=120,
settings={"key": "value"},
)
adapter = await config_store.get_adapter("test_adapter")
assert adapter is not None
assert adapter.name == "test_adapter"
assert adapter.enabled is True
assert adapter.cadence_s == 120
assert adapter.settings == {"key": "value"}
@pytest.mark.asyncio
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
"""Getting nonexistent adapter returns None."""
adapter = await config_store.get_adapter("does_not_exist")
assert adapter is None
@pytest.mark.asyncio
async def test_list_adapters(self, config_store: ConfigStore) -> None:
"""Can list all adapters."""
await config_store.upsert_adapter("adapter_a", True, 60, {})
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
adapters = await config_store.list_adapters()
assert len(adapters) == 2
names = [a.name for a in adapters]
assert "adapter_a" in names
assert "adapter_b" in names
@pytest.mark.asyncio
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
"""Upsert updates existing adapter."""
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
adapter = await config_store.get_adapter("updater")
assert adapter is not None
assert adapter.enabled is False
assert adapter.cadence_s == 120
assert adapter.settings == {"v": 2}
@pytest.mark.asyncio
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
"""Can pause and unpause adapter."""
await config_store.upsert_adapter("pausable", True, 60, {})
await config_store.pause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is True
await config_store.unpause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is False
class TestApiKeys:
"""Tests for API key operations."""
@pytest.mark.asyncio
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
"""Can store and retrieve encrypted API key."""
await config_store.set_api_key("test_key", "super_secret_value")
value = await config_store.get_api_key("test_key")
assert value == "super_secret_value"
@pytest.mark.asyncio
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
"""Getting nonexistent key returns None."""
value = await config_store.get_api_key("does_not_exist")
assert value is None
@pytest.mark.asyncio
async def test_key_rotation(self, config_store: ConfigStore) -> None:
"""Updating key sets rotated_at."""
await config_store.set_api_key("rotate_me", "value1")
await config_store.set_api_key("rotate_me", "value2")
value = await config_store.get_api_key("rotate_me")
assert value == "value2"
@pytest.mark.asyncio
async def test_delete_key(self, config_store: ConfigStore) -> None:
"""Can delete API key."""
await config_store.set_api_key("delete_me", "value")
deleted = await config_store.delete_api_key("delete_me")
assert deleted is True
value = await config_store.get_api_key("delete_me")
assert value is None
@pytest.mark.asyncio
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
"""Deleting nonexistent key returns False."""
deleted = await config_store.delete_api_key("never_existed")
assert deleted is False
class TestNotifications:
"""Tests for LISTEN/NOTIFY functionality."""
@pytest.mark.asyncio
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when adapter is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
# Start listener in background
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
# Give listener time to subscribe
await asyncio.sleep(0.1)
# Trigger a change
await config_store.upsert_adapter("notify_test", True, 60, {})
# Wait for notification (with timeout)
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "adapters"
assert notifications[0][1] == "notify_test"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when API key is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
await asyncio.sleep(0.1)
await config_store.set_api_key("notify_key", "secret")
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "api_keys"
assert notifications[0][1] == "notify_key"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
class TestListenerReconnect:
"""Tests for listener reconnection on connection loss."""
@pytest.mark.asyncio
async def test_listener_cancellation_propagates(
self, config_store: ConfigStore
) -> None:
"""Cancellation cleanly stops the listener without reconnect loop."""
async def callback(table: str, key: str) -> None:
pass
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
# Give listener time to start
await asyncio.sleep(0.1)
# Cancel and verify it stops
listen_task.cancel()
try:
await asyncio.wait_for(listen_task, timeout=2.0)
except asyncio.CancelledError:
pass # Expected
except asyncio.TimeoutError:
pytest.fail("Listener did not stop after cancellation")
assert listen_task.cancelled() or listen_task.done()

175
tests/test_crypto.py Normal file
View file

@ -0,0 +1,175 @@
"""Tests for cryptographic primitives."""
import base64
import os
from pathlib import Path
import pytest
from central.crypto import (
KEY_SIZE,
DecryptionError,
KeyLoadError,
clear_key_cache,
decrypt,
encrypt,
)
@pytest.fixture
def master_key(tmp_path: Path) -> Path:
"""Create a valid master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "master.key"
key_path.write_text(base64.b64encode(key).decode())
clear_key_cache()
return key_path
@pytest.fixture
def wrong_key(tmp_path: Path) -> Path:
"""Create a different master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "wrong.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
class TestEncryptDecrypt:
"""Test encrypt/decrypt round-trip."""
def test_round_trip(self, master_key: Path) -> None:
"""Encrypting then decrypting returns original plaintext."""
plaintext = b"Hello, Central!"
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_empty(self, master_key: Path) -> None:
"""Empty plaintext encrypts and decrypts correctly."""
plaintext = b""
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_large(self, master_key: Path) -> None:
"""Large plaintext encrypts and decrypts correctly."""
plaintext = os.urandom(1024 * 1024) # 1MB
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
"""Same plaintext produces different ciphertext (random nonce)."""
plaintext = b"test"
ct1 = encrypt(plaintext, key_path=master_key)
ct2 = encrypt(plaintext, key_path=master_key)
assert ct1 != ct2
# But both decrypt to same plaintext
assert decrypt(ct1, key_path=master_key) == plaintext
assert decrypt(ct2, key_path=master_key) == plaintext
class TestDecryptionFailures:
"""Test AEAD authentication catches tampering."""
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
"""Decryption with wrong key raises DecryptionError."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
clear_key_cache() # Clear cache so wrong_key is loaded
with pytest.raises(DecryptionError):
decrypt(ciphertext, key_path=wrong_key)
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
"""Modified ciphertext is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the ciphertext (after nonce, before tag)
tampered = bytearray(ciphertext)
tampered[15] ^= 0x01 # Flip one bit
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_tampered_tag_fails(self, master_key: Path) -> None:
"""Modified authentication tag is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the last byte (part of the tag)
tampered = bytearray(ciphertext)
tampered[-1] ^= 0x01
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
"""Truncated ciphertext is rejected."""
ciphertext = b"tooshort"
with pytest.raises(DecryptionError, match="too short"):
decrypt(ciphertext, key_path=master_key)
class TestKeyLoading:
"""Test master key loading."""
def test_missing_key_file(self, tmp_path: Path) -> None:
"""Missing key file raises KeyLoadError."""
clear_key_cache()
missing = tmp_path / "nonexistent.key"
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=missing)
def test_invalid_key_size(self, tmp_path: Path) -> None:
"""Key file with wrong size raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text(base64.b64encode(b"tooshort").decode())
with pytest.raises(KeyLoadError, match="Invalid master key size"):
encrypt(b"test", key_path=bad_key)
def test_invalid_base64(self, tmp_path: Path) -> None:
"""Invalid base64 in key file raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text("not valid base64!!!")
with pytest.raises(KeyLoadError):
encrypt(b"test", key_path=bad_key)
def test_key_cached(self, master_key: Path) -> None:
"""Key is cached after first load."""
# First encryption loads the key
encrypt(b"test1", key_path=master_key)
# Delete the file
master_key.unlink()
# Second encryption should still work (cached)
ciphertext = encrypt(b"test2", key_path=master_key)
assert len(ciphertext) > 0
def test_cache_clear(self, master_key: Path) -> None:
"""clear_key_cache forces reload."""
encrypt(b"test", key_path=master_key)
master_key.unlink()
clear_key_cache()
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=master_key)

View file

@ -49,7 +49,7 @@ def sample_config() -> Config:
}, },
cloudevents=CloudEventsConfig( cloudevents=CloudEventsConfig(
type_prefix="central", type_prefix="central",
source="central.echo6.mesh", source="central.local",
schema_version="1.0", schema_version="1.0",
), ),
nats=NATSConfig(url="nats://localhost:4222"), nats=NATSConfig(url="nats://localhost:4222"),