diff --git a/src/central/migrate.py b/src/central/migrate.py
index 908e7e7..98b684d 100644
--- a/src/central/migrate.py
+++ b/src/central/migrate.py
@@ -8,15 +8,19 @@ plain SQL files in `sql/migrations/` named with numeric prefixes:
 
 Usage:
     central-migrate [--dry-run]
+    central-migrate --check  # report drift; exit 1 if any file is untracked
 """
 
 import argparse
 import asyncio
+import logging
 import sys
 from pathlib import Path
 
 import asyncpg
 
+logger = logging.getLogger(__name__)
+
 MIGRATIONS_DIR = Path(__file__).parent.parent.parent / "sql" / "migrations"
 
 
@@ -52,6 +56,47 @@ def discover_migrations(migrations_dir: Path) -> list[tuple[str, Path]]:
     return migrations
 
 
+def find_drift(
+    applied: set[str], discovered: list[tuple[str, Path]]
+) -> tuple[list[str], list[str]]:
+    """Compare schema_migrations rows against migration files on disk.
+
+    Returns (untracked, orphan):
+      - untracked: migration files present on disk with NO schema_migrations row.
+        This is the v0.9.18 failure mode -- migrations applied out-of-band (direct
+        psql) that were never recorded, so a fresh restore would try to replay
+        them and an audit of the tracking table understates what is live.
+      - orphan: schema_migrations versions with NO matching .sql file (e.g. a
+        migration file removed from the repo).
+
+    Pure function (no I/O) so the drift logic is unit-testable without a database.
+    """
+    disk_versions = {v for v, _ in discovered}
+    untracked = sorted(v for v in disk_versions if v not in applied)
+    orphan = sorted(v for v in applied if v not in disk_versions)
+    return untracked, orphan
+
+
+def log_drift(untracked: list[str], orphan: list[str]) -> None:
+    """Emit a WARN per drifted migration so divergence is visible in logs.
+
+    Cheap insurance against drift compounding silently (v0.9.18): any file not
+    recorded in schema_migrations, or any row with no file, is surfaced on every
+    central-migrate run.
+    """
+    for version in untracked:
+        logger.warning(
+            "migration file %s is not recorded in schema_migrations "
+            "(pending apply, or applied out-of-band -- investigate)",
+            version,
+        )
+    for version in orphan:
+        logger.warning(
+            "schema_migrations records %s but no matching .sql file exists on disk",
+            version,
+        )
+
+
 async def apply_migration(
     conn: asyncpg.Connection, version: str, sql_path: Path, dry_run: bool = False
 ) -> None:
@@ -80,9 +125,12 @@ async def run_migrations(dsn: str, dry_run: bool = False) -> int:
     try:
         await ensure_migrations_table(conn)
         applied = await get_applied_migrations(conn)
-        pending = [
-            (v, p) for v, p in discover_migrations(MIGRATIONS_DIR) if v not in applied
-        ]
+        discovered = discover_migrations(MIGRATIONS_DIR)
+
+        untracked, orphan = find_drift(applied, discovered)
+        log_drift(untracked, orphan)
+
+        pending = [(v, p) for v, p in discovered if v not in applied]
 
         if not pending:
             print("No pending migrations.")
@@ -97,19 +145,64 @@ async def run_migrations(dsn: str, dry_run: bool = False) -> int:
         await conn.close()
 
 
+async def check_drift(dsn: str) -> int:
+    """Report migration drift; return 1 if any file is untracked, else 0.
+
+    The CI-assertion form of log_drift: run against a DB expected to be fully
+    reconciled (after a restore, or in CI). A non-zero exit means a migration
+    file was never recorded -- the v0.9.18 drift, made loud and scriptable.
+
+    Orphan rows (recorded versions with no .sql file on disk) are printed but
+    do not change the exit status.
+    """
+    conn = await asyncpg.connect(dsn)
+    try:
+        await ensure_migrations_table(conn)
+        applied = await get_applied_migrations(conn)
+        untracked, orphan = find_drift(applied, discover_migrations(MIGRATIONS_DIR))
+    finally:
+        await conn.close()
+
+    for version in untracked:
+        print(f"DRIFT: {version} present on disk but not in schema_migrations")
+    for version in orphan:
+        print(f"DRIFT: {version} in schema_migrations but no .sql file on disk")
+
+    if untracked:
+        print(f"{len(untracked)} untracked migration file(s) -- drift detected.")
+        return 1
+    if orphan:
+        # Orphans were listed above; say so instead of claiming a clean state.
+        print(f"{len(orphan)} orphan schema_migrations row(s) -- no untracked files.")
+    else:
+        print("No migration drift detected.")
+    return 0
+
+
 async def async_main() -> None:
     """Async entry point."""
+    logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser(description="Run database migrations")
     parser.add_argument(
         "--dry-run",
         action="store_true",
         help="Show what would be applied without executing",
     )
+    parser.add_argument(
+        "--check",
+        action="store_true",
+        help="Report drift and exit 1 if any migration file is untracked "
+        "(does not apply migrations)",
+    )
     args = parser.parse_args()
 
     from central.bootstrap_config import get_settings
 
     settings = get_settings()
+
+    if args.check:
+        sys.exit(await check_drift(settings.db_dsn))
+
     count = await run_migrations(settings.db_dsn, dry_run=args.dry_run)
 
     if count > 0 and not args.dry_run:
diff --git a/tests/test_migrate.py b/tests/test_migrate.py
new file mode 100644
index 0000000..5fc70b3
--- /dev/null
+++ b/tests/test_migrate.py
@@ -0,0 +1,61 @@
+"""Drift-detection tests for the migration runner (v0.9.18).
+
+`find_drift` is a pure function so the divergence logic is verified without a
+database -- the suite runs identically as `zvx` or `central`.
+"""
+
+from pathlib import Path
+
+from central.migrate import find_drift
+
+
+def _discovered(*versions: str) -> list[tuple[str, Path]]:
+    """Build a discover_migrations-shaped list from version stems."""
+    return [(v, Path(f"sql/migrations/{v}.sql")) for v in versions]
+
+
+def test_find_drift_untracked_file():
+    """A migration file on disk with no schema_migrations row is untracked.
+
+    This is the v0.9.18 case: 025-029 existed on disk but were never recorded.
+    """
+    applied = {"001_a", "002_b"}
+    discovered = _discovered("001_a", "002_b", "003_c")
+
+    untracked, orphan = find_drift(applied, discovered)
+
+    assert untracked == ["003_c"]
+    assert orphan == []
+
+
+def test_find_drift_orphan_row():
+    """A schema_migrations row with no matching .sql file is an orphan."""
+    applied = {"001_a", "002_b", "099_removed"}
+    discovered = _discovered("001_a", "002_b")
+
+    untracked, orphan = find_drift(applied, discovered)
+
+    assert untracked == []
+    assert orphan == ["099_removed"]
+
+
+def test_find_drift_clean():
+    """No drift when disk and schema_migrations match exactly."""
+    applied = {"001_a", "002_b", "003_c"}
+    discovered = _discovered("001_a", "002_b", "003_c")
+
+    untracked, orphan = find_drift(applied, discovered)
+
+    assert untracked == []
+    assert orphan == []
+
+
+def test_find_drift_both_sorted():
+    """Untracked and orphan versions are reported together, each list sorted."""
+    applied = {"001_a", "098_gone", "097_gone"}
+    discovered = _discovered("003_c", "001_a", "002_b")
+
+    untracked, orphan = find_drift(applied, discovered)
+
+    assert untracked == ["002_b", "003_c"]
+    assert orphan == ["097_gone", "098_gone"]