mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
Merge pull request #1 from zvx-echo6/feature/1a-config-storage
feat(config): Phase 1a-2 config storage primitives
This commit is contained in:
commit
ee081c9bc2
14 changed files with 1570 additions and 161 deletions
|
|
@ -15,14 +15,18 @@ dependencies = [
|
||||||
"aiolimiter>=1.2.1",
|
"aiolimiter>=1.2.1",
|
||||||
"asyncpg>=0.31.0",
|
"asyncpg>=0.31.0",
|
||||||
"cloudevents>=2.0.0",
|
"cloudevents>=2.0.0",
|
||||||
|
"cryptography>=44.0.0",
|
||||||
"nats-py>=2.14.0",
|
"nats-py>=2.14.0",
|
||||||
"pydantic>=2,<3",
|
"pydantic>=2,<3",
|
||||||
|
"pydantic-settings>=2.7.0",
|
||||||
"tenacity>=9.1.4",
|
"tenacity>=9.1.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
central-supervisor = "central.supervisor:main"
|
central-supervisor = "central.supervisor:main"
|
||||||
central-archive = "central.archive:main"
|
central-archive = "central.archive:main"
|
||||||
|
central-migrate = "central.migrate:main"
|
||||||
|
central-cli = "central.cli:main"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src/central"]
|
packages = ["src/central"]
|
||||||
|
|
|
||||||
64
sql/migrations/001_create_config_schema.sql
Normal file
64
sql/migrations/001_create_config_schema.sql
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
-- Migration: 001_create_config_schema
|
||||||
|
-- Creates the config schema with adapters and api_keys tables.
|
||||||
|
-- Also seeds the NWS adapter row from current TOML config.
|
||||||
|
|
||||||
|
-- Create config schema
|
||||||
|
CREATE SCHEMA config;
|
||||||
|
|
||||||
|
-- Adapters configuration table
|
||||||
|
CREATE TABLE config.adapters (
|
||||||
|
name TEXT PRIMARY KEY,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||||
|
cadence_s INTEGER NOT NULL,
|
||||||
|
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
paused_at TIMESTAMPTZ,
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
|
-- API keys table (encrypted values)
|
||||||
|
CREATE TABLE config.api_keys (
|
||||||
|
alias TEXT PRIMARY KEY,
|
||||||
|
encrypted_value BYTEA NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
rotated_at TIMESTAMPTZ,
|
||||||
|
last_used_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Notify function for config changes
|
||||||
|
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||||
|
RETURNS trigger AS $$
|
||||||
|
DECLARE
|
||||||
|
key_value TEXT;
|
||||||
|
BEGIN
|
||||||
|
-- Handle different table structures
|
||||||
|
IF TG_TABLE_NAME = 'adapters' THEN
|
||||||
|
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||||
|
ELSIF TG_TABLE_NAME = 'api_keys' THEN
|
||||||
|
key_value := COALESCE(NEW.alias, OLD.alias, '');
|
||||||
|
ELSE
|
||||||
|
key_value := '';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||||
|
RETURN COALESCE(NEW, OLD);
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
-- Trigger for adapters table
|
||||||
|
CREATE TRIGGER adapters_notify
|
||||||
|
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
|
||||||
|
|
||||||
|
-- Trigger for api_keys table
|
||||||
|
CREATE TRIGGER api_keys_notify
|
||||||
|
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
|
||||||
|
|
||||||
|
-- Seed NWS adapter from current TOML config values
|
||||||
|
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||||
|
VALUES (
|
||||||
|
'nws',
|
||||||
|
true,
|
||||||
|
60,
|
||||||
|
'{"states": ["ID", "OR", "WA", "MT", "WY", "UT", "NV"], "contact_email": "mj@k7zvx.com"}'::jsonb
|
||||||
|
);
|
||||||
21
sql/migrations/002_add_updated_at_trigger_and_index.sql
Normal file
21
sql/migrations/002_add_updated_at_trigger_and_index.sql
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
-- Migration: 002_add_updated_at_trigger_and_index
|
||||||
|
-- Adds auto-update trigger for updated_at column on adapters table
|
||||||
|
-- Adds partial index for efficient enabled adapter queries
|
||||||
|
|
||||||
|
-- Auto-update trigger for updated_at
|
||||||
|
CREATE OR REPLACE FUNCTION config.set_updated_at()
|
||||||
|
RETURNS trigger AS $$
|
||||||
|
BEGIN
|
||||||
|
NEW.updated_at := now();
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
CREATE TRIGGER adapters_set_updated_at
|
||||||
|
BEFORE UPDATE ON config.adapters
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION config.set_updated_at();
|
||||||
|
|
||||||
|
-- Partial index for enabled adapters (common query pattern)
|
||||||
|
CREATE INDEX adapters_enabled_idx
|
||||||
|
ON config.adapters (enabled)
|
||||||
|
WHERE enabled = true;
|
||||||
46
src/central/bootstrap_config.py
Normal file
46
src/central/bootstrap_config.py
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
"""Bootstrap configuration from environment variables.
|
||||||
|
|
||||||
|
This module provides early-stage configuration loading from environment
|
||||||
|
variables or a .env file. Used before the database-backed config store
|
||||||
|
is available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Bootstrap settings loaded from environment or .env file."""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="CENTRAL_",
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_dsn: str = Field(description="PostgreSQL connection string")
|
||||||
|
nats_url: str = Field(default="nats://localhost:4222", description="NATS server URL")
|
||||||
|
master_key_path: Path = Field(
|
||||||
|
default=Path("/etc/central/master.key"),
|
||||||
|
description="Path to AES-256 master key file",
|
||||||
|
)
|
||||||
|
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||||
|
default="INFO",
|
||||||
|
description="Logging level",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings(env_file: Path | None = None) -> Settings:
|
||||||
|
"""Load settings, optionally from a specific .env file.
|
||||||
|
|
||||||
|
Results are cached. Call get_settings.cache_clear() to reload.
|
||||||
|
"""
|
||||||
|
if env_file is not None:
|
||||||
|
return Settings(_env_file=env_file)
|
||||||
|
return Settings()
|
||||||
75
src/central/cli.py
Normal file
75
src/central/cli.py
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""Central CLI commands."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
async def config_store_check() -> int:
|
||||||
|
"""Smoke test for config store connectivity.
|
||||||
|
|
||||||
|
Connects via bootstrap_config, lists adapters, and verifies crypto.
|
||||||
|
Returns 0 on success, 1 on failure.
|
||||||
|
"""
|
||||||
|
from central.bootstrap_config import get_settings
|
||||||
|
from central.config_store import ConfigStore
|
||||||
|
from central.crypto import decrypt, encrypt
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
print(f"Connecting to: {settings.db_dsn.split('@')[1]}") # Hide password
|
||||||
|
|
||||||
|
try:
|
||||||
|
store = await ConfigStore.create(settings.db_dsn)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: Failed to connect to database: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# List adapters
|
||||||
|
adapters = await store.list_adapters()
|
||||||
|
print(f"\nAdapters ({len(adapters)}):")
|
||||||
|
for adapter in adapters:
|
||||||
|
print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
|
||||||
|
print(f" settings: {adapter.settings}")
|
||||||
|
|
||||||
|
# Test crypto
|
||||||
|
test_plaintext = b"config_store_check_test"
|
||||||
|
try:
|
||||||
|
ciphertext = encrypt(test_plaintext)
|
||||||
|
decrypted = decrypt(ciphertext)
|
||||||
|
if decrypted == test_plaintext:
|
||||||
|
print("\ncrypto: ok")
|
||||||
|
else:
|
||||||
|
print("\ncrypto: FAILED (round-trip mismatch)")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\ncrypto: FAILED ({e})")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print("\nAll checks passed.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main_config_store_check() -> None:
|
||||||
|
"""Entry point for central-cli config-store-check."""
|
||||||
|
sys.exit(asyncio.run(config_store_check()))
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Main CLI entry point."""
|
||||||
|
parser = argparse.ArgumentParser(description="Central CLI")
|
||||||
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
subparsers.add_parser("config-store-check", help="Test config store connectivity")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command == "config-store-check":
|
||||||
|
main_config_store_check()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
39
src/central/config_models.py
Normal file
39
src/central/config_models.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""Pydantic models for database-backed configuration."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterConfig(BaseModel):
|
||||||
|
"""Configuration for a single adapter."""
|
||||||
|
|
||||||
|
name: str = Field(description="Unique adapter identifier")
|
||||||
|
enabled: bool = Field(default=True, description="Whether adapter is active")
|
||||||
|
cadence_s: int = Field(description="Poll interval in seconds")
|
||||||
|
settings: dict[str, Any] = Field(
|
||||||
|
default_factory=dict, description="Adapter-specific settings"
|
||||||
|
)
|
||||||
|
paused_at: datetime | None = Field(
|
||||||
|
default=None, description="When adapter was paused, if paused"
|
||||||
|
)
|
||||||
|
updated_at: datetime = Field(description="Last configuration update time")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_paused(self) -> bool:
|
||||||
|
"""Check if adapter is currently paused."""
|
||||||
|
return self.paused_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyInfo(BaseModel):
|
||||||
|
"""Metadata about an API key (without the decrypted value)."""
|
||||||
|
|
||||||
|
alias: str = Field(description="Key identifier/alias")
|
||||||
|
created_at: datetime = Field(description="When key was created")
|
||||||
|
rotated_at: datetime | None = Field(
|
||||||
|
default=None, description="Last rotation time"
|
||||||
|
)
|
||||||
|
last_used_at: datetime | None = Field(
|
||||||
|
default=None, description="Last usage time"
|
||||||
|
)
|
||||||
269
src/central/config_store.py
Normal file
269
src/central/config_store.py
Normal file
|
|
@ -0,0 +1,269 @@
|
||||||
|
"""Database-backed configuration store.
|
||||||
|
|
||||||
|
Provides async access to the config schema tables with support for
|
||||||
|
Postgres LISTEN/NOTIFY for real-time config change notifications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from central.config_models import AdapterConfig
|
||||||
|
from central.crypto import decrypt, encrypt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
|
||||||
|
"""Set up JSON codec for asyncpg connection."""
|
||||||
|
await conn.set_type_codec(
|
||||||
|
"jsonb",
|
||||||
|
encoder=json.dumps,
|
||||||
|
decoder=json.loads,
|
||||||
|
schema="pg_catalog",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigStore:
|
||||||
|
"""Async interface to the config schema in Postgres."""
|
||||||
|
|
||||||
|
def __init__(self, pool: asyncpg.Pool) -> None:
|
||||||
|
self._pool = pool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
|
||||||
|
"""Create a ConfigStore with a new connection pool."""
|
||||||
|
pool = await asyncpg.create_pool(
|
||||||
|
dsn,
|
||||||
|
min_size=min_size,
|
||||||
|
max_size=max_size,
|
||||||
|
init=_setup_json_codec,
|
||||||
|
)
|
||||||
|
return cls(pool)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the connection pool."""
|
||||||
|
await self._pool.close()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Adapter configuration
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_adapter(self, name: str) -> AdapterConfig | None:
|
||||||
|
"""Get configuration for a specific adapter."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"""
|
||||||
|
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||||
|
FROM config.adapters
|
||||||
|
WHERE name = $1
|
||||||
|
""",
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return AdapterConfig(**dict(row))
|
||||||
|
|
||||||
|
async def list_adapters(self) -> list[AdapterConfig]:
|
||||||
|
"""List all configured adapters."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"""
|
||||||
|
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||||
|
FROM config.adapters
|
||||||
|
ORDER BY name
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return [AdapterConfig(**dict(row)) for row in rows]
|
||||||
|
|
||||||
|
async def upsert_adapter(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
enabled: bool,
|
||||||
|
cadence_s: int,
|
||||||
|
settings: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Insert or update an adapter configuration."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, now())
|
||||||
|
ON CONFLICT (name) DO UPDATE SET
|
||||||
|
enabled = EXCLUDED.enabled,
|
||||||
|
cadence_s = EXCLUDED.cadence_s,
|
||||||
|
settings = EXCLUDED.settings,
|
||||||
|
updated_at = now()
|
||||||
|
""",
|
||||||
|
name,
|
||||||
|
enabled,
|
||||||
|
cadence_s,
|
||||||
|
settings, # Will be encoded as JSON by the codec
|
||||||
|
)
|
||||||
|
|
||||||
|
async def pause_adapter(self, name: str) -> None:
|
||||||
|
"""Pause an adapter by setting paused_at."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE config.adapters
|
||||||
|
SET paused_at = now(), updated_at = now()
|
||||||
|
WHERE name = $1
|
||||||
|
""",
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def unpause_adapter(self, name: str) -> None:
|
||||||
|
"""Unpause an adapter by clearing paused_at."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE config.adapters
|
||||||
|
SET paused_at = NULL, updated_at = now()
|
||||||
|
WHERE name = $1
|
||||||
|
""",
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# API key management
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def set_api_key(self, alias: str, plaintext_value: str) -> None:
|
||||||
|
"""Store an API key, encrypting it with the master key."""
|
||||||
|
encrypted = encrypt(plaintext_value.encode("utf-8"))
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO config.api_keys (alias, encrypted_value)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
ON CONFLICT (alias) DO UPDATE SET
|
||||||
|
encrypted_value = EXCLUDED.encrypted_value,
|
||||||
|
rotated_at = now()
|
||||||
|
""",
|
||||||
|
alias,
|
||||||
|
encrypted,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_api_key(self, alias: str) -> str | None:
|
||||||
|
"""Retrieve and decrypt an API key by alias."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"""
|
||||||
|
SELECT encrypted_value FROM config.api_keys WHERE alias = $1
|
||||||
|
""",
|
||||||
|
alias,
|
||||||
|
)
|
||||||
|
if row is not None:
|
||||||
|
# Update last_used_at
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
|
||||||
|
""",
|
||||||
|
alias,
|
||||||
|
)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return decrypt(row["encrypted_value"]).decode("utf-8")
|
||||||
|
|
||||||
|
async def delete_api_key(self, alias: str) -> bool:
|
||||||
|
"""Delete an API key. Returns True if key existed."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
"DELETE FROM config.api_keys WHERE alias = $1", alias
|
||||||
|
)
|
||||||
|
return result == "DELETE 1"
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Change notifications
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def listen_for_changes(
|
||||||
|
self,
|
||||||
|
callback: Callable[[str, str], Awaitable[None] | None],
|
||||||
|
) -> None:
|
||||||
|
"""Listen for config changes via Postgres NOTIFY.
|
||||||
|
|
||||||
|
Runs forever, calling callback(table, key) each time a change is
|
||||||
|
detected. The callback can be sync or async.
|
||||||
|
|
||||||
|
On connection loss, automatically reconnects with exponential backoff.
|
||||||
|
Cancellation (via task.cancel()) propagates cleanly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function called with (table_name, row_key) on each change.
|
||||||
|
"""
|
||||||
|
backoff = 1.0
|
||||||
|
max_backoff = 30.0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = await self._pool.acquire()
|
||||||
|
logger.info("Config listener connected to database")
|
||||||
|
backoff = 1.0 # Reset backoff on successful connect
|
||||||
|
|
||||||
|
def notification_handler(
|
||||||
|
conn: asyncpg.Connection,
|
||||||
|
pid: int,
|
||||||
|
channel: str,
|
||||||
|
payload: str,
|
||||||
|
) -> None:
|
||||||
|
# payload format: "table_name:key"
|
||||||
|
if ":" in payload:
|
||||||
|
table, key = payload.split(":", 1)
|
||||||
|
else:
|
||||||
|
table, key = payload, ""
|
||||||
|
|
||||||
|
result = callback(table, key)
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
asyncio.create_task(result)
|
||||||
|
|
||||||
|
await conn.add_listener("config_changed", notification_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Keep connection alive with periodic keepalive
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
await conn.execute("SELECT 1")
|
||||||
|
finally:
|
||||||
|
await conn.remove_listener("config_changed", notification_handler)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Cancellation must propagate cleanly
|
||||||
|
logger.info("Config listener cancelled")
|
||||||
|
raise
|
||||||
|
|
||||||
|
except (
|
||||||
|
asyncpg.PostgresConnectionError,
|
||||||
|
asyncpg.InterfaceError,
|
||||||
|
ConnectionResetError,
|
||||||
|
OSError,
|
||||||
|
) as e:
|
||||||
|
logger.warning(
|
||||||
|
"Config listener connection lost, reconnecting in %.1fs: %s",
|
||||||
|
backoff,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(backoff)
|
||||||
|
backoff = min(backoff * 2, max_backoff)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Unexpected error - log and retry with backoff
|
||||||
|
logger.exception(
|
||||||
|
"Config listener unexpected error, reconnecting in %.1fs",
|
||||||
|
backoff,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(backoff)
|
||||||
|
backoff = min(backoff * 2, max_backoff)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if conn is not None:
|
||||||
|
try:
|
||||||
|
await self._pool.release(conn)
|
||||||
|
except Exception:
|
||||||
|
pass # Connection may already be invalid
|
||||||
111
src/central/crypto.py
Normal file
111
src/central/crypto.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Cryptographic primitives for secret storage.
|
||||||
|
|
||||||
|
Uses AES-256-GCM for authenticated encryption. The master key is read
|
||||||
|
from the path specified in bootstrap config on first use and cached.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||||
|
|
||||||
|
# AES-256 requires 32-byte key
|
||||||
|
KEY_SIZE = 32
|
||||||
|
# GCM nonce size (96 bits recommended by NIST)
|
||||||
|
NONCE_SIZE = 12
|
||||||
|
|
||||||
|
|
||||||
|
class CryptoError(Exception):
|
||||||
|
"""Base exception for crypto operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class KeyLoadError(CryptoError):
|
||||||
|
"""Failed to load master key."""
|
||||||
|
|
||||||
|
|
||||||
|
class DecryptionError(CryptoError):
|
||||||
|
"""Failed to decrypt ciphertext (wrong key or tampered data)."""
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def _load_master_key(path: Path) -> bytes:
|
||||||
|
"""Load and decode the base64-encoded master key from file."""
|
||||||
|
try:
|
||||||
|
key_b64 = path.read_text().strip()
|
||||||
|
key = base64.b64decode(key_b64)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise KeyLoadError(f"Master key file not found: {path}")
|
||||||
|
except Exception as e:
|
||||||
|
raise KeyLoadError(f"Failed to read master key from {path}: {e}")
|
||||||
|
|
||||||
|
if len(key) != KEY_SIZE:
|
||||||
|
raise KeyLoadError(
|
||||||
|
f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
|
||||||
|
)
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
|
||||||
|
"""Encrypt plaintext using AES-256-GCM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plaintext: Data to encrypt.
|
||||||
|
key_path: Path to master key file. If None, uses default from
|
||||||
|
bootstrap config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Ciphertext in format: nonce (12 bytes) || ciphertext || tag (16 bytes)
|
||||||
|
"""
|
||||||
|
if key_path is None:
|
||||||
|
from central.bootstrap_config import get_settings
|
||||||
|
key_path = get_settings().master_key_path
|
||||||
|
|
||||||
|
key = _load_master_key(key_path)
|
||||||
|
nonce = os.urandom(NONCE_SIZE)
|
||||||
|
aesgcm = AESGCM(key)
|
||||||
|
|
||||||
|
# GCM appends the 16-byte tag to the ciphertext
|
||||||
|
ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, associated_data=None)
|
||||||
|
|
||||||
|
return nonce + ciphertext_with_tag
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
|
||||||
|
"""Decrypt ciphertext using AES-256-GCM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ciphertext: Data in format: nonce || ciphertext || tag
|
||||||
|
key_path: Path to master key file. If None, uses default from
|
||||||
|
bootstrap config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decrypted plaintext.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DecryptionError: If decryption fails (wrong key or tampered data).
|
||||||
|
"""
|
||||||
|
if key_path is None:
|
||||||
|
from central.bootstrap_config import get_settings
|
||||||
|
key_path = get_settings().master_key_path
|
||||||
|
|
||||||
|
if len(ciphertext) < NONCE_SIZE + 16: # nonce + minimum tag
|
||||||
|
raise DecryptionError("Ciphertext too short")
|
||||||
|
|
||||||
|
key = _load_master_key(key_path)
|
||||||
|
nonce = ciphertext[:NONCE_SIZE]
|
||||||
|
ciphertext_with_tag = ciphertext[NONCE_SIZE:]
|
||||||
|
|
||||||
|
aesgcm = AESGCM(key)
|
||||||
|
try:
|
||||||
|
plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
|
||||||
|
except Exception as e:
|
||||||
|
raise DecryptionError(f"Decryption failed: {e}")
|
||||||
|
|
||||||
|
return plaintext
|
||||||
|
|
||||||
|
|
||||||
|
def clear_key_cache() -> None:
|
||||||
|
"""Clear the cached master key. Use after key rotation."""
|
||||||
|
_load_master_key.cache_clear()
|
||||||
125
src/central/migrate.py
Normal file
125
src/central/migrate.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""Simple database migration runner.
|
||||||
|
|
||||||
|
Tracks applied migrations in a `schema_migrations` table. Migrations are
|
||||||
|
plain SQL files in `sql/migrations/` named with numeric prefixes:
|
||||||
|
001_create_config_schema.sql
|
||||||
|
002_add_operators_table.sql
|
||||||
|
...
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
central-migrate [--dry-run]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
|
||||||
|
"""Create the schema_migrations table if it doesn't exist."""
|
||||||
|
await conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||||
|
version TEXT PRIMARY KEY,
|
||||||
|
applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
|
||||||
|
"""Return set of already-applied migration versions."""
|
||||||
|
rows = await conn.fetch("SELECT version FROM schema_migrations")
|
||||||
|
return {row["version"] for row in rows}
|
||||||
|
|
||||||
|
|
||||||
|
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
|
||||||
|
"""Find all .sql files in migrations directory, sorted by name.
|
||||||
|
|
||||||
|
Returns list of (version, path) tuples where version is the filename
|
||||||
|
without extension.
|
||||||
|
"""
|
||||||
|
if not migrations_dir.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
migrations = []
|
||||||
|
for f in sorted(migrations_dir.glob("*.sql")):
|
||||||
|
version = f.stem # e.g., "001_create_config_schema"
|
||||||
|
migrations.append((version, f))
|
||||||
|
return migrations
|
||||||
|
|
||||||
|
|
||||||
|
async def apply_migration(
|
||||||
|
conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""Apply a single migration."""
|
||||||
|
sql = sql_path.read_text()
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print(f"[DRY RUN] Would apply: {version}")
|
||||||
|
print(f" SQL: {sql[:200]}..." if len(sql) > 200 else f" SQL: {sql}")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with conn.transaction():
|
||||||
|
await conn.execute(sql)
|
||||||
|
await conn.execute(
|
||||||
|
"INSERT INTO schema_migrations (version) VALUES ($1)", version
|
||||||
|
)
|
||||||
|
print(f"Applied: {version}")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
|
||||||
|
"""Run all pending migrations.
|
||||||
|
|
||||||
|
Returns number of migrations applied.
|
||||||
|
"""
|
||||||
|
conn = await asyncpg.connect(dsn)
|
||||||
|
try:
|
||||||
|
await ensure_migrations_table(conn)
|
||||||
|
applied = await get_applied_migrations(conn)
|
||||||
|
pending = [
|
||||||
|
(v, p) for v, p in discover_migrations(MIGRATIONS_DIR) if v not in applied
|
||||||
|
]
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
print("No pending migrations.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print(f"Found {len(pending)} pending migration(s).")
|
||||||
|
for version, path in pending:
|
||||||
|
await apply_migration(conn, version, path, dry_run)
|
||||||
|
|
||||||
|
return len(pending)
|
||||||
|
finally:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_main() -> None:
|
||||||
|
"""Async entry point."""
|
||||||
|
parser = argparse.ArgumentParser(description="Run database migrations")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
help="Show what would be applied without executing",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
from central.bootstrap_config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
count = await run_migrations(settings.db_dsn, dry_run=args.dry_run)
|
||||||
|
|
||||||
|
if count > 0 and not args.dry_run:
|
||||||
|
print(f"Successfully applied {count} migration(s).")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Entry point."""
|
||||||
|
asyncio.run(async_main())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
18
tests/README.md
Normal file
18
tests/README.md
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Central Tests
|
||||||
|
|
||||||
|
## Test Database
|
||||||
|
|
||||||
|
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
|
||||||
|
By default, tests connect to:
|
||||||
|
|
||||||
|
```
|
||||||
|
postgresql://central_test:testpass@localhost/central_test
|
||||||
|
```
|
||||||
|
|
||||||
|
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
|
||||||
|
environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
|
||||||
|
uv run pytest tests/test_config_store.py
|
||||||
|
```
|
||||||
123
tests/test_bootstrap_config.py
Normal file
123
tests/test_bootstrap_config.py
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
"""Tests for bootstrap configuration."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from central.bootstrap_config import Settings, get_settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsFromEnv:
|
||||||
|
"""Test loading settings from environment variables."""
|
||||||
|
|
||||||
|
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Settings are read from CENTRAL_* environment variables."""
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
|
||||||
|
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
|
||||||
|
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
|
||||||
|
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
|
||||||
|
assert settings.nats_url == "nats://10.0.0.1:4222"
|
||||||
|
assert settings.master_key_path == Path("/tmp/test.key")
|
||||||
|
assert settings.log_level == "DEBUG"
|
||||||
|
|
||||||
|
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Default values are used when env vars not set."""
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
|
||||||
|
# Clear any existing env vars that might interfere
|
||||||
|
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
|
||||||
|
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
|
||||||
|
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
assert settings.nats_url == "nats://localhost:4222"
|
||||||
|
assert settings.master_key_path == Path("/etc/central/master.key")
|
||||||
|
assert settings.log_level == "INFO"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsFromFile:
|
||||||
|
"""Test loading settings from .env file."""
|
||||||
|
|
||||||
|
def test_reads_from_env_file(self, tmp_path: Path) -> None:
|
||||||
|
"""Settings are read from .env file when env vars not present."""
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
env_file.write_text(
|
||||||
|
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
|
||||||
|
"CENTRAL_NATS_URL=nats://file.local:4222\n"
|
||||||
|
"CENTRAL_LOG_LEVEL=WARNING\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create settings pointing to the temp .env file
|
||||||
|
settings = Settings(_env_file=env_file)
|
||||||
|
|
||||||
|
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
|
||||||
|
assert settings.nats_url == "nats://file.local:4222"
|
||||||
|
assert settings.log_level == "WARNING"
|
||||||
|
|
||||||
|
def test_env_vars_override_file(
|
||||||
|
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""Environment variables take precedence over .env file."""
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
|
||||||
|
|
||||||
|
settings = Settings(_env_file=env_file)
|
||||||
|
|
||||||
|
assert settings.db_dsn == "postgresql://env@localhost/envdb"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsValidation:
|
||||||
|
"""Test settings validation and error handling."""
|
||||||
|
|
||||||
|
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Clear error when required CENTRAL_DB_DSN is missing."""
|
||||||
|
# Ensure no env vars or .env file provides the DSN
|
||||||
|
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
# Use a non-existent .env file path to ensure no fallback
|
||||||
|
Settings(_env_file=Path("/nonexistent/.env"))
|
||||||
|
|
||||||
|
# pydantic-settings raises ValidationError for missing required fields
|
||||||
|
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Invalid log level values are rejected."""
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
|
||||||
|
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
Settings()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSettings:
|
||||||
|
"""Test the cached settings loader."""
|
||||||
|
|
||||||
|
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""get_settings() returns cached instance."""
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
s1 = get_settings()
|
||||||
|
s2 = get_settings()
|
||||||
|
|
||||||
|
assert s1 is s2
|
||||||
|
|
||||||
|
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""cache_clear() forces reload on next call."""
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
|
||||||
|
get_settings.cache_clear()
|
||||||
|
s1 = get_settings()
|
||||||
|
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
|
||||||
|
get_settings.cache_clear()
|
||||||
|
s2 = get_settings()
|
||||||
|
|
||||||
|
assert s1.db_dsn != s2.db_dsn
|
||||||
339
tests/test_config_store.py
Normal file
339
tests/test_config_store.py
Normal file
|
|
@ -0,0 +1,339 @@
|
||||||
|
"""Tests for database-backed configuration store.
|
||||||
|
|
||||||
|
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
|
||||||
|
environment variable to override the default test database connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from central.config_store import ConfigStore
|
||||||
|
from central.crypto import KEY_SIZE, clear_key_cache
|
||||||
|
|
||||||
|
# Test database DSN - uses central_test database with well-known test password.
|
||||||
|
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
|
||||||
|
TEST_DB_DSN = os.environ.get(
|
||||||
|
"CENTRAL_TEST_DB_DSN",
|
||||||
|
"postgresql://central_test:testpass@localhost/central_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||||
|
"""Create a master key file for the test session."""
|
||||||
|
key = os.urandom(KEY_SIZE)
|
||||||
|
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||||
|
key_path.write_text(base64.b64encode(key).decode())
|
||||||
|
return key_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Configure master key path for all tests."""
|
||||||
|
clear_key_cache()
|
||||||
|
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||||
|
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_conn() -> asyncpg.Connection:
|
||||||
|
"""Get a direct database connection for setup/teardown."""
|
||||||
|
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||||
|
yield conn
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||||
|
"""Ensure config schema exists and is clean before each test."""
|
||||||
|
# Create schema if not exists
|
||||||
|
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||||
|
|
||||||
|
# Create tables if not exist
|
||||||
|
await db_conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||||
|
name TEXT PRIMARY KEY,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||||
|
cadence_s INTEGER NOT NULL,
|
||||||
|
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
paused_at TIMESTAMPTZ,
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
await db_conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS config.api_keys (
|
||||||
|
alias TEXT PRIMARY KEY,
|
||||||
|
encrypted_value BYTEA NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
rotated_at TIMESTAMPTZ,
|
||||||
|
last_used_at TIMESTAMPTZ
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create notify function with proper key detection
|
||||||
|
await db_conn.execute("""
|
||||||
|
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||||
|
RETURNS trigger AS $$
|
||||||
|
DECLARE
|
||||||
|
key_value TEXT;
|
||||||
|
BEGIN
|
||||||
|
IF TG_TABLE_NAME = 'adapters' THEN
|
||||||
|
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||||
|
ELSIF TG_TABLE_NAME = 'api_keys' THEN
|
||||||
|
key_value := COALESCE(NEW.alias, OLD.alias, '');
|
||||||
|
ELSE
|
||||||
|
key_value := '';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||||
|
RETURN COALESCE(NEW, OLD);
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create triggers if not exist
|
||||||
|
await db_conn.execute("""
|
||||||
|
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
|
||||||
|
CREATE TRIGGER adapters_notify
|
||||||
|
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||||
|
""")
|
||||||
|
await db_conn.execute("""
|
||||||
|
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
|
||||||
|
CREATE TRIGGER api_keys_notify
|
||||||
|
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Clean tables
|
||||||
|
await db_conn.execute("DELETE FROM config.adapters")
|
||||||
|
await db_conn.execute("DELETE FROM config.api_keys")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def config_store(clean_config_schema: None) -> ConfigStore:
|
||||||
|
"""Create a ConfigStore connected to the test database."""
|
||||||
|
store = await ConfigStore.create(TEST_DB_DSN)
|
||||||
|
yield store
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdapterConfig:
|
||||||
|
"""Tests for adapter configuration operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Can insert and retrieve adapter config."""
|
||||||
|
await config_store.upsert_adapter(
|
||||||
|
name="test_adapter",
|
||||||
|
enabled=True,
|
||||||
|
cadence_s=120,
|
||||||
|
settings={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = await config_store.get_adapter("test_adapter")
|
||||||
|
|
||||||
|
assert adapter is not None
|
||||||
|
assert adapter.name == "test_adapter"
|
||||||
|
assert adapter.enabled is True
|
||||||
|
assert adapter.cadence_s == 120
|
||||||
|
assert adapter.settings == {"key": "value"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Getting nonexistent adapter returns None."""
|
||||||
|
adapter = await config_store.get_adapter("does_not_exist")
|
||||||
|
assert adapter is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_adapters(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Can list all adapters."""
|
||||||
|
await config_store.upsert_adapter("adapter_a", True, 60, {})
|
||||||
|
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
|
||||||
|
|
||||||
|
adapters = await config_store.list_adapters()
|
||||||
|
|
||||||
|
assert len(adapters) == 2
|
||||||
|
names = [a.name for a in adapters]
|
||||||
|
assert "adapter_a" in names
|
||||||
|
assert "adapter_b" in names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Upsert updates existing adapter."""
|
||||||
|
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
|
||||||
|
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
|
||||||
|
|
||||||
|
adapter = await config_store.get_adapter("updater")
|
||||||
|
|
||||||
|
assert adapter is not None
|
||||||
|
assert adapter.enabled is False
|
||||||
|
assert adapter.cadence_s == 120
|
||||||
|
assert adapter.settings == {"v": 2}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Can pause and unpause adapter."""
|
||||||
|
await config_store.upsert_adapter("pausable", True, 60, {})
|
||||||
|
|
||||||
|
await config_store.pause_adapter("pausable")
|
||||||
|
adapter = await config_store.get_adapter("pausable")
|
||||||
|
assert adapter is not None
|
||||||
|
assert adapter.is_paused is True
|
||||||
|
|
||||||
|
await config_store.unpause_adapter("pausable")
|
||||||
|
adapter = await config_store.get_adapter("pausable")
|
||||||
|
assert adapter is not None
|
||||||
|
assert adapter.is_paused is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiKeys:
|
||||||
|
"""Tests for API key operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Can store and retrieve encrypted API key."""
|
||||||
|
await config_store.set_api_key("test_key", "super_secret_value")
|
||||||
|
|
||||||
|
value = await config_store.get_api_key("test_key")
|
||||||
|
|
||||||
|
assert value == "super_secret_value"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Getting nonexistent key returns None."""
|
||||||
|
value = await config_store.get_api_key("does_not_exist")
|
||||||
|
assert value is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_key_rotation(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Updating key sets rotated_at."""
|
||||||
|
await config_store.set_api_key("rotate_me", "value1")
|
||||||
|
await config_store.set_api_key("rotate_me", "value2")
|
||||||
|
|
||||||
|
value = await config_store.get_api_key("rotate_me")
|
||||||
|
assert value == "value2"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_key(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Can delete API key."""
|
||||||
|
await config_store.set_api_key("delete_me", "value")
|
||||||
|
|
||||||
|
deleted = await config_store.delete_api_key("delete_me")
|
||||||
|
assert deleted is True
|
||||||
|
|
||||||
|
value = await config_store.get_api_key("delete_me")
|
||||||
|
assert value is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
|
||||||
|
"""Deleting nonexistent key returns False."""
|
||||||
|
deleted = await config_store.delete_api_key("never_existed")
|
||||||
|
assert deleted is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotifications:
|
||||||
|
"""Tests for LISTEN/NOTIFY functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
|
||||||
|
"""NOTIFY fires when adapter is changed."""
|
||||||
|
notifications: list[tuple[str, str]] = []
|
||||||
|
notification_received = asyncio.Event()
|
||||||
|
|
||||||
|
async def callback(table: str, key: str) -> None:
|
||||||
|
notifications.append((table, key))
|
||||||
|
notification_received.set()
|
||||||
|
|
||||||
|
# Start listener in background
|
||||||
|
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Give listener time to subscribe
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# Trigger a change
|
||||||
|
await config_store.upsert_adapter("notify_test", True, 60, {})
|
||||||
|
|
||||||
|
# Wait for notification (with timeout)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pytest.fail("Notification not received within timeout")
|
||||||
|
|
||||||
|
assert len(notifications) >= 1
|
||||||
|
assert notifications[0][0] == "adapters"
|
||||||
|
assert notifications[0][1] == "notify_test"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
listen_task.cancel()
|
||||||
|
try:
|
||||||
|
await listen_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
|
||||||
|
"""NOTIFY fires when API key is changed."""
|
||||||
|
notifications: list[tuple[str, str]] = []
|
||||||
|
notification_received = asyncio.Event()
|
||||||
|
|
||||||
|
async def callback(table: str, key: str) -> None:
|
||||||
|
notifications.append((table, key))
|
||||||
|
notification_received.set()
|
||||||
|
|
||||||
|
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await config_store.set_api_key("notify_key", "secret")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pytest.fail("Notification not received within timeout")
|
||||||
|
|
||||||
|
assert len(notifications) >= 1
|
||||||
|
assert notifications[0][0] == "api_keys"
|
||||||
|
assert notifications[0][1] == "notify_key"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
listen_task.cancel()
|
||||||
|
try:
|
||||||
|
await listen_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestListenerReconnect:
|
||||||
|
"""Tests for listener reconnection on connection loss."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_listener_cancellation_propagates(
|
||||||
|
self, config_store: ConfigStore
|
||||||
|
) -> None:
|
||||||
|
"""Cancellation cleanly stops the listener without reconnect loop."""
|
||||||
|
async def callback(table: str, key: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||||
|
|
||||||
|
# Give listener time to start
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# Cancel and verify it stops
|
||||||
|
listen_task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(listen_task, timeout=2.0)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass # Expected
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pytest.fail("Listener did not stop after cancellation")
|
||||||
|
|
||||||
|
assert listen_task.cancelled() or listen_task.done()
|
||||||
175
tests/test_crypto.py
Normal file
175
tests/test_crypto.py
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
"""Tests for cryptographic primitives."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from central.crypto import (
|
||||||
|
KEY_SIZE,
|
||||||
|
DecryptionError,
|
||||||
|
KeyLoadError,
|
||||||
|
clear_key_cache,
|
||||||
|
decrypt,
|
||||||
|
encrypt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def master_key(tmp_path: Path) -> Path:
|
||||||
|
"""Create a valid master key file."""
|
||||||
|
key = os.urandom(KEY_SIZE)
|
||||||
|
key_path = tmp_path / "master.key"
|
||||||
|
key_path.write_text(base64.b64encode(key).decode())
|
||||||
|
clear_key_cache()
|
||||||
|
return key_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def wrong_key(tmp_path: Path) -> Path:
|
||||||
|
"""Create a different master key file."""
|
||||||
|
key = os.urandom(KEY_SIZE)
|
||||||
|
key_path = tmp_path / "wrong.key"
|
||||||
|
key_path.write_text(base64.b64encode(key).decode())
|
||||||
|
return key_path
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncryptDecrypt:
|
||||||
|
"""Test encrypt/decrypt round-trip."""
|
||||||
|
|
||||||
|
def test_round_trip(self, master_key: Path) -> None:
|
||||||
|
"""Encrypting then decrypting returns original plaintext."""
|
||||||
|
plaintext = b"Hello, Central!"
|
||||||
|
|
||||||
|
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||||
|
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||||
|
|
||||||
|
assert decrypted == plaintext
|
||||||
|
|
||||||
|
def test_round_trip_empty(self, master_key: Path) -> None:
|
||||||
|
"""Empty plaintext encrypts and decrypts correctly."""
|
||||||
|
plaintext = b""
|
||||||
|
|
||||||
|
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||||
|
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||||
|
|
||||||
|
assert decrypted == plaintext
|
||||||
|
|
||||||
|
def test_round_trip_large(self, master_key: Path) -> None:
|
||||||
|
"""Large plaintext encrypts and decrypts correctly."""
|
||||||
|
plaintext = os.urandom(1024 * 1024) # 1MB
|
||||||
|
|
||||||
|
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||||
|
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||||
|
|
||||||
|
assert decrypted == plaintext
|
||||||
|
|
||||||
|
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
|
||||||
|
"""Same plaintext produces different ciphertext (random nonce)."""
|
||||||
|
plaintext = b"test"
|
||||||
|
|
||||||
|
ct1 = encrypt(plaintext, key_path=master_key)
|
||||||
|
ct2 = encrypt(plaintext, key_path=master_key)
|
||||||
|
|
||||||
|
assert ct1 != ct2
|
||||||
|
# But both decrypt to same plaintext
|
||||||
|
assert decrypt(ct1, key_path=master_key) == plaintext
|
||||||
|
assert decrypt(ct2, key_path=master_key) == plaintext
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecryptionFailures:
|
||||||
|
"""Test AEAD authentication catches tampering."""
|
||||||
|
|
||||||
|
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
|
||||||
|
"""Decryption with wrong key raises DecryptionError."""
|
||||||
|
plaintext = b"secret"
|
||||||
|
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||||
|
|
||||||
|
clear_key_cache() # Clear cache so wrong_key is loaded
|
||||||
|
with pytest.raises(DecryptionError):
|
||||||
|
decrypt(ciphertext, key_path=wrong_key)
|
||||||
|
|
||||||
|
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
|
||||||
|
"""Modified ciphertext is detected and rejected."""
|
||||||
|
plaintext = b"secret"
|
||||||
|
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||||
|
|
||||||
|
# Flip a bit in the ciphertext (after nonce, before tag)
|
||||||
|
tampered = bytearray(ciphertext)
|
||||||
|
tampered[15] ^= 0x01 # Flip one bit
|
||||||
|
tampered = bytes(tampered)
|
||||||
|
|
||||||
|
with pytest.raises(DecryptionError):
|
||||||
|
decrypt(tampered, key_path=master_key)
|
||||||
|
|
||||||
|
def test_tampered_tag_fails(self, master_key: Path) -> None:
|
||||||
|
"""Modified authentication tag is detected and rejected."""
|
||||||
|
plaintext = b"secret"
|
||||||
|
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||||
|
|
||||||
|
# Flip a bit in the last byte (part of the tag)
|
||||||
|
tampered = bytearray(ciphertext)
|
||||||
|
tampered[-1] ^= 0x01
|
||||||
|
tampered = bytes(tampered)
|
||||||
|
|
||||||
|
with pytest.raises(DecryptionError):
|
||||||
|
decrypt(tampered, key_path=master_key)
|
||||||
|
|
||||||
|
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
|
||||||
|
"""Truncated ciphertext is rejected."""
|
||||||
|
ciphertext = b"tooshort"
|
||||||
|
|
||||||
|
with pytest.raises(DecryptionError, match="too short"):
|
||||||
|
decrypt(ciphertext, key_path=master_key)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyLoading:
|
||||||
|
"""Test master key loading."""
|
||||||
|
|
||||||
|
def test_missing_key_file(self, tmp_path: Path) -> None:
|
||||||
|
"""Missing key file raises KeyLoadError."""
|
||||||
|
clear_key_cache()
|
||||||
|
missing = tmp_path / "nonexistent.key"
|
||||||
|
|
||||||
|
with pytest.raises(KeyLoadError, match="not found"):
|
||||||
|
encrypt(b"test", key_path=missing)
|
||||||
|
|
||||||
|
def test_invalid_key_size(self, tmp_path: Path) -> None:
|
||||||
|
"""Key file with wrong size raises KeyLoadError."""
|
||||||
|
clear_key_cache()
|
||||||
|
bad_key = tmp_path / "bad.key"
|
||||||
|
bad_key.write_text(base64.b64encode(b"tooshort").decode())
|
||||||
|
|
||||||
|
with pytest.raises(KeyLoadError, match="Invalid master key size"):
|
||||||
|
encrypt(b"test", key_path=bad_key)
|
||||||
|
|
||||||
|
def test_invalid_base64(self, tmp_path: Path) -> None:
|
||||||
|
"""Invalid base64 in key file raises KeyLoadError."""
|
||||||
|
clear_key_cache()
|
||||||
|
bad_key = tmp_path / "bad.key"
|
||||||
|
bad_key.write_text("not valid base64!!!")
|
||||||
|
|
||||||
|
with pytest.raises(KeyLoadError):
|
||||||
|
encrypt(b"test", key_path=bad_key)
|
||||||
|
|
||||||
|
def test_key_cached(self, master_key: Path) -> None:
|
||||||
|
"""Key is cached after first load."""
|
||||||
|
# First encryption loads the key
|
||||||
|
encrypt(b"test1", key_path=master_key)
|
||||||
|
|
||||||
|
# Delete the file
|
||||||
|
master_key.unlink()
|
||||||
|
|
||||||
|
# Second encryption should still work (cached)
|
||||||
|
ciphertext = encrypt(b"test2", key_path=master_key)
|
||||||
|
assert len(ciphertext) > 0
|
||||||
|
|
||||||
|
def test_cache_clear(self, master_key: Path) -> None:
|
||||||
|
"""clear_key_cache forces reload."""
|
||||||
|
encrypt(b"test", key_path=master_key)
|
||||||
|
master_key.unlink()
|
||||||
|
clear_key_cache()
|
||||||
|
|
||||||
|
with pytest.raises(KeyLoadError, match="not found"):
|
||||||
|
encrypt(b"test", key_path=master_key)
|
||||||
|
|
@ -49,7 +49,7 @@ def sample_config() -> Config:
|
||||||
},
|
},
|
||||||
cloudevents=CloudEventsConfig(
|
cloudevents=CloudEventsConfig(
|
||||||
type_prefix="central",
|
type_prefix="central",
|
||||||
source="central.echo6.mesh",
|
source="central.local",
|
||||||
schema_version="1.0",
|
schema_version="1.0",
|
||||||
),
|
),
|
||||||
nats=NATSConfig(url="nats://localhost:4222"),
|
nats=NATSConfig(url="nats://localhost:4222"),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue