From f059f982bcd459864e60e2e601a68ea69ec5cc53 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 17 May 2026 05:30:49 +0000 Subject: [PATCH] feat(gui): add auth core, setup gate, and first-run operator creation - Add migrations 007-010 for system config, operators, sessions, audit_log - Implement argon2id password hashing via argon2-cffi - Implement session-based authentication with database-stored tokens - Add SetupGateMiddleware to redirect to /setup until first operator created - Add SessionMiddleware to load session from cookie and attach operator - Create /setup, /login, /logout, /change-password routes with CSRF protection - Add periodic session cleanup task (hourly) - Add audit logging for auth events - Update systemd unit with EnvironmentFile for /etc/central/central.env - Add comprehensive tests for auth, middleware, and audit modules Co-Authored-By: Claude Opus 4.5 --- .gitignore | 1 + docs/environment.md | 28 ++ pyproject.toml | 2 + sql/migrations/007_add_config_system.sql | 21 ++ sql/migrations/008_add_operators.sql | 10 + sql/migrations/009_add_sessions.sql | 11 + sql/migrations/010_add_audit_log.sql | 15 + src/central/bootstrap_config.py | 3 + src/central/gui/__init__.py | 156 +++++++- src/central/gui/audit.py | 37 ++ src/central/gui/auth.py | 138 +++++++ src/central/gui/db.py | 48 +++ src/central/gui/middleware.py | 96 +++++ src/central/gui/routes.py | 352 +++++++++++++++++- src/central/gui/templates/base.html | 29 ++ .../gui/templates/change_password.html | 36 ++ src/central/gui/templates/login.html | 29 ++ src/central/gui/templates/setup.html | 37 ++ systemd/central-gui.service | 4 + tests/conftest.py | 50 +++ tests/test_audit.py | 92 +++++ tests/test_auth.py | 183 +++++++++ tests/test_session_auth.py | 173 +++++++++ tests/test_setup_gate.py | 162 ++++++++ uv.lock | 61 +++ 25 files changed, 1758 insertions(+), 16 deletions(-) create mode 100644 sql/migrations/007_add_config_system.sql create mode 100644 sql/migrations/008_add_operators.sql create mode 100644 sql/migrations/009_add_sessions.sql create mode 100644 sql/migrations/010_add_audit_log.sql create mode 100644 src/central/gui/audit.py create mode 100644 src/central/gui/auth.py create mode 100644 src/central/gui/db.py create mode 100644 src/central/gui/middleware.py create mode 100644 src/central/gui/templates/change_password.html create mode 100644 src/central/gui/templates/login.html create mode 100644 src/central/gui/templates/setup.html create mode 100644 tests/conftest.py create mode 100644 tests/test_audit.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_session_auth.py create mode 100644 tests/test_setup_gate.py diff --git a/.gitignore b/.gitignore index a61310f..04e7f06 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ db.env .vscode/ *.swp .DS_Store +.ssh/ diff --git a/docs/environment.md b/docs/environment.md index 7396362..9659443 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -68,6 +68,34 @@ journalctl -u central-archive -f ## Database +## Environment Variables + +Environment variables are stored in `/etc/central/central.env` and loaded by +systemd services via `EnvironmentFile=`. + +| Variable | Required | Description | +|----------|----------|-------------| +| `CENTRAL_CSRF_SECRET` | Yes (for GUI) | Secret key for CSRF token signing. Generate with `python3 -c "import secrets; print(secrets.token_urlsafe(32))"` | + +### Generating CSRF Secret + +```bash +python3 -c "import secrets; print(secrets.token_urlsafe(32))" +``` + +Add the generated value to `/etc/central/central.env`: + +```bash +CENTRAL_CSRF_SECRET= +``` + +Ensure the file has restricted permissions: + +```bash +sudo chmod 640 /etc/central/central.env +sudo chown central:central /etc/central/central.env +``` + PostgreSQL 16 with TimescaleDB runs on CT104: ```bash diff --git a/pyproject.toml b/pyproject.toml index 476998b..d1d7f8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,11 @@ license = {text = "MIT"} authors = [{name = "Matt Johnson"}] dependencies = [ "aiohttp>=3.13.5", + "argon2-cffi>=25.1.0", "asyncpg>=0.31.0", "cloudevents>=2.0.0", "cryptography>=44.0.0", + "fastapi-csrf-protect>=0.4.0", "fastapi>=0.115.0", "jinja2>=3.1.6", "nats-py>=2.14.0", diff --git a/sql/migrations/007_add_config_system.sql b/sql/migrations/007_add_config_system.sql new file mode 100644 index 0000000..6c0b80b --- /dev/null +++ b/sql/migrations/007_add_config_system.sql @@ -0,0 +1,21 @@ +-- Migration 007: Add config.system table for global settings +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.system ( + id BOOLEAN PRIMARY KEY DEFAULT true CHECK (id = true), + setup_complete BOOLEAN NOT NULL DEFAULT false, + session_lifetime_days INTEGER NOT NULL DEFAULT 90, + map_tile_url TEXT NOT NULL DEFAULT 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', + map_attribution TEXT NOT NULL DEFAULT '© OpenStreetMap contributors', + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Reuse existing set_updated_at trigger function +DROP TRIGGER IF EXISTS system_set_updated_at ON config.system; +CREATE TRIGGER system_set_updated_at + BEFORE UPDATE ON config.system + FOR EACH ROW + EXECUTE FUNCTION config.set_updated_at(); + +-- Seed single row +INSERT INTO config.system (id) VALUES (true) ON CONFLICT DO NOTHING; diff --git a/sql/migrations/008_add_operators.sql b/sql/migrations/008_add_operators.sql new file mode 100644 index 0000000..2886caa --- /dev/null +++ b/sql/migrations/008_add_operators.sql @@ -0,0 +1,10 @@ +-- Migration 008: Add config.operators table for user accounts +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.operators ( + id BIGSERIAL PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + password_changed_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/sql/migrations/009_add_sessions.sql b/sql/migrations/009_add_sessions.sql new file mode 100644 index 0000000..9016676 --- /dev/null +++ b/sql/migrations/009_add_sessions.sql @@ -0,0 +1,11 @@ +-- Migration 009: Add config.sessions table for auth tokens +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.sessions ( + token TEXT PRIMARY KEY, + operator_id BIGINT NOT NULL REFERENCES config.operators(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + expires_at TIMESTAMPTZ NOT NULL +); + +CREATE INDEX IF NOT EXISTS sessions_expires_at_idx ON config.sessions(expires_at); diff --git a/sql/migrations/010_add_audit_log.sql b/sql/migrations/010_add_audit_log.sql new file mode 100644 index 0000000..d6baa65 --- /dev/null +++ b/sql/migrations/010_add_audit_log.sql @@ -0,0 +1,15 @@ +-- Migration 010: Add config.audit_log table +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.audit_log ( + id BIGSERIAL PRIMARY KEY, + ts TIMESTAMPTZ NOT NULL DEFAULT now(), + operator_id BIGINT REFERENCES config.operators(id) ON DELETE SET NULL, + action TEXT NOT NULL, + target TEXT, + before JSONB, + after JSONB +); + +CREATE INDEX IF NOT EXISTS audit_log_ts_idx ON config.audit_log(ts DESC); +CREATE INDEX IF NOT EXISTS audit_log_action_idx ON config.audit_log(action); diff --git a/src/central/bootstrap_config.py b/src/central/bootstrap_config.py index f5e46bc..334c4c6 100644 --- a/src/central/bootstrap_config.py +++ b/src/central/bootstrap_config.py @@ -33,6 +33,9 @@ class Settings(BaseSettings): default="INFO", description="Logging level", ) + csrf_secret: str = Field( + description="Secret key for CSRF token signing (generate with: python -c \"import secrets; print(secrets.token_urlsafe(32))\")", + ) @lru_cache diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 807bae7..f7ee746 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -1,13 +1,17 @@ """Central GUI — FastAPI + Jinja2 + HTMX.""" +import asyncio +import logging +from contextlib import asynccontextmanager from pathlib import Path +from typing import Any import uvicorn from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from central.gui.routes import router +logger = logging.getLogger(__name__) # Template and static directories GUI_DIR = Path(__file__).parent @@ -17,17 +21,108 @@ STATIC_DIR = GUI_DIR / "static" # Jinja2 templates instance (shared with routes) templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) +# Shutdown event and cleanup task +_shutdown_event: asyncio.Event | None = None +_cleanup_task: asyncio.Task | None = None + +# Lazy app singleton +_app: FastAPI | None = None + + +def _configure_csrf() -> None: + """Configure CSRF protection. Must be called before app starts.""" + from fastapi_csrf_protect import CsrfProtect + from pydantic import BaseModel + from central.bootstrap_config import get_settings + + class CsrfSettings(BaseModel): + secret_key: str + + @CsrfProtect.load_config + def get_csrf_config(): + settings = get_settings() + return CsrfSettings(secret_key=settings.csrf_secret) + + +async def _session_cleanup_loop() -> None: + """Periodically clean up expired sessions.""" + global _shutdown_event + + from central.gui.db import get_pool + + if _shutdown_event is None: + return + + while not _shutdown_event.is_set(): + try: + await asyncio.wait_for(_shutdown_event.wait(), timeout=3600) + except asyncio.TimeoutError: + try: + pool = get_pool() + if pool: + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM config.sessions WHERE expires_at < now()" + ) + deleted = result.split()[-1] if result else "0" + if int(deleted) > 0: + logger.info("Session cleanup", extra={"deleted": deleted}) + except Exception: + logger.warning("Session cleanup failed", exc_info=True) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + global _shutdown_event, _cleanup_task + + from central.bootstrap_config import get_settings + from central.gui.db import close_pool, init_pool + + settings = get_settings() + + # Initialize database pool + await init_pool(settings.db_dsn) + + # Start session cleanup task + _shutdown_event = asyncio.Event() + _cleanup_task = asyncio.create_task(_session_cleanup_loop()) + + logger.info("Central GUI started") + + yield + + # Shutdown + if _shutdown_event: + _shutdown_event.set() + if _cleanup_task: + try: + await asyncio.wait_for(_cleanup_task, timeout=5.0) + except asyncio.TimeoutError: + _cleanup_task.cancel() + + await close_pool() + logger.info("Central GUI stopped") + + +def _create_app() -> FastAPI: + """Create the FastAPI application.""" + from central.gui.middleware import SessionMiddleware, SetupGateMiddleware + from central.gui.routes import router + + # Configure CSRF before creating app + _configure_csrf() -def create_app() -> FastAPI: - """Create and configure the FastAPI application.""" app = FastAPI( - title="Central", - description="Central Data Hub GUI", - docs_url=None, # Disable Swagger UI for now - redoc_url=None, # Disable ReDoc for now + title="Central GUI", + lifespan=lifespan, ) - # Mount static files if directory exists and has content + # Add middleware (order matters - first added runs last) + app.add_middleware(SessionMiddleware) + app.add_middleware(SetupGateMiddleware) + + # Mount static files if STATIC_DIR.exists(): app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") @@ -37,10 +132,47 @@ def create_app() -> FastAPI: return app -# Application instance -app = create_app() +def __getattr__(name: str) -> Any: + """Lazy attribute access for app singleton.""" + global _app + if name == "app": + if _app is None: + _app = _create_app() + return _app + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def main() -> None: - """Entry point for central-gui console script.""" - uvicorn.run(app, host="127.0.0.1", port=8000) + """Entry point for central-gui command.""" + import logging.config + + logging.config.dictConfig({ + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(asctime)s %(levelname)s %(name)s: %(message)s", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "root": { + "level": "INFO", + "handlers": ["console"], + }, + }) + + uvicorn.run( + "central.gui:app", + host="0.0.0.0", + port=8088, + reload=False, + ) + + +if __name__ == "__main__": + main() diff --git a/src/central/gui/audit.py b/src/central/gui/audit.py new file mode 100644 index 0000000..428275a --- /dev/null +++ b/src/central/gui/audit.py @@ -0,0 +1,37 @@ +"""Audit logging for Central GUI.""" + +import json +from typing import Any + +# Audit action constants +AUTH_LOGIN = "auth.login" +AUTH_LOGIN_FAILED = "auth.login_failed" +AUTH_LOGOUT = "auth.logout" +AUTH_PASSWORD_CHANGE = "auth.password_change" +OPERATOR_CREATE = "operator.create" + + +async def write_audit( + conn: Any, # asyncpg.Connection + action: str, + operator_id: int | None = None, + target: str | None = None, + before: dict[str, Any] | None = None, + after: dict[str, Any] | None = None, +) -> None: + """Write an audit log entry.""" + # Serialize before/after as JSON strings if provided + before_json = json.dumps(before) if before else None + after_json = json.dumps(after) if after else None + + await conn.execute( + """ + INSERT INTO config.audit_log (operator_id, action, target, before, after) + VALUES ($1, $2, $3, $4::jsonb, $5::jsonb) + """, + operator_id, + action, + target, + before_json, + after_json, + ) diff --git a/src/central/gui/auth.py b/src/central/gui/auth.py new file mode 100644 index 0000000..3b74ac0 --- /dev/null +++ b/src/central/gui/auth.py @@ -0,0 +1,138 @@ +"""Authentication utilities for Central GUI.""" + +import secrets +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any + +from argon2 import PasswordHasher +from argon2.exceptions import VerifyMismatchError + +# Use argon2-cffi defaults (argon2id) +_hasher = PasswordHasher() + + +@dataclass +class Operator: + """Operator account.""" + id: int + username: str + created_at: datetime + password_changed_at: datetime | None = None + + +def hash_password(plain: str) -> str: + """Hash a password using argon2id.""" + return _hasher.hash(plain) + + +def verify_password(plain: str, hashed: str) -> bool: + """Verify a password against its hash.""" + try: + _hasher.verify(hashed, plain) + return True + except VerifyMismatchError: + return False + + +def validate_password(plain: str) -> None: + """Validate password meets requirements. Raises ValueError if invalid.""" + if len(plain) < 8: + raise ValueError("Password must be at least 8 characters") + + +def generate_token() -> str: + """Generate a cryptographically secure session token.""" + return secrets.token_urlsafe(32) + + +async def create_session( + conn: Any, # asyncpg.Connection + operator_id: int, + lifetime_days: int, +) -> tuple[str, datetime]: + """Create a new session for an operator. + + Returns (token, expires_at). + """ + token = generate_token() + expires_at = datetime.now(timezone.utc) + timedelta(days=lifetime_days) + + await conn.execute( + """ + INSERT INTO config.sessions (token, operator_id, expires_at) + VALUES ($1, $2, $3) + """, + token, + operator_id, + expires_at, + ) + + return token, expires_at + + +async def get_session(conn: Any, token: str) -> Operator | None: + """Look up a session and return the associated operator. + + Returns None if token is invalid or expired. + """ + row = await conn.fetchrow( + """ + SELECT o.id, o.username, o.created_at, o.password_changed_at + FROM config.sessions s + JOIN config.operators o ON s.operator_id = o.id + WHERE s.token = $1 AND s.expires_at > now() + """, + token, + ) + + if row is None: + return None + + return Operator( + id=row["id"], + username=row["username"], + created_at=row["created_at"], + password_changed_at=row.get("password_changed_at"), + ) + + +async def delete_session(conn: Any, token: str) -> None: + """Delete a session.""" + await conn.execute( + "DELETE FROM config.sessions WHERE token = $1", + token, + ) + + +async def get_operator_by_username(conn: Any, username: str) -> dict | None: + """Get an operator by username. + + Returns the row dict or None if not found. + """ + return await conn.fetchrow( + """ + SELECT id, username, password_hash, created_at, password_changed_at + FROM config.operators + WHERE username = $1 + """, + username, + ) + + +async def create_operator(conn: Any, username: str, password: str) -> int: + """Create a new operator. + + Returns the new operator ID. + """ + password_hash = hash_password(password) + row = await conn.fetchval( + """ + INSERT INTO config.operators (username, password_hash) + VALUES ($1, $2) + RETURNING id + """, + username, + password_hash, + ) + return row diff --git a/src/central/gui/db.py b/src/central/gui/db.py new file mode 100644 index 0000000..645fdc7 --- /dev/null +++ b/src/central/gui/db.py @@ -0,0 +1,48 @@ +"""Database connection pool for GUI.""" + +import json +from typing import Any + +import asyncpg + +# Module-level pool instance +_pool: asyncpg.Pool | None = None + + +# TODO: Deduplicate with central.config_store._setup_json_codec +async def _setup_json_codec(conn: asyncpg.Connection) -> None: + """Set up JSON codec for asyncpg connection.""" + await conn.set_type_codec( + "jsonb", + encoder=json.dumps, + decoder=json.loads, + schema="pg_catalog", + ) + + +async def init_pool(dsn: str) -> asyncpg.Pool: + """Initialize the connection pool.""" + global _pool + if _pool is None: + _pool = await asyncpg.create_pool( + dsn, + min_size=1, + max_size=5, + init=_setup_json_codec, + ) + return _pool + + +def get_pool() -> asyncpg.Pool: + """Get the connection pool. Must call init_pool first.""" + if _pool is None: + raise RuntimeError("Database pool not initialized. Call init_pool first.") + return _pool + + +async def close_pool() -> None: + """Close the connection pool.""" + global _pool + if _pool is not None: + await _pool.close() + _pool = None diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py new file mode 100644 index 0000000..e451528 --- /dev/null +++ b/src/central/gui/middleware.py @@ -0,0 +1,96 @@ +"""Middleware for Central GUI.""" + +import logging + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response + +from central.gui.auth import get_session +from central.gui.db import get_pool + +logger = logging.getLogger(__name__) + +# Paths that don't require setup to be complete +SETUP_EXEMPT_PATHS = {"/setup", "/health"} +SETUP_EXEMPT_PREFIXES = ("/static/",) + +# Paths that don't require authentication +AUTH_EXEMPT_PATHS = {"/setup", "/login", "/health"} +AUTH_EXEMPT_PREFIXES = ("/static/",) + + +def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: + """Check if a path is exempt from a check.""" + if path in exempt_paths: + return True + for prefix in exempt_prefixes: + if path.startswith(prefix): + return True + return False + + +class SetupGateMiddleware(BaseHTTPMiddleware): + """Redirect to /setup if setup is not complete.""" + + async def dispatch(self, request: Request, call_next) -> Response: + path = request.url.path + + # Check setup status from database + pool = get_pool() + if pool is None: + # Pool not initialized yet + return await call_next(request) + + setup_complete = False + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT setup_complete FROM config.system WHERE id = true" + ) + setup_complete = row["setup_complete"] if row else False + except Exception: + logger.warning("Failed to check setup status", exc_info=True) + # On error, allow the request through + return await call_next(request) + + if not setup_complete: + # Setup not complete - only allow exempt paths + if not _is_exempt(path, SETUP_EXEMPT_PATHS, SETUP_EXEMPT_PREFIXES): + return RedirectResponse(url="/setup", status_code=307) + else: + # Setup complete - redirect /setup to / + if path == "/setup": + return RedirectResponse(url="/", status_code=302) + + return await call_next(request) + + +class SessionMiddleware(BaseHTTPMiddleware): + """Load session from cookie and attach operator to request.state.""" + + async def dispatch(self, request: Request, call_next) -> Response: + path = request.url.path + + # Initialize operator to None + request.state.operator = None + + # Try to load session from cookie + session_token = request.cookies.get("central_session") + if session_token: + pool = get_pool() + if pool is not None: + try: + async with pool.acquire() as conn: + operator = await get_session(conn, session_token) + request.state.operator = operator + except Exception: + logger.warning("Failed to load session", exc_info=True) + request.state.operator = None + + # Check if auth is required + if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): + if request.state.operator is None: + return RedirectResponse(url="/login", status_code=302) + + return await call_next(request) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index e06ee2d..19a8b6d 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -1,11 +1,60 @@ """Route handlers for Central GUI.""" -from fastapi import APIRouter, Request -from fastapi.responses import HTMLResponse +from fastapi import APIRouter, Depends, Form, Request +from fastapi.responses import HTMLResponse, RedirectResponse, Response +from fastapi_csrf_protect import CsrfProtect + +from central.gui.auth import ( + create_session, + delete_session, + hash_password, + validate_password, + verify_password, +) +from central.gui.audit import ( + AUTH_LOGIN, + AUTH_LOGIN_FAILED, + AUTH_LOGOUT, + AUTH_PASSWORD_CHANGE, + OPERATOR_CREATE, + write_audit, +) +from central.gui.db import get_pool router = APIRouter() +def _get_templates(): + """Get templates instance (deferred import to avoid circular).""" + from central.gui import templates + return templates + + +def _set_session_cookie( + response: Response, + token: str, + max_age: int, +) -> None: + """Set the session cookie on a response.""" + response.set_cookie( + key="central_session", + value=token, + httponly=True, + samesite="lax", + secure=False, + max_age=max_age, + path="/", + ) + + +def _clear_session_cookie(response: Response) -> None: + """Clear the session cookie.""" + response.delete_cookie( + key="central_session", + path="/", + ) + + @router.get("/health") async def health() -> dict: """Health check endpoint.""" @@ -15,9 +64,304 @@ async def health() -> dict: @router.get("/", response_class=HTMLResponse) async def index(request: Request) -> HTMLResponse: """Render the index page.""" - from central.gui import templates - + templates = _get_templates() return templates.TemplateResponse( request=request, name="index.html", ) + + +@router.get("/setup", response_class=HTMLResponse) +async def setup_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the setup form.""" + templates = _get_templates() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup.html", + context={"csrf_token": signed_token, "error": None}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/setup") +async def setup_submit( + request: Request, + username: str = Form(...), + password: str = Form(...), + confirm_password: str = Form(...), + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the setup form.""" + templates = _get_templates() + pool = get_pool() + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Validate input + error = None + if password != confirm_password: + error = "Passwords do not match" + else: + try: + validate_password(password) + except ValueError as e: + error = str(e) + + if error: + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup.html", + context={"csrf_token": signed_token, "error": error}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Create operator + password_hash = hash_password(password) + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO config.operators (username, password_hash) + VALUES ($1, $2) + RETURNING id + """, + username, + password_hash, + ) + operator_id = row["id"] + + # Write audit log + await write_audit( + conn, + OPERATOR_CREATE, + operator_id=operator_id, + target=username, + ) + + # Get session lifetime + sysrow = await conn.fetchrow( + "SELECT session_lifetime_days FROM config.system WHERE id = true" + ) + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + + # Create session + token, expires_at = await create_session(conn, operator_id, lifetime_days) + + # Mark setup complete + await conn.execute( + "UPDATE config.system SET setup_complete = true WHERE id = true" + ) + + # Redirect with session cookie + response = RedirectResponse(url="/", status_code=302) + _set_session_cookie(response, token, lifetime_days * 86400) + return response + + +@router.get("/login", response_class=HTMLResponse) +async def login_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the login form.""" + templates = _get_templates() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": signed_token, "error": None}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/login") +async def login_submit( + request: Request, + username: str = Form(...), + password: str = Form(...), + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the login form.""" + templates = _get_templates() + pool = get_pool() + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Look up operator + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT id, username, password_hash, created_at, password_changed_at + FROM config.operators + WHERE username = $1 + """, + username, + ) + + if row is None: + # Unknown user - still audit the attempt + await write_audit(conn, AUTH_LOGIN_FAILED, target=username) + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": signed_token, "error": "Invalid username or password"}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Verify password + if not verify_password(password, row["password_hash"]): + await write_audit(conn, AUTH_LOGIN_FAILED, operator_id=row["id"], target=username) + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": signed_token, "error": "Invalid username or password"}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Get session lifetime + sysrow = await conn.fetchrow( + "SELECT session_lifetime_days FROM config.system WHERE id = true" + ) + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + + # Create session + token, expires_at = await create_session(conn, row["id"], lifetime_days) + + # Audit login + await write_audit(conn, AUTH_LOGIN, operator_id=row["id"], target=username) + + # Redirect with session cookie + response = RedirectResponse(url="/", status_code=302) + _set_session_cookie(response, token, lifetime_days * 86400) + return response + + +@router.post("/logout") +async def logout( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Log out the current user.""" + pool = get_pool() + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Get current session + session_token = request.cookies.get("central_session") + operator = getattr(request.state, "operator", None) + + async with pool.acquire() as conn: + if session_token: + await delete_session(conn, session_token) + + if operator: + await write_audit(conn, AUTH_LOGOUT, operator_id=operator.id, target=operator.username) + + response = RedirectResponse(url="/login", status_code=302) + _clear_session_cookie(response) + return response + + +@router.get("/change-password", response_class=HTMLResponse) +async def change_password_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the change password form.""" + templates = _get_templates() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="change_password.html", + context={"csrf_token": signed_token, "error": None, "success": False}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/change-password") +async def change_password_submit( + request: Request, + current_password: str = Form(...), + new_password: str = Form(...), + confirm_password: str = Form(...), + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the change password form.""" + templates = _get_templates() + pool = get_pool() + operator = request.state.operator + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Get current password hash + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT password_hash FROM config.operators WHERE id = $1", + operator.id, + ) + + error = None + + # Verify current password + if not verify_password(current_password, row["password_hash"]): + error = "Current password is incorrect" + elif new_password != confirm_password: + error = "New passwords do not match" + else: + try: + validate_password(new_password) + except ValueError as e: + error = str(e) + + if error: + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="change_password.html", + context={"csrf_token": signed_token, "error": error, "success": False}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Update password + new_hash = hash_password(new_password) + await conn.execute( + """ + UPDATE config.operators + SET password_hash = $1, password_changed_at = now() + WHERE id = $2 + """, + new_hash, + operator.id, + ) + + # Audit + await write_audit( + conn, + AUTH_PASSWORD_CHANGE, + operator_id=operator.id, + target=operator.username, + ) + + # Redirect to index + return RedirectResponse(url="/", status_code=302) diff --git a/src/central/gui/templates/base.html b/src/central/gui/templates/base.html index cc076ab..631c542 100644 --- a/src/central/gui/templates/base.html +++ b/src/central/gui/templates/base.html @@ -9,7 +9,36 @@ {% block head %}{% endblock %} +
+ {% if error %} +
+

{{ error }}

+
+ {% endif %} + {% if success %} +
+

{{ success }}

+
+ {% endif %} {% block content %}{% endblock %}
diff --git a/src/central/gui/templates/change_password.html b/src/central/gui/templates/change_password.html new file mode 100644 index 0000000..c353c60 --- /dev/null +++ b/src/central/gui/templates/change_password.html @@ -0,0 +1,36 @@ +{% extends "base.html" %} + +{% block title %}Central - Change Password{% endblock %} + +{% block content %} +
+
+

Change Password

+
+ +
+ + + + + + + + + +
+
+{% endblock %} diff --git a/src/central/gui/templates/login.html b/src/central/gui/templates/login.html new file mode 100644 index 0000000..3510bf8 --- /dev/null +++ b/src/central/gui/templates/login.html @@ -0,0 +1,29 @@ +{% extends "base.html" %} + +{% block title %}Central - Login{% endblock %} + +{% block content %} +
+
+

Login

+
+ +
+ + + + + + + +
+
+{% endblock %} diff --git a/src/central/gui/templates/setup.html b/src/central/gui/templates/setup.html new file mode 100644 index 0000000..7b72249 --- /dev/null +++ b/src/central/gui/templates/setup.html @@ -0,0 +1,37 @@ +{% extends "base.html" %} + +{% block title %}Central - Setup{% endblock %} + +{% block content %} +
+
+

Central First-Time Setup

+

Create the initial operator account to get started.

+
+ +
+ + + + + + + + + +
+
+{% endblock %} diff --git a/systemd/central-gui.service b/systemd/central-gui.service index 20967c1..3b08c31 100644 --- a/systemd/central-gui.service +++ b/systemd/central-gui.service @@ -9,6 +9,7 @@ User=central Group=central WorkingDirectory=/opt/central Environment=HOME=/opt/central +EnvironmentFile=/etc/central/central.env ExecStart=/opt/central/.venv/bin/central-gui Restart=on-failure RestartSec=5 @@ -18,3 +19,6 @@ ProtectHome=true PrivateTmp=true StandardOutput=journal StandardError=journal + +[Install] +WantedBy=multi-user.target diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ad93825 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,50 @@ +"""Shared fixtures for auth tests.""" + +import asyncio +import tempfile +from pathlib import Path +from typing import AsyncGenerator + +import asyncpg +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from central.bootstrap_config import Settings + + +@pytest.fixture(scope="session") +def event_loop(): + """Create an event loop for the test session.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def mock_settings(): + """Create mock settings for testing.""" + return Settings( + db_dsn="postgresql://test:test@localhost/test", + nats_url="nats://localhost:4222", + csrf_secret="test-csrf-secret-for-testing-only-32chars", + ) + + +@pytest.fixture +def mock_pool(): + """Create a mock database pool.""" + pool = MagicMock() + pool.acquire = MagicMock() + pool.close = AsyncMock() + return pool + + +@pytest.fixture +def mock_conn(): + """Create a mock database connection.""" + conn = MagicMock() + conn.fetchrow = AsyncMock() + conn.fetchval = AsyncMock() + conn.execute = AsyncMock() + return conn diff --git a/tests/test_audit.py b/tests/test_audit.py new file mode 100644 index 0000000..dbf2a30 --- /dev/null +++ b/tests/test_audit.py @@ -0,0 +1,92 @@ +"""Tests for audit log module.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from central.gui.audit import ( + write_audit, + AUTH_LOGIN, + AUTH_LOGIN_FAILED, + AUTH_LOGOUT, + AUTH_PASSWORD_CHANGE, + OPERATOR_CREATE, +) + + +class TestAuditConstants: + """Tests for audit action constants.""" + + def test_auth_login(self): + assert AUTH_LOGIN == "auth.login" + + def test_auth_login_failed(self): + assert AUTH_LOGIN_FAILED == "auth.login_failed" + + def test_auth_logout(self): + assert AUTH_LOGOUT == "auth.logout" + + def test_auth_password_change(self): + assert AUTH_PASSWORD_CHANGE == "auth.password_change" + + def test_operator_create(self): + assert OPERATOR_CREATE == "operator.create" + + +class TestWriteAudit: + """Tests for write_audit function.""" + + @pytest.mark.asyncio + async def test_write_audit_basic(self): + """write_audit inserts basic audit record.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit(mock_conn, action="auth.login", operator_id=1) + + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + assert "INSERT INTO config.audit_log" in call_args[0][0] + + @pytest.mark.asyncio + async def test_write_audit_with_target(self): + """write_audit includes target when provided.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit( + mock_conn, + action="operator.create", + operator_id=1, + target="newuser", + ) + + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + # target is the 3rd positional arg (after operator_id and action) + assert "newuser" in call_args[0] + + @pytest.mark.asyncio + async def test_write_audit_with_before_after(self): + """write_audit includes before/after when provided.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit( + mock_conn, + action="config.update", + operator_id=1, + before={"value": "old"}, + after={"value": "new"}, + ) + + mock_conn.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_write_audit_no_operator(self): + """write_audit works with operator_id=None.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit(mock_conn, action="auth.login_failed", operator_id=None) + + mock_conn.execute.assert_called_once() diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..2ea9569 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,183 @@ +"""Tests for auth module.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from datetime import datetime, timezone + +from central.gui.auth import ( + hash_password, + verify_password, + validate_password, + generate_token, + create_session, + get_session, + delete_session, + get_operator_by_username, + create_operator, + Operator, +) + + +class TestPasswordHashing: + """Tests for password hashing functions.""" + + def test_hash_password_returns_string(self): + """hash_password returns a string.""" + result = hash_password("testpassword") + assert isinstance(result, str) + + def test_hash_password_includes_argon2id(self): + """hash_password uses argon2id algorithm.""" + result = hash_password("testpassword") + assert result.startswith("$argon2id$") + + def test_hash_password_different_each_time(self): + """hash_password produces different hashes for same password.""" + hash1 = hash_password("testpassword") + hash2 = hash_password("testpassword") + assert hash1 != hash2 + + def test_verify_password_correct(self): + """verify_password returns True for correct password.""" + password = "testpassword" + hashed = hash_password(password) + assert verify_password(password, hashed) is True + + def test_verify_password_incorrect(self): + """verify_password returns False for wrong password.""" + hashed = hash_password("testpassword") + assert verify_password("wrongpassword", hashed) is False + + def test_verify_password_empty(self): + """verify_password handles empty strings.""" + hashed = hash_password("testpassword") + assert verify_password("", hashed) is False + + +class TestPasswordValidation: + """Tests for password validation.""" + + def test_valid_password(self): + """validate_password passes for valid password.""" + validate_password("password123") # No exception + + def test_short_password(self): + """validate_password raises for short password.""" + with pytest.raises(ValueError) as exc_info: + validate_password("short") + assert "8 characters" in str(exc_info.value) + + +class TestTokenGeneration: + """Tests for token generation.""" + + def test_generate_token_length(self): + """generate_token produces expected length.""" + token = generate_token() + # URL-safe base64 of 32 bytes is 43 characters + assert len(token) == 43 + + def test_generate_token_unique(self): + """generate_token produces unique tokens.""" + tokens = [generate_token() for _ in range(100)] + assert len(set(tokens)) == 100 + + +class TestSessionManagement: + """Tests for session creation and retrieval.""" + + @pytest.mark.asyncio + async def test_create_session(self): + """create_session inserts a session record.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + token, expires_at = await create_session(mock_conn, operator_id=1, lifetime_days=90) + + assert len(token) == 43 + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + assert "INSERT INTO config.sessions" in call_args[0][0] + + @pytest.mark.asyncio + async def test_get_session_found(self): + """get_session returns Operator when session exists.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "testuser", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + + operator = await get_session(mock_conn, "valid-token") + + assert operator is not None + assert operator.id == 1 + assert operator.username == "testuser" + + @pytest.mark.asyncio + async def test_get_session_not_found(self): + """get_session returns None when session doesn\'t exist.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value=None) + + operator = await get_session(mock_conn, "invalid-token") + + assert operator is None + + @pytest.mark.asyncio + async def test_delete_session(self): + """delete_session removes the session.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await delete_session(mock_conn, "some-token") + + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + assert "DELETE FROM config.sessions" in call_args[0][0] + + +class TestOperatorManagement: + """Tests for operator creation and retrieval.""" + + @pytest.mark.asyncio + async def test_get_operator_by_username_found(self): + """get_operator_by_username returns operator when found.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "admin", + "password_hash": "somehash", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + + result = await get_operator_by_username(mock_conn, "admin") + + assert result is not None + assert result["username"] == "admin" + + @pytest.mark.asyncio + async def test_get_operator_by_username_not_found(self): + """get_operator_by_username returns None when not found.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value=None) + + result = await get_operator_by_username(mock_conn, "nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_create_operator(self): + """create_operator inserts and returns operator ID.""" + mock_conn = MagicMock() + mock_conn.fetchval = AsyncMock(return_value=1) + + operator_id = await create_operator(mock_conn, "newuser", "password123") + + assert operator_id == 1 + mock_conn.fetchval.assert_called_once() + call_args = mock_conn.fetchval.call_args + assert "INSERT INTO config.operators" in call_args[0][0] diff --git a/tests/test_session_auth.py b/tests/test_session_auth.py new file mode 100644 index 0000000..004e756 --- /dev/null +++ b/tests/test_session_auth.py @@ -0,0 +1,173 @@ +"""Tests for session authentication middleware.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime, timezone +from starlette.testclient import TestClient +from fastapi import FastAPI, Request + +from central.gui.middleware import SessionMiddleware +from central.gui.auth import Operator + + +class TestSessionMiddleware: + """Tests for SessionMiddleware.""" + + @pytest.mark.asyncio + async def test_no_cookie_sets_none_on_exempt_path(self): + """SessionMiddleware sets operator=None when no session cookie on exempt path.""" + mock_pool = MagicMock() + mock_pool.acquire = MagicMock() + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(request: Request): + return {"operator": getattr(request.state, "operator", "missing")} + + app.add_middleware(SessionMiddleware) + client = TestClient(app) + + response = client.get("/health") + assert response.status_code == 200 + assert response.json()["operator"] is None + + @pytest.mark.asyncio + async def test_valid_cookie_sets_operator_on_exempt_path(self): + """SessionMiddleware sets operator when valid session cookie on exempt path.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "admin", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(request: Request): + op = getattr(request.state, "operator", None) + if op: + return {"username": op.username} + return {"operator": None} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "valid-token"}) + + response = client.get("/health") + assert response.status_code == 200 + assert response.json()["username"] == "admin" + + @pytest.mark.asyncio + async def test_no_cookie_redirects_on_protected_path(self): + """SessionMiddleware redirects to /login when no cookie on protected path.""" + mock_pool = MagicMock() + mock_pool.acquire = MagicMock() + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(request: Request): + return {"message": "home"} + + @app.get("/login") + async def login(): + return {"message": "login"} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/") + assert response.status_code == 302 + assert response.headers["location"] == "/login" + + @pytest.mark.asyncio + async def test_valid_cookie_allows_protected_path(self): + """SessionMiddleware allows protected path with valid session.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "admin", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(request: Request): + op = request.state.operator + return {"message": "home", "user": op.username} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "valid-token"}) + + response = client.get("/") + assert response.status_code == 200 + assert response.json()["user"] == "admin" + + @pytest.mark.asyncio + async def test_invalid_cookie_redirects_on_protected_path(self): + """SessionMiddleware redirects when session is invalid/expired.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value=None) # No session found + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(request: Request): + return {"operator": getattr(request.state, "operator", "missing")} + + @app.get("/login") + async def login(): + return {"message": "login"} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "expired-token"}, follow_redirects=False) + + response = client.get("/") + assert response.status_code == 302 + assert response.headers["location"] == "/login" + + @pytest.mark.asyncio + async def test_middleware_handles_db_error(self): + """SessionMiddleware handles database errors gracefully on exempt path.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(side_effect=Exception("DB error")) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(request: Request): + return {"operator": getattr(request.state, "operator", "missing")} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "some-token"}) + + response = client.get("/health") + # Should not crash, just set operator to None + assert response.status_code == 200 + assert response.json()["operator"] is None diff --git a/tests/test_setup_gate.py b/tests/test_setup_gate.py new file mode 100644 index 0000000..a29fc39 --- /dev/null +++ b/tests/test_setup_gate.py @@ -0,0 +1,162 @@ +"""Tests for setup gate middleware.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from starlette.testclient import TestClient +from fastapi import FastAPI + +from central.gui.middleware import SetupGateMiddleware + + +class TestSetupGateMiddleware: + """Tests for SetupGateMiddleware.""" + + @pytest.mark.asyncio + async def test_allows_setup_route_when_incomplete(self): + """SetupGateMiddleware allows /setup when setup_complete=False.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/setup") + assert response.status_code == 200 + assert response.json() == {"message": "setup"} + + @pytest.mark.asyncio + async def test_allows_health_when_incomplete(self): + """SetupGateMiddleware allows /health regardless of setup state.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/health") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_redirects_other_routes_when_incomplete(self): + """SetupGateMiddleware redirects non-setup routes when setup_complete=False.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/") + assert response.status_code == 307 + assert response.headers["location"] == "/setup" + + @pytest.mark.asyncio + async def test_allows_all_routes_when_complete(self): + """SetupGateMiddleware allows all routes when setup_complete=True.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "home"} + + @pytest.mark.asyncio + async def test_allows_static_when_incomplete(self): + """SetupGateMiddleware allows /static routes when setup_complete=False.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/static/test.css") + async def static(): + return "css" + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/static/test.css") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_redirects_setup_when_complete(self): + """SetupGateMiddleware redirects /setup to / when setup_complete=True.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup") + assert response.status_code == 302 + assert response.headers["location"] == "/" diff --git a/uv.lock b/uv.lock index 4a748b3..ee8cabc 100644 --- a/uv.lock +++ b/uv.lock @@ -89,6 +89,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, ] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706, upload-time = "2025-06-03T06:55:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657, upload-time = "2025-06-03T06:55:30.804Z" }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/db8af0df73c1cf454f71b2bbe5e356b8c1f8041c979f505b3d3186e520a9/argon2_cffi_bindings-25.1.0.tar.gz", hash = "sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d", size = 1783441, upload-time = "2025-07-30T10:02:05.147Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/57/96b8b9f93166147826da5f90376e784a10582dd39a393c99bb62cfcf52f0/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500", size = 54121, upload-time = "2025-07-30T10:01:50.815Z" }, + { url = "https://files.pythonhosted.org/packages/0a/08/a9bebdb2e0e602dde230bdde8021b29f71f7841bd54801bcfd514acb5dcf/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44", size = 29177, upload-time = "2025-07-30T10:01:51.681Z" }, + { url = "https://files.pythonhosted.org/packages/b6/02/d297943bcacf05e4f2a94ab6f462831dc20158614e5d067c35d4e63b9acb/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0", size = 31090, upload-time = "2025-07-30T10:01:53.184Z" }, + { url = "https://files.pythonhosted.org/packages/c1/93/44365f3d75053e53893ec6d733e4a5e3147502663554b4d864587c7828a7/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6", size = 81246, upload-time = "2025-07-30T10:01:54.145Z" }, + { url = "https://files.pythonhosted.org/packages/09/52/94108adfdd6e2ddf58be64f959a0b9c7d4ef2fa71086c38356d22dc501ea/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a", size = 87126, upload-time = "2025-07-30T10:01:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/72/70/7a2993a12b0ffa2a9271259b79cc616e2389ed1a4d93842fac5a1f923ffd/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d", size = 80343, upload-time = "2025-07-30T10:01:56.007Z" }, + { url = "https://files.pythonhosted.org/packages/78/9a/4e5157d893ffc712b74dbd868c7f62365618266982b64accab26bab01edc/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99", size = 86777, upload-time = "2025-07-30T10:01:56.943Z" }, + { url = "https://files.pythonhosted.org/packages/74/cd/15777dfde1c29d96de7f18edf4cc94c385646852e7c7b0320aa91ccca583/argon2_cffi_bindings-25.1.0-cp39-abi3-win32.whl", hash = "sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2", size = 27180, upload-time = "2025-07-30T10:01:57.759Z" }, + { url = "https://files.pythonhosted.org/packages/e2/c6/a759ece8f1829d1f162261226fbfd2c6832b3ff7657384045286d2afa384/argon2_cffi_bindings-25.1.0-cp39-abi3-win_amd64.whl", hash = "sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98", size = 31715, upload-time = "2025-07-30T10:01:58.56Z" }, + { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149, upload-time = "2025-07-30T10:01:59.329Z" }, +] + [[package]] name = "ast-serialize" version = "0.4.0" @@ -143,10 +176,12 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "aiohttp" }, + { name = "argon2-cffi" }, { name = "asyncpg" }, { name = "cloudevents" }, { name = "cryptography" }, { name = "fastapi" }, + { name = "fastapi-csrf-protect" }, { name = "jinja2" }, { name = "nats-py" }, { name = "pydantic" }, @@ -169,10 +204,12 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp", specifier = ">=3.13.5" }, + { name = "argon2-cffi", specifier = ">=25.1.0" }, { name = "asyncpg", specifier = ">=0.31.0" }, { name = "cloudevents", specifier = ">=2.0.0" }, { name = "cryptography", specifier = ">=44.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, + { name = "fastapi-csrf-protect", specifier = ">=0.4.0" }, { name = "jinja2", specifier = ">=3.1.6" }, { name = "nats-py", specifier = ">=2.14.0" }, { name = "pydantic", specifier = ">=2,<3" }, @@ -325,6 +362,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/ff/2e4eca3ade2c22fe1dea7043b8ee9dabe47753349eb1b56a202de8af6349/fastapi-0.136.1-py3-none-any.whl", hash = "sha256:a6e9d7eeada96c93a4d69cb03836b44fa34e2854accb7244a1ece36cd4781c3f", size = 117683, upload-time = "2026-04-23T16:49:42.437Z" }, ] +[[package]] +name = "fastapi-csrf-protect" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "itsdangerous" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/1a/fedbcb4aba24ccc8abfb5d30e08112073c6a9f20b8d88adbdd3051ceedac/fastapi_csrf_protect-1.0.7.tar.gz", hash = "sha256:888b15b232625aae5b997fbcf81ef45633a7694f0312a054f1eec6d132b295fb", size = 207326, upload-time = "2025-09-16T07:06:08.586Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/10/f248aab919678444723d557da918088e5c737b44e03e3aa4a0ad7afc7dae/fastapi_csrf_protect-1.0.7-py3-none-any.whl", hash = "sha256:ca3c5b50564af932ac4ed3d06caeed61bf16eed13a31cfe2bdfc3f7c1e8612a3", size = 18412, upload-time = "2025-09-16T07:06:05.926Z" }, +] + [[package]] name = "frozenlist" version = "1.8.0" @@ -420,6 +472,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, +] + [[package]] name = "jinja2" version = "3.1.6"