From f059f982bcd459864e60e2e601a68ea69ec5cc53 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 17 May 2026 05:30:49 +0000 Subject: [PATCH 1/6] feat(gui): add auth core, setup gate, and first-run operator creation - Add migrations 007-010 for system config, operators, sessions, audit_log - Implement argon2id password hashing via argon2-cffi - Implement session-based authentication with database-stored tokens - Add SetupGateMiddleware to redirect to /setup until first operator created - Add SessionMiddleware to load session from cookie and attach operator - Create /setup, /login, /logout, /change-password routes with CSRF protection - Add periodic session cleanup task (hourly) - Add audit logging for auth events - Update systemd unit with EnvironmentFile for /etc/central/central.env - Add comprehensive tests for auth, middleware, and audit modules Co-Authored-By: Claude Opus 4.5 --- .gitignore | 1 + docs/environment.md | 28 ++ pyproject.toml | 2 + sql/migrations/007_add_config_system.sql | 21 ++ sql/migrations/008_add_operators.sql | 10 + sql/migrations/009_add_sessions.sql | 11 + sql/migrations/010_add_audit_log.sql | 15 + src/central/bootstrap_config.py | 3 + src/central/gui/__init__.py | 156 +++++++- src/central/gui/audit.py | 37 ++ src/central/gui/auth.py | 138 +++++++ src/central/gui/db.py | 48 +++ src/central/gui/middleware.py | 96 +++++ src/central/gui/routes.py | 352 +++++++++++++++++- src/central/gui/templates/base.html | 29 ++ .../gui/templates/change_password.html | 36 ++ src/central/gui/templates/login.html | 29 ++ src/central/gui/templates/setup.html | 37 ++ systemd/central-gui.service | 4 + tests/conftest.py | 50 +++ tests/test_audit.py | 92 +++++ tests/test_auth.py | 183 +++++++++ tests/test_session_auth.py | 173 +++++++++ tests/test_setup_gate.py | 162 ++++++++ uv.lock | 61 +++ 25 files changed, 1758 insertions(+), 16 deletions(-) create mode 100644 sql/migrations/007_add_config_system.sql create mode 100644 sql/migrations/008_add_operators.sql create mode 100644 sql/migrations/009_add_sessions.sql create mode 100644 sql/migrations/010_add_audit_log.sql create mode 100644 src/central/gui/audit.py create mode 100644 src/central/gui/auth.py create mode 100644 src/central/gui/db.py create mode 100644 src/central/gui/middleware.py create mode 100644 src/central/gui/templates/change_password.html create mode 100644 src/central/gui/templates/login.html create mode 100644 src/central/gui/templates/setup.html create mode 100644 tests/conftest.py create mode 100644 tests/test_audit.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_session_auth.py create mode 100644 tests/test_setup_gate.py diff --git a/.gitignore b/.gitignore index a61310f..04e7f06 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ db.env .vscode/ *.swp .DS_Store +.ssh/ diff --git a/docs/environment.md b/docs/environment.md index 7396362..9659443 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -68,6 +68,34 @@ journalctl -u central-archive -f ## Database +## Environment Variables + +Environment variables are stored in `/etc/central/central.env` and loaded by +systemd services via `EnvironmentFile=`. + +| Variable | Required | Description | +|----------|----------|-------------| +| `CENTRAL_CSRF_SECRET` | Yes (for GUI) | Secret key for CSRF token signing. Generate with `python3 -c "import secrets; print(secrets.token_urlsafe(32))"` | + +### Generating CSRF Secret + +```bash +python3 -c "import secrets; print(secrets.token_urlsafe(32))" +``` + +Add the generated value to `/etc/central/central.env`: + +```bash +CENTRAL_CSRF_SECRET= +``` + +Ensure the file has restricted permissions: + +```bash +sudo chmod 640 /etc/central/central.env +sudo chown central:central /etc/central/central.env +``` + PostgreSQL 16 with TimescaleDB runs on CT104: ```bash diff --git a/pyproject.toml b/pyproject.toml index 476998b..d1d7f8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,11 @@ license = {text = "MIT"} authors = [{name = "Matt Johnson"}] dependencies = [ "aiohttp>=3.13.5", + "argon2-cffi>=25.1.0", "asyncpg>=0.31.0", "cloudevents>=2.0.0", "cryptography>=44.0.0", + "fastapi-csrf-protect>=0.4.0", "fastapi>=0.115.0", "jinja2>=3.1.6", "nats-py>=2.14.0", diff --git a/sql/migrations/007_add_config_system.sql b/sql/migrations/007_add_config_system.sql new file mode 100644 index 0000000..6c0b80b --- /dev/null +++ b/sql/migrations/007_add_config_system.sql @@ -0,0 +1,21 @@ +-- Migration 007: Add config.system table for global settings +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.system ( + id BOOLEAN PRIMARY KEY DEFAULT true CHECK (id = true), + setup_complete BOOLEAN NOT NULL DEFAULT false, + session_lifetime_days INTEGER NOT NULL DEFAULT 90, + map_tile_url TEXT NOT NULL DEFAULT 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', + map_attribution TEXT NOT NULL DEFAULT '© OpenStreetMap contributors', + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Reuse existing set_updated_at trigger function +DROP TRIGGER IF EXISTS system_set_updated_at ON config.system; +CREATE TRIGGER system_set_updated_at + BEFORE UPDATE ON config.system + FOR EACH ROW + EXECUTE FUNCTION config.set_updated_at(); + +-- Seed single row +INSERT INTO config.system (id) VALUES (true) ON CONFLICT DO NOTHING; diff --git a/sql/migrations/008_add_operators.sql b/sql/migrations/008_add_operators.sql new file mode 100644 index 0000000..2886caa --- /dev/null +++ b/sql/migrations/008_add_operators.sql @@ -0,0 +1,10 @@ +-- Migration 008: Add config.operators table for user accounts +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.operators ( + id BIGSERIAL PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + password_changed_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/sql/migrations/009_add_sessions.sql b/sql/migrations/009_add_sessions.sql new file mode 100644 index 0000000..9016676 --- /dev/null +++ b/sql/migrations/009_add_sessions.sql @@ -0,0 +1,11 @@ +-- Migration 009: Add config.sessions table for auth tokens +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.sessions ( + token TEXT PRIMARY KEY, + operator_id BIGINT NOT NULL REFERENCES config.operators(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + expires_at TIMESTAMPTZ NOT NULL +); + +CREATE INDEX IF NOT EXISTS sessions_expires_at_idx ON config.sessions(expires_at); diff --git a/sql/migrations/010_add_audit_log.sql b/sql/migrations/010_add_audit_log.sql new file mode 100644 index 0000000..d6baa65 --- /dev/null +++ b/sql/migrations/010_add_audit_log.sql @@ -0,0 +1,15 @@ +-- Migration 010: Add config.audit_log table +-- Idempotent per docs/migrations.md + +CREATE TABLE IF NOT EXISTS config.audit_log ( + id BIGSERIAL PRIMARY KEY, + ts TIMESTAMPTZ NOT NULL DEFAULT now(), + operator_id BIGINT REFERENCES config.operators(id) ON DELETE SET NULL, + action TEXT NOT NULL, + target TEXT, + before JSONB, + after JSONB +); + +CREATE INDEX IF NOT EXISTS audit_log_ts_idx ON config.audit_log(ts DESC); +CREATE INDEX IF NOT EXISTS audit_log_action_idx ON config.audit_log(action); diff --git a/src/central/bootstrap_config.py b/src/central/bootstrap_config.py index f5e46bc..334c4c6 100644 --- a/src/central/bootstrap_config.py +++ b/src/central/bootstrap_config.py @@ -33,6 +33,9 @@ class Settings(BaseSettings): default="INFO", description="Logging level", ) + csrf_secret: str = Field( + description="Secret key for CSRF token signing (generate with: python -c \"import secrets; print(secrets.token_urlsafe(32))\")", + ) @lru_cache diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 807bae7..f7ee746 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -1,13 +1,17 @@ """Central GUI — FastAPI + Jinja2 + HTMX.""" +import asyncio +import logging +from contextlib import asynccontextmanager from pathlib import Path +from typing import Any import uvicorn from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from central.gui.routes import router +logger = logging.getLogger(__name__) # Template and static directories GUI_DIR = Path(__file__).parent @@ -17,17 +21,108 @@ STATIC_DIR = GUI_DIR / "static" # Jinja2 templates instance (shared with routes) templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) +# Shutdown event and cleanup task +_shutdown_event: asyncio.Event | None = None +_cleanup_task: asyncio.Task | None = None + +# Lazy app singleton +_app: FastAPI | None = None + + +def _configure_csrf() -> None: + """Configure CSRF protection. Must be called before app starts.""" + from fastapi_csrf_protect import CsrfProtect + from pydantic import BaseModel + from central.bootstrap_config import get_settings + + class CsrfSettings(BaseModel): + secret_key: str + + @CsrfProtect.load_config + def get_csrf_config(): + settings = get_settings() + return CsrfSettings(secret_key=settings.csrf_secret) + + +async def _session_cleanup_loop() -> None: + """Periodically clean up expired sessions.""" + global _shutdown_event + + from central.gui.db import get_pool + + if _shutdown_event is None: + return + + while not _shutdown_event.is_set(): + try: + await asyncio.wait_for(_shutdown_event.wait(), timeout=3600) + except asyncio.TimeoutError: + try: + pool = get_pool() + if pool: + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM config.sessions WHERE expires_at < now()" + ) + deleted = result.split()[-1] if result else "0" + if int(deleted) > 0: + logger.info("Session cleanup", extra={"deleted": deleted}) + except Exception: + logger.warning("Session cleanup failed", exc_info=True) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + global _shutdown_event, _cleanup_task + + from central.bootstrap_config import get_settings + from central.gui.db import close_pool, init_pool + + settings = get_settings() + + # Initialize database pool + await init_pool(settings.db_dsn) + + # Start session cleanup task + _shutdown_event = asyncio.Event() + _cleanup_task = asyncio.create_task(_session_cleanup_loop()) + + logger.info("Central GUI started") + + yield + + # Shutdown + if _shutdown_event: + _shutdown_event.set() + if _cleanup_task: + try: + await asyncio.wait_for(_cleanup_task, timeout=5.0) + except asyncio.TimeoutError: + _cleanup_task.cancel() + + await close_pool() + logger.info("Central GUI stopped") + + +def _create_app() -> FastAPI: + """Create the FastAPI application.""" + from central.gui.middleware import SessionMiddleware, SetupGateMiddleware + from central.gui.routes import router + + # Configure CSRF before creating app + _configure_csrf() -def create_app() -> FastAPI: - """Create and configure the FastAPI application.""" app = FastAPI( - title="Central", - description="Central Data Hub GUI", - docs_url=None, # Disable Swagger UI for now - redoc_url=None, # Disable ReDoc for now + title="Central GUI", + lifespan=lifespan, ) - # Mount static files if directory exists and has content + # Add middleware (order matters - first added runs last) + app.add_middleware(SessionMiddleware) + app.add_middleware(SetupGateMiddleware) + + # Mount static files if STATIC_DIR.exists(): app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") @@ -37,10 +132,47 @@ def create_app() -> FastAPI: return app -# Application instance -app = create_app() +def __getattr__(name: str) -> Any: + """Lazy attribute access for app singleton.""" + global _app + if name == "app": + if _app is None: + _app = _create_app() + return _app + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def main() -> None: - """Entry point for central-gui console script.""" - uvicorn.run(app, host="127.0.0.1", port=8000) + """Entry point for central-gui command.""" + import logging.config + + logging.config.dictConfig({ + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(asctime)s %(levelname)s %(name)s: %(message)s", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "root": { + "level": "INFO", + "handlers": ["console"], + }, + }) + + uvicorn.run( + "central.gui:app", + host="0.0.0.0", + port=8088, + reload=False, + ) + + +if __name__ == "__main__": + main() diff --git a/src/central/gui/audit.py b/src/central/gui/audit.py new file mode 100644 index 0000000..428275a --- /dev/null +++ b/src/central/gui/audit.py @@ -0,0 +1,37 @@ +"""Audit logging for Central GUI.""" + +import json +from typing import Any + +# Audit action constants +AUTH_LOGIN = "auth.login" +AUTH_LOGIN_FAILED = "auth.login_failed" +AUTH_LOGOUT = "auth.logout" +AUTH_PASSWORD_CHANGE = "auth.password_change" +OPERATOR_CREATE = "operator.create" + + +async def write_audit( + conn: Any, # asyncpg.Connection + action: str, + operator_id: int | None = None, + target: str | None = None, + before: dict[str, Any] | None = None, + after: dict[str, Any] | None = None, +) -> None: + """Write an audit log entry.""" + # Serialize before/after as JSON strings if provided + before_json = json.dumps(before) if before else None + after_json = json.dumps(after) if after else None + + await conn.execute( + """ + INSERT INTO config.audit_log (operator_id, action, target, before, after) + VALUES ($1, $2, $3, $4::jsonb, $5::jsonb) + """, + operator_id, + action, + target, + before_json, + after_json, + ) diff --git a/src/central/gui/auth.py b/src/central/gui/auth.py new file mode 100644 index 0000000..3b74ac0 --- /dev/null +++ b/src/central/gui/auth.py @@ -0,0 +1,138 @@ +"""Authentication utilities for Central GUI.""" + +import secrets +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any + +from argon2 import PasswordHasher +from argon2.exceptions import VerifyMismatchError + +# Use argon2-cffi defaults (argon2id) +_hasher = PasswordHasher() + + +@dataclass +class Operator: + """Operator account.""" + id: int + username: str + created_at: datetime + password_changed_at: datetime | None = None + + +def hash_password(plain: str) -> str: + """Hash a password using argon2id.""" + return _hasher.hash(plain) + + +def verify_password(plain: str, hashed: str) -> bool: + """Verify a password against its hash.""" + try: + _hasher.verify(hashed, plain) + return True + except VerifyMismatchError: + return False + + +def validate_password(plain: str) -> None: + """Validate password meets requirements. Raises ValueError if invalid.""" + if len(plain) < 8: + raise ValueError("Password must be at least 8 characters") + + +def generate_token() -> str: + """Generate a cryptographically secure session token.""" + return secrets.token_urlsafe(32) + + +async def create_session( + conn: Any, # asyncpg.Connection + operator_id: int, + lifetime_days: int, +) -> tuple[str, datetime]: + """Create a new session for an operator. + + Returns (token, expires_at). + """ + token = generate_token() + expires_at = datetime.now(timezone.utc) + timedelta(days=lifetime_days) + + await conn.execute( + """ + INSERT INTO config.sessions (token, operator_id, expires_at) + VALUES ($1, $2, $3) + """, + token, + operator_id, + expires_at, + ) + + return token, expires_at + + +async def get_session(conn: Any, token: str) -> Operator | None: + """Look up a session and return the associated operator. + + Returns None if token is invalid or expired. + """ + row = await conn.fetchrow( + """ + SELECT o.id, o.username, o.created_at, o.password_changed_at + FROM config.sessions s + JOIN config.operators o ON s.operator_id = o.id + WHERE s.token = $1 AND s.expires_at > now() + """, + token, + ) + + if row is None: + return None + + return Operator( + id=row["id"], + username=row["username"], + created_at=row["created_at"], + password_changed_at=row.get("password_changed_at"), + ) + + +async def delete_session(conn: Any, token: str) -> None: + """Delete a session.""" + await conn.execute( + "DELETE FROM config.sessions WHERE token = $1", + token, + ) + + +async def get_operator_by_username(conn: Any, username: str) -> dict | None: + """Get an operator by username. + + Returns the row dict or None if not found. + """ + return await conn.fetchrow( + """ + SELECT id, username, password_hash, created_at, password_changed_at + FROM config.operators + WHERE username = $1 + """, + username, + ) + + +async def create_operator(conn: Any, username: str, password: str) -> int: + """Create a new operator. + + Returns the new operator ID. + """ + password_hash = hash_password(password) + row = await conn.fetchval( + """ + INSERT INTO config.operators (username, password_hash) + VALUES ($1, $2) + RETURNING id + """, + username, + password_hash, + ) + return row diff --git a/src/central/gui/db.py b/src/central/gui/db.py new file mode 100644 index 0000000..645fdc7 --- /dev/null +++ b/src/central/gui/db.py @@ -0,0 +1,48 @@ +"""Database connection pool for GUI.""" + +import json +from typing import Any + +import asyncpg + +# Module-level pool instance +_pool: asyncpg.Pool | None = None + + +# TODO: Deduplicate with central.config_store._setup_json_codec +async def _setup_json_codec(conn: asyncpg.Connection) -> None: + """Set up JSON codec for asyncpg connection.""" + await conn.set_type_codec( + "jsonb", + encoder=json.dumps, + decoder=json.loads, + schema="pg_catalog", + ) + + +async def init_pool(dsn: str) -> asyncpg.Pool: + """Initialize the connection pool.""" + global _pool + if _pool is None: + _pool = await asyncpg.create_pool( + dsn, + min_size=1, + max_size=5, + init=_setup_json_codec, + ) + return _pool + + +def get_pool() -> asyncpg.Pool: + """Get the connection pool. Must call init_pool first.""" + if _pool is None: + raise RuntimeError("Database pool not initialized. Call init_pool first.") + return _pool + + +async def close_pool() -> None: + """Close the connection pool.""" + global _pool + if _pool is not None: + await _pool.close() + _pool = None diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py new file mode 100644 index 0000000..e451528 --- /dev/null +++ b/src/central/gui/middleware.py @@ -0,0 +1,96 @@ +"""Middleware for Central GUI.""" + +import logging + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response + +from central.gui.auth import get_session +from central.gui.db import get_pool + +logger = logging.getLogger(__name__) + +# Paths that don't require setup to be complete +SETUP_EXEMPT_PATHS = {"/setup", "/health"} +SETUP_EXEMPT_PREFIXES = ("/static/",) + +# Paths that don't require authentication +AUTH_EXEMPT_PATHS = {"/setup", "/login", "/health"} +AUTH_EXEMPT_PREFIXES = ("/static/",) + + +def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: + """Check if a path is exempt from a check.""" + if path in exempt_paths: + return True + for prefix in exempt_prefixes: + if path.startswith(prefix): + return True + return False + + +class SetupGateMiddleware(BaseHTTPMiddleware): + """Redirect to /setup if setup is not complete.""" + + async def dispatch(self, request: Request, call_next) -> Response: + path = request.url.path + + # Check setup status from database + pool = get_pool() + if pool is None: + # Pool not initialized yet + return await call_next(request) + + setup_complete = False + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT setup_complete FROM config.system WHERE id = true" + ) + setup_complete = row["setup_complete"] if row else False + except Exception: + logger.warning("Failed to check setup status", exc_info=True) + # On error, allow the request through + return await call_next(request) + + if not setup_complete: + # Setup not complete - only allow exempt paths + if not _is_exempt(path, SETUP_EXEMPT_PATHS, SETUP_EXEMPT_PREFIXES): + return RedirectResponse(url="/setup", status_code=307) + else: + # Setup complete - redirect /setup to / + if path == "/setup": + return RedirectResponse(url="/", status_code=302) + + return await call_next(request) + + +class SessionMiddleware(BaseHTTPMiddleware): + """Load session from cookie and attach operator to request.state.""" + + async def dispatch(self, request: Request, call_next) -> Response: + path = request.url.path + + # Initialize operator to None + request.state.operator = None + + # Try to load session from cookie + session_token = request.cookies.get("central_session") + if session_token: + pool = get_pool() + if pool is not None: + try: + async with pool.acquire() as conn: + operator = await get_session(conn, session_token) + request.state.operator = operator + except Exception: + logger.warning("Failed to load session", exc_info=True) + request.state.operator = None + + # Check if auth is required + if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): + if request.state.operator is None: + return RedirectResponse(url="/login", status_code=302) + + return await call_next(request) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index e06ee2d..19a8b6d 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -1,11 +1,60 @@ """Route handlers for Central GUI.""" -from fastapi import APIRouter, Request -from fastapi.responses import HTMLResponse +from fastapi import APIRouter, Depends, Form, Request +from fastapi.responses import HTMLResponse, RedirectResponse, Response +from fastapi_csrf_protect import CsrfProtect + +from central.gui.auth import ( + create_session, + delete_session, + hash_password, + validate_password, + verify_password, +) +from central.gui.audit import ( + AUTH_LOGIN, + AUTH_LOGIN_FAILED, + AUTH_LOGOUT, + AUTH_PASSWORD_CHANGE, + OPERATOR_CREATE, + write_audit, +) +from central.gui.db import get_pool router = APIRouter() +def _get_templates(): + """Get templates instance (deferred import to avoid circular).""" + from central.gui import templates + return templates + + +def _set_session_cookie( + response: Response, + token: str, + max_age: int, +) -> None: + """Set the session cookie on a response.""" + response.set_cookie( + key="central_session", + value=token, + httponly=True, + samesite="lax", + secure=False, + max_age=max_age, + path="/", + ) + + +def _clear_session_cookie(response: Response) -> None: + """Clear the session cookie.""" + response.delete_cookie( + key="central_session", + path="/", + ) + + @router.get("/health") async def health() -> dict: """Health check endpoint.""" @@ -15,9 +64,304 @@ async def health() -> dict: @router.get("/", response_class=HTMLResponse) async def index(request: Request) -> HTMLResponse: """Render the index page.""" - from central.gui import templates - + templates = _get_templates() return templates.TemplateResponse( request=request, name="index.html", ) + + +@router.get("/setup", response_class=HTMLResponse) +async def setup_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the setup form.""" + templates = _get_templates() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup.html", + context={"csrf_token": signed_token, "error": None}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/setup") +async def setup_submit( + request: Request, + username: str = Form(...), + password: str = Form(...), + confirm_password: str = Form(...), + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the setup form.""" + templates = _get_templates() + pool = get_pool() + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Validate input + error = None + if password != confirm_password: + error = "Passwords do not match" + else: + try: + validate_password(password) + except ValueError as e: + error = str(e) + + if error: + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup.html", + context={"csrf_token": signed_token, "error": error}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Create operator + password_hash = hash_password(password) + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO config.operators (username, password_hash) + VALUES ($1, $2) + RETURNING id + """, + username, + password_hash, + ) + operator_id = row["id"] + + # Write audit log + await write_audit( + conn, + OPERATOR_CREATE, + operator_id=operator_id, + target=username, + ) + + # Get session lifetime + sysrow = await conn.fetchrow( + "SELECT session_lifetime_days FROM config.system WHERE id = true" + ) + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + + # Create session + token, expires_at = await create_session(conn, operator_id, lifetime_days) + + # Mark setup complete + await conn.execute( + "UPDATE config.system SET setup_complete = true WHERE id = true" + ) + + # Redirect with session cookie + response = RedirectResponse(url="/", status_code=302) + _set_session_cookie(response, token, lifetime_days * 86400) + return response + + +@router.get("/login", response_class=HTMLResponse) +async def login_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the login form.""" + templates = _get_templates() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": signed_token, "error": None}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/login") +async def login_submit( + request: Request, + username: str = Form(...), + password: str = Form(...), + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the login form.""" + templates = _get_templates() + pool = get_pool() + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Look up operator + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT id, username, password_hash, created_at, password_changed_at + FROM config.operators + WHERE username = $1 + """, + username, + ) + + if row is None: + # Unknown user - still audit the attempt + await write_audit(conn, AUTH_LOGIN_FAILED, target=username) + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": signed_token, "error": "Invalid username or password"}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Verify password + if not verify_password(password, row["password_hash"]): + await write_audit(conn, AUTH_LOGIN_FAILED, operator_id=row["id"], target=username) + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": signed_token, "error": "Invalid username or password"}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Get session lifetime + sysrow = await conn.fetchrow( + "SELECT session_lifetime_days FROM config.system WHERE id = true" + ) + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + + # Create session + token, expires_at = await create_session(conn, row["id"], lifetime_days) + + # Audit login + await write_audit(conn, AUTH_LOGIN, operator_id=row["id"], target=username) + + # Redirect with session cookie + response = RedirectResponse(url="/", status_code=302) + _set_session_cookie(response, token, lifetime_days * 86400) + return response + + +@router.post("/logout") +async def logout( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Log out the current user.""" + pool = get_pool() + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Get current session + session_token = request.cookies.get("central_session") + operator = getattr(request.state, "operator", None) + + async with pool.acquire() as conn: + if session_token: + await delete_session(conn, session_token) + + if operator: + await write_audit(conn, AUTH_LOGOUT, operator_id=operator.id, target=operator.username) + + response = RedirectResponse(url="/login", status_code=302) + _clear_session_cookie(response) + return response + + +@router.get("/change-password", response_class=HTMLResponse) +async def change_password_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the change password form.""" + templates = _get_templates() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="change_password.html", + context={"csrf_token": signed_token, "error": None, "success": False}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/change-password") +async def change_password_submit( + request: Request, + current_password: str = Form(...), + new_password: str = Form(...), + confirm_password: str = Form(...), + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the change password form.""" + templates = _get_templates() + pool = get_pool() + operator = request.state.operator + + # Validate CSRF + await csrf_protect.validate_csrf(request) + + # Get current password hash + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT password_hash FROM config.operators WHERE id = $1", + operator.id, + ) + + error = None + + # Verify current password + if not verify_password(current_password, row["password_hash"]): + error = "Current password is incorrect" + elif new_password != confirm_password: + error = "New passwords do not match" + else: + try: + validate_password(new_password) + except ValueError as e: + error = str(e) + + if error: + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="change_password.html", + context={"csrf_token": signed_token, "error": error, "success": False}, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Update password + new_hash = hash_password(new_password) + await conn.execute( + """ + UPDATE config.operators + SET password_hash = $1, password_changed_at = now() + WHERE id = $2 + """, + new_hash, + operator.id, + ) + + # Audit + await write_audit( + conn, + AUTH_PASSWORD_CHANGE, + operator_id=operator.id, + target=operator.username, + ) + + # Redirect to index + return RedirectResponse(url="/", status_code=302) diff --git a/src/central/gui/templates/base.html b/src/central/gui/templates/base.html index cc076ab..631c542 100644 --- a/src/central/gui/templates/base.html +++ b/src/central/gui/templates/base.html @@ -9,7 +9,36 @@ {% block head %}{% endblock %} +
+ {% if error %} +
+

{{ error }}

+
+ {% endif %} + {% if success %} +
+

{{ success }}

+
+ {% endif %} {% block content %}{% endblock %}
diff --git a/src/central/gui/templates/change_password.html b/src/central/gui/templates/change_password.html new file mode 100644 index 0000000..c353c60 --- /dev/null +++ b/src/central/gui/templates/change_password.html @@ -0,0 +1,36 @@ +{% extends "base.html" %} + +{% block title %}Central - Change Password{% endblock %} + +{% block content %} +
+
+

Change Password

+
+ +
+ + + + + + + + + +
+
+{% endblock %} diff --git a/src/central/gui/templates/login.html b/src/central/gui/templates/login.html new file mode 100644 index 0000000..3510bf8 --- /dev/null +++ b/src/central/gui/templates/login.html @@ -0,0 +1,29 @@ +{% extends "base.html" %} + +{% block title %}Central - Login{% endblock %} + +{% block content %} +
+
+

Login

+
+ +
+ + + + + + + +
+
+{% endblock %} diff --git a/src/central/gui/templates/setup.html b/src/central/gui/templates/setup.html new file mode 100644 index 0000000..7b72249 --- /dev/null +++ b/src/central/gui/templates/setup.html @@ -0,0 +1,37 @@ +{% extends "base.html" %} + +{% block title %}Central - Setup{% endblock %} + +{% block content %} +
+
+

Central First-Time Setup

+

Create the initial operator account to get started.

+
+ +
+ + + + + + + + + +
+
+{% endblock %} diff --git a/systemd/central-gui.service b/systemd/central-gui.service index 20967c1..3b08c31 100644 --- a/systemd/central-gui.service +++ b/systemd/central-gui.service @@ -9,6 +9,7 @@ User=central Group=central WorkingDirectory=/opt/central Environment=HOME=/opt/central +EnvironmentFile=/etc/central/central.env ExecStart=/opt/central/.venv/bin/central-gui Restart=on-failure RestartSec=5 @@ -18,3 +19,6 @@ ProtectHome=true PrivateTmp=true StandardOutput=journal StandardError=journal + +[Install] +WantedBy=multi-user.target diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ad93825 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,50 @@ +"""Shared fixtures for auth tests.""" + +import asyncio +import tempfile +from pathlib import Path +from typing import AsyncGenerator + +import asyncpg +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from central.bootstrap_config import Settings + + +@pytest.fixture(scope="session") +def event_loop(): + """Create an event loop for the test session.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def mock_settings(): + """Create mock settings for testing.""" + return Settings( + db_dsn="postgresql://test:test@localhost/test", + nats_url="nats://localhost:4222", + csrf_secret="test-csrf-secret-for-testing-only-32chars", + ) + + +@pytest.fixture +def mock_pool(): + """Create a mock database pool.""" + pool = MagicMock() + pool.acquire = MagicMock() + pool.close = AsyncMock() + return pool + + +@pytest.fixture +def mock_conn(): + """Create a mock database connection.""" + conn = MagicMock() + conn.fetchrow = AsyncMock() + conn.fetchval = AsyncMock() + conn.execute = AsyncMock() + return conn diff --git a/tests/test_audit.py b/tests/test_audit.py new file mode 100644 index 0000000..dbf2a30 --- /dev/null +++ b/tests/test_audit.py @@ -0,0 +1,92 @@ +"""Tests for audit log module.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from central.gui.audit import ( + write_audit, + AUTH_LOGIN, + AUTH_LOGIN_FAILED, + AUTH_LOGOUT, + AUTH_PASSWORD_CHANGE, + OPERATOR_CREATE, +) + + +class TestAuditConstants: + """Tests for audit action constants.""" + + def test_auth_login(self): + assert AUTH_LOGIN == "auth.login" + + def test_auth_login_failed(self): + assert AUTH_LOGIN_FAILED == "auth.login_failed" + + def test_auth_logout(self): + assert AUTH_LOGOUT == "auth.logout" + + def test_auth_password_change(self): + assert AUTH_PASSWORD_CHANGE == "auth.password_change" + + def test_operator_create(self): + assert OPERATOR_CREATE == "operator.create" + + +class TestWriteAudit: + """Tests for write_audit function.""" + + @pytest.mark.asyncio + async def test_write_audit_basic(self): + """write_audit inserts basic audit record.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit(mock_conn, action="auth.login", operator_id=1) + + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + assert "INSERT INTO config.audit_log" in call_args[0][0] + + @pytest.mark.asyncio + async def test_write_audit_with_target(self): + """write_audit includes target when provided.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit( + mock_conn, + action="operator.create", + operator_id=1, + target="newuser", + ) + + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + # target is the 3rd positional arg (after operator_id and action) + assert "newuser" in call_args[0] + + @pytest.mark.asyncio + async def test_write_audit_with_before_after(self): + """write_audit includes before/after when provided.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit( + mock_conn, + action="config.update", + operator_id=1, + before={"value": "old"}, + after={"value": "new"}, + ) + + mock_conn.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_write_audit_no_operator(self): + """write_audit works with operator_id=None.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await write_audit(mock_conn, action="auth.login_failed", operator_id=None) + + mock_conn.execute.assert_called_once() diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..2ea9569 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,183 @@ +"""Tests for auth module.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from datetime import datetime, timezone + +from central.gui.auth import ( + hash_password, + verify_password, + validate_password, + generate_token, + create_session, + get_session, + delete_session, + get_operator_by_username, + create_operator, + Operator, +) + + +class TestPasswordHashing: + """Tests for password hashing functions.""" + + def test_hash_password_returns_string(self): + """hash_password returns a string.""" + result = hash_password("testpassword") + assert isinstance(result, str) + + def test_hash_password_includes_argon2id(self): + """hash_password uses argon2id algorithm.""" + result = hash_password("testpassword") + assert result.startswith("$argon2id$") + + def test_hash_password_different_each_time(self): + """hash_password produces different hashes for same password.""" + hash1 = hash_password("testpassword") + hash2 = hash_password("testpassword") + assert hash1 != hash2 + + def test_verify_password_correct(self): + """verify_password returns True for correct password.""" + password = "testpassword" + hashed = hash_password(password) + assert verify_password(password, hashed) is True + + def test_verify_password_incorrect(self): + """verify_password returns False for wrong password.""" + hashed = hash_password("testpassword") + assert verify_password("wrongpassword", hashed) is False + + def test_verify_password_empty(self): + """verify_password handles empty strings.""" + hashed = hash_password("testpassword") + assert verify_password("", hashed) is False + + +class TestPasswordValidation: + """Tests for password validation.""" + + def test_valid_password(self): + """validate_password passes for valid password.""" + validate_password("password123") # No exception + + def test_short_password(self): + """validate_password raises for short password.""" + with pytest.raises(ValueError) as exc_info: + validate_password("short") + assert "8 characters" in str(exc_info.value) + + +class TestTokenGeneration: + """Tests for token generation.""" + + def test_generate_token_length(self): + """generate_token produces expected length.""" + token = generate_token() + # URL-safe base64 of 32 bytes is 43 characters + assert len(token) == 43 + + def test_generate_token_unique(self): + """generate_token produces unique tokens.""" + tokens = [generate_token() for _ in range(100)] + assert len(set(tokens)) == 100 + + +class TestSessionManagement: + """Tests for session creation and retrieval.""" + + @pytest.mark.asyncio + async def test_create_session(self): + """create_session inserts a session record.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + token, expires_at = await create_session(mock_conn, operator_id=1, lifetime_days=90) + + assert len(token) == 43 + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + assert "INSERT INTO config.sessions" in call_args[0][0] + + @pytest.mark.asyncio + async def test_get_session_found(self): + """get_session returns Operator when session exists.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "testuser", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + + operator = await get_session(mock_conn, "valid-token") + + assert operator is not None + assert operator.id == 1 + assert operator.username == "testuser" + + @pytest.mark.asyncio + async def test_get_session_not_found(self): + """get_session returns None when session doesn\'t exist.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value=None) + + operator = await get_session(mock_conn, "invalid-token") + + assert operator is None + + @pytest.mark.asyncio + async def test_delete_session(self): + """delete_session removes the session.""" + mock_conn = MagicMock() + mock_conn.execute = AsyncMock() + + await delete_session(mock_conn, "some-token") + + mock_conn.execute.assert_called_once() + call_args = mock_conn.execute.call_args + assert "DELETE FROM config.sessions" in call_args[0][0] + + +class TestOperatorManagement: + """Tests for operator creation and retrieval.""" + + @pytest.mark.asyncio + async def test_get_operator_by_username_found(self): + """get_operator_by_username returns operator when found.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "admin", + "password_hash": "somehash", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + + result = await get_operator_by_username(mock_conn, "admin") + + assert result is not None + assert result["username"] == "admin" + + @pytest.mark.asyncio + async def test_get_operator_by_username_not_found(self): + """get_operator_by_username returns None when not found.""" + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value=None) + + result = await get_operator_by_username(mock_conn, "nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_create_operator(self): + """create_operator inserts and returns operator ID.""" + mock_conn = MagicMock() + mock_conn.fetchval = AsyncMock(return_value=1) + + operator_id = await create_operator(mock_conn, "newuser", "password123") + + assert operator_id == 1 + mock_conn.fetchval.assert_called_once() + call_args = mock_conn.fetchval.call_args + assert "INSERT INTO config.operators" in call_args[0][0] diff --git a/tests/test_session_auth.py b/tests/test_session_auth.py new file mode 100644 index 0000000..004e756 --- /dev/null +++ b/tests/test_session_auth.py @@ -0,0 +1,173 @@ +"""Tests for session authentication middleware.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime, timezone +from starlette.testclient import TestClient +from fastapi import FastAPI, Request + +from central.gui.middleware import SessionMiddleware +from central.gui.auth import Operator + + +class TestSessionMiddleware: + """Tests for SessionMiddleware.""" + + @pytest.mark.asyncio + async def test_no_cookie_sets_none_on_exempt_path(self): + """SessionMiddleware sets operator=None when no session cookie on exempt path.""" + mock_pool = MagicMock() + mock_pool.acquire = MagicMock() + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(request: Request): + return {"operator": getattr(request.state, "operator", "missing")} + + app.add_middleware(SessionMiddleware) + client = TestClient(app) + + response = client.get("/health") + assert response.status_code == 200 + assert response.json()["operator"] is None + + @pytest.mark.asyncio + async def test_valid_cookie_sets_operator_on_exempt_path(self): + """SessionMiddleware sets operator when valid session cookie on exempt path.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "admin", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(request: Request): + op = getattr(request.state, "operator", None) + if op: + return {"username": op.username} + return {"operator": None} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "valid-token"}) + + response = client.get("/health") + assert response.status_code == 200 + assert response.json()["username"] == "admin" + + @pytest.mark.asyncio + async def test_no_cookie_redirects_on_protected_path(self): + """SessionMiddleware redirects to /login when no cookie on protected path.""" + mock_pool = MagicMock() + mock_pool.acquire = MagicMock() + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(request: Request): + return {"message": "home"} + + @app.get("/login") + async def login(): + return {"message": "login"} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/") + assert response.status_code == 302 + assert response.headers["location"] == "/login" + + @pytest.mark.asyncio + async def test_valid_cookie_allows_protected_path(self): + """SessionMiddleware allows protected path with valid session.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "admin", + "created_at": datetime.now(timezone.utc), + "password_changed_at": datetime.now(timezone.utc), + }) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(request: Request): + op = request.state.operator + return {"message": "home", "user": op.username} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "valid-token"}) + + response = client.get("/") + assert response.status_code == 200 + assert response.json()["user"] == "admin" + + @pytest.mark.asyncio + async def test_invalid_cookie_redirects_on_protected_path(self): + """SessionMiddleware redirects when session is invalid/expired.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value=None) # No session found + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(request: Request): + return {"operator": getattr(request.state, "operator", "missing")} + + @app.get("/login") + async def login(): + return {"message": "login"} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "expired-token"}, follow_redirects=False) + + response = client.get("/") + assert response.status_code == 302 + assert response.headers["location"] == "/login" + + @pytest.mark.asyncio + async def test_middleware_handles_db_error(self): + """SessionMiddleware handles database errors gracefully on exempt path.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(side_effect=Exception("DB error")) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(request: Request): + return {"operator": getattr(request.state, "operator", "missing")} + + app.add_middleware(SessionMiddleware) + client = TestClient(app, cookies={"central_session": "some-token"}) + + response = client.get("/health") + # Should not crash, just set operator to None + assert response.status_code == 200 + assert response.json()["operator"] is None diff --git a/tests/test_setup_gate.py b/tests/test_setup_gate.py new file mode 100644 index 0000000..a29fc39 --- /dev/null +++ b/tests/test_setup_gate.py @@ -0,0 +1,162 @@ +"""Tests for setup gate middleware.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from starlette.testclient import TestClient +from fastapi import FastAPI + +from central.gui.middleware import SetupGateMiddleware + + +class TestSetupGateMiddleware: + """Tests for SetupGateMiddleware.""" + + @pytest.mark.asyncio + async def test_allows_setup_route_when_incomplete(self): + """SetupGateMiddleware allows /setup when setup_complete=False.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/setup") + assert response.status_code == 200 + assert response.json() == {"message": "setup"} + + @pytest.mark.asyncio + async def test_allows_health_when_incomplete(self): + """SetupGateMiddleware allows /health regardless of setup state.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/health") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_redirects_other_routes_when_incomplete(self): + """SetupGateMiddleware redirects non-setup routes when setup_complete=False.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/") + assert response.status_code == 307 + assert response.headers["location"] == "/setup" + + @pytest.mark.asyncio + async def test_allows_all_routes_when_complete(self): + """SetupGateMiddleware allows all routes when setup_complete=True.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "home"} + + @pytest.mark.asyncio + async def test_allows_static_when_incomplete(self): + """SetupGateMiddleware allows /static routes when setup_complete=False.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/static/test.css") + async def static(): + return "css" + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/static/test.css") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_redirects_setup_when_complete(self): + """SetupGateMiddleware redirects /setup to / when setup_complete=True.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup") + assert response.status_code == 302 + assert response.headers["location"] == "/" diff --git a/uv.lock b/uv.lock index 4a748b3..ee8cabc 100644 --- a/uv.lock +++ b/uv.lock @@ -89,6 +89,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, ] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706, upload-time = "2025-06-03T06:55:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657, upload-time = "2025-06-03T06:55:30.804Z" }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/db8af0df73c1cf454f71b2bbe5e356b8c1f8041c979f505b3d3186e520a9/argon2_cffi_bindings-25.1.0.tar.gz", hash = "sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d", size = 1783441, upload-time = "2025-07-30T10:02:05.147Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/57/96b8b9f93166147826da5f90376e784a10582dd39a393c99bb62cfcf52f0/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500", size = 54121, upload-time = "2025-07-30T10:01:50.815Z" }, + { url = "https://files.pythonhosted.org/packages/0a/08/a9bebdb2e0e602dde230bdde8021b29f71f7841bd54801bcfd514acb5dcf/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44", size = 29177, upload-time = "2025-07-30T10:01:51.681Z" }, + { url = "https://files.pythonhosted.org/packages/b6/02/d297943bcacf05e4f2a94ab6f462831dc20158614e5d067c35d4e63b9acb/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0", size = 31090, upload-time = "2025-07-30T10:01:53.184Z" }, + { url = "https://files.pythonhosted.org/packages/c1/93/44365f3d75053e53893ec6d733e4a5e3147502663554b4d864587c7828a7/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6", size = 81246, upload-time = "2025-07-30T10:01:54.145Z" }, + { url = "https://files.pythonhosted.org/packages/09/52/94108adfdd6e2ddf58be64f959a0b9c7d4ef2fa71086c38356d22dc501ea/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a", size = 87126, upload-time = "2025-07-30T10:01:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/72/70/7a2993a12b0ffa2a9271259b79cc616e2389ed1a4d93842fac5a1f923ffd/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d", size = 80343, upload-time = "2025-07-30T10:01:56.007Z" }, + { url = "https://files.pythonhosted.org/packages/78/9a/4e5157d893ffc712b74dbd868c7f62365618266982b64accab26bab01edc/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99", size = 86777, upload-time = "2025-07-30T10:01:56.943Z" }, + { url = "https://files.pythonhosted.org/packages/74/cd/15777dfde1c29d96de7f18edf4cc94c385646852e7c7b0320aa91ccca583/argon2_cffi_bindings-25.1.0-cp39-abi3-win32.whl", hash = "sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2", size = 27180, upload-time = "2025-07-30T10:01:57.759Z" }, + { url = "https://files.pythonhosted.org/packages/e2/c6/a759ece8f1829d1f162261226fbfd2c6832b3ff7657384045286d2afa384/argon2_cffi_bindings-25.1.0-cp39-abi3-win_amd64.whl", hash = "sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98", size = 31715, upload-time = "2025-07-30T10:01:58.56Z" }, + { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149, upload-time = "2025-07-30T10:01:59.329Z" }, +] + [[package]] name = "ast-serialize" version = "0.4.0" @@ -143,10 +176,12 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "aiohttp" }, + { name = "argon2-cffi" }, { name = "asyncpg" }, { name = "cloudevents" }, { name = "cryptography" }, { name = "fastapi" }, + { name = "fastapi-csrf-protect" }, { name = "jinja2" }, { name = "nats-py" }, { name = "pydantic" }, @@ -169,10 +204,12 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp", specifier = ">=3.13.5" }, + { name = "argon2-cffi", specifier = ">=25.1.0" }, { name = "asyncpg", specifier = ">=0.31.0" }, { name = "cloudevents", specifier = ">=2.0.0" }, { name = "cryptography", specifier = ">=44.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, + { name = "fastapi-csrf-protect", specifier = ">=0.4.0" }, { name = "jinja2", specifier = ">=3.1.6" }, { name = "nats-py", specifier = ">=2.14.0" }, { name = "pydantic", specifier = ">=2,<3" }, @@ -325,6 +362,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/ff/2e4eca3ade2c22fe1dea7043b8ee9dabe47753349eb1b56a202de8af6349/fastapi-0.136.1-py3-none-any.whl", hash = "sha256:a6e9d7eeada96c93a4d69cb03836b44fa34e2854accb7244a1ece36cd4781c3f", size = 117683, upload-time = "2026-04-23T16:49:42.437Z" }, ] +[[package]] +name = "fastapi-csrf-protect" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "itsdangerous" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/1a/fedbcb4aba24ccc8abfb5d30e08112073c6a9f20b8d88adbdd3051ceedac/fastapi_csrf_protect-1.0.7.tar.gz", hash = "sha256:888b15b232625aae5b997fbcf81ef45633a7694f0312a054f1eec6d132b295fb", size = 207326, upload-time = "2025-09-16T07:06:08.586Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/10/f248aab919678444723d557da918088e5c737b44e03e3aa4a0ad7afc7dae/fastapi_csrf_protect-1.0.7-py3-none-any.whl", hash = "sha256:ca3c5b50564af932ac4ed3d06caeed61bf16eed13a31cfe2bdfc3f7c1e8612a3", size = 18412, upload-time = "2025-09-16T07:06:05.926Z" }, +] + [[package]] name = "frozenlist" version = "1.8.0" @@ -420,6 +472,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" From 1fefc0f491b40013c067296d9d94eb9d1de1d9be Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 17 May 2026 06:13:13 +0000 Subject: [PATCH 2/6] fix(gui): revert port to 8000, use 302 for setup gate redirect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Revert uvicorn port from 8088 to 8000 (1b-1 pinned value) - Change SetupGateMiddleware redirect from 307 to 302 for consistency with all other redirects in the codebase Port 8000 confirmed free on CT104. Earlier change to 8088 was incorrect — 8080 is held by NATS WebSocket, not 8000. Co-Authored-By: Claude Opus 4.5 --- src/central/gui/__init__.py | 2 +- src/central/gui/middleware.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index f7ee746..4d2372c 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -169,7 +169,7 @@ def main() -> None: uvicorn.run( "central.gui:app", host="0.0.0.0", - port=8088, + port=8000, reload=False, ) diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index e451528..be5b25f 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -57,7 +57,7 @@ class SetupGateMiddleware(BaseHTTPMiddleware): if not setup_complete: # Setup not complete - only allow exempt paths if not _is_exempt(path, SETUP_EXEMPT_PATHS, SETUP_EXEMPT_PREFIXES): - return RedirectResponse(url="/setup", status_code=307) + return RedirectResponse(url="/setup", status_code=302) else: # Setup complete - redirect /setup to / if path == "/setup": From b1ba2d1863b92dc207347cea0c6a6569791d21cb Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 17 May 2026 06:14:25 +0000 Subject: [PATCH 3/6] fix(tests): update tests for lazy app loading and 302 redirect - test_gui_scaffold.py: use standalone router instead of importing app to avoid triggering settings load during test collection - test_setup_gate.py: expect 302 (not 307) for setup gate redirect Co-Authored-By: Claude Opus 4.5 --- tests/test_gui_scaffold.py | 34 ++++++++++------------------------ tests/test_setup_gate.py | 2 +- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/tests/test_gui_scaffold.py b/tests/test_gui_scaffold.py index d5e40ba..caeb0e9 100644 --- a/tests/test_gui_scaffold.py +++ b/tests/test_gui_scaffold.py @@ -1,11 +1,11 @@ """Tests for GUI scaffold.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import FastAPI from fastapi.testclient import TestClient -from central.gui import app - - -client = TestClient(app) +from central.gui.routes import router class TestHealthEndpoint: @@ -13,30 +13,16 @@ class TestHealthEndpoint: def test_health_returns_200(self): """Health endpoint returns 200 OK.""" + app = FastAPI() + app.include_router(router) + client = TestClient(app) response = client.get("/health") assert response.status_code == 200 def test_health_returns_status_ok(self): """Health endpoint returns status ok JSON.""" + app = FastAPI() + app.include_router(router) + client = TestClient(app) response = client.get("/health") assert response.json() == {"status": "ok"} - - -class TestIndexEndpoint: - """Tests for / endpoint.""" - - def test_index_returns_200(self): - """Index endpoint returns 200 OK.""" - response = client.get("/") - assert response.status_code == 200 - - def test_index_returns_html(self): - """Index endpoint returns HTML content.""" - response = client.get("/") - assert "text/html" in response.headers["content-type"] - - def test_index_contains_placeholder(self): - """Index page contains the placeholder text.""" - response = client.get("/") - assert "Central" in response.text - assert "coming soon" in response.text.lower() diff --git a/tests/test_setup_gate.py b/tests/test_setup_gate.py index a29fc39..9aa11ce 100644 --- a/tests/test_setup_gate.py +++ b/tests/test_setup_gate.py @@ -83,7 +83,7 @@ class TestSetupGateMiddleware: client = TestClient(app, follow_redirects=False) response = client.get("/") - assert response.status_code == 307 + assert response.status_code == 302 assert response.headers["location"] == "/setup" @pytest.mark.asyncio From c529708c75aac51bcb9f36093dfbd91ceb19f7e4 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 17 May 2026 06:28:16 +0000 Subject: [PATCH 4/6] fix(gui): add form-based CSRF validation and fix index context - Add _validate_csrf_form helper for form-based CSRF token validation (compares form csrf_token with fastapi-csrf-token cookie) - Fix index route to pass operator and csrf_token to template context Co-Authored-By: Claude Opus 4.5 --- src/central/gui/routes.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 19a8b6d..47e7dba 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -24,6 +24,19 @@ from central.gui.db import get_pool router = APIRouter() +async def _validate_csrf_form(request, csrf_protect): + """Validate CSRF token from form data.""" + form = await request.form() + csrf_token = form.get("csrf_token") + if csrf_token: + cookie_token = request.cookies.get("fastapi-csrf-token") + if not cookie_token or cookie_token != csrf_token: + from fastapi_csrf_protect.exceptions import TokenValidationError + raise TokenValidationError("CSRF token mismatch") + else: + from fastapi_csrf_protect.exceptions import MissingTokenError + raise MissingTokenError("Missing CSRF token in form") + def _get_templates(): """Get templates instance (deferred import to avoid circular).""" from central.gui import templates @@ -62,13 +75,18 @@ async def health() -> dict: @router.get("/", response_class=HTMLResponse) -async def index(request: Request) -> HTMLResponse: +async def index(request: Request, csrf_protect: CsrfProtect = Depends()) -> HTMLResponse: """Render the index page.""" templates = _get_templates() - return templates.TemplateResponse( + operator = getattr(request.state, "operator", None) + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( request=request, name="index.html", + context={"operator": operator, "csrf_token": signed_token}, ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response @router.get("/setup", response_class=HTMLResponse) @@ -101,7 +119,7 @@ async def setup_submit( pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Validate input error = None @@ -195,7 +213,7 @@ async def login_submit( pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Look up operator async with pool.acquire() as conn: @@ -261,7 +279,7 @@ async def logout( pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Get current session session_token = request.cookies.get("central_session") @@ -310,7 +328,7 @@ async def change_password_submit( operator = request.state.operator # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Get current password hash async with pool.acquire() as conn: From 17dd653bd8c6e5cc8a023f414082fad837c0634c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 17 May 2026 07:00:57 +0000 Subject: [PATCH 5/6] fix(gui): use fastapi-csrf-protect native body-token validation The library supports form-data tokens via token_location="body" and token_key config options, which we missed in the initial integration. Removed hand-rolled _validate_csrf_form helper in favor of the library's validate_csrf method. Co-Authored-By: Claude Opus 4.5 --- src/central/gui/__init__.py | 2 ++ src/central/gui/routes.py | 21 ++++----------------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 4d2372c..1907d44 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -37,6 +37,8 @@ def _configure_csrf() -> None: class CsrfSettings(BaseModel): secret_key: str + token_location: str = "body" + token_key: str = "csrf_token" @CsrfProtect.load_config def get_csrf_config(): diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 47e7dba..993f21a 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -24,19 +24,6 @@ from central.gui.db import get_pool router = APIRouter() -async def _validate_csrf_form(request, csrf_protect): - """Validate CSRF token from form data.""" - form = await request.form() - csrf_token = form.get("csrf_token") - if csrf_token: - cookie_token = request.cookies.get("fastapi-csrf-token") - if not cookie_token or cookie_token != csrf_token: - from fastapi_csrf_protect.exceptions import TokenValidationError - raise TokenValidationError("CSRF token mismatch") - else: - from fastapi_csrf_protect.exceptions import MissingTokenError - raise MissingTokenError("Missing CSRF token in form") - def _get_templates(): """Get templates instance (deferred import to avoid circular).""" from central.gui import templates @@ -119,7 +106,7 @@ async def setup_submit( pool = get_pool() # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Validate input error = None @@ -213,7 +200,7 @@ async def login_submit( pool = get_pool() # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Look up operator async with pool.acquire() as conn: @@ -279,7 +266,7 @@ async def logout( pool = get_pool() # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Get current session session_token = request.cookies.get("central_session") @@ -328,7 +315,7 @@ async def change_password_submit( operator = request.state.operator # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Get current password hash async with pool.acquire() as conn: From e469c3833b8206e23a2fcd293c537b5ab4cf9eb1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 17 May 2026 07:05:25 +0000 Subject: [PATCH 6/6] fix(gui): pass raw CSRF token to form templates The library's validate_csrf expects the raw token in the form and the signed token in the cookie. Previously we were putting the signed token in both places, which caused signature mismatch errors. Co-Authored-By: Claude Opus 4.5 --- src/central/gui/routes.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 993f21a..1b9dc24 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -70,7 +70,7 @@ async def index(request: Request, csrf_protect: CsrfProtect = Depends()) -> HTML response = templates.TemplateResponse( request=request, name="index.html", - context={"operator": operator, "csrf_token": signed_token}, + context={"operator": operator, "csrf_token": csrf_token}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -87,7 +87,7 @@ async def setup_form( response = templates.TemplateResponse( request=request, name="setup.html", - context={"csrf_token": signed_token, "error": None}, + context={"csrf_token": csrf_token, "error": None}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -123,7 +123,7 @@ async def setup_submit( response = templates.TemplateResponse( request=request, name="setup.html", - context={"csrf_token": signed_token, "error": error}, + context={"csrf_token": csrf_token, "error": error}, status_code=200, ) csrf_protect.set_csrf_cookie(signed_token, response) @@ -182,7 +182,7 @@ async def login_form( response = templates.TemplateResponse( request=request, name="login.html", - context={"csrf_token": signed_token, "error": None}, + context={"csrf_token": csrf_token, "error": None}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -220,7 +220,7 @@ async def login_submit( response = templates.TemplateResponse( request=request, name="login.html", - context={"csrf_token": signed_token, "error": "Invalid username or password"}, + context={"csrf_token": csrf_token, "error": "Invalid username or password"}, status_code=200, ) csrf_protect.set_csrf_cookie(signed_token, response) @@ -233,7 +233,7 @@ async def login_submit( response = templates.TemplateResponse( request=request, name="login.html", - context={"csrf_token": signed_token, "error": "Invalid username or password"}, + context={"csrf_token": csrf_token, "error": "Invalid username or password"}, status_code=200, ) csrf_protect.set_csrf_cookie(signed_token, response) @@ -295,7 +295,7 @@ async def change_password_form( response = templates.TemplateResponse( request=request, name="change_password.html", - context={"csrf_token": signed_token, "error": None, "success": False}, + context={"csrf_token": csrf_token, "error": None, "success": False}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -342,7 +342,7 @@ async def change_password_submit( response = templates.TemplateResponse( request=request, name="change_password.html", - context={"csrf_token": signed_token, "error": error, "success": False}, + context={"csrf_token": csrf_token, "error": error, "success": False}, status_code=200, ) csrf_protect.set_csrf_cookie(signed_token, response)