diff --git a/pyproject.toml b/pyproject.toml index d1d7f8e..c101844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "asyncpg>=0.31.0", "cloudevents>=2.0.0", "cryptography>=44.0.0", - "fastapi-csrf-protect>=0.4.0", "fastapi>=0.115.0", "jinja2>=3.1.6", "nats-py>=2.14.0", diff --git a/sql/migrations/013_add_session_csrf_token.sql b/sql/migrations/013_add_session_csrf_token.sql new file mode 100644 index 0000000..fbd0d11 --- /dev/null +++ b/sql/migrations/013_add_session_csrf_token.sql @@ -0,0 +1,9 @@ +-- Add CSRF token column to sessions table +-- Session-bound CSRF tokens prevent race conditions from cookie rotation + +ALTER TABLE config.sessions + ADD COLUMN csrf_token TEXT NOT NULL + DEFAULT encode(gen_random_bytes(32), 'hex'); + +-- Comment +COMMENT ON COLUMN config.sessions.csrf_token IS 'Session-bound CSRF token for synchronizer token pattern'; diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 20d79aa..71a302b 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -29,23 +29,6 @@ _cleanup_task: asyncio.Task | None = None _app: FastAPI | None = None -def _configure_csrf() -> None: - """Configure CSRF protection. Must be called before app starts.""" - from fastapi_csrf_protect import CsrfProtect - from pydantic import BaseModel - from central.bootstrap_config import get_settings - - class CsrfSettings(BaseModel): - secret_key: str - token_location: str = "body" - token_key: str = "csrf_token" - - @CsrfProtect.load_config - def get_csrf_config(): - settings = get_settings() - return CsrfSettings(secret_key=settings.csrf_secret) - - async def _session_cleanup_loop() -> None: """Periodically clean up expired sessions.""" global _shutdown_event @@ -117,9 +100,6 @@ def _create_app() -> FastAPI: from central.gui.middleware import SessionMiddleware, SetupGateMiddleware from central.gui.routes import router - # Configure CSRF before creating app - _configure_csrf() - app = FastAPI( title="Central GUI", lifespan=lifespan, @@ -137,45 +117,214 @@ def _create_app() -> FastAPI: app.include_router(router) # CSRF exception handler - return friendly error instead of 500 - from fastapi_csrf_protect.exceptions import CsrfProtectError + from central.gui.auth import CsrfValidationError + from central.gui.csrf import generate_pre_auth_csrf, set_pre_auth_csrf_cookie + from central.bootstrap_config import get_settings from fastapi.responses import RedirectResponse - @app.exception_handler(CsrfProtectError) - async def csrf_exception_handler(request, exc: CsrfProtectError): - from fastapi_csrf_protect import CsrfProtect - - csrf_protect = CsrfProtect() - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() - + @app.exception_handler(CsrfValidationError) + async def csrf_exception_handler(request, exc: CsrfValidationError): + from central.gui.db import get_pool + + settings = get_settings() + # For pre-auth paths, generate a new pre-auth token + # For session paths, we'll just show the error (session token stays valid) + csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) + error_msg = "Your session expired. Please try again." + if request.url.path == "/login": response = templates.TemplateResponse( request=request, name="login.html", - context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + context={"csrf_token": csrf_token, "error": error_msg}, ) - csrf_protect.set_csrf_cookie(signed_token, response) + set_pre_auth_csrf_cookie(response, signed_token) return response + elif request.url.path == "/setup": + # /setup is a redirect path now, not a form + return RedirectResponse("/setup", status_code=302) + + elif request.url.path == "/setup/operator": response = templates.TemplateResponse( request=request, - name="setup.html", - context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + name="setup_operator.html", + context={"csrf_token": csrf_token, "error": error_msg, "form_data": None}, ) - csrf_protect.set_csrf_cookie(signed_token, response) + set_pre_auth_csrf_cookie(response, signed_token) return response + + elif request.url.path == "/setup/system": + pool = get_pool() + system = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "© OpenStreetMap contributors", + } + if pool: + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + if row: + system = { + "map_tile_url": row["map_tile_url"], + "map_attribution": row["map_attribution"], + } + except Exception: + pass + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": error_msg, + "errors": None, + "form_data": None, + "system": system, + }, + ) + set_pre_auth_csrf_cookie(response, signed_token) + return response + + elif request.url.path == "/setup/keys": + pool = get_pool() + keys = [] + if pool: + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + except Exception: + pass + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": None, + "form_data": None, + "success": None, + "error": error_msg, + }, + ) + set_pre_auth_csrf_cookie(response, signed_token) + return response + + elif request.url.path == "/setup/adapters": + pool = get_pool() + adapters = [] + api_keys = [] + tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = "© OpenStreetMap contributors" + if pool: + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name" + ) + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) + key_rows = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) + api_keys = [{"alias": k["alias"]} for k in key_rows] + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + if sys_row: + tile_url = sys_row["map_tile_url"] + tile_attribution = sys_row["map_attribution"] + except Exception: + pass + + # Import helper functions for valid values + from central.gui.routes import _get_valid_satellites, _get_valid_feeds + + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": api_keys, + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": error_msg, + "errors": None, + "form_data": None, + }, + ) + set_pre_auth_csrf_cookie(response, signed_token) + return response + + elif request.url.path == "/setup/finish": + pool = get_pool() + operator_count = 0 + key_count = 0 + system = {"map_tile_url": ""} + adapters = [] + if pool: + try: + async with pool.acquire() as conn: + operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + if sys_row: + system = {"map_tile_url": sys_row["map_tile_url"]} + rows = await conn.fetch( + "SELECT name, enabled, cadence_s FROM config.adapters ORDER BY name" + ) + adapters = [ + {"name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"]} + for row in rows + ] + except Exception: + pass + response = templates.TemplateResponse( + request=request, + name="setup_finish.html", + context={ + "csrf_token": csrf_token, + "operator_count": operator_count, + "key_count": key_count, + "system": system, + "adapters": adapters, + "error": error_msg, + }, + ) + set_pre_auth_csrf_cookie(response, signed_token) + return response + elif request.url.path == "/logout": return RedirectResponse("/login", status_code=302) + elif request.url.path == "/change-password": response = templates.TemplateResponse( request=request, name="change_password.html", - context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + context={"csrf_token": csrf_token, "error": error_msg}, ) - csrf_protect.set_csrf_cookie(signed_token, response) + set_pre_auth_csrf_cookie(response, signed_token) return response + elif request.url.path.startswith("/adapters/"): # Redirect back to adapters list return RedirectResponse("/adapters", status_code=302) + else: # Fallback: redirect to login return RedirectResponse("/login", status_code=302) diff --git a/src/central/gui/audit.py b/src/central/gui/audit.py index 1bdb66b..b7cfd47 100644 --- a/src/central/gui/audit.py +++ b/src/central/gui/audit.py @@ -14,6 +14,8 @@ STREAM_UPDATE = "stream.update" API_KEY_CREATE = "api_key.create" API_KEY_ROTATE = "api_key.rotate" API_KEY_DELETE = "api_key.delete" +SYSTEM_UPDATE = "system.update" +SETUP_COMPLETE = "setup.complete" async def write_audit( diff --git a/src/central/gui/auth.py b/src/central/gui/auth.py index 3b74ac0..ca9ec73 100644 --- a/src/central/gui/auth.py +++ b/src/central/gui/auth.py @@ -12,6 +12,11 @@ from argon2.exceptions import VerifyMismatchError _hasher = PasswordHasher() +class CsrfValidationError(Exception): + """Raised when CSRF token validation fails.""" + pass + + @dataclass class Operator: """Operator account.""" @@ -46,39 +51,46 @@ def generate_token() -> str: return secrets.token_urlsafe(32) +def generate_csrf_token() -> str: + """Generate a cryptographically secure CSRF token.""" + return secrets.token_hex(32) + + async def create_session( conn: Any, # asyncpg.Connection operator_id: int, lifetime_days: int, -) -> tuple[str, datetime]: +) -> tuple[str, datetime, str]: """Create a new session for an operator. - Returns (token, expires_at). + Returns (token, expires_at, csrf_token). """ token = generate_token() + csrf_token = generate_csrf_token() expires_at = datetime.now(timezone.utc) + timedelta(days=lifetime_days) await conn.execute( """ - INSERT INTO config.sessions (token, operator_id, expires_at) - VALUES ($1, $2, $3) + INSERT INTO config.sessions (token, operator_id, expires_at, csrf_token) + VALUES ($1, $2, $3, $4) """, token, operator_id, expires_at, + csrf_token, ) - return token, expires_at + return token, expires_at, csrf_token -async def get_session(conn: Any, token: str) -> Operator | None: - """Look up a session and return the associated operator. +async def get_session(conn: Any, token: str) -> tuple[Operator, str] | None: + """Look up a session and return the associated operator and csrf_token. - Returns None if token is invalid or expired. + Returns (Operator, csrf_token) or None if token is invalid or expired. """ row = await conn.fetchrow( """ - SELECT o.id, o.username, o.created_at, o.password_changed_at + SELECT o.id, o.username, o.created_at, o.password_changed_at, s.csrf_token FROM config.sessions s JOIN config.operators o ON s.operator_id = o.id WHERE s.token = $1 AND s.expires_at > now() @@ -89,12 +101,14 @@ async def get_session(conn: Any, token: str) -> Operator | None: if row is None: return None - return Operator( + operator = Operator( id=row["id"], username=row["username"], created_at=row["created_at"], password_changed_at=row.get("password_changed_at"), ) + + return operator, row["csrf_token"] async def delete_session(conn: Any, token: str) -> None: diff --git a/src/central/gui/csrf.py b/src/central/gui/csrf.py new file mode 100644 index 0000000..0d6198f --- /dev/null +++ b/src/central/gui/csrf.py @@ -0,0 +1,72 @@ +"""Pre-auth CSRF protection for login and setup/operator pages. + +These routes cannot use session-bound CSRF because no session exists yet. +Uses a simple cookie-based pattern with short-lived tokens. +""" + +import secrets +from typing import Optional + +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired +from starlette.requests import Request +from starlette.responses import Response + + +# 10 minute max age for pre-auth CSRF tokens +PRE_AUTH_CSRF_MAX_AGE = 600 +PRE_AUTH_CSRF_COOKIE = "central_preauth_csrf" + + +def _get_serializer(secret_key: str) -> URLSafeTimedSerializer: + """Get a timed serializer for CSRF tokens.""" + return URLSafeTimedSerializer(secret_key, salt="preauth-csrf") + + +def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]: + """Generate a pre-auth CSRF token pair. + + Returns (plain_token, signed_token). + The plain_token goes in the form, signed_token goes in the cookie. + """ + plain_token = secrets.token_hex(32) + serializer = _get_serializer(secret_key) + signed_token = serializer.dumps(plain_token) + return plain_token, signed_token + + +def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None: + """Set the pre-auth CSRF cookie on a response.""" + response.set_cookie( + PRE_AUTH_CSRF_COOKIE, + signed_token, + max_age=PRE_AUTH_CSRF_MAX_AGE, + path="/", + httponly=True, + samesite="lax", + ) + + +def validate_pre_auth_csrf( + request: Request, + form_token: str, + secret_key: str, +) -> bool: + """Validate a pre-auth CSRF token. + + Returns True if valid, False otherwise. + """ + cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE) + if not cookie_value or not form_token: + return False + + serializer = _get_serializer(secret_key) + try: + expected_token = serializer.loads(cookie_value, max_age=PRE_AUTH_CSRF_MAX_AGE) + return secrets.compare_digest(form_token, expected_token) + except (BadSignature, SignatureExpired): + return False + + +def unset_pre_auth_csrf_cookie(response: Response) -> None: + """Remove the pre-auth CSRF cookie.""" + response.delete_cookie(PRE_AUTH_CSRF_COOKIE, path="/") diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index be5b25f..155112b 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -12,11 +12,10 @@ from central.gui.db import get_pool logger = logging.getLogger(__name__) # Paths that don't require setup to be complete -SETUP_EXEMPT_PATHS = {"/setup", "/health"} -SETUP_EXEMPT_PREFIXES = ("/static/",) +SETUP_EXEMPT_PREFIXES = ("/static/", "/setup") # Paths that don't require authentication -AUTH_EXEMPT_PATHS = {"/setup", "/login", "/health"} +AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"} AUTH_EXEMPT_PREFIXES = ("/static/",) @@ -30,6 +29,35 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: return False +async def _get_wizard_redirect_step(conn) -> str: + """Determine which wizard step to redirect to based on DB state.""" + # Check if any operators exist + op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + if op_count == 0: + return "/setup/operator" + + # Check if system settings have been configured (map_tile_url not default) + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + if sys_row is None or sys_row["map_tile_url"] == default_tile: + return "/setup/system" + + # Keys step is optional, so check adapters have been reviewed + # We consider adapters reviewed if any adapter has a non-null updated_at + # (meaning it was explicitly saved during setup) + adapters_touched = await conn.fetchval( + "SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL" + ) + if adapters_touched == 0: + # Go to keys first, then adapters + return "/setup/keys" + + # All steps done, go to finish + return "/setup/finish" + + class SetupGateMiddleware(BaseHTTPMiddleware): """Redirect to /setup if setup is not complete.""" @@ -55,25 +83,44 @@ class SetupGateMiddleware(BaseHTTPMiddleware): return await call_next(request) if not setup_complete: - # Setup not complete - only allow exempt paths - if not _is_exempt(path, SETUP_EXEMPT_PATHS, SETUP_EXEMPT_PREFIXES): + # Setup not complete - only allow setup paths and static/health + if path.startswith("/setup"): + # Allow all /setup/* paths (handler will enforce auth) + # But /setup with no subpath should redirect to appropriate step + if path == "/setup" or path == "/setup/": + try: + async with pool.acquire() as conn: + redirect_step = await _get_wizard_redirect_step(conn) + return RedirectResponse(url=redirect_step, status_code=302) + except Exception: + logger.warning("Failed to determine wizard step", exc_info=True) + return RedirectResponse(url="/setup/operator", status_code=302) + return await call_next(request) + elif path == "/health" or path.startswith("/static/"): + return await call_next(request) + elif path == "/login": + # During setup, login redirects to /setup + return RedirectResponse(url="/setup", status_code=302) + else: + # All other paths redirect to /setup return RedirectResponse(url="/setup", status_code=302) else: - # Setup complete - redirect /setup to / - if path == "/setup": + # Setup complete - redirect /setup* to / + if path.startswith("/setup"): return RedirectResponse(url="/", status_code=302) return await call_next(request) class SessionMiddleware(BaseHTTPMiddleware): - """Load session from cookie and attach operator to request.state.""" + """Load session from cookie and attach operator + csrf_token to request.state.""" async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path - # Initialize operator to None + # Initialize state request.state.operator = None + request.state.csrf_token = None # Try to load session from cookie session_token = request.cookies.get("central_session") @@ -82,11 +129,15 @@ class SessionMiddleware(BaseHTTPMiddleware): if pool is not None: try: async with pool.acquire() as conn: - operator = await get_session(conn, session_token) - request.state.operator = operator + result = await get_session(conn, session_token) + if result is not None: + operator, csrf_token = result + request.state.operator = operator + request.state.csrf_token = csrf_token except Exception: logger.warning("Failed to load session", exc_info=True) request.state.operator = None + request.state.csrf_token = None # Check if auth is required if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 37a5c37..afb12a0 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -9,9 +9,16 @@ logger = logging.getLogger("central.gui.routes") from fastapi import APIRouter, Depends, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse, Response -from fastapi_csrf_protect import CsrfProtect +from central.bootstrap_config import get_settings +from central.gui.csrf import ( + generate_pre_auth_csrf, + set_pre_auth_csrf_cookie, + validate_pre_auth_csrf, + unset_pre_auth_csrf_cookie, +) from central.gui.auth import ( + CsrfValidationError, create_session, delete_session, hash_password, @@ -28,7 +35,9 @@ from central.gui.audit import ( AUTH_LOGOUT, AUTH_PASSWORD_CHANGE, OPERATOR_CREATE, + SETUP_COMPLETE, STREAM_UPDATE, + SYSTEM_UPDATE, write_audit, ) from central.gui.db import get_pool @@ -101,17 +110,16 @@ async def health() -> dict: @router.get("/", response_class=HTMLResponse) -async def index(request: Request, csrf_protect: CsrfProtect = Depends()) -> HTMLResponse: +async def index(request: Request) -> HTMLResponse: """Render the index page.""" templates = _get_templates() operator = getattr(request.state, "operator", None) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="index.html", context={"operator": operator, "csrf_token": csrf_token}, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -252,37 +260,83 @@ async def dashboard_polls(request: Request) -> HTMLResponse: ) -@router.get("/setup", response_class=HTMLResponse) -async def setup_form( +# ============================================================================= +# Setup Wizard routes +# ============================================================================= + + +@router.get("/setup/operator", response_class=HTMLResponse) +async def setup_operator_form( request: Request, - csrf_protect: CsrfProtect = Depends(), ) -> HTMLResponse: - """Render the setup form.""" + """Render the setup operator form (step 1).""" templates = _get_templates() - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + pool = get_pool() + settings = get_settings() + csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) + + # Check if operator already exists + existing_operator = None + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT username FROM config.operators ORDER BY id LIMIT 1" + ) + if row: + existing_operator = {"username": row["username"]} + response = templates.TemplateResponse( request=request, - name="setup.html", - context={"csrf_token": csrf_token, "error": None}, + name="setup_operator.html", + context={ + "csrf_token": csrf_token, + "error": None, + "form_data": None, + "existing_operator": existing_operator, + }, ) - csrf_protect.set_csrf_cookie(signed_token, response) + set_pre_auth_csrf_cookie(response, signed_token) return response -@router.post("/setup") -async def setup_submit( +@router.post("/setup/operator") +async def setup_operator_submit( request: Request, username: str = Form(...), password: str = Form(...), confirm_password: str = Form(...), - csrf_protect: CsrfProtect = Depends(), + ) -> Response: - """Process the setup form.""" + """Process the setup operator form (step 1).""" templates = _get_templates() pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + settings = get_settings() + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + raise CsrfValidationError("Invalid CSRF token") + + # Check if operator already exists (single-operator-per-install design) + async with pool.acquire() as conn: + count = await conn.fetchval("SELECT count(*) FROM config.operators") + if count > 0: + # Operator already exists — render confirmation page + existing = await conn.fetchrow( + "SELECT username FROM config.operators ORDER BY id LIMIT 1" + ) + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_operator.html", + context={ + "csrf_token": csrf_token, + "error": None, + "form_data": None, + "existing_operator": {"username": existing["username"]}, + }, + ) + return response # Validate input error = None @@ -295,14 +349,18 @@ async def setup_submit( error = str(e) if error: - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, - name="setup.html", - context={"csrf_token": csrf_token, "error": error}, + name="setup_operator.html", + context={ + "csrf_token": csrf_token, + "error": error, + "form_data": {"username": username}, + "existing_operator": None, + }, status_code=200, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Create operator @@ -334,33 +392,673 @@ async def setup_submit( lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 # Create session - token, expires_at = await create_session(conn, operator_id, lifetime_days) + token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) + # Redirect to next step with session cookie + response = RedirectResponse(url="/setup/system", status_code=302) + _set_session_cookie(response, token, lifetime_days * 86400) + return response + + +@router.get("/setup/system", response_class=HTMLResponse) +async def setup_system_form( + request: Request, + +) -> HTMLResponse: + """Render the system settings form (step 2).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", + } + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": None, + "form_data": None, + "system": system, + }, + ) + return response + + +@router.post("/setup/system") +async def setup_system_submit( + request: Request, + +) -> Response: + """Process the system settings form (step 2).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") + + form = await request.form() + map_tile_url = form.get("map_tile_url", "").strip() + map_attribution = form.get("map_attribution", "").strip() + + form_data = { + "map_tile_url": map_tile_url, + "map_attribution": map_attribution, + } + + errors: dict[str, str] = {} + + # Validate map_tile_url + if not map_tile_url: + errors["map_tile_url"] = "Map tile URL is required" + elif "{z}" not in map_tile_url or "{x}" not in map_tile_url or "{y}" not in map_tile_url: + errors["map_tile_url"] = "URL must contain {z}, {x}, and {y} placeholders" + + # Validate map_attribution + if not map_attribution: + errors["map_attribution"] = "Map attribution is required" + + async with pool.acquire() as conn: + if errors: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "", + "map_attribution": row["map_attribution"] if row else "", + } + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": errors, + "form_data": form_data, + "system": system, + }, + status_code=200, + ) + return response + + # Get current values for audit + old_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + before = { + "map_tile_url": old_row["map_tile_url"] if old_row else None, + "map_attribution": old_row["map_attribution"] if old_row else None, + } + + # Update system settings + await conn.execute( + """ + UPDATE config.system + SET map_tile_url = $1, map_attribution = $2 + WHERE id = true + """, + map_tile_url, + map_attribution, + ) + + # Write audit log + await write_audit( + conn, + SYSTEM_UPDATE, + operator_id=operator.id, + target="system", + before=before, + after={"map_tile_url": map_tile_url, "map_attribution": map_attribution}, + ) + + return RedirectResponse(url="/setup/keys", status_code=302) + + +@router.get("/setup/keys", response_class=HTMLResponse) +async def setup_keys_form( + request: Request, + +) -> HTMLResponse: + """Render the API keys form (step 3).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + from central.crypto import encrypt + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": None, + "form_data": None, + "success": None, + }, + ) + return response + + +@router.post("/setup/keys") +async def setup_keys_submit( + request: Request, + +) -> Response: + """Process the API keys form (step 3).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") + + form = await request.form() + action = form.get("action", "add") + + # If action is "next", redirect to adapters step + if action == "next": + return RedirectResponse(url="/setup/adapters", status_code=302) + + from central.crypto import encrypt + + templates = _get_templates() + pool = get_pool() + + # Otherwise, add a new key + alias = form.get("alias", "").strip() + plaintext_key = form.get("plaintext_key", "") + + form_data = {"alias": alias} + errors: dict[str, str] = {} + + # Validate alias + if not alias: + errors["alias"] = "Alias is required" + elif len(alias) > 64: + errors["alias"] = "Alias must be at most 64 characters" + elif not ALIAS_REGEX.match(alias): + errors["alias"] = "Alias must contain only letters, numbers, and underscores" + + # Validate plaintext_key + if not plaintext_key: + errors["plaintext_key"] = "API key is required" + elif len(plaintext_key) > 4096: + errors["plaintext_key"] = "API key must be at most 4096 characters" + + async with pool.acquire() as conn: + if not errors: + # Check if alias already exists + existing = await conn.fetchrow( + "SELECT alias FROM config.api_keys WHERE alias = $1", + alias, + ) + if existing: + errors["alias"] = "An API key with this alias already exists" + + keys = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + + if errors: + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": errors, + "form_data": form_data, + "success": None, + }, + status_code=200, + ) + return response + + # Encrypt the key + encrypted_value = encrypt(plaintext_key.encode()) + + # Insert the new key + row = await conn.fetchrow( + """ + INSERT INTO config.api_keys (alias, encrypted_value) + VALUES ($1, $2) + RETURNING created_at + """, + alias, + encrypted_value, + ) + + # Write audit log (no plaintext!) + await write_audit( + conn, + API_KEY_CREATE, + operator_id=operator.id, + target=alias, + before=None, + after={"alias": alias, "created_at": row["created_at"].isoformat()}, + ) + + # Refresh keys list + keys = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + + # Re-render with success message + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": None, + "form_data": None, + "success": f"API key '{alias}' added successfully.", + }, + ) + return response + + +@router.get("/setup/adapters", response_class=HTMLResponse) +async def setup_adapters_form( + request: Request, + +) -> HTMLResponse: + """Render the adapters configuration form (step 4).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + adapters = [] + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) + + # Get API keys for dropdown + api_keys = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) + + # Get map tile settings + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": [{"alias": k["alias"]} for k in api_keys], + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": None, + "errors": None, + "form_data": None, + }, + ) + return response + + +@router.post("/setup/adapters") +async def setup_adapters_submit( + request: Request, + +) -> Response: + """Process the adapters configuration form (step 4).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") + + form = await request.form() + errors: dict[str, str] = {} + + async with pool.acquire() as conn: + # Get current adapters + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + + for row in rows: + adapter_name = row["name"] + current_settings = row["settings"] or {} + new_settings = dict(current_settings) + + # Parse enabled + enabled = f"{adapter_name}_enabled" in form + + # Parse cadence + cadence_str = form.get(f"{adapter_name}_cadence_s", "") + try: + cadence_s = int(cadence_str) + if cadence_s < 60 or cadence_s > 3600: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" + except ValueError: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" + cadence_s = row["cadence_s"] + + # Adapter-specific validation + if adapter_name == "nws": + contact_email = form.get(f"{adapter_name}_contact_email", "").strip() + if enabled: + if not contact_email: + errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(contact_email): + errors[f"{adapter_name}_contact_email"] = "Invalid email format" + else: + new_settings["contact_email"] = contact_email + else: + new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + + elif adapter_name == "firms": + api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() + satellites = form.getlist(f"{adapter_name}_satellites") + + if api_key_alias: + key_exists = await conn.fetchrow( + "SELECT 1 FROM config.api_keys WHERE alias = $1", + api_key_alias, + ) + if not key_exists: + errors[f"{adapter_name}_api_key_alias"] = f"API key alias '{api_key_alias}' does not exist" + else: + new_settings["api_key_alias"] = api_key_alias + else: + new_settings["api_key_alias"] = None + + # Validate satellites + valid_sats = set(_get_valid_satellites()) + invalid_sats = [s for s in satellites if s not in valid_sats] + if invalid_sats: + errors[f"{adapter_name}_satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}" + else: + new_settings["satellites"] = satellites + + elif adapter_name == "usgs_quake": + feed = form.get(f"{adapter_name}_feed", "").strip() + valid_feeds = _get_valid_feeds() + if feed not in valid_feeds: + errors[f"{adapter_name}_feed"] = f"Invalid feed" + else: + new_settings["feed"] = feed + + # Region validation + region_north_str = form.get(f"{adapter_name}_region_north", "").strip() + region_south_str = form.get(f"{adapter_name}_region_south", "").strip() + region_east_str = form.get(f"{adapter_name}_region_east", "").strip() + region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_region"] = "Invalid latitude: south must be less than north, both between -90 and 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_region"] = "Invalid longitude: west must be less than east, both between -180 and 180" + else: + new_settings["region"] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + + # Store parsed data for re-render on error or update + if not errors.get(f"{adapter_name}_cadence_s"): + # Update adapter + await conn.execute( + """ + UPDATE config.adapters + SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() + WHERE name = $4 + """, + enabled, + cadence_s, + new_settings, + adapter_name, + ) + + # If any errors, re-render + if errors: + adapters = [] + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) + + api_keys = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) + + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": [{"alias": k["alias"]} for k in api_keys], + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": "Please fix the errors below.", + "errors": errors, + "form_data": form, + }, + status_code=200, + ) + return response + + return RedirectResponse(url="/setup/finish", status_code=302) + + +@router.get("/setup/finish", response_class=HTMLResponse) +async def setup_finish_form( + request: Request, + +) -> HTMLResponse: + """Render the finish setup page (step 5).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + # Get counts + operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + + # Get system settings + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": sys_row["map_tile_url"] if sys_row else "", + } + + # Get adapters + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s + FROM config.adapters + ORDER BY name + """ + ) + adapters = [ + { + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + } + for row in rows + ] + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_finish.html", + context={ + "csrf_token": csrf_token, + "operator_count": operator_count, + "key_count": key_count, + "system": system, + "adapters": adapters, + }, + ) + return response + + +@router.post("/setup/finish") +async def setup_finish_submit( + request: Request, + +) -> Response: + """Complete the setup wizard.""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + pool = get_pool() + + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") + + async with pool.acquire() as conn: # Mark setup complete await conn.execute( "UPDATE config.system SET setup_complete = true WHERE id = true" ) - # Redirect with session cookie - response = RedirectResponse(url="/", status_code=302) - _set_session_cookie(response, token, lifetime_days * 86400) - return response + # Write audit log + await write_audit( + conn, + SETUP_COMPLETE, + operator_id=operator.id, + target="system", + ) + + return RedirectResponse(url="/", status_code=302) @router.get("/login", response_class=HTMLResponse) async def login_form( request: Request, - csrf_protect: CsrfProtect = Depends(), ) -> HTMLResponse: """Render the login form.""" templates = _get_templates() - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + settings = get_settings() + csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) response = templates.TemplateResponse( request=request, name="login.html", context={"csrf_token": csrf_token, "error": None}, ) - csrf_protect.set_csrf_cookie(signed_token, response) + set_pre_auth_csrf_cookie(response, signed_token) return response @@ -369,14 +1067,18 @@ async def login_submit( request: Request, username: str = Form(...), password: str = Form(...), - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Process the login form.""" templates = _get_templates() pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + settings = get_settings() + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + raise CsrfValidationError("Invalid CSRF token") # Look up operator async with pool.acquire() as conn: @@ -392,27 +1094,25 @@ async def login_submit( if row is None: # Unknown user - still audit the attempt await write_audit(conn, AUTH_LOGIN_FAILED, target=username) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="login.html", context={"csrf_token": csrf_token, "error": "Invalid username or password"}, status_code=200, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Verify password if not verify_password(password, row["password_hash"]): await write_audit(conn, AUTH_LOGIN_FAILED, operator_id=row["id"], target=username) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="login.html", context={"csrf_token": csrf_token, "error": "Invalid username or password"}, status_code=200, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Get session lifetime @@ -422,7 +1122,7 @@ async def login_submit( lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 # Create session - token, expires_at = await create_session(conn, row["id"], lifetime_days) + token, expires_at, _ = await create_session(conn, row["id"], lifetime_days) # Audit login await write_audit(conn, AUTH_LOGIN, operator_id=row["id"], target=username) @@ -436,13 +1136,16 @@ async def login_submit( @router.post("/logout") async def logout( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Log out the current user.""" pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") # Get current session session_token = request.cookies.get("central_session") @@ -463,17 +1166,16 @@ async def logout( @router.get("/change-password", response_class=HTMLResponse) async def change_password_form( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> HTMLResponse: """Render the change password form.""" templates = _get_templates() - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="change_password.html", context={"csrf_token": csrf_token, "error": None, "success": False}, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -483,7 +1185,7 @@ async def change_password_submit( current_password: str = Form(...), new_password: str = Form(...), confirm_password: str = Form(...), - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Process the change password form.""" templates = _get_templates() @@ -491,7 +1193,10 @@ async def change_password_submit( operator = request.state.operator # Validate CSRF - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") # Get current password hash async with pool.acquire() as conn: @@ -514,14 +1219,13 @@ async def change_password_submit( error = str(e) if error: - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="change_password.html", context={"csrf_token": csrf_token, "error": error, "success": False}, status_code=200, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Update password @@ -556,7 +1260,7 @@ async def change_password_submit( @router.get("/adapters", response_class=HTMLResponse) async def adapters_list( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> HTMLResponse: """List all adapters.""" templates = _get_templates() @@ -584,7 +1288,7 @@ async def adapters_list( "updated_at": row["updated_at"], }) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="adapters_list.html", @@ -594,7 +1298,6 @@ async def adapters_list( "adapters": adapters, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -602,7 +1305,7 @@ async def adapters_list( async def adapters_edit_form( request: Request, name: str, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Render the adapter edit form.""" templates = _get_templates() @@ -644,7 +1347,7 @@ async def adapters_edit_form( "updated_at": row["updated_at"], } - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="adapters_edit.html", @@ -661,7 +1364,6 @@ async def adapters_edit_form( "tile_attribution": tile_attribution, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -669,7 +1371,7 @@ async def adapters_edit_form( async def adapters_edit_submit( request: Request, name: str, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Process the adapter edit form.""" templates = _get_templates() @@ -677,7 +1379,10 @@ async def adapters_edit_submit( operator = request.state.operator # Validate CSRF - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") # Parse form data form = await request.form() @@ -820,7 +1525,7 @@ async def adapters_edit_submit( tile_url = sys_row["map_tile_url"] if sys_row else "https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png" tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="adapters_edit.html", @@ -838,7 +1543,6 @@ async def adapters_edit_submit( }, status_code=200, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Build before state for audit @@ -889,7 +1593,7 @@ async def adapters_edit_submit( @router.get("/streams", response_class=HTMLResponse) async def streams_list( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> HTMLResponse: """List all streams with live data.""" from central.gui.nats import get_js @@ -972,7 +1676,7 @@ async def streams_list( streams.append(stream_data) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="streams_list.html", @@ -982,7 +1686,6 @@ async def streams_list( "streams": streams, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -990,7 +1693,7 @@ async def streams_list( async def streams_update( request: Request, name: str, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Update stream max_age_s.""" from central.gui.nats import get_js @@ -1000,7 +1703,10 @@ async def streams_update( operator = request.state.operator # Validate CSRF - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") form = await request.form() max_age_s_str = form.get("max_age_s", "").strip() @@ -1069,7 +1775,7 @@ async def streams_update( streams.append(stream_data) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="streams_list.html", @@ -1080,7 +1786,6 @@ async def streams_update( "errors": errors, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response old_max_age_s = row["max_age_s"] @@ -1116,7 +1821,7 @@ ALIAS_REGEX = re.compile(r'^[a-zA-Z0-9_]+$') @router.get("/api-keys", response_class=HTMLResponse) async def api_keys_list( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> HTMLResponse: """List all API keys.""" templates = _get_templates() @@ -1152,7 +1857,7 @@ async def api_keys_list( "used_by": [a["name"] for a in adapters], }) - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_list.html", @@ -1162,20 +1867,19 @@ async def api_keys_list( "keys": keys, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @router.get("/api-keys/new", response_class=HTMLResponse) async def api_keys_new( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> HTMLResponse: """Show form to add a new API key.""" templates = _get_templates() operator = request.state.operator - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_new.html", @@ -1184,14 +1888,13 @@ async def api_keys_new( "csrf_token": csrf_token, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @router.post("/api-keys", response_class=HTMLResponse) async def api_keys_create( request: Request, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Create a new API key.""" from central.crypto import encrypt @@ -1200,7 +1903,10 @@ async def api_keys_create( pool = get_pool() operator = request.state.operator - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") form = await request.form() alias = form.get("alias", "").strip() @@ -1223,7 +1929,7 @@ async def api_keys_create( errors["plaintext_key"] = "API key must be at most 4096 characters" if errors: - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_new.html", @@ -1234,7 +1940,6 @@ async def api_keys_create( "alias": alias, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Encrypt the key @@ -1249,7 +1954,7 @@ async def api_keys_create( if existing: errors["alias"] = "An API key with this alias already exists" - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_new.html", @@ -1260,7 +1965,6 @@ async def api_keys_create( "alias": alias, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Insert the new key @@ -1291,7 +1995,7 @@ async def api_keys_create( async def api_keys_edit( request: Request, alias: str, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Show form to rotate or delete an API key.""" templates = _get_templates() @@ -1329,7 +2033,7 @@ async def api_keys_edit( "used_by": [a["name"] for a in adapters], } - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_edit.html", @@ -1339,7 +2043,6 @@ async def api_keys_edit( "key": key, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response @@ -1347,7 +2050,7 @@ async def api_keys_edit( async def api_keys_rotate( request: Request, alias: str, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Rotate an API key.""" from central.crypto import encrypt @@ -1356,7 +2059,10 @@ async def api_keys_rotate( pool = get_pool() operator = request.state.operator - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") form = await request.form() new_plaintext_key = form.get("new_plaintext_key", "") @@ -1400,7 +2106,7 @@ async def api_keys_rotate( "used_by": [a["name"] for a in adapters], } - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_edit.html", @@ -1411,7 +2117,6 @@ async def api_keys_rotate( "errors": errors, }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response old_rotated_at = row["rotated_at"] @@ -1448,14 +2153,17 @@ async def api_keys_rotate( async def api_keys_delete( request: Request, alias: str, - csrf_protect: CsrfProtect = Depends(), + ) -> Response: """Delete an API key.""" templates = _get_templates() pool = get_pool() operator = request.state.operator - await csrf_protect.validate_csrf(request) + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") async with pool.acquire() as conn: row = await conn.fetchrow( @@ -1490,7 +2198,7 @@ async def api_keys_delete( "used_by": adapter_names, } - csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="api_keys_edit.html", @@ -1501,7 +2209,6 @@ async def api_keys_delete( "error": f"Cannot delete: used by {', '.join(adapter_names)}. Remove these references first.", }, ) - csrf_protect.set_csrf_cookie(signed_token, response) return response # Delete the key diff --git a/src/central/gui/templates/_region_picker.html b/src/central/gui/templates/_region_picker.html index 5c53bc9..9c9b211 100644 --- a/src/central/gui/templates/_region_picker.html +++ b/src/central/gui/templates/_region_picker.html @@ -59,6 +59,10 @@ maxZoom: 18 }).addTo(map); + // Ensure map renders correctly even if container has not + // finished laying out at init time + setTimeout(function() { map.invalidateSize(); }, 100); + // Create initial rectangle const bounds = L.latLngBounds( L.latLng(savedSouth, savedWest), @@ -69,11 +73,34 @@ map.fitBounds(bounds.pad(0.1)); // Create editable rectangle - const rectangle = L.rectangle(bounds, { + let rectangle = L.rectangle(bounds, { color: '#3388ff', weight: 2, fillOpacity: 0.2 - }).addTo(map); + }); + + // Set up Leaflet.draw for click-to-draw + const drawnItems = new L.FeatureGroup(); + drawnItems.addLayer(rectangle); + map.addLayer(drawnItems); + + const drawControl = new L.Control.Draw({ + draw: { + rectangle: { shapeOptions: { color: '#3388ff', weight: 2, + fillOpacity: 0.2 } }, + polyline: false, + polygon: false, + circle: false, + marker: false, + circlemarker: false + }, + edit: { + featureGroup: drawnItems, + edit: false, + remove: false + } + }); + map.addControl(drawControl); // Make rectangle editable rectangle.editing.enable(); @@ -96,13 +123,33 @@ // Listen for rectangle edit events rectangle.on('edit', updateInputs); + // When user draws a new rectangle, replace the existing one + map.on(L.Draw.Event.CREATED, function(e) { + drawnItems.clearLayers(); + rectangle = e.layer; + rectangle.setStyle({ color: '#3388ff', weight: 2, + fillOpacity: 0.2 }); + drawnItems.addLayer(rectangle); + rectangle.editing.enable(); + rectangle.on('edit', updateInputs); + updateInputs(); + }); + // Reset button document.getElementById('region-reset-btn').addEventListener('click', function() { const originalBounds = L.latLngBounds( L.latLng(savedSouth, savedWest), L.latLng(savedNorth, savedEast) ); - rectangle.setBounds(originalBounds); + drawnItems.clearLayers(); + rectangle = L.rectangle(originalBounds, { + color: '#3388ff', + weight: 2, + fillOpacity: 0.2 + }); + drawnItems.addLayer(rectangle); + rectangle.editing.enable(); + rectangle.on('edit', updateInputs); updateInputs(); }); diff --git a/src/central/gui/templates/_wizard_header.html b/src/central/gui/templates/_wizard_header.html new file mode 100644 index 0000000..941d18e --- /dev/null +++ b/src/central/gui/templates/_wizard_header.html @@ -0,0 +1,6 @@ +
+
+ Step {{ step }} of 5 — {{ step_name }} +
+ +
diff --git a/src/central/gui/templates/base_wizard.html b/src/central/gui/templates/base_wizard.html new file mode 100644 index 0000000..a3eacac --- /dev/null +++ b/src/central/gui/templates/base_wizard.html @@ -0,0 +1,24 @@ + + + + + + {% block title %}Central - Setup{% endblock %} + + + {% block head %}{% endblock %} + + + +
+ {% block content %}{% endblock %} +
+ + diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html new file mode 100644 index 0000000..9ff2f50 --- /dev/null +++ b/src/central/gui/templates/setup_adapters.html @@ -0,0 +1,256 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - Configure Adapters{% endblock %} + +{% block head %} + + + + +{% endblock %} + +{% block content %} +{% with step=4, step_name="Configure Adapters" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

Configure Adapters

+

Enable and configure data source adapters. Each adapter polls an external API and normalizes events.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + {% for adapter in adapters %} +
+ {{ adapter.name }} + +
+ + {% if errors and errors.get(adapter.name + '_enabled') %} + {{ errors[adapter.name + '_enabled'] }} + {% endif %} + + + + {% if errors and errors.get(adapter.name + '_cadence_s') %} + {{ errors[adapter.name + '_cadence_s'] }} + {% endif %} + + {% if adapter.name == 'nws' %} + + + {% if errors and errors.get(adapter.name + '_contact_email') %} + {{ errors[adapter.name + '_contact_email'] }} + {% endif %} + {% endif %} + + {% if adapter.name == 'firms' %} + + + {% if errors and errors.get(adapter.name + '_api_key_alias') %} + {{ errors[adapter.name + '_api_key_alias'] }} + {% endif %} + + + {% for sat in valid_satellites %} + + {% endfor %} + {% endif %} + + {% if adapter.name == 'usgs_quake' %} + + + {% if errors and errors.get(adapter.name + '_feed') %} + {{ errors[adapter.name + '_feed'] }} + {% endif %} + {% endif %} + +

Region

+ {% set region = form_data if form_data else adapter.settings.region %} +
+ +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ {% if errors and errors.get(adapter.name + '_region') %} + {{ errors[adapter.name + '_region'] }} + {% endif %} +
+
+
+ {% endfor %} + +
+ ← Back + +
+
+
+ + +{% endblock %} diff --git a/src/central/gui/templates/setup_finish.html b/src/central/gui/templates/setup_finish.html new file mode 100644 index 0000000..bc250bc --- /dev/null +++ b/src/central/gui/templates/setup_finish.html @@ -0,0 +1,73 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - Finish Setup{% endblock %} + +{% block content %} +{% with step=5, step_name="Finish Setup" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

Setup Complete

+

Review your configuration and finish the setup wizard.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +

Summary

+ + + + + + + + + + + + + + + + +
Operators{{ operator_count }} configured
API Keys{{ key_count }} configured
Map Tile URL{{ system.map_tile_url }}
+ +

Adapters

+ + + + + + + + + + {% for adapter in adapters %} + + + + + + {% endfor %} + +
AdapterStatusCadence
{{ adapter.name }} + {% if adapter.enabled %} + Enabled + {% else %} + Disabled + {% endif %} + {{ adapter.cadence_s }}s
+ +
+ +
+ ← Back + +
+
+
+{% endblock %} diff --git a/src/central/gui/templates/setup_keys.html b/src/central/gui/templates/setup_keys.html new file mode 100644 index 0000000..4c3f125 --- /dev/null +++ b/src/central/gui/templates/setup_keys.html @@ -0,0 +1,88 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - API Keys{% endblock %} + +{% block content %} +{% with step=3, step_name="API Keys" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

API Keys

+

Add API keys for adapters that require external service credentials (e.g., FIRMS).

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + + {% if success %} +

{{ success }}

+ {% endif %} + + {% if keys %} +

Existing Keys

+ + + + + + + + + {% for key in keys %} + + + + + {% endfor %} + +
AliasCreated
{{ key.alias }}{{ key.created_at.strftime('%Y-%m-%d %H:%M') if key.created_at else '(never)' }}
+ {% else %} +

No API keys configured yet.

+ {% endif %} + +

Add New Key

+
+ + + +
+
+ + + {% if errors and errors.alias %} + {{ errors.alias }} + {% else %} + Letters, numbers, and underscores only. + {% endif %} +
+
+ + + {% if errors and errors.plaintext_key %} + {{ errors.plaintext_key }} + {% else %} + Will be encrypted before storage. + {% endif %} +
+
+ + +
+ +
+ +
+ + +
+ ← Back + +
+
+
+{% endblock %} diff --git a/src/central/gui/templates/setup_operator.html b/src/central/gui/templates/setup_operator.html new file mode 100644 index 0000000..36932b4 --- /dev/null +++ b/src/central/gui/templates/setup_operator.html @@ -0,0 +1,57 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - Create Operator{% endblock %} + +{% block content %} +{% with step=1, step_name="Create Operator" %} +{% include "_wizard_header.html" %} +{% endwith %} + +{% if existing_operator %} +
+
+

Operator Already Configured

+
+

The operator account {{ existing_operator.username }} has been created.

+
+ Next → +
+
+{% else %} +
+
+

Create Operator Account

+

Create the initial operator account to manage Central.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + + + + + + + +
+
+{% endif %} +{% endblock %} diff --git a/src/central/gui/templates/setup_system.html b/src/central/gui/templates/setup_system.html new file mode 100644 index 0000000..c49cb9c --- /dev/null +++ b/src/central/gui/templates/setup_system.html @@ -0,0 +1,49 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - System Settings{% endblock %} + +{% block content %} +{% with step=2, step_name="System Settings" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

System Settings

+

Configure map tile provider for the region picker.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + + {% if errors and errors.map_tile_url %} + {{ errors.map_tile_url }} + {% endif %} + + + {% if errors and errors.map_attribution %} + {{ errors.map_attribution }} + {% endif %} + +
+ ← Back + +
+
+
+{% endblock %} diff --git a/tests/conftest.py b/tests/conftest.py index ad93825..97a4f5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,3 +48,49 @@ def mock_conn(): conn.fetchval = AsyncMock() conn.execute = AsyncMock() return conn + + +# CSRF fixtures for route tests + +@pytest.fixture +def bypass_pre_auth_csrf(): + """Patch pre-auth CSRF validation to always pass. + + Use for tests of pre-auth routes: /login, /setup/operator + """ + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_csrf_token", "test_signed_token")): + yield + + +@pytest.fixture +def bypass_session_csrf(): + """Create a mock request with session CSRF properly configured. + + Use for tests of authenticated routes that check request.state.csrf_token. + Returns a configured mock_request. + """ + request = MagicMock() + request.state.csrf_token = "test_csrf_token_12345" + request.state.operator = MagicMock() + request.state.operator.id = 1 + request.state.operator.username = "testuser" + + # Mock form() to return dict with matching CSRF token + form_data = {"csrf_token": "test_csrf_token_12345"} + + async def mock_form(): + return form_data + + request.form = mock_form + request._form_data = form_data # Allow tests to modify form data + + return request + + +@pytest.fixture +def patch_route_settings(): + """Patch get_settings in routes module.""" + with patch("central.gui.routes.get_settings") as mock: + mock.return_value.csrf_secret = "test-csrf-secret-for-testing-only-32chars" + yield mock diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 17352f0..fa25c8b 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -55,13 +55,9 @@ class TestAdaptersListAuthenticated: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_list(mock_request, mock_csrf) + result = await adapters_list(mock_request) # Verify template was called with adapters call_args = mock_templates.TemplateResponse.call_args @@ -105,13 +101,9 @@ class TestAdaptersEditForm: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_form(mock_request, "nws", mock_csrf) + result = await adapters_edit_form(mock_request, "nws") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -133,11 +125,8 @@ class TestAdaptersEditForm: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_form(mock_request, "nonexistent", mock_csrf) + result = await adapters_edit_form(mock_request, "nonexistent") assert result.status_code == 404 @@ -156,7 +145,9 @@ class TestAdaptersEditSubmit: # Mock form data mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "new@example.com", "region_north": "49.0", @@ -183,12 +174,9 @@ class TestAdaptersEditSubmit: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await adapters_edit_submit(mock_request, "nws", mock_csrf) + result = await adapters_edit_submit(mock_request, "nws") assert result.status_code == 302 assert result.headers["location"] == "/adapters" @@ -204,7 +192,9 @@ class TestAdaptersEditSubmit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "30", "contact_email": "test@example.com", "region_north": "49.0", @@ -239,14 +229,9 @@ class TestAdaptersEditSubmit: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "nws", mock_csrf) + result = await adapters_edit_submit(mock_request, "nws") # Should re-render form with error call_args = mock_templates.TemplateResponse.call_args @@ -263,7 +248,9 @@ class TestAdaptersEditSubmit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "nonexistent_key", "region_north": "49.5", @@ -299,14 +286,9 @@ class TestAdaptersEditSubmit: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -322,7 +304,9 @@ class TestAdaptersEditSubmit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "feed": "invalid_feed", "region_north": "49.0", @@ -357,14 +341,9 @@ class TestAdaptersEditSubmit: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "usgs_quake", mock_csrf) + result = await adapters_edit_submit(mock_request, "usgs_quake") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -383,7 +362,9 @@ class TestAdaptersAudit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "new@example.com", "region_north": "49.0", @@ -410,9 +391,6 @@ class TestAdaptersAudit: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -423,7 +401,7 @@ class TestAdaptersAudit: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - result = await adapters_edit_submit(mock_request, "nws", mock_csrf) + result = await adapters_edit_submit(mock_request, "nws") assert captured_audit["action"] == "adapter.update" assert captured_audit["target"] == "nws" @@ -449,7 +427,9 @@ class TestAdaptersJsonbRegression: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "test@example.com", "region_north": "49.0", @@ -476,12 +456,9 @@ class TestAdaptersJsonbRegression: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - await adapters_edit_submit(mock_request, "nws", mock_csrf) + await adapters_edit_submit(mock_request, "nws") # Get the settings argument passed to execute (3rd positional arg after query) call_args = mock_conn.execute.call_args @@ -502,7 +479,9 @@ class TestAdaptersJsonbRegression: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "new@example.com", "region_north": "49.0", @@ -529,9 +508,6 @@ class TestAdaptersJsonbRegression: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -540,7 +516,7 @@ class TestAdaptersJsonbRegression: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - await adapters_edit_submit(mock_request, "nws", mock_csrf) + await adapters_edit_submit(mock_request, "nws") # CRITICAL: before and after must be dicts, NOT strings assert isinstance(captured_audit["before"], dict), f"before should be dict, got {type(captured_audit['before'])}" diff --git a/tests/test_api_keys.py b/tests/test_api_keys.py index 6bd43be..674231b 100644 --- a/tests/test_api_keys.py +++ b/tests/test_api_keys.py @@ -75,13 +75,9 @@ class TestApiKeysListAuthenticated: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_list(mock_request, mock_csrf) + result = await api_keys_list(mock_request) # Check template was called with correct context call_args = mock_templates.TemplateResponse.call_args @@ -104,7 +100,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "test1", "plaintext_key": "secret-api-key-123"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "test1", "plaintext_key": "secret-api-key-123"} mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() @@ -119,13 +116,10 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted_data"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/api-keys" @@ -136,7 +130,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "firms", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "firms", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -150,15 +145,10 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted"): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) # Should re-render form with error call_args = mock_templates.TemplateResponse.call_args @@ -172,7 +162,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -183,14 +174,9 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -203,7 +189,8 @@ class TestApiKeysCreate: mock_request.state.operator = MagicMock(id=1, username="admin") # Test with space - form_data = {"alias": "test key", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "test key", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -214,14 +201,9 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -233,7 +215,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "test-key", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "test-key", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -244,14 +227,9 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -267,7 +245,8 @@ class TestApiKeysRotate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"new_plaintext_key": "new-secret-key-456"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "new_plaintext_key": "new-secret-key-456"} mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() @@ -290,13 +269,10 @@ class TestApiKeysRotate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"new_encrypted"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await api_keys_rotate(mock_request, "test1", mock_csrf) + result = await api_keys_rotate(mock_request, "test1") assert result.status_code == 302 # Check audit was called with no plaintext @@ -313,6 +289,8 @@ class TestApiKeysDelete: """POST /api-keys/{alias}/delete with references shows error.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf_token"}) mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() @@ -331,14 +309,9 @@ class TestApiKeysDelete: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_delete(mock_request, "firms", mock_csrf) + result = await api_keys_delete(mock_request, "firms") # Should re-render with error call_args = mock_templates.TemplateResponse.call_args @@ -351,6 +324,8 @@ class TestApiKeysDelete: """POST /api-keys/{alias}/delete without references deletes and redirects.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf_token"}) mock_conn = AsyncMock() mock_conn.fetchrow.return_value = { @@ -367,12 +342,9 @@ class TestApiKeysDelete: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await api_keys_delete(mock_request, "test1", mock_csrf) + result = await api_keys_delete(mock_request, "test1") assert result.status_code == 302 assert result.headers["location"] == "/api-keys" @@ -388,7 +360,8 @@ class TestApiKeysAuditNoPlaintext: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "newkey", "plaintext_key": "super-secret-value"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "newkey", "plaintext_key": "super-secret-value"} mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() @@ -401,13 +374,10 @@ class TestApiKeysAuditNoPlaintext: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - await api_keys_create(mock_request, mock_csrf) + await api_keys_create(mock_request) # Check audit call arguments call_kwargs = mock_audit.call_args.kwargs diff --git a/tests/test_auth.py b/tests/test_auth.py index 2ea9569..6a36782 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -92,29 +92,33 @@ class TestSessionManagement: mock_conn = MagicMock() mock_conn.execute = AsyncMock() - token, expires_at = await create_session(mock_conn, operator_id=1, lifetime_days=90) + token, expires_at, csrf_token = await create_session(mock_conn, operator_id=1, lifetime_days=90) assert len(token) == 43 + assert len(csrf_token) == 64 # 32 bytes hex = 64 chars mock_conn.execute.assert_called_once() call_args = mock_conn.execute.call_args assert "INSERT INTO config.sessions" in call_args[0][0] @pytest.mark.asyncio async def test_get_session_found(self): - """get_session returns Operator when session exists.""" + """get_session returns (Operator, csrf_token) when session exists.""" mock_conn = MagicMock() mock_conn.fetchrow = AsyncMock(return_value={ "id": 1, "username": "testuser", "created_at": datetime.now(timezone.utc), "password_changed_at": datetime.now(timezone.utc), + "csrf_token": "test_csrf_token_12345", }) - operator = await get_session(mock_conn, "valid-token") + result = await get_session(mock_conn, "valid-token") - assert operator is not None + assert result is not None + operator, csrf_token = result assert operator.id == 1 assert operator.username == "testuser" + assert csrf_token == "test_csrf_token_12345" @pytest.mark.asyncio async def test_get_session_not_found(self): diff --git a/tests/test_config_store.py b/tests/test_config_store.py index 797a221..4653e32 100644 --- a/tests/test_config_store.py +++ b/tests/test_config_store.py @@ -39,6 +39,7 @@ def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> clear_key_cache() monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN) monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path)) + monkeypatch.setenv("CENTRAL_CSRF_SECRET", "test-csrf-secret-for-testing-only-32chars") @pytest_asyncio.fixture diff --git a/tests/test_csrf_handler.py b/tests/test_csrf_handler.py index 58456e3..0b873c3 100644 --- a/tests/test_csrf_handler.py +++ b/tests/test_csrf_handler.py @@ -14,23 +14,18 @@ class TestCsrfExceptionHandlerRegistered: """Verify CSRF exception handler is properly registered.""" def test_csrf_exception_handler_is_registered(self): - """The app has a CsrfProtectError exception handler registered.""" + """The app has a CsrfValidationError exception handler registered.""" from central.gui import app - from fastapi_csrf_protect.exceptions import CsrfProtectError + from central.gui.auth import CsrfValidationError - assert CsrfProtectError in app.exception_handlers, \ - "CsrfProtectError handler should be registered" + assert CsrfValidationError in app.exception_handlers, \ + "CsrfValidationError handler should be registered" - def test_csrf_subclasses_are_caught(self): - """MissingTokenError and TokenValidationError inherit from CsrfProtectError.""" - from fastapi_csrf_protect.exceptions import ( - CsrfProtectError, - MissingTokenError, - TokenValidationError, - ) + def test_csrf_validation_error_is_exception(self): + """CsrfValidationError is a proper Exception subclass.""" + from central.gui.auth import CsrfValidationError - assert issubclass(MissingTokenError, CsrfProtectError) - assert issubclass(TokenValidationError, CsrfProtectError) + assert issubclass(CsrfValidationError, Exception) class TestCsrfExceptionHandlerBehavior: @@ -40,10 +35,10 @@ class TestCsrfExceptionHandlerBehavior: """CSRF handler checks request path for /login.""" import inspect from central.gui import _create_app - from fastapi_csrf_protect.exceptions import CsrfProtectError + from central.gui.auth import CsrfValidationError app = _create_app() - handler = app.exception_handlers.get(CsrfProtectError) + handler = app.exception_handlers.get(CsrfValidationError) # Verify handler source contains /login path check source = inspect.getsource(handler) @@ -54,17 +49,16 @@ class TestCsrfExceptionHandlerBehavior: async def test_logout_csrf_error_redirects_to_login(self): """CSRF error on /logout should redirect to /login.""" from central.gui import _create_app - from fastapi_csrf_protect.exceptions import TokenValidationError + from central.gui.auth import CsrfValidationError from fastapi.responses import RedirectResponse app = _create_app() - from fastapi_csrf_protect.exceptions import CsrfProtectError - handler = app.exception_handlers.get(CsrfProtectError) + handler = app.exception_handlers.get(CsrfValidationError) mock_request = MagicMock() mock_request.url.path = "/logout" - exc = TokenValidationError("Invalid token") + exc = CsrfValidationError("Invalid token") result = await handler(mock_request, exc) @@ -75,17 +69,16 @@ class TestCsrfExceptionHandlerBehavior: async def test_adapters_csrf_error_redirects_to_adapters(self): """CSRF error on /adapters/{name} should redirect to /adapters.""" from central.gui import _create_app - from fastapi_csrf_protect.exceptions import TokenValidationError + from central.gui.auth import CsrfValidationError from fastapi.responses import RedirectResponse app = _create_app() - from fastapi_csrf_protect.exceptions import CsrfProtectError - handler = app.exception_handlers.get(CsrfProtectError) + handler = app.exception_handlers.get(CsrfValidationError) mock_request = MagicMock() mock_request.url.path = "/adapters/nws" - exc = TokenValidationError("Invalid token") + exc = CsrfValidationError("Invalid token") result = await handler(mock_request, exc) @@ -94,16 +87,171 @@ class TestCsrfExceptionHandlerBehavior: class TestCsrfHandlerNoTraceback: - """Verify exception handler doesn't expose Python internals.""" + """Verify exception handler does not expose Python internals.""" def test_handler_exists_and_is_async(self): """The CSRF handler should be an async function.""" import inspect from central.gui import _create_app - from fastapi_csrf_protect.exceptions import CsrfProtectError + from central.gui.auth import CsrfValidationError app = _create_app() - handler = app.exception_handlers.get(CsrfProtectError) + handler = app.exception_handlers.get(CsrfValidationError) assert handler is not None assert inspect.iscoroutinefunction(handler) + + +class TestCsrfHandlerWizardPaths: + """Test CSRF exception handler for wizard paths.""" + + @pytest.mark.asyncio + async def test_setup_operator_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/operator re-renders form with error message.""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/operator" + + exc = CsrfValidationError("Invalid token") + + result = await handler(mock_request, exc) + + # Should be HTML response, not redirect + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_system_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/system re-renders form with error message.""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/system" + + exc = CsrfValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_keys_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/keys re-renders form with error message.""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/keys" + + exc = CsrfValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_adapters_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/adapters re-renders form with error message.""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/adapters" + + exc = CsrfValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_finish_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/finish re-renders form with error message.""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/finish" + + exc = CsrfValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_base_csrf_error_redirects_to_setup(self): + """CSRF error on /setup redirects to /setup (middleware routes to step).""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + from fastapi.responses import RedirectResponse + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/setup" + + exc = CsrfValidationError("Invalid token") + + result = await handler(mock_request, exc) + + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + @pytest.mark.asyncio + async def test_login_csrf_error_still_works(self): + """CSRF error on /login still renders login form with error (regression test).""" + from central.gui import _create_app + from central.gui.auth import CsrfValidationError + + app = _create_app() + handler = app.exception_handlers.get(CsrfValidationError) + + mock_request = MagicMock() + mock_request.url.path = "/login" + + exc = CsrfValidationError("Invalid token") + + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() diff --git a/tests/test_csrf_race_condition.py b/tests/test_csrf_race_condition.py new file mode 100644 index 0000000..6235903 --- /dev/null +++ b/tests/test_csrf_race_condition.py @@ -0,0 +1,108 @@ +""" +Integration test for CSRF race condition fix. + +This test verifies that the session-bound CSRF implementation fixes the race +condition where interleaved GET requests would invalidate CSRF tokens. + +See: PR #24 - Central 1b-8 fix-up phase 2 +""" + +import pytest + + +class TestCsrfRaceConditionFix: + """Verify that interleaved GETs don't break CSRF validation.""" + + def test_session_bound_csrf_consistent_across_gets(self): + """Session-bound CSRF tokens remain consistent across multiple GETs. + + This was the core bug: fastapi-csrf-protect rotated tokens on every GET, + causing race conditions when users had multiple tabs or slow connections. + + With session-bound CSRF, the token is stored in the session row and + remains constant until the session is destroyed. + """ + from unittest.mock import MagicMock, AsyncMock + from central.gui.auth import get_session + + # Mock a session with a csrf_token + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={ + "id": 1, + "username": "testuser", + "created_at": "2024-01-01T00:00:00Z", + "password_changed_at": "2024-01-01T00:00:00Z", + "csrf_token": "fixed_csrf_token_12345", + }) + + import asyncio + + async def test(): + # First GET + result1 = await get_session(mock_conn, "test-token") + assert result1 is not None + op1, csrf1 = result1 + + # Second GET (simulating interleaved request) + result2 = await get_session(mock_conn, "test-token") + assert result2 is not None + op2, csrf2 = result2 + + # CSRF tokens should be identical (the fix!) + assert csrf1 == csrf2 == "fixed_csrf_token_12345" + + asyncio.run(test()) + + def test_pre_auth_csrf_tokens_independently_valid(self): + """Pre-auth CSRF tokens are independently valid. + + For unauthenticated routes, each GET generates a new token+cookie pair. + Each pair should validate independently, allowing the original token + to work even if another GET happened in between. + """ + from central.gui.csrf import generate_pre_auth_csrf, validate_pre_auth_csrf + from unittest.mock import MagicMock + + secret = "testsecret12345678901234567890ab" + + # First GET generates token1 + cookie1 + token1, signed1 = generate_pre_auth_csrf(secret) + + # Second GET generates token2 + cookie2 + token2, signed2 = generate_pre_auth_csrf(secret) + + # Tokens should be different (fresh random tokens) + assert token1 != token2 + assert signed1 != signed2 + + # But each pair should validate independently + mock_request1 = MagicMock() + mock_request1.cookies = {"central_preauth_csrf": signed1} + + mock_request2 = MagicMock() + mock_request2.cookies = {"central_preauth_csrf": signed2} + + # Original token still validates with original cookie + assert validate_pre_auth_csrf(mock_request1, token1, secret) is True + + # Second token validates with second cookie + assert validate_pre_auth_csrf(mock_request2, token2, secret) is True + + # Cross-validation should fail + assert validate_pre_auth_csrf(mock_request1, token2, secret) is False + assert validate_pre_auth_csrf(mock_request2, token1, secret) is False + + def test_csrf_token_generation_is_secure(self): + """CSRF tokens are cryptographically secure.""" + from central.gui.auth import generate_csrf_token + + # Generate multiple tokens + tokens = [generate_csrf_token() for _ in range(100)] + + # All tokens should be unique + assert len(set(tokens)) == 100 + + # Tokens should be 64 hex chars (32 bytes) + for token in tokens: + assert len(token) == 64 + assert all(c in "0123456789abcdef" for c in token) diff --git a/tests/test_region_picker.py b/tests/test_region_picker.py index f5a8816..63683ea 100644 --- a/tests/test_region_picker.py +++ b/tests/test_region_picker.py @@ -51,13 +51,9 @@ class TestRegionPickerInTemplate: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_form(mock_request, "firms", mock_csrf) + result = await adapters_edit_form(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -79,7 +75,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "45.0", @@ -109,9 +107,6 @@ class TestRegionValidation: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_settings = {} async def capture_execute(query, *args): @@ -122,7 +117,7 @@ class TestRegionValidation: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") assert result.status_code == 302 assert captured_settings["settings"]["region"]["north"] == 45.0 @@ -139,7 +134,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "30.0", # Less than south! @@ -175,14 +172,9 @@ class TestRegionValidation: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -198,7 +190,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "45.0", @@ -234,14 +228,9 @@ class TestRegionValidation: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -257,7 +246,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "95.0", # > 90! @@ -293,14 +284,9 @@ class TestRegionValidation: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -319,7 +305,9 @@ class TestRegionAuditLog: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "45.0", @@ -352,9 +340,6 @@ class TestRegionAuditLog: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -363,7 +348,7 @@ class TestRegionAuditLog: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") # Before should have old region assert captured_audit["before"]["settings"]["region"]["north"] == 49.5 diff --git a/tests/test_session_auth.py b/tests/test_session_auth.py index 004e756..30efb50 100644 --- a/tests/test_session_auth.py +++ b/tests/test_session_auth.py @@ -43,6 +43,7 @@ class TestSessionMiddleware: "username": "admin", "created_at": datetime.now(timezone.utc), "password_changed_at": datetime.now(timezone.utc), + "csrf_token": "mock_csrf_token_12345", }) mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) mock_conn.__aexit__ = AsyncMock() @@ -99,6 +100,7 @@ class TestSessionMiddleware: "username": "admin", "created_at": datetime.now(timezone.utc), "password_changed_at": datetime.now(timezone.utc), + "csrf_token": "mock_csrf_token_12345", }) mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) mock_conn.__aexit__ = AsyncMock() diff --git a/tests/test_setup_gate.py b/tests/test_setup_gate.py index 9aa11ce..7138162 100644 --- a/tests/test_setup_gate.py +++ b/tests/test_setup_gate.py @@ -12,8 +12,8 @@ class TestSetupGateMiddleware: """Tests for SetupGateMiddleware.""" @pytest.mark.asyncio - async def test_allows_setup_route_when_incomplete(self): - """SetupGateMiddleware allows /setup when setup_complete=False.""" + async def test_allows_setup_subpath_when_incomplete(self): + """SetupGateMiddleware allows /setup/operator when setup_complete=False.""" mock_pool = MagicMock() mock_conn = MagicMock() mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) @@ -21,6 +21,31 @@ class TestSetupGateMiddleware: mock_conn.__aexit__ = AsyncMock() mock_pool.acquire = MagicMock(return_value=mock_conn) + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/setup/operator") + assert response.status_code == 200 + assert response.json() == {"message": "operator"} + + @pytest.mark.asyncio + async def test_redirects_setup_base_to_wizard_step(self): + """SetupGateMiddleware redirects /setup to wizard step when incomplete.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.fetchval = AsyncMock(return_value=0) # No operators + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + with patch("central.gui.middleware.get_pool", return_value=mock_pool): app = FastAPI() @@ -28,12 +53,16 @@ class TestSetupGateMiddleware: async def setup(): return {"message": "setup"} + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + app.add_middleware(SetupGateMiddleware) - client = TestClient(app) + client = TestClient(app, follow_redirects=False) response = client.get("/setup") - assert response.status_code == 200 - assert response.json() == {"message": "setup"} + assert response.status_code == 302 + assert response.headers["location"] == "/setup/operator" @pytest.mark.asyncio async def test_allows_health_when_incomplete(self): @@ -135,7 +164,7 @@ class TestSetupGateMiddleware: @pytest.mark.asyncio async def test_redirects_setup_when_complete(self): - """SetupGateMiddleware redirects /setup to / when setup_complete=True.""" + """SetupGateMiddleware redirects /setup/* to / when setup_complete=True.""" mock_pool = MagicMock() mock_conn = MagicMock() mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) @@ -154,9 +183,18 @@ class TestSetupGateMiddleware: async def setup(): return {"message": "setup"} + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + app.add_middleware(SetupGateMiddleware) client = TestClient(app, follow_redirects=False) + # Both /setup and /setup/operator should redirect to / response = client.get("/setup") assert response.status_code == 302 assert response.headers["location"] == "/" + + response = client.get("/setup/operator") + assert response.status_code == 302 + assert response.headers["location"] == "/" diff --git a/tests/test_streams.py b/tests/test_streams.py index c2346fa..528e0ed 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -52,10 +52,6 @@ class TestStreamsListAuthenticated: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream with proper state fields mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -74,7 +70,7 @@ class TestStreamsListAuthenticated: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -117,14 +113,10 @@ class TestStreamsListNatsUnavailable: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -157,10 +149,6 @@ class TestStreamsListPartialFailure: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream - CENTRAL_FIRE raises ValueError, CENTRAL_WX works mock_js = AsyncMock() test_ts = datetime(2026, 5, 17, 12, 0, 0, tzinfo=timezone.utc) @@ -184,7 +172,7 @@ class TestStreamsListPartialFailure: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -222,10 +210,6 @@ class TestStreamsListEmptyStream: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream with empty stream (first_seq = 0) mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -239,7 +223,7 @@ class TestStreamsListEmptyStream: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -278,10 +262,6 @@ class TestStreamsListSingleMessage: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream with single message (first_seq == last_seq) mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -299,7 +279,7 @@ class TestStreamsListSingleMessage: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -337,10 +317,6 @@ class TestStreamsListGetMsgFailure: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -365,7 +341,7 @@ class TestStreamsListGetMsgFailure: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -394,8 +370,12 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "1209600" # 14 days + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "1209600", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -406,9 +386,6 @@ class TestStreamsUpdate: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -419,7 +396,7 @@ class TestStreamsUpdate: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf) + result = await streams_update(mock_request, "CENTRAL_WX") assert result.status_code == 302 assert result.headers["location"] == "/streams" @@ -438,8 +415,12 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "60" # 1 minute - too small + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "60", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -458,15 +439,10 @@ class TestStreamsUpdate: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): - result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf) + result = await streams_update(mock_request, "CENTRAL_WX") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -480,9 +456,10 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "999999999" # Way too large + mock_form.get.side_effect = lambda k, d="": {"csrf_token": "test_csrf_token", "max_age_s": "999999999"}.get(k, d) # Way too large mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -501,15 +478,10 @@ class TestStreamsUpdate: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): - result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf) + result = await streams_update(mock_request, "CENTRAL_WX") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -523,8 +495,12 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "604800" + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "604800", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -534,11 +510,8 @@ class TestStreamsUpdate: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await streams_update(mock_request, "nonexistent", mock_csrf) + result = await streams_update(mock_request, "nonexistent") assert result.status_code == 404 @@ -554,8 +527,12 @@ class TestStreamsAudit: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "1209600" # 14 days + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "1209600", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -566,9 +543,6 @@ class TestStreamsAudit: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -580,7 +554,7 @@ class TestStreamsAudit: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - await streams_update(mock_request, "CENTRAL_QUAKE", mock_csrf) + await streams_update(mock_request, "CENTRAL_QUAKE") assert captured_audit["action"] == "stream.update" assert captured_audit["operator_id"] == 1 diff --git a/tests/test_wizard.py b/tests/test_wizard.py new file mode 100644 index 0000000..dcaa7fe --- /dev/null +++ b/tests/test_wizard.py @@ -0,0 +1,665 @@ +"""Tests for the first-run setup wizard.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from central.gui.routes import ( + setup_operator_form, + setup_operator_submit, + setup_system_form, + setup_system_submit, + setup_keys_form, + setup_keys_submit, + setup_adapters_form, + setup_adapters_submit, + setup_finish_form, + setup_finish_submit, +) +from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step + + +class TestWizardStepRedirect: + """Test wizard step redirect logic.""" + + @pytest.mark.asyncio + async def test_no_operators_redirects_to_operator(self): + """When no operators exist, redirect to /setup/operator.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [0] # No operators + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/operator" + + @pytest.mark.asyncio + async def test_default_tile_url_redirects_to_system(self): + """When map_tile_url is default, redirect to /setup/system.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1] # Has operator + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/system" + + @pytest.mark.asyncio + async def test_no_adapters_touched_redirects_to_keys(self): + """When no adapters have been updated, redirect to /setup/keys.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/keys" + + @pytest.mark.asyncio + async def test_all_steps_complete_redirects_to_finish(self): + """When all steps done, redirect to /setup/finish.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/finish" + + +class TestSetupOperatorForm: + """Test operator creation form (step 1).""" + + @pytest.mark.asyncio + async def test_get_returns_form(self): + """GET /setup/operator returns the form when no operator exists.""" + mock_request = MagicMock() + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = None # No operator exists + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) + + mock_templates.TemplateResponse.assert_called_once() + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert "csrf_token" in context and context["csrf_token"] + assert context["error"] is None + assert context["existing_operator"] is None + + @pytest.mark.asyncio + async def test_get_returns_confirmation_when_operator_exists(self): + """GET /setup/operator shows confirmation when operator already exists.""" + mock_request = MagicMock() + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.body = b"Operator Already Configured" + mock_templates.TemplateResponse.return_value = mock_response + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) + + mock_templates.TemplateResponse.assert_called_once() + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["existing_operator"] == {"username": "admin"} + assert context["error"] is None + + +class TestSetupOperatorSubmit: + """Test operator creation submission.""" + + @pytest.mark.asyncio + async def test_password_mismatch_shows_error(self): + """POST with password mismatch re-renders with error.""" + mock_request = MagicMock() + mock_request.state.csrf_token = "test_csrf" + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password1", + "confirm_password": "password2", # Mismatch + }) + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchval.return_value = 0 # No existing operators + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password1", + confirm_password="password2", + ) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["error"] == "Passwords do not match" + + @pytest.mark.asyncio + async def test_valid_creates_operator_and_redirects(self): + """POST with valid data creates operator and redirects to /setup/system.""" + mock_request = MagicMock() + mock_request.state.csrf_token = "test_csrf" + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password123", + "confirm_password": "password123", + }) + + mock_conn = AsyncMock() + mock_conn.fetchval.return_value = 0 # No existing operators + mock_conn.fetchrow.side_effect = [ + {"id": 1}, # INSERT RETURNING id + {"session_lifetime_days": 90}, # system settings + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.hash_password", return_value="hashed"): + with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: + mock_session.return_value = ("session_token", datetime.now(), "csrf_token") + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) + + assert result.status_code == 302 + assert result.headers["location"] == "/setup/system" + + @pytest.mark.asyncio + async def test_post_when_operator_exists_shows_confirmation(self): + """POST when operator exists returns 200 with confirmation, no insert.""" + mock_request = MagicMock() + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password123", + "confirm_password": "password123", + }) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_templates.TemplateResponse.return_value = mock_response + + mock_conn = AsyncMock() + mock_conn.fetchval.return_value = 1 # Operator already exists + mock_conn.fetchrow.return_value = {"username": "existing_admin"} + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_request.state.csrf_token = "test_csrf" + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) + + # Should return 200, not 500 or redirect + assert result.status_code == 200 + + # Should render confirmation state + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["existing_operator"] == {"username": "existing_admin"} + + # Should NOT call write_audit (no insert happened) + mock_audit.assert_not_called() + + +class TestSetupSystemForm: + """Test system settings form (step 2).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/system without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_system_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_authenticated_returns_form(self): + """GET /setup/system with auth returns the form.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "© OpenStreetMap contributors", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_system_form(mock_request) + + mock_templates.TemplateResponse.assert_called_once() + + +class TestSetupSystemSubmit: + """Test system settings submission.""" + + @pytest.mark.asyncio + async def test_missing_placeholders_shows_error(self): + """POST without {z},{x},{y} placeholders shows error.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "map_tile_url": "https://example.com/tiles", + "map_attribution": "Test", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "", + "map_attribution": "", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_system_submit(mock_request) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert "map_tile_url" in context["errors"] + + @pytest.mark.asyncio + async def test_valid_updates_and_redirects(self): + """POST with valid data updates system and redirects to /setup/keys.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "map_tile_url": "https://example.com/{z}/{x}/{y}.png", + "map_attribution": "Test Attribution", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "old_url", + "map_attribution": "old_attr", + } + mock_conn.execute = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_system_submit(mock_request) + + assert result.status_code == 302 + assert result.headers["location"] == "/setup/keys" + + +class TestSetupKeysForm: + """Test API keys form (step 3).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/keys without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_keys_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + +class TestSetupKeysSubmit: + """Test API keys submission.""" + + @pytest.mark.asyncio + async def test_next_action_redirects_to_adapters(self): + """POST with action=next redirects to /setup/adapters.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "action": "next", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + # No need to mock get_pool since action="next" returns before it's called + result = await setup_keys_submit(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/adapters" + + @pytest.mark.asyncio + async def test_add_key_creates_and_rerenders(self): + """POST with action=add creates key and re-renders with success.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "action": "add", + "alias": "testkey", + "plaintext_key": "secret123", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.side_effect = [ + None, # No existing key + {"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}, + ] + mock_conn.fetch.side_effect = [ + [], # First list + [{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.crypto.encrypt", return_value=b"encrypted"): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_keys_submit(mock_request) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["success"] == "API key 'testkey' added successfully." + + +class TestSetupAdaptersForm: + """Test adapters configuration form (step 4).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/adapters without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_adapters_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + +class TestSetupFinishForm: + """Test finish page (step 5).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/finish without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_finish_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_authenticated_shows_summary(self): + """GET /setup/finish with auth shows summary.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys + mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"} + mock_conn.fetch.return_value = [ + {"name": "nws", "enabled": True, "cadence_s": 300}, + {"name": "firms", "enabled": False, "cadence_s": 600}, + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_finish_form(mock_request) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["operator_count"] == 1 + assert context["key_count"] == 2 + assert len(context["adapters"]) == 2 + + +class TestSetupFinishSubmit: + """Test setup completion.""" + + @pytest.mark.asyncio + async def test_marks_setup_complete_and_redirects(self): + """POST /setup/finish marks setup_complete=true and redirects to /.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + # Mock form with CSRF token + form_data = MagicMock() + form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: + result = await setup_finish_submit(mock_request) + + assert result.status_code == 302 + assert result.headers["location"] == "/" + mock_conn.execute.assert_called_once() + mock_audit.assert_called_once() + + +class TestSetupGateMiddlewareWizard: + """Test SetupGateMiddleware with wizard paths.""" + + @pytest.mark.asyncio + async def test_allows_setup_operator_when_incomplete(self): + """SetupGateMiddleware allows /setup/operator when setup_complete=False.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator form"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/setup/operator") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_redirects_base_setup_to_wizard_step(self): + """SetupGateMiddleware redirects /setup to appropriate wizard step.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.fetchval = AsyncMock(return_value=0) # No operators + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup") + async def setup(): + return {"message": "base setup"} + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup") + assert response.status_code == 302 + assert response.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_redirects_login_to_setup_when_incomplete(self): + """SetupGateMiddleware redirects /login to /setup when setup_complete=False.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/login") + async def login(): + return {"message": "login"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/login") + assert response.status_code == 302 + assert response.headers["location"] == "/setup" + + @pytest.mark.asyncio + async def test_redirects_all_setup_paths_when_complete(self): + """SetupGateMiddleware redirects /setup/* to / when setup_complete=True.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup/operator") + assert response.status_code == 302 + assert response.headers["location"] == "/"