mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
Merge pull request #1 from zvx-echo6/feature/1a-config-storage
feat(config): Phase 1a-2 config storage primitives
This commit is contained in:
commit
ee081c9bc2
14 changed files with 1570 additions and 161 deletions
|
|
@ -15,14 +15,18 @@ dependencies = [
|
|||
"aiolimiter>=1.2.1",
|
||||
"asyncpg>=0.31.0",
|
||||
"cloudevents>=2.0.0",
|
||||
"cryptography>=44.0.0",
|
||||
"nats-py>=2.14.0",
|
||||
"pydantic>=2,<3",
|
||||
"pydantic-settings>=2.7.0",
|
||||
"tenacity>=9.1.4",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
central-supervisor = "central.supervisor:main"
|
||||
central-archive = "central.archive:main"
|
||||
central-migrate = "central.migrate:main"
|
||||
central-cli = "central.cli:main"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/central"]
|
||||
|
|
|
|||
64
sql/migrations/001_create_config_schema.sql
Normal file
64
sql/migrations/001_create_config_schema.sql
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
-- Migration: 001_create_config_schema
|
||||
-- Creates the config schema with adapters and api_keys tables.
|
||||
-- Also seeds the NWS adapter row from current TOML config.
|
||||
|
||||
-- Create config schema
|
||||
CREATE SCHEMA config;
|
||||
|
||||
-- Adapters configuration table
|
||||
CREATE TABLE config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
-- API keys table (encrypted values)
|
||||
CREATE TABLE config.api_keys (
|
||||
alias TEXT PRIMARY KEY,
|
||||
encrypted_value BYTEA NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
rotated_at TIMESTAMPTZ,
|
||||
last_used_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
-- Notify function for config changes
|
||||
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
key_value TEXT;
|
||||
BEGIN
|
||||
-- Handle different table structures
|
||||
IF TG_TABLE_NAME = 'adapters' THEN
|
||||
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||
ELSIF TG_TABLE_NAME = 'api_keys' THEN
|
||||
key_value := COALESCE(NEW.alias, OLD.alias, '');
|
||||
ELSE
|
||||
key_value := '';
|
||||
END IF;
|
||||
|
||||
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||
RETURN COALESCE(NEW, OLD);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Trigger for adapters table
|
||||
CREATE TRIGGER adapters_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
|
||||
|
||||
-- Trigger for api_keys table
|
||||
CREATE TRIGGER api_keys_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change();
|
||||
|
||||
-- Seed NWS adapter from current TOML config values
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES (
|
||||
'nws',
|
||||
true,
|
||||
60,
|
||||
'{"states": ["ID", "OR", "WA", "MT", "WY", "UT", "NV"], "contact_email": "mj@k7zvx.com"}'::jsonb
|
||||
);
|
||||
21
sql/migrations/002_add_updated_at_trigger_and_index.sql
Normal file
21
sql/migrations/002_add_updated_at_trigger_and_index.sql
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
-- Migration: 002_add_updated_at_trigger_and_index
|
||||
-- Adds auto-update trigger for updated_at column on adapters table
|
||||
-- Adds partial index for efficient enabled adapter queries
|
||||
|
||||
-- Auto-update trigger for updated_at
|
||||
CREATE OR REPLACE FUNCTION config.set_updated_at()
|
||||
RETURNS trigger AS $$
|
||||
BEGIN
|
||||
NEW.updated_at := now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER adapters_set_updated_at
|
||||
BEFORE UPDATE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.set_updated_at();
|
||||
|
||||
-- Partial index for enabled adapters (common query pattern)
|
||||
CREATE INDEX adapters_enabled_idx
|
||||
ON config.adapters (enabled)
|
||||
WHERE enabled = true;
|
||||
46
src/central/bootstrap_config.py
Normal file
46
src/central/bootstrap_config.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Bootstrap configuration from environment variables.
|
||||
|
||||
This module provides early-stage configuration loading from environment
|
||||
variables or a .env file. Used before the database-backed config store
|
||||
is available.
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Bootstrap settings loaded from environment or .env file."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="CENTRAL_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
db_dsn: str = Field(description="PostgreSQL connection string")
|
||||
nats_url: str = Field(default="nats://localhost:4222", description="NATS server URL")
|
||||
master_key_path: Path = Field(
|
||||
default=Path("/etc/central/master.key"),
|
||||
description="Path to AES-256 master key file",
|
||||
)
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||
default="INFO",
|
||||
description="Logging level",
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings(env_file: Path | None = None) -> Settings:
|
||||
"""Load settings, optionally from a specific .env file.
|
||||
|
||||
Results are cached. Call get_settings.cache_clear() to reload.
|
||||
"""
|
||||
if env_file is not None:
|
||||
return Settings(_env_file=env_file)
|
||||
return Settings()
|
||||
75
src/central/cli.py
Normal file
75
src/central/cli.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Central CLI commands."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
|
||||
async def config_store_check() -> int:
|
||||
"""Smoke test for config store connectivity.
|
||||
|
||||
Connects via bootstrap_config, lists adapters, and verifies crypto.
|
||||
Returns 0 on success, 1 on failure.
|
||||
"""
|
||||
from central.bootstrap_config import get_settings
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import decrypt, encrypt
|
||||
|
||||
settings = get_settings()
|
||||
print(f"Connecting to: {settings.db_dsn.split('@')[1]}") # Hide password
|
||||
|
||||
try:
|
||||
store = await ConfigStore.create(settings.db_dsn)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to connect to database: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
# List adapters
|
||||
adapters = await store.list_adapters()
|
||||
print(f"\nAdapters ({len(adapters)}):")
|
||||
for adapter in adapters:
|
||||
print(f" - {adapter.name}: enabled={adapter.enabled}, cadence_s={adapter.cadence_s}")
|
||||
print(f" settings: {adapter.settings}")
|
||||
|
||||
# Test crypto
|
||||
test_plaintext = b"config_store_check_test"
|
||||
try:
|
||||
ciphertext = encrypt(test_plaintext)
|
||||
decrypted = decrypt(ciphertext)
|
||||
if decrypted == test_plaintext:
|
||||
print("\ncrypto: ok")
|
||||
else:
|
||||
print("\ncrypto: FAILED (round-trip mismatch)")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"\ncrypto: FAILED ({e})")
|
||||
return 1
|
||||
|
||||
print("\nAll checks passed.")
|
||||
return 0
|
||||
|
||||
finally:
|
||||
await store.close()
|
||||
|
||||
|
||||
def main_config_store_check() -> None:
|
||||
"""Entry point for central-cli config-store-check."""
|
||||
sys.exit(asyncio.run(config_store_check()))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main CLI entry point."""
|
||||
parser = argparse.ArgumentParser(description="Central CLI")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
subparsers.add_parser("config-store-check", help="Test config store connectivity")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "config-store-check":
|
||||
main_config_store_check()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
39
src/central/config_models.py
Normal file
39
src/central/config_models.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Pydantic models for database-backed configuration."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AdapterConfig(BaseModel):
|
||||
"""Configuration for a single adapter."""
|
||||
|
||||
name: str = Field(description="Unique adapter identifier")
|
||||
enabled: bool = Field(default=True, description="Whether adapter is active")
|
||||
cadence_s: int = Field(description="Poll interval in seconds")
|
||||
settings: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Adapter-specific settings"
|
||||
)
|
||||
paused_at: datetime | None = Field(
|
||||
default=None, description="When adapter was paused, if paused"
|
||||
)
|
||||
updated_at: datetime = Field(description="Last configuration update time")
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Check if adapter is currently paused."""
|
||||
return self.paused_at is not None
|
||||
|
||||
|
||||
class ApiKeyInfo(BaseModel):
|
||||
"""Metadata about an API key (without the decrypted value)."""
|
||||
|
||||
alias: str = Field(description="Key identifier/alias")
|
||||
created_at: datetime = Field(description="When key was created")
|
||||
rotated_at: datetime | None = Field(
|
||||
default=None, description="Last rotation time"
|
||||
)
|
||||
last_used_at: datetime | None = Field(
|
||||
default=None, description="Last usage time"
|
||||
)
|
||||
269
src/central/config_store.py
Normal file
269
src/central/config_store.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
"""Database-backed configuration store.
|
||||
|
||||
Provides async access to the config schema tables with support for
|
||||
Postgres LISTEN/NOTIFY for real-time config change notifications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from central.config_models import AdapterConfig
|
||||
from central.crypto import decrypt, encrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _setup_json_codec(conn: asyncpg.Connection) -> None:
|
||||
"""Set up JSON codec for asyncpg connection."""
|
||||
await conn.set_type_codec(
|
||||
"jsonb",
|
||||
encoder=json.dumps,
|
||||
decoder=json.loads,
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
|
||||
class ConfigStore:
|
||||
"""Async interface to the config schema in Postgres."""
|
||||
|
||||
def __init__(self, pool: asyncpg.Pool) -> None:
|
||||
self._pool = pool
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dsn: str, min_size: int = 1, max_size: int = 5) -> "ConfigStore":
|
||||
"""Create a ConfigStore with a new connection pool."""
|
||||
pool = await asyncpg.create_pool(
|
||||
dsn,
|
||||
min_size=min_size,
|
||||
max_size=max_size,
|
||||
init=_setup_json_codec,
|
||||
)
|
||||
return cls(pool)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection pool."""
|
||||
await self._pool.close()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Adapter configuration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def get_adapter(self, name: str) -> AdapterConfig | None:
|
||||
"""Get configuration for a specific adapter."""
|
||||
async with self._pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return AdapterConfig(**dict(row))
|
||||
|
||||
async def list_adapters(self) -> list[AdapterConfig]:
|
||||
"""List all configured adapters."""
|
||||
async with self._pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
ORDER BY name
|
||||
"""
|
||||
)
|
||||
return [AdapterConfig(**dict(row)) for row in rows]
|
||||
|
||||
async def upsert_adapter(
|
||||
self,
|
||||
name: str,
|
||||
enabled: bool,
|
||||
cadence_s: int,
|
||||
settings: dict[str, Any],
|
||||
) -> None:
|
||||
"""Insert or update an adapter configuration."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings, updated_at)
|
||||
VALUES ($1, $2, $3, $4, now())
|
||||
ON CONFLICT (name) DO UPDATE SET
|
||||
enabled = EXCLUDED.enabled,
|
||||
cadence_s = EXCLUDED.cadence_s,
|
||||
settings = EXCLUDED.settings,
|
||||
updated_at = now()
|
||||
""",
|
||||
name,
|
||||
enabled,
|
||||
cadence_s,
|
||||
settings, # Will be encoded as JSON by the codec
|
||||
)
|
||||
|
||||
async def pause_adapter(self, name: str) -> None:
|
||||
"""Pause an adapter by setting paused_at."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.adapters
|
||||
SET paused_at = now(), updated_at = now()
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
|
||||
async def unpause_adapter(self, name: str) -> None:
|
||||
"""Unpause an adapter by clearing paused_at."""
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.adapters
|
||||
SET paused_at = NULL, updated_at = now()
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# API key management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def set_api_key(self, alias: str, plaintext_value: str) -> None:
|
||||
"""Store an API key, encrypting it with the master key."""
|
||||
encrypted = encrypt(plaintext_value.encode("utf-8"))
|
||||
async with self._pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO config.api_keys (alias, encrypted_value)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (alias) DO UPDATE SET
|
||||
encrypted_value = EXCLUDED.encrypted_value,
|
||||
rotated_at = now()
|
||||
""",
|
||||
alias,
|
||||
encrypted,
|
||||
)
|
||||
|
||||
async def get_api_key(self, alias: str) -> str | None:
|
||||
"""Retrieve and decrypt an API key by alias."""
|
||||
async with self._pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT encrypted_value FROM config.api_keys WHERE alias = $1
|
||||
""",
|
||||
alias,
|
||||
)
|
||||
if row is not None:
|
||||
# Update last_used_at
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.api_keys SET last_used_at = now() WHERE alias = $1
|
||||
""",
|
||||
alias,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return decrypt(row["encrypted_value"]).decode("utf-8")
|
||||
|
||||
async def delete_api_key(self, alias: str) -> bool:
|
||||
"""Delete an API key. Returns True if key existed."""
|
||||
async with self._pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM config.api_keys WHERE alias = $1", alias
|
||||
)
|
||||
return result == "DELETE 1"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Change notifications
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def listen_for_changes(
|
||||
self,
|
||||
callback: Callable[[str, str], Awaitable[None] | None],
|
||||
) -> None:
|
||||
"""Listen for config changes via Postgres NOTIFY.
|
||||
|
||||
Runs forever, calling callback(table, key) each time a change is
|
||||
detected. The callback can be sync or async.
|
||||
|
||||
On connection loss, automatically reconnects with exponential backoff.
|
||||
Cancellation (via task.cancel()) propagates cleanly.
|
||||
|
||||
Args:
|
||||
callback: Function called with (table_name, row_key) on each change.
|
||||
"""
|
||||
backoff = 1.0
|
||||
max_backoff = 30.0
|
||||
|
||||
while True:
|
||||
conn = None
|
||||
try:
|
||||
conn = await self._pool.acquire()
|
||||
logger.info("Config listener connected to database")
|
||||
backoff = 1.0 # Reset backoff on successful connect
|
||||
|
||||
def notification_handler(
|
||||
conn: asyncpg.Connection,
|
||||
pid: int,
|
||||
channel: str,
|
||||
payload: str,
|
||||
) -> None:
|
||||
# payload format: "table_name:key"
|
||||
if ":" in payload:
|
||||
table, key = payload.split(":", 1)
|
||||
else:
|
||||
table, key = payload, ""
|
||||
|
||||
result = callback(table, key)
|
||||
if asyncio.iscoroutine(result):
|
||||
asyncio.create_task(result)
|
||||
|
||||
await conn.add_listener("config_changed", notification_handler)
|
||||
|
||||
try:
|
||||
# Keep connection alive with periodic keepalive
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
await conn.execute("SELECT 1")
|
||||
finally:
|
||||
await conn.remove_listener("config_changed", notification_handler)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Cancellation must propagate cleanly
|
||||
logger.info("Config listener cancelled")
|
||||
raise
|
||||
|
||||
except (
|
||||
asyncpg.PostgresConnectionError,
|
||||
asyncpg.InterfaceError,
|
||||
ConnectionResetError,
|
||||
OSError,
|
||||
) as e:
|
||||
logger.warning(
|
||||
"Config listener connection lost, reconnecting in %.1fs: %s",
|
||||
backoff,
|
||||
e,
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, max_backoff)
|
||||
|
||||
except Exception as e:
|
||||
# Unexpected error - log and retry with backoff
|
||||
logger.exception(
|
||||
"Config listener unexpected error, reconnecting in %.1fs",
|
||||
backoff,
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, max_backoff)
|
||||
|
||||
finally:
|
||||
if conn is not None:
|
||||
try:
|
||||
await self._pool.release(conn)
|
||||
except Exception:
|
||||
pass # Connection may already be invalid
|
||||
111
src/central/crypto.py
Normal file
111
src/central/crypto.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Cryptographic primitives for secret storage.
|
||||
|
||||
Uses AES-256-GCM for authenticated encryption. The master key is read
|
||||
from the path specified in bootstrap config on first use and cached.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
# AES-256 requires 32-byte key
|
||||
KEY_SIZE = 32
|
||||
# GCM nonce size (96 bits recommended by NIST)
|
||||
NONCE_SIZE = 12
|
||||
|
||||
|
||||
class CryptoError(Exception):
|
||||
"""Base exception for crypto operations."""
|
||||
|
||||
|
||||
class KeyLoadError(CryptoError):
|
||||
"""Failed to load master key."""
|
||||
|
||||
|
||||
class DecryptionError(CryptoError):
|
||||
"""Failed to decrypt ciphertext (wrong key or tampered data)."""
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _load_master_key(path: Path) -> bytes:
|
||||
"""Load and decode the base64-encoded master key from file."""
|
||||
try:
|
||||
key_b64 = path.read_text().strip()
|
||||
key = base64.b64decode(key_b64)
|
||||
except FileNotFoundError:
|
||||
raise KeyLoadError(f"Master key file not found: {path}")
|
||||
except Exception as e:
|
||||
raise KeyLoadError(f"Failed to read master key from {path}: {e}")
|
||||
|
||||
if len(key) != KEY_SIZE:
|
||||
raise KeyLoadError(
|
||||
f"Invalid master key size: expected {KEY_SIZE} bytes, got {len(key)}"
|
||||
)
|
||||
return key
|
||||
|
||||
|
||||
def encrypt(plaintext: bytes, key_path: Path | None = None) -> bytes:
|
||||
"""Encrypt plaintext using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
plaintext: Data to encrypt.
|
||||
key_path: Path to master key file. If None, uses default from
|
||||
bootstrap config.
|
||||
|
||||
Returns:
|
||||
Ciphertext in format: nonce (12 bytes) || ciphertext || tag (16 bytes)
|
||||
"""
|
||||
if key_path is None:
|
||||
from central.bootstrap_config import get_settings
|
||||
key_path = get_settings().master_key_path
|
||||
|
||||
key = _load_master_key(key_path)
|
||||
nonce = os.urandom(NONCE_SIZE)
|
||||
aesgcm = AESGCM(key)
|
||||
|
||||
# GCM appends the 16-byte tag to the ciphertext
|
||||
ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, associated_data=None)
|
||||
|
||||
return nonce + ciphertext_with_tag
|
||||
|
||||
|
||||
def decrypt(ciphertext: bytes, key_path: Path | None = None) -> bytes:
|
||||
"""Decrypt ciphertext using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
ciphertext: Data in format: nonce || ciphertext || tag
|
||||
key_path: Path to master key file. If None, uses default from
|
||||
bootstrap config.
|
||||
|
||||
Returns:
|
||||
Decrypted plaintext.
|
||||
|
||||
Raises:
|
||||
DecryptionError: If decryption fails (wrong key or tampered data).
|
||||
"""
|
||||
if key_path is None:
|
||||
from central.bootstrap_config import get_settings
|
||||
key_path = get_settings().master_key_path
|
||||
|
||||
if len(ciphertext) < NONCE_SIZE + 16: # nonce + minimum tag
|
||||
raise DecryptionError("Ciphertext too short")
|
||||
|
||||
key = _load_master_key(key_path)
|
||||
nonce = ciphertext[:NONCE_SIZE]
|
||||
ciphertext_with_tag = ciphertext[NONCE_SIZE:]
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
try:
|
||||
plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, associated_data=None)
|
||||
except Exception as e:
|
||||
raise DecryptionError(f"Decryption failed: {e}")
|
||||
|
||||
return plaintext
|
||||
|
||||
|
||||
def clear_key_cache() -> None:
|
||||
"""Clear the cached master key. Use after key rotation."""
|
||||
_load_master_key.cache_clear()
|
||||
125
src/central/migrate.py
Normal file
125
src/central/migrate.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Simple database migration runner.
|
||||
|
||||
Tracks applied migrations in a `schema_migrations` table. Migrations are
|
||||
plain SQL files in `sql/migrations/` named with numeric prefixes:
|
||||
001_create_config_schema.sql
|
||||
002_add_operators_table.sql
|
||||
...
|
||||
|
||||
Usage:
|
||||
central-migrate [--dry-run]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
|
||||
MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
|
||||
|
||||
|
||||
async def ensure_migrations_table(conn: asyncpg.Connection) -> None:
|
||||
"""Create the schema_migrations table if it doesn't exist."""
|
||||
await conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
async def get_applied_migrations(conn: asyncpg.Connection) -> set[str]:
|
||||
"""Return set of already-applied migration versions."""
|
||||
rows = await conn.fetch("SELECT version FROM schema_migrations")
|
||||
return {row["version"] for row in rows}
|
||||
|
||||
|
||||
def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
|
||||
"""Find all .sql files in migrations directory, sorted by name.
|
||||
|
||||
Returns list of (version, path) tuples where version is the filename
|
||||
without extension.
|
||||
"""
|
||||
if not migrations_dir.exists():
|
||||
return []
|
||||
|
||||
migrations = []
|
||||
for f in sorted(migrations_dir.glob("*.sql")):
|
||||
version = f.stem # e.g., "001_create_config_schema"
|
||||
migrations.append((version, f))
|
||||
return migrations
|
||||
|
||||
|
||||
async def apply_migration(
|
||||
conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
|
||||
) -> None:
|
||||
"""Apply a single migration."""
|
||||
sql = sql_path.read_text()
|
||||
|
||||
if dry_run:
|
||||
print(f"[DRY RUN] Would apply: {version}")
|
||||
print(f" SQL: {sql[:200]}..." if len(sql) > 200 else f" SQL: {sql}")
|
||||
return
|
||||
|
||||
async with conn.transaction():
|
||||
await conn.execute(sql)
|
||||
await conn.execute(
|
||||
"INSERT INTO schema_migrations (version) VALUES ($1)", version
|
||||
)
|
||||
print(f"Applied: {version}")
|
||||
|
||||
|
||||
async def run_migrations(dsn: str, dry_run: bool = False) -> int:
|
||||
"""Run all pending migrations.
|
||||
|
||||
Returns number of migrations applied.
|
||||
"""
|
||||
conn = await asyncpg.connect(dsn)
|
||||
try:
|
||||
await ensure_migrations_table(conn)
|
||||
applied = await get_applied_migrations(conn)
|
||||
pending = [
|
||||
(v, p) for v, p in discover_migrations(MIGRATIONS_DIR) if v not in applied
|
||||
]
|
||||
|
||||
if not pending:
|
||||
print("No pending migrations.")
|
||||
return 0
|
||||
|
||||
print(f"Found {len(pending)} pending migration(s).")
|
||||
for version, path in pending:
|
||||
await apply_migration(conn, version, path, dry_run)
|
||||
|
||||
return len(pending)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def async_main() -> None:
|
||||
"""Async entry point."""
|
||||
parser = argparse.ArgumentParser(description="Run database migrations")
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Show what would be applied without executing",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
from central.bootstrap_config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
count = await run_migrations(settings.db_dsn, dry_run=args.dry_run)
|
||||
|
||||
if count > 0 and not args.dry_run:
|
||||
print(f"Successfully applied {count} migration(s).")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point."""
|
||||
asyncio.run(async_main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
18
tests/README.md
Normal file
18
tests/README.md
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
# Central Tests
|
||||
|
||||
## Test Database
|
||||
|
||||
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
|
||||
By default, tests connect to:
|
||||
|
||||
```
|
||||
postgresql://central_test:testpass@localhost/central_test
|
||||
```
|
||||
|
||||
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
|
||||
environment variable:
|
||||
|
||||
```bash
|
||||
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
|
||||
uv run pytest tests/test_config_store.py
|
||||
```
|
||||
123
tests/test_bootstrap_config.py
Normal file
123
tests/test_bootstrap_config.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Tests for bootstrap configuration."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import pytest
|
||||
|
||||
from central.bootstrap_config import Settings, get_settings
|
||||
|
||||
|
||||
class TestSettingsFromEnv:
|
||||
"""Test loading settings from environment variables."""
|
||||
|
||||
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Settings are read from CENTRAL_* environment variables."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
|
||||
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
|
||||
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
|
||||
assert settings.nats_url == "nats://10.0.0.1:4222"
|
||||
assert settings.master_key_path == Path("/tmp/test.key")
|
||||
assert settings.log_level == "DEBUG"
|
||||
|
||||
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Default values are used when env vars not set."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
|
||||
# Clear any existing env vars that might interfere
|
||||
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
|
||||
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
|
||||
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.nats_url == "nats://localhost:4222"
|
||||
assert settings.master_key_path == Path("/etc/central/master.key")
|
||||
assert settings.log_level == "INFO"
|
||||
|
||||
|
||||
class TestSettingsFromFile:
|
||||
"""Test loading settings from .env file."""
|
||||
|
||||
def test_reads_from_env_file(self, tmp_path: Path) -> None:
|
||||
"""Settings are read from .env file when env vars not present."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text(
|
||||
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
|
||||
"CENTRAL_NATS_URL=nats://file.local:4222\n"
|
||||
"CENTRAL_LOG_LEVEL=WARNING\n"
|
||||
)
|
||||
|
||||
# Create settings pointing to the temp .env file
|
||||
settings = Settings(_env_file=env_file)
|
||||
|
||||
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
|
||||
assert settings.nats_url == "nats://file.local:4222"
|
||||
assert settings.log_level == "WARNING"
|
||||
|
||||
def test_env_vars_override_file(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Environment variables take precedence over .env file."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
|
||||
|
||||
settings = Settings(_env_file=env_file)
|
||||
|
||||
assert settings.db_dsn == "postgresql://env@localhost/envdb"
|
||||
|
||||
|
||||
class TestSettingsValidation:
|
||||
"""Test settings validation and error handling."""
|
||||
|
||||
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Clear error when required CENTRAL_DB_DSN is missing."""
|
||||
# Ensure no env vars or .env file provides the DSN
|
||||
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# Use a non-existent .env file path to ensure no fallback
|
||||
Settings(_env_file=Path("/nonexistent/.env"))
|
||||
|
||||
# pydantic-settings raises ValidationError for missing required fields
|
||||
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
|
||||
|
||||
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Invalid log level values are rejected."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
|
||||
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
Settings()
|
||||
|
||||
|
||||
class TestGetSettings:
|
||||
"""Test the cached settings loader."""
|
||||
|
||||
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""get_settings() returns cached instance."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
|
||||
s1 = get_settings()
|
||||
s2 = get_settings()
|
||||
|
||||
assert s1 is s2
|
||||
|
||||
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""cache_clear() forces reload on next call."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
s1 = get_settings()
|
||||
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
s2 = get_settings()
|
||||
|
||||
assert s1.db_dsn != s2.db_dsn
|
||||
339
tests/test_config_store.py
Normal file
339
tests/test_config_store.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
"""Tests for database-backed configuration store.
|
||||
|
||||
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
|
||||
environment variable to override the default test database connection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN - uses central_test database with well-known test password.
|
||||
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
# Create schema if not exists
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
|
||||
# Create tables if not exist
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.api_keys (
|
||||
alias TEXT PRIMARY KEY,
|
||||
encrypted_value BYTEA NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
rotated_at TIMESTAMPTZ,
|
||||
last_used_at TIMESTAMPTZ
|
||||
)
|
||||
""")
|
||||
|
||||
# Create notify function with proper key detection
|
||||
await db_conn.execute("""
|
||||
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
key_value TEXT;
|
||||
BEGIN
|
||||
IF TG_TABLE_NAME = 'adapters' THEN
|
||||
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||
ELSIF TG_TABLE_NAME = 'api_keys' THEN
|
||||
key_value := COALESCE(NEW.alias, OLD.alias, '');
|
||||
ELSE
|
||||
key_value := '';
|
||||
END IF;
|
||||
|
||||
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||
RETURN COALESCE(NEW, OLD);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
""")
|
||||
|
||||
# Create triggers if not exist
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
|
||||
CREATE TRIGGER adapters_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
|
||||
CREATE TRIGGER api_keys_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
|
||||
# Clean tables
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
await db_conn.execute("DELETE FROM config.api_keys")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def config_store(clean_config_schema: None) -> ConfigStore:
|
||||
"""Create a ConfigStore connected to the test database."""
|
||||
store = await ConfigStore.create(TEST_DB_DSN)
|
||||
yield store
|
||||
await store.close()
|
||||
|
||||
|
||||
class TestAdapterConfig:
|
||||
"""Tests for adapter configuration operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
|
||||
"""Can insert and retrieve adapter config."""
|
||||
await config_store.upsert_adapter(
|
||||
name="test_adapter",
|
||||
enabled=True,
|
||||
cadence_s=120,
|
||||
settings={"key": "value"},
|
||||
)
|
||||
|
||||
adapter = await config_store.get_adapter("test_adapter")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.name == "test_adapter"
|
||||
assert adapter.enabled is True
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
|
||||
"""Getting nonexistent adapter returns None."""
|
||||
adapter = await config_store.get_adapter("does_not_exist")
|
||||
assert adapter is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_adapters(self, config_store: ConfigStore) -> None:
|
||||
"""Can list all adapters."""
|
||||
await config_store.upsert_adapter("adapter_a", True, 60, {})
|
||||
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
|
||||
|
||||
adapters = await config_store.list_adapters()
|
||||
|
||||
assert len(adapters) == 2
|
||||
names = [a.name for a in adapters]
|
||||
assert "adapter_a" in names
|
||||
assert "adapter_b" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
|
||||
"""Upsert updates existing adapter."""
|
||||
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
|
||||
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
|
||||
|
||||
adapter = await config_store.get_adapter("updater")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.enabled is False
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"v": 2}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
|
||||
"""Can pause and unpause adapter."""
|
||||
await config_store.upsert_adapter("pausable", True, 60, {})
|
||||
|
||||
await config_store.pause_adapter("pausable")
|
||||
adapter = await config_store.get_adapter("pausable")
|
||||
assert adapter is not None
|
||||
assert adapter.is_paused is True
|
||||
|
||||
await config_store.unpause_adapter("pausable")
|
||||
adapter = await config_store.get_adapter("pausable")
|
||||
assert adapter is not None
|
||||
assert adapter.is_paused is False
|
||||
|
||||
|
||||
class TestApiKeys:
|
||||
"""Tests for API key operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
|
||||
"""Can store and retrieve encrypted API key."""
|
||||
await config_store.set_api_key("test_key", "super_secret_value")
|
||||
|
||||
value = await config_store.get_api_key("test_key")
|
||||
|
||||
assert value == "super_secret_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
|
||||
"""Getting nonexistent key returns None."""
|
||||
value = await config_store.get_api_key("does_not_exist")
|
||||
assert value is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_rotation(self, config_store: ConfigStore) -> None:
|
||||
"""Updating key sets rotated_at."""
|
||||
await config_store.set_api_key("rotate_me", "value1")
|
||||
await config_store.set_api_key("rotate_me", "value2")
|
||||
|
||||
value = await config_store.get_api_key("rotate_me")
|
||||
assert value == "value2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key(self, config_store: ConfigStore) -> None:
|
||||
"""Can delete API key."""
|
||||
await config_store.set_api_key("delete_me", "value")
|
||||
|
||||
deleted = await config_store.delete_api_key("delete_me")
|
||||
assert deleted is True
|
||||
|
||||
value = await config_store.get_api_key("delete_me")
|
||||
assert value is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
|
||||
"""Deleting nonexistent key returns False."""
|
||||
deleted = await config_store.delete_api_key("never_existed")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
class TestNotifications:
|
||||
"""Tests for LISTEN/NOTIFY functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
|
||||
"""NOTIFY fires when adapter is changed."""
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
# Start listener in background
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
try:
|
||||
# Give listener time to subscribe
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Trigger a change
|
||||
await config_store.upsert_adapter("notify_test", True, 60, {})
|
||||
|
||||
# Wait for notification (with timeout)
|
||||
try:
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Notification not received within timeout")
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0][0] == "adapters"
|
||||
assert notifications[0][1] == "notify_test"
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
|
||||
"""NOTIFY fires when API key is changed."""
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
await config_store.set_api_key("notify_key", "secret")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Notification not received within timeout")
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0][0] == "api_keys"
|
||||
assert notifications[0][1] == "notify_key"
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class TestListenerReconnect:
|
||||
"""Tests for listener reconnection on connection loss."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_cancellation_propagates(
|
||||
self, config_store: ConfigStore
|
||||
) -> None:
|
||||
"""Cancellation cleanly stops the listener without reconnect loop."""
|
||||
async def callback(table: str, key: str) -> None:
|
||||
pass
|
||||
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
# Give listener time to start
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Cancel and verify it stops
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(listen_task, timeout=2.0)
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Listener did not stop after cancellation")
|
||||
|
||||
assert listen_task.cancelled() or listen_task.done()
|
||||
175
tests/test_crypto.py
Normal file
175
tests/test_crypto.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""Tests for cryptographic primitives."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from central.crypto import (
|
||||
KEY_SIZE,
|
||||
DecryptionError,
|
||||
KeyLoadError,
|
||||
clear_key_cache,
|
||||
decrypt,
|
||||
encrypt,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def master_key(tmp_path: Path) -> Path:
|
||||
"""Create a valid master key file."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
clear_key_cache()
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_key(tmp_path: Path) -> Path:
|
||||
"""Create a different master key file."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path / "wrong.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
"""Test encrypt/decrypt round-trip."""
|
||||
|
||||
def test_round_trip(self, master_key: Path) -> None:
|
||||
"""Encrypting then decrypting returns original plaintext."""
|
||||
plaintext = b"Hello, Central!"
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_round_trip_empty(self, master_key: Path) -> None:
|
||||
"""Empty plaintext encrypts and decrypts correctly."""
|
||||
plaintext = b""
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_round_trip_large(self, master_key: Path) -> None:
|
||||
"""Large plaintext encrypts and decrypts correctly."""
|
||||
plaintext = os.urandom(1024 * 1024) # 1MB
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
|
||||
"""Same plaintext produces different ciphertext (random nonce)."""
|
||||
plaintext = b"test"
|
||||
|
||||
ct1 = encrypt(plaintext, key_path=master_key)
|
||||
ct2 = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
assert ct1 != ct2
|
||||
# But both decrypt to same plaintext
|
||||
assert decrypt(ct1, key_path=master_key) == plaintext
|
||||
assert decrypt(ct2, key_path=master_key) == plaintext
|
||||
|
||||
|
||||
class TestDecryptionFailures:
|
||||
"""Test AEAD authentication catches tampering."""
|
||||
|
||||
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
|
||||
"""Decryption with wrong key raises DecryptionError."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
clear_key_cache() # Clear cache so wrong_key is loaded
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(ciphertext, key_path=wrong_key)
|
||||
|
||||
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
|
||||
"""Modified ciphertext is detected and rejected."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
# Flip a bit in the ciphertext (after nonce, before tag)
|
||||
tampered = bytearray(ciphertext)
|
||||
tampered[15] ^= 0x01 # Flip one bit
|
||||
tampered = bytes(tampered)
|
||||
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(tampered, key_path=master_key)
|
||||
|
||||
def test_tampered_tag_fails(self, master_key: Path) -> None:
|
||||
"""Modified authentication tag is detected and rejected."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
# Flip a bit in the last byte (part of the tag)
|
||||
tampered = bytearray(ciphertext)
|
||||
tampered[-1] ^= 0x01
|
||||
tampered = bytes(tampered)
|
||||
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(tampered, key_path=master_key)
|
||||
|
||||
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
|
||||
"""Truncated ciphertext is rejected."""
|
||||
ciphertext = b"tooshort"
|
||||
|
||||
with pytest.raises(DecryptionError, match="too short"):
|
||||
decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
|
||||
class TestKeyLoading:
|
||||
"""Test master key loading."""
|
||||
|
||||
def test_missing_key_file(self, tmp_path: Path) -> None:
|
||||
"""Missing key file raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
missing = tmp_path / "nonexistent.key"
|
||||
|
||||
with pytest.raises(KeyLoadError, match="not found"):
|
||||
encrypt(b"test", key_path=missing)
|
||||
|
||||
def test_invalid_key_size(self, tmp_path: Path) -> None:
|
||||
"""Key file with wrong size raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
bad_key = tmp_path / "bad.key"
|
||||
bad_key.write_text(base64.b64encode(b"tooshort").decode())
|
||||
|
||||
with pytest.raises(KeyLoadError, match="Invalid master key size"):
|
||||
encrypt(b"test", key_path=bad_key)
|
||||
|
||||
def test_invalid_base64(self, tmp_path: Path) -> None:
|
||||
"""Invalid base64 in key file raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
bad_key = tmp_path / "bad.key"
|
||||
bad_key.write_text("not valid base64!!!")
|
||||
|
||||
with pytest.raises(KeyLoadError):
|
||||
encrypt(b"test", key_path=bad_key)
|
||||
|
||||
def test_key_cached(self, master_key: Path) -> None:
|
||||
"""Key is cached after first load."""
|
||||
# First encryption loads the key
|
||||
encrypt(b"test1", key_path=master_key)
|
||||
|
||||
# Delete the file
|
||||
master_key.unlink()
|
||||
|
||||
# Second encryption should still work (cached)
|
||||
ciphertext = encrypt(b"test2", key_path=master_key)
|
||||
assert len(ciphertext) > 0
|
||||
|
||||
def test_cache_clear(self, master_key: Path) -> None:
|
||||
"""clear_key_cache forces reload."""
|
||||
encrypt(b"test", key_path=master_key)
|
||||
master_key.unlink()
|
||||
clear_key_cache()
|
||||
|
||||
with pytest.raises(KeyLoadError, match="not found"):
|
||||
encrypt(b"test", key_path=master_key)
|
||||
|
|
@ -49,7 +49,7 @@ def sample_config() -> Config:
|
|||
},
|
||||
cloudevents=CloudEventsConfig(
|
||||
type_prefix="central",
|
||||
source="central.echo6.mesh",
|
||||
source="central.local",
|
||||
schema_version="1.0",
|
||||
),
|
||||
nats=NATSConfig(url="nats://localhost:4222"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue