From 78b6fcf150cb49a2e9808d66e36dd3df1bd1063f Mon Sep 17 00:00:00 2001 From: malice Date: Mon, 18 May 2026 08:18:04 -0600 Subject: [PATCH] 1b-8: Wizard redesign (deferred-commit) + map fixes + favicon CSRF race fix (#27) * feat(wizard): implement deferred-commit pattern for setup wizard Replace the current "POST each step -> DB write -> redirect" architecture with "collect values across steps in a signed cookie, commit everything in one transaction at Finish." Key changes: - Add wizard.py: WizardState dataclass and cookie helpers - csrf.py: Add reuse_or_generate_pre_auth_csrf helper - routes.py: All wizard handlers now use cookie state, no DB writes until finish - middleware.py: Cookie-based wizard step routing instead of DB queries - setup_operator.html: Remove "Operator Already Configured" branch Benefits: - Back navigation works: can return to any step and edit values - Atomic commit: all DB writes happen in single transaction at finish - No orphaned state: failed wizard leaves no DB artifacts - Simpler auth: pre-auth CSRF for all 5 steps (no session until finish) Tests updated for new behavior. 287 tests passing. Co-Authored-By: Claude Opus 4.5 * fix(templates): correct SRI hashes for leaflet.draw assets The integrity hashes for leaflet.draw.css and leaflet.draw.js were incorrect, causing browsers to silently block these resources. This broke the Leaflet.draw toolbar and map rendering for FIRMS/USGS adapter region pickers. Updated both setup_adapters.html and adapters_edit.html with the correct sha512 hashes computed from the actual CDN files. Co-Authored-By: Claude Opus 4.5 * fix(gui): return 204 for browser-noise paths to prevent CSRF races Browser requests for /favicon.ico, /apple-touch-icon.png, etc. were triggering parallel GET requests that could race with form loads, causing CSRF token rotation issues. Added BROWSER_NOISE_PATHS constant and early 204 response in both SetupGateMiddleware and SessionMiddleware to short-circuit these requests before any cookie/token handling occurs. Co-Authored-By: Claude Opus 4.5 --------- Co-authored-by: Matt Johnson Co-authored-by: Claude Opus 4.5 --- src/central/gui/csrf.py | 31 +- src/central/gui/middleware.py | 65 +- src/central/gui/routes.py | 985 +++++++++--------- src/central/gui/templates/adapters_edit.html | 4 +- src/central/gui/templates/setup_adapters.html | 4 +- src/central/gui/templates/setup_operator.html | 12 - src/central/gui/wizard.py | 131 +++ tests/test_wizard.py | 576 +--------- 8 files changed, 726 insertions(+), 1082 deletions(-) create mode 100644 src/central/gui/wizard.py diff --git a/src/central/gui/csrf.py b/src/central/gui/csrf.py index 0d6198f..37848cd 100644 --- a/src/central/gui/csrf.py +++ b/src/central/gui/csrf.py @@ -1,11 +1,10 @@ -"""Pre-auth CSRF protection for login and setup/operator pages. +"""Pre-auth CSRF protection for login and setup pages. These routes cannot use session-bound CSRF because no session exists yet. Uses a simple cookie-based pattern with short-lived tokens. """ import secrets -from typing import Optional from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from starlette.requests import Request @@ -34,6 +33,34 @@ def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]: return plain_token, signed_token +def reuse_or_generate_pre_auth_csrf( + request: Request, + secret_key: str, +) -> tuple[str, str | None]: + """Reuse an existing valid pre-auth CSRF token, or generate new. + + Returns (plain_token, signed_token_for_cookie). + If signed_token_for_cookie is None, the existing cookie is + still valid and caller should not call set_pre_auth_csrf_cookie. + If non-None, caller MUST call set_pre_auth_csrf_cookie with + it to persist the new value. + """ + cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE) + if cookie_value: + serializer = _get_serializer(secret_key) + try: + plain_token = serializer.loads( + cookie_value, + max_age=PRE_AUTH_CSRF_MAX_AGE, + ) + return plain_token, None # reuse existing + except (BadSignature, SignatureExpired): + pass # fall through to generate + + plain_token, signed_token = generate_pre_auth_csrf(secret_key) + return plain_token, signed_token + + def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None: """Set the pre-auth CSRF cookie on a response.""" response.set_cookie( diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index 155112b..2af6230 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -16,7 +16,15 @@ SETUP_EXEMPT_PREFIXES = ("/static/", "/setup") # Paths that don't require authentication AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"} -AUTH_EXEMPT_PREFIXES = ("/static/",) +AUTH_EXEMPT_PREFIXES = ("/static/", "/setup/") + +# Browser-noise paths that trigger CSRF race conditions +BROWSER_NOISE_PATHS = { + "/favicon.ico", + "/apple-touch-icon.png", + "/apple-touch-icon-precomposed.png", + "/robots.txt", +} def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: @@ -29,33 +37,14 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: return False -async def _get_wizard_redirect_step(conn) -> str: - """Determine which wizard step to redirect to based on DB state.""" - # Check if any operators exist - op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") - if op_count == 0: +def _get_wizard_redirect_from_cookie(request: Request, csrf_secret: str) -> str: + """Determine wizard redirect step from cookie state.""" + from central.gui.wizard import get_wizard_state, get_step_route + + state = get_wizard_state(request, csrf_secret) + if state is None: return "/setup/operator" - - # Check if system settings have been configured (map_tile_url not default) - sys_row = await conn.fetchrow( - "SELECT map_tile_url FROM config.system WHERE id = true" - ) - default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - if sys_row is None or sys_row["map_tile_url"] == default_tile: - return "/setup/system" - - # Keys step is optional, so check adapters have been reviewed - # We consider adapters reviewed if any adapter has a non-null updated_at - # (meaning it was explicitly saved during setup) - adapters_touched = await conn.fetchval( - "SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL" - ) - if adapters_touched == 0: - # Go to keys first, then adapters - return "/setup/keys" - - # All steps done, go to finish - return "/setup/finish" + return get_step_route(state.wizard_step) class SetupGateMiddleware(BaseHTTPMiddleware): @@ -64,6 +53,10 @@ class SetupGateMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path + # Short-circuit browser-noise requests that cause CSRF races + if path in BROWSER_NOISE_PATHS: + return Response(status_code=204) + # Check setup status from database pool = get_pool() if pool is None: @@ -85,13 +78,16 @@ class SetupGateMiddleware(BaseHTTPMiddleware): if not setup_complete: # Setup not complete - only allow setup paths and static/health if path.startswith("/setup"): - # Allow all /setup/* paths (handler will enforce auth) + # Allow all /setup/* paths # But /setup with no subpath should redirect to appropriate step if path == "/setup" or path == "/setup/": try: - async with pool.acquire() as conn: - redirect_step = await _get_wizard_redirect_step(conn) - return RedirectResponse(url=redirect_step, status_code=302) + from central.bootstrap_config import get_settings + settings = get_settings() + redirect_step = _get_wizard_redirect_from_cookie( + request, settings.csrf_secret + ) + return RedirectResponse(url=redirect_step, status_code=302) except Exception: logger.warning("Failed to determine wizard step", exc_info=True) return RedirectResponse(url="/setup/operator", status_code=302) @@ -118,6 +114,11 @@ class SessionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path + # Short-circuit browser-noise requests (already handled by SetupGateMiddleware, + # but this protects if middleware order changes) + if path in BROWSER_NOISE_PATHS: + return Response(status_code=204) + # Initialize state request.state.operator = None request.state.csrf_token = None @@ -139,7 +140,7 @@ class SessionMiddleware(BaseHTTPMiddleware): request.state.operator = None request.state.csrf_token = None - # Check if auth is required + # Check if auth is required - setup paths are exempt during wizard if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): if request.state.operator is None: return RedirectResponse(url="/login", status_code=302) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 0d08e77..f86b6e1 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -13,6 +13,7 @@ from fastapi import APIRouter, Depends, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse, Response from central.bootstrap_config import get_settings from central.gui.csrf import ( + reuse_or_generate_pre_auth_csrf, generate_pre_auth_csrf, set_pre_auth_csrf_cookie, validate_pre_auth_csrf, @@ -49,6 +50,9 @@ router = APIRouter() # Streams to display on dashboard DASHBOARD_STREAMS = ["CENTRAL_WX", "CENTRAL_FIRE", "CENTRAL_QUAKE", "CENTRAL_META"] +# Email validation regex (simple but effective) +ALIAS_REGEX = re.compile(r"^[a-zA-Z0-9_]+$") + # Email validation regex (simple but effective) EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") @@ -267,24 +271,29 @@ async def dashboard_polls(request: Request) -> HTMLResponse: # ============================================================================= -@router.get("/setup/operator", response_class=HTMLResponse) -async def setup_operator_form( - request: Request, -) -> HTMLResponse: - """Render the setup operator form (step 1).""" - templates = _get_templates() - pool = get_pool() - settings = get_settings() - csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) +# ============================================================================= +# Setup Wizard routes (deferred-commit pattern) +# ============================================================================= - # Check if operator already exists - existing_operator = None - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT username FROM config.operators ORDER BY id LIMIT 1" - ) - if row: - existing_operator = {"username": row["username"]} + +@router.get("/setup/operator", response_class=HTMLResponse) +async def setup_operator_form(request: Request) -> HTMLResponse: + """Render the setup operator form (step 1).""" + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + templates = _get_templates() + settings = get_settings() + + # Get wizard state from cookie (if any) + state = get_wizard_state(request, settings.csrf_secret) + + # Pre-fill from cookie state if available + form_data = None + if state and state.operator: + form_data = {"username": state.operator.get("username", "")} + + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, @@ -292,11 +301,11 @@ async def setup_operator_form( context={ "csrf_token": csrf_token, "error": None, - "form_data": None, - "existing_operator": existing_operator, + "form_data": form_data, }, ) - set_pre_auth_csrf_cookie(response, signed_token) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @@ -306,39 +315,22 @@ async def setup_operator_submit( username: str = Form(...), password: str = Form(...), confirm_password: str = Form(...), - ) -> Response: """Process the setup operator form (step 1).""" + from central.gui.wizard import get_wizard_state, set_wizard_cookie, WizardState + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + templates = _get_templates() - pool = get_pool() + settings = get_settings() # Validate CSRF - settings = get_settings() form = await request.form() form_csrf = form.get("csrf_token", "") if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - # Check if operator already exists (single-operator-per-install design) - async with pool.acquire() as conn: - count = await conn.fetchval("SELECT count(*) FROM config.operators") - if count > 0: - # Operator already exists — render confirmation page - existing = await conn.fetchrow( - "SELECT username FROM config.operators ORDER BY id LIMIT 1" - ) - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_operator.html", - context={ - "csrf_token": csrf_token, - "error": None, - "form_data": None, - "existing_operator": {"username": existing["username"]}, - }, - ) - return response + # Get or create wizard state + state = get_wizard_state(request, settings.csrf_secret) or WizardState() # Validate input error = None @@ -351,7 +343,7 @@ async def setup_operator_submit( error = str(e) if error: - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_operator.html", @@ -359,73 +351,54 @@ async def setup_operator_submit( "csrf_token": csrf_token, "error": error, "form_data": {"username": username}, - "existing_operator": None, }, status_code=200, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response - # Create operator + # Hash password and store in wizard state (NO DB write) password_hash = hash_password(password) - async with pool.acquire() as conn: - row = await conn.fetchrow( - """ - INSERT INTO config.operators (username, password_hash) - VALUES ($1, $2) - RETURNING id - """, - username, - password_hash, - ) - operator_id = row["id"] + state.operator = {"username": username, "password_hash": password_hash} + state.wizard_step = max(state.wizard_step, 2) - # Write audit log - await write_audit( - conn, - OPERATOR_CREATE, - operator_id=operator_id, - target=username, - ) - - # Get session lifetime - sysrow = await conn.fetchrow( - "SELECT session_lifetime_days FROM config.system WHERE id = true" - ) - lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 - - # Create session - token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) - - # Redirect to next step with session cookie + # Redirect to next step with updated wizard cookie response = RedirectResponse(url="/setup/system", status_code=302) - _set_session_cookie(response, token, lifetime_days * 86400) + set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/system", response_class=HTMLResponse) -async def setup_system_form( - request: Request, - -) -> HTMLResponse: +async def setup_system_form(request: Request) -> HTMLResponse: """Render the system settings form (step 2).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required for step 2+ + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", - } + # Pre-fill from cookie state or DB defaults + if state.system: + system = state.system + else: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", + } - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_system.html", @@ -437,29 +410,31 @@ async def setup_system_form( "system": system, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/system") -async def setup_system_submit( - request: Request, - -) -> Response: +async def setup_system_submit(request: Request) -> Response: """Process the system settings form (step 2).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf templates = _get_templates() - pool = get_pool() + settings = get_settings() + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - form = await request.form() map_tile_url = form.get("map_tile_url", "").strip() map_attribution = form.get("map_attribution", "").strip() @@ -480,87 +455,52 @@ async def setup_system_submit( if not map_attribution: errors["map_attribution"] = "Map attribution is required" - async with pool.acquire() as conn: - if errors: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "", - "map_attribution": row["map_attribution"] if row else "", - } - - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_system.html", - context={ - "csrf_token": csrf_token, - "error": None, - "errors": errors, - "form_data": form_data, - "system": system, - }, - status_code=200, - ) - return response - - # Get current values for audit - old_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + if errors: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": errors, + "form_data": form_data, + "system": state.system or form_data, + }, + status_code=200, ) - before = { - "map_tile_url": old_row["map_tile_url"] if old_row else None, - "map_attribution": old_row["map_attribution"] if old_row else None, - } + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response - # Update system settings - await conn.execute( - """ - UPDATE config.system - SET map_tile_url = $1, map_attribution = $2 - WHERE id = true - """, - map_tile_url, - map_attribution, - ) + # Update wizard state (NO DB write) + state.system = {"map_tile_url": map_tile_url, "map_attribution": map_attribution} + state.wizard_step = max(state.wizard_step, 3) - # Write audit log - await write_audit( - conn, - SYSTEM_UPDATE, - operator_id=operator.id, - target="system", - before=before, - after={"map_tile_url": map_tile_url, "map_attribution": map_attribution}, - ) - - return RedirectResponse(url="/setup/keys", status_code=302) + response = RedirectResponse(url="/setup/keys", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response @router.get("/setup/keys", response_class=HTMLResponse) -async def setup_keys_form( - request: Request, - -) -> HTMLResponse: +async def setup_keys_form(request: Request) -> HTMLResponse: """Render the API keys form (step 3).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) - from central.crypto import encrypt - templates = _get_templates() - pool = get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" - ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + # Keys come from cookie state (not DB) + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -572,36 +512,40 @@ async def setup_keys_form( "success": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/keys") -async def setup_keys_submit( - request: Request, - -) -> Response: +async def setup_keys_submit(request: Request) -> Response: """Process the API keys form (step 3).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) - - form = await request.form() - form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: - raise CsrfValidationError("Invalid CSRF token") - - form = await request.form() - action = form.get("action", "add") - - # If action is "next", redirect to adapters step - if action == "next": - return RedirectResponse(url="/setup/adapters", status_code=302) - + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf from central.crypto import encrypt templates = _get_templates() - pool = get_pool() + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + raise CsrfValidationError("Invalid CSRF token") + + action = form.get("action", "add") + + # If action is "next", advance to adapters step + if action == "next": + state.wizard_step = max(state.wizard_step, 4) + response = RedirectResponse(url="/setup/adapters", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response # Otherwise, add a new key alias = form.get("alias", "").strip() @@ -617,6 +561,8 @@ async def setup_keys_submit( errors["alias"] = "Alias must be at most 64 characters" elif not ALIAS_REGEX.match(alias): errors["alias"] = "Alias must contain only letters, numbers, and underscores" + elif any(k["alias"] == alias for k in state.api_keys): + errors["alias"] = "An API key with this alias already exists" # Validate plaintext_key if not plaintext_key: @@ -624,69 +570,34 @@ async def setup_keys_submit( elif len(plaintext_key) > 4096: errors["plaintext_key"] = "API key must be at most 4096 characters" - async with pool.acquire() as conn: - if not errors: - # Check if alias already exists - existing = await conn.fetchrow( - "SELECT alias FROM config.api_keys WHERE alias = $1", - alias, - ) - if existing: - errors["alias"] = "An API key with this alias already exists" + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - keys = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + if errors: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": errors, + "form_data": form_data, + "success": None, + }, + status_code=200, ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response - if errors: - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_keys.html", - context={ - "csrf_token": csrf_token, - "keys": keys, - "errors": errors, - "form_data": form_data, - "success": None, - }, - status_code=200, - ) - return response - - # Encrypt the key - encrypted_value = encrypt(plaintext_key.encode()) - - # Insert the new key - row = await conn.fetchrow( - """ - INSERT INTO config.api_keys (alias, encrypted_value) - VALUES ($1, $2) - RETURNING created_at - """, - alias, - encrypted_value, - ) - - # Write audit log (no plaintext!) - await write_audit( - conn, - API_KEY_CREATE, - operator_id=operator.id, - target=alias, - before=None, - after={"alias": alias, "created_at": row["created_at"].isoformat()}, - ) - - # Refresh keys list - keys = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" - ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + # Encrypt the key and add to state (NO DB write) + encrypted_value = encrypt(plaintext_key.encode()) + encrypted_b64 = base64.b64encode(encrypted_value).decode() + state.api_keys.append({"alias": alias, "encrypted_value_b64": encrypted_b64}) # Re-render with success message - csrf_token = request.state.csrf_token + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -698,61 +609,82 @@ async def setup_keys_submit( "success": f"API key '{alias}' added successfully.", }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/adapters", response_class=HTMLResponse) -async def setup_adapters_form( - request: Request, - -) -> HTMLResponse: +async def setup_adapters_form(request: Request) -> HTMLResponse: """Render the adapters configuration form (step 4).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) + # Pre-fill from cookie state or DB defaults + if state.adapters: adapters = [] - for row in rows: - settings = row["settings"] or {} - adapters.append({ - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - "settings": settings, - }) + for name in ["firms", "nws", "usgs_quake"]: + if name in state.adapters: + a = state.adapters[name] + adapters.append({ + "name": name, + "enabled": a["enabled"], + "cadence_s": a["cadence_s"], + "settings": a["settings"], + }) + else: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + adapters = [] + for row in rows: + settings_data = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings_data, + }) - # Get API keys for dropdown - api_keys = await conn.fetch( - "SELECT alias FROM config.api_keys ORDER BY alias" - ) + # Get API keys from wizard state (not DB) + api_keys = [{"alias": k["alias"]} for k in state.api_keys] - # Get map tile settings - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Get map tile settings from wizard state or DB + if state.system: + tile_url = state.system["map_tile_url"] + tile_attribution = state.system["map_attribution"] + else: + async with pool.acquire() as conn: + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_adapters.html", context={ "csrf_token": csrf_token, "adapters": adapters, - "api_keys": [{"alias": k["alias"]} for k in api_keys], + "api_keys": api_keys, "valid_satellites": _get_valid_satellites(), "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, @@ -762,242 +694,216 @@ async def setup_adapters_form( "form_data": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/adapters") -async def setup_adapters_submit( - request: Request, - -) -> Response: +async def setup_adapters_submit(request: Request) -> Response: """Process the adapters configuration form (step 4).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf templates = _get_templates() pool = get_pool() + settings = get_settings() + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - form = await request.form() errors: dict[str, str] = {} + new_adapters: dict[str, dict] = {} - async with pool.acquire() as conn: - # Get current adapters - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) - - for row in rows: - adapter_name = row["name"] - current_settings = row["settings"] or {} - new_settings = dict(current_settings) - - # Parse enabled - enabled = f"{adapter_name}_enabled" in form - - # Parse cadence - cadence_str = form.get(f"{adapter_name}_cadence_s", "") - try: - cadence_s = int(cadence_str) - if cadence_s < 60 or cadence_s > 3600: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" - except ValueError: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" - cadence_s = row["cadence_s"] - - # Adapter-specific validation - if adapter_name == "nws": - contact_email = form.get(f"{adapter_name}_contact_email", "").strip() - if enabled: - if not contact_email: - errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" - elif not EMAIL_REGEX.match(contact_email): - errors[f"{adapter_name}_contact_email"] = "Invalid email format" - else: - new_settings["contact_email"] = contact_email - else: - new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") - - elif adapter_name == "firms": - api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() - satellites = form.getlist(f"{adapter_name}_satellites") - - if api_key_alias: - key_exists = await conn.fetchrow( - "SELECT 1 FROM config.api_keys WHERE alias = $1", - api_key_alias, - ) - if not key_exists: - errors[f"{adapter_name}_api_key_alias"] = f"API key alias '{api_key_alias}' does not exist" - else: - new_settings["api_key_alias"] = api_key_alias - else: - new_settings["api_key_alias"] = None - - # Validate satellites - valid_sats = set(_get_valid_satellites()) - invalid_sats = [s for s in satellites if s not in valid_sats] - if invalid_sats: - errors[f"{adapter_name}_satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}" - else: - new_settings["satellites"] = satellites - - elif adapter_name == "usgs_quake": - feed = form.get(f"{adapter_name}_feed", "").strip() - valid_feeds = _get_valid_feeds() - if feed not in valid_feeds: - errors[f"{adapter_name}_feed"] = f"Invalid feed" - else: - new_settings["feed"] = feed - - # Region validation - region_north_str = form.get(f"{adapter_name}_region_north", "").strip() - region_south_str = form.get(f"{adapter_name}_region_south", "").strip() - region_east_str = form.get(f"{adapter_name}_region_east", "").strip() - region_west_str = form.get(f"{adapter_name}_region_west", "").strip() - - try: - region_north = float(region_north_str) - region_south = float(region_south_str) - region_east = float(region_east_str) - region_west = float(region_west_str) - - if not (-90 <= region_south < region_north <= 90): - errors[f"{adapter_name}_region"] = "Invalid latitude: south must be less than north, both between -90 and 90" - elif not (-180 <= region_west < region_east <= 180): - errors[f"{adapter_name}_region"] = "Invalid longitude: west must be less than east, both between -180 and 180" - else: - new_settings["region"] = { - "north": region_north, - "south": region_south, - "east": region_east, - "west": region_west, - } - except ValueError: - errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" - - # Store parsed data for re-render on error or update - if not errors.get(f"{adapter_name}_cadence_s"): - # Update adapter - await conn.execute( - """ - UPDATE config.adapters - SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() - WHERE name = $4 - """, - enabled, - cadence_s, - new_settings, - adapter_name, - ) - - # If any errors, re-render - if errors: - adapters = [] + # Get current adapter configs from state or DB as baseline + if state.adapters: + current_adapters = state.adapters + else: + async with pool.acquire() as conn: rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ + "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name" ) + current_adapters = {} for row in rows: - settings = row["settings"] or {} - adapters.append({ - "name": row["name"], + current_adapters[row["name"]] = { "enabled": row["enabled"], "cadence_s": row["cadence_s"], - "settings": settings, - }) + "settings": row["settings"] or {}, + } - api_keys = await conn.fetch( - "SELECT alias FROM config.api_keys ORDER BY alias" - ) + for adapter_name in ["firms", "nws", "usgs_quake"]: + current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}}) + current_settings = current.get("settings", {}) + new_settings = dict(current_settings) - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Parse enabled + enabled = f"{adapter_name}_enabled" in form - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_adapters.html", - context={ - "csrf_token": csrf_token, - "adapters": adapters, - "api_keys": [{"alias": k["alias"]} for k in api_keys], - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), - "tile_url": tile_url, - "tile_attribution": tile_attribution, - "error": "Please fix the errors below.", - "errors": errors, - "form_data": form, - }, - status_code=200, - ) - return response + # Parse cadence + cadence_str = form.get(f"{adapter_name}_cadence_s", "") + try: + cadence_s = int(cadence_str) + if cadence_s < 60 or cadence_s > 3600: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" + except ValueError: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" + cadence_s = current.get("cadence_s", 300) - return RedirectResponse(url="/setup/finish", status_code=302) + # Adapter-specific validation + if adapter_name == "nws": + contact_email = form.get(f"{adapter_name}_contact_email", "").strip() + if enabled: + if not contact_email: + errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(contact_email): + errors[f"{adapter_name}_contact_email"] = "Invalid email format" + else: + new_settings["contact_email"] = contact_email + else: + new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + elif adapter_name == "firms": + api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() + satellites = form.getlist(f"{adapter_name}_satellites") + + if api_key_alias: + # Validate against wizard state keys + if not any(k["alias"] == api_key_alias for k in state.api_keys): + errors[f"{adapter_name}_api_key_alias"] = f"API key alias does not exist" + else: + new_settings["api_key_alias"] = api_key_alias + else: + new_settings["api_key_alias"] = None + + # Validate satellites + valid_sats = set(_get_valid_satellites()) + invalid_sats = [s for s in satellites if s not in valid_sats] + if invalid_sats: + errors[f"{adapter_name}_satellites"] = f"Invalid satellites: " + ", ".join(invalid_sats) + else: + new_settings["satellites"] = satellites + + elif adapter_name == "usgs_quake": + feed = form.get(f"{adapter_name}_feed", "").strip() + valid_feeds = _get_valid_feeds() + if feed not in valid_feeds: + errors[f"{adapter_name}_feed"] = "Invalid feed" + else: + new_settings["feed"] = feed + + # Region validation (all adapters) + region_north_str = form.get(f"{adapter_name}_region_north", "").strip() + region_south_str = form.get(f"{adapter_name}_region_south", "").strip() + region_east_str = form.get(f"{adapter_name}_region_east", "").strip() + region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_region"] = "Invalid latitude: south < north, both -90 to 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_region"] = "Invalid longitude: west < east, both -180 to 180" + else: + new_settings["region"] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + + new_adapters[adapter_name] = { + "enabled": enabled, + "cadence_s": cadence_s, + "settings": new_settings, + } + + # If errors, re-render + if errors: + adapters = [ + {"name": name, "enabled": new_adapters[name]["enabled"], + "cadence_s": new_adapters[name]["cadence_s"], + "settings": new_adapters[name]["settings"]} + for name in ["firms", "nws", "usgs_quake"] + ] + api_keys = [{"alias": k["alias"]} for k in state.api_keys] + + if state.system: + tile_url = state.system["map_tile_url"] + tile_attribution = state.system["map_attribution"] + else: + tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = "© OpenStreetMap contributors" + + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": api_keys, + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": "Please fix the errors below.", + "errors": errors, + "form_data": form, + }, + status_code=200, + ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response + + # Update wizard state (NO DB write) + state.adapters = new_adapters + state.wizard_step = max(state.wizard_step, 5) + + response = RedirectResponse(url="/setup/finish", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response @router.get("/setup/finish", response_class=HTMLResponse) -async def setup_finish_form( - request: Request, - -) -> HTMLResponse: +async def setup_finish_form(request: Request) -> HTMLResponse: """Render the finish setup page (step 5).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() - pool = get_pool() - async with pool.acquire() as conn: - # Get counts - operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") - key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + operator_count = 1 if state.operator else 0 + key_count = len(state.api_keys) + system = state.system or {"map_tile_url": "(not configured)"} - # Get system settings - sys_row = await conn.fetchrow( - "SELECT map_tile_url FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": sys_row["map_tile_url"] if sys_row else "", - } + adapters = [] + if state.adapters: + for name in ["firms", "nws", "usgs_quake"]: + if name in state.adapters: + a = state.adapters[name] + adapters.append({"name": name, "enabled": a["enabled"], "cadence_s": a["cadence_s"]}) - # Get adapters - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s - FROM config.adapters - ORDER BY name - """ - ) - adapters = [ - { - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - } - for row in rows - ] - - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_finish.html", @@ -1007,46 +913,101 @@ async def setup_finish_form( "key_count": key_count, "system": system, "adapters": adapters, + "error": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/finish") -async def setup_finish_submit( - request: Request, - -) -> Response: - """Complete the setup wizard.""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) +async def setup_finish_submit(request: Request) -> Response: + """Complete the setup wizard - atomic commit of all wizard state.""" + from central.gui.wizard import get_wizard_state, clear_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + from asyncpg.exceptions import UniqueViolationError + templates = _get_templates() pool = get_pool() + settings = get_settings() + + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - async with pool.acquire() as conn: - # Mark setup complete - await conn.execute( - "UPDATE config.system SET setup_complete = true WHERE id = true" - ) + if not state.system: + return RedirectResponse(url="/setup/system", status_code=302) + if not state.adapters: + return RedirectResponse(url="/setup/adapters", status_code=302) - # Write audit log - await write_audit( - conn, - SETUP_COMPLETE, - operator_id=operator.id, - target="system", - ) + try: + async with pool.acquire() as conn: + async with conn.transaction(): + # 1. INSERT operator + op_row = await conn.fetchrow( + "INSERT INTO config.operators (username, password_hash) VALUES ($1, $2) RETURNING id", + state.operator["username"], + state.operator["password_hash"], + ) + operator_id = op_row["id"] - return RedirectResponse(url="/", status_code=302) + await write_audit(conn, OPERATOR_CREATE, operator_id=operator_id, target=state.operator["username"]) + # 2. Create session + sysrow = await conn.fetchrow("SELECT session_lifetime_days FROM config.system WHERE id = true") + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) + # 3. UPDATE config.system + old_sys = await conn.fetchrow("SELECT map_tile_url, map_attribution FROM config.system WHERE id = true") + await conn.execute( + "UPDATE config.system SET map_tile_url = $1, map_attribution = $2, setup_complete = true WHERE id = true", + state.system["map_tile_url"], + state.system["map_attribution"], + ) + await write_audit(conn, SYSTEM_UPDATE, operator_id=operator_id, target="system", + before={"map_tile_url": old_sys["map_tile_url"], "map_attribution": old_sys["map_attribution"]} if old_sys else None, + after={"map_tile_url": state.system["map_tile_url"], "map_attribution": state.system["map_attribution"]}) + + # 4. INSERT each API key + for key in state.api_keys: + encrypted = base64.b64decode(key["encrypted_value_b64"]) + await conn.execute("INSERT INTO config.api_keys (alias, encrypted_value) VALUES ($1, $2)", key["alias"], encrypted) + await write_audit(conn, API_KEY_CREATE, operator_id=operator_id, target=key["alias"]) + + # 5. UPDATE config.adapters + for name, adapter_cfg in state.adapters.items(): + old_adapter = await conn.fetchrow("SELECT enabled, cadence_s, settings FROM config.adapters WHERE name = $1", name) + await conn.execute( + "UPDATE config.adapters SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() WHERE name = $4", + adapter_cfg["enabled"], adapter_cfg["cadence_s"], adapter_cfg["settings"], name) + await write_audit(conn, ADAPTER_UPDATE, operator_id=operator_id, target=name, + before={"enabled": old_adapter["enabled"], "cadence_s": old_adapter["cadence_s"]} if old_adapter else None, + after={"enabled": adapter_cfg["enabled"], "cadence_s": adapter_cfg["cadence_s"]}) + + await write_audit(conn, SETUP_COMPLETE, operator_id=operator_id, target="system") + + except UniqueViolationError: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse(request=request, name="setup_finish.html", + context={"csrf_token": csrf_token, "operator_count": 1, "key_count": len(state.api_keys), + "system": state.system, "adapters": [{"name": n, "enabled": a["enabled"], "cadence_s": a["cadence_s"]} for n, a in state.adapters.items()], + "error": f"Username '{state.operator['username']}' already exists."}, status_code=200) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response + + response = RedirectResponse(url="/", status_code=302) + clear_wizard_cookie(response) + unset_pre_auth_csrf_cookie(response) + _set_session_cookie(response, token, lifetime_days * 86400) + return response @router.get("/login", response_class=HTMLResponse) async def login_form( request: Request, diff --git a/src/central/gui/templates/adapters_edit.html b/src/central/gui/templates/adapters_edit.html index 35ae6c8..939aa75 100644 --- a/src/central/gui/templates/adapters_edit.html +++ b/src/central/gui/templates/adapters_edit.html @@ -4,9 +4,9 @@ {% block head %} - + - + {% endblock %} {% block content %} diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html index 9ff2f50..de3f8c2 100644 --- a/src/central/gui/templates/setup_adapters.html +++ b/src/central/gui/templates/setup_adapters.html @@ -4,9 +4,9 @@ {% block head %} - + - + {% endblock %} {% block content %} diff --git a/src/central/gui/templates/setup_operator.html b/src/central/gui/templates/setup_operator.html index 36932b4..f4e9277 100644 --- a/src/central/gui/templates/setup_operator.html +++ b/src/central/gui/templates/setup_operator.html @@ -7,17 +7,6 @@ {% include "_wizard_header.html" %} {% endwith %} -{% if existing_operator %} -
-
-

Operator Already Configured

-
-

The operator account {{ existing_operator.username }} has been created.

-
- Next → -
-
-{% else %}

Create Operator Account

@@ -53,5 +42,4 @@
-{% endif %} {% endblock %} diff --git a/src/central/gui/wizard.py b/src/central/gui/wizard.py new file mode 100644 index 0000000..8f28ac3 --- /dev/null +++ b/src/central/gui/wizard.py @@ -0,0 +1,131 @@ +"""Wizard state management for deferred-commit setup flow. + +The wizard collects configuration across 5 steps and commits everything +atomically at the final step. State is carried in a signed cookie. +""" + +import base64 +from dataclasses import dataclass, field, asdict +from typing import Any + +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired +from starlette.requests import Request +from starlette.responses import Response + + +# 1 hour max age for wizard cookie +WIZARD_MAX_AGE = 3600 +WIZARD_COOKIE = "central_wizard" + + +@dataclass +class WizardOperator: + """Operator data collected in step 1.""" + username: str + password_hash: str + + +@dataclass +class WizardSystem: + """System settings collected in step 2.""" + map_tile_url: str + map_attribution: str + + +@dataclass +class WizardApiKey: + """API key collected in step 3.""" + alias: str + encrypted_value_b64: str # base64-encoded encrypted value + + +@dataclass +class WizardAdapter: + """Adapter config collected in step 4.""" + enabled: bool + cadence_s: int + settings: dict[str, Any] + + +@dataclass +class WizardState: + """Complete wizard state carried across all steps.""" + wizard_step: int = 1 + operator: dict | None = None + system: dict | None = None + api_keys: list[dict] = field(default_factory=list) + adapters: dict[str, dict] | None = None + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "wizard_step": self.wizard_step, + "operator": self.operator, + "system": self.system, + "api_keys": self.api_keys, + "adapters": self.adapters, + } + + @classmethod + def from_dict(cls, data: dict) -> "WizardState": + """Create from dictionary.""" + return cls( + wizard_step=data.get("wizard_step", 1), + operator=data.get("operator"), + system=data.get("system"), + api_keys=data.get("api_keys", []), + adapters=data.get("adapters"), + ) + + +def _get_wizard_serializer(secret_key: str) -> URLSafeTimedSerializer: + """Get a timed serializer for wizard state.""" + return URLSafeTimedSerializer(secret_key, salt="wizard-state") + + +def get_wizard_state(request: Request, secret_key: str) -> WizardState | None: + """Decode wizard state from cookie. + + Returns WizardState if valid, None if missing/invalid/expired. + """ + cookie_value = request.cookies.get(WIZARD_COOKIE) + if not cookie_value: + return None + + serializer = _get_wizard_serializer(secret_key) + try: + data = serializer.loads(cookie_value, max_age=WIZARD_MAX_AGE) + return WizardState.from_dict(data) + except (BadSignature, SignatureExpired): + return None + + +def set_wizard_cookie(response: Response, state: WizardState, secret_key: str) -> None: + """Set the wizard state cookie on a response.""" + serializer = _get_wizard_serializer(secret_key) + signed_value = serializer.dumps(state.to_dict()) + response.set_cookie( + WIZARD_COOKIE, + signed_value, + max_age=WIZARD_MAX_AGE, + path="/setup", + httponly=True, + samesite="lax", + ) + + +def clear_wizard_cookie(response: Response) -> None: + """Remove the wizard state cookie.""" + response.delete_cookie(WIZARD_COOKIE, path="/setup") + + +def get_step_route(step: int) -> str: + """Get the route for a wizard step number.""" + routes = { + 1: "/setup/operator", + 2: "/setup/system", + 3: "/setup/keys", + 4: "/setup/adapters", + 5: "/setup/finish", + } + return routes.get(step, "/setup/operator") diff --git a/tests/test_wizard.py b/tests/test_wizard.py index dcaa7fe..0492276 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -1,4 +1,4 @@ -"""Tests for the first-run setup wizard.""" +"""Tests for the first-run setup wizard with deferred-commit pattern.""" from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch @@ -11,60 +11,38 @@ from central.gui.routes import ( setup_system_submit, setup_keys_form, setup_keys_submit, - setup_adapters_form, - setup_adapters_submit, setup_finish_form, setup_finish_submit, ) -from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step +from central.gui.middleware import SetupGateMiddleware +from central.gui.wizard import WizardState, get_wizard_state, set_wizard_cookie class TestWizardStepRedirect: - """Test wizard step redirect logic.""" + """Test wizard step redirect logic based on cookie state.""" - @pytest.mark.asyncio - async def test_no_operators_redirects_to_operator(self): - """When no operators exist, redirect to /setup/operator.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [0] # No operators + def test_no_cookie_redirects_to_operator(self): + """When no wizard cookie exists, redirect to /setup/operator.""" + from central.gui.middleware import _get_wizard_redirect_from_cookie - result = await _get_wizard_redirect_step(mock_conn) + mock_request = MagicMock() + mock_request.cookies = {} + + result = _get_wizard_redirect_from_cookie(mock_request, "testsecret") assert result == "/setup/operator" - @pytest.mark.asyncio - async def test_default_tile_url_redirects_to_system(self): - """When map_tile_url is default, redirect to /setup/system.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1] # Has operator - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - } + def test_cookie_step_2_redirects_to_system(self): + """When wizard_step=2 in cookie, redirect to /setup/system.""" + from central.gui.wizard import get_step_route - result = await _get_wizard_redirect_step(mock_conn) + result = get_step_route(2) assert result == "/setup/system" - @pytest.mark.asyncio - async def test_no_adapters_touched_redirects_to_keys(self): - """When no adapters have been updated, redirect to /setup/keys.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" - } + def test_cookie_step_5_redirects_to_finish(self): + """When wizard_step=5 in cookie, redirect to /setup/finish.""" + from central.gui.wizard import get_step_route - result = await _get_wizard_redirect_step(mock_conn) - assert result == "/setup/keys" - - @pytest.mark.asyncio - async def test_all_steps_complete_redirects_to_finish(self): - """When all steps done, redirect to /setup/finish.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" - } - - result = await _get_wizard_redirect_step(mock_conn) + result = get_step_route(5) assert result == "/setup/finish" @@ -72,63 +50,26 @@ class TestSetupOperatorForm: """Test operator creation form (step 1).""" @pytest.mark.asyncio - async def test_get_returns_form(self): - """GET /setup/operator returns the form when no operator exists.""" + async def test_get_returns_form_without_prefill(self): + """GET /setup/operator returns the form when no wizard cookie exists.""" mock_request = MagicMock() + mock_request.cookies = {} mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = None # No operator exists - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) mock_templates.TemplateResponse.assert_called_once() call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) assert "csrf_token" in context and context["csrf_token"] assert context["error"] is None - assert context["existing_operator"] is None - - @pytest.mark.asyncio - async def test_get_returns_confirmation_when_operator_exists(self): - """GET /setup/operator shows confirmation when operator already exists.""" - mock_request = MagicMock() - - mock_templates = MagicMock() - mock_response = MagicMock() - mock_response.body = b"Operator Already Configured" - mock_templates.TemplateResponse.return_value = mock_response - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) - - mock_templates.TemplateResponse.assert_called_once() - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["existing_operator"] == {"username": "admin"} - assert context["error"] is None + assert context["form_data"] is None class TestSetupOperatorSubmit: @@ -138,28 +79,17 @@ class TestSetupOperatorSubmit: async def test_password_mismatch_shows_error(self): """POST with password mismatch re-renders with error.""" mock_request = MagicMock() - mock_request.state.csrf_token = "test_csrf" - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password1", - "confirm_password": "password2", # Mismatch - }) + mock_request.cookies = {} + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) + mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 0 # No existing operators - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed")): result = await setup_operator_submit( mock_request, username="testuser", @@ -172,374 +102,43 @@ class TestSetupOperatorSubmit: assert context["error"] == "Passwords do not match" @pytest.mark.asyncio - async def test_valid_creates_operator_and_redirects(self): - """POST with valid data creates operator and redirects to /setup/system.""" + async def test_valid_creates_wizard_cookie_and_redirects(self): + """POST with valid data creates wizard cookie and redirects to /setup/system.""" mock_request = MagicMock() - mock_request.state.csrf_token = "test_csrf" - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password123", - "confirm_password": "password123", - }) + mock_request.cookies = {} + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 0 # No existing operators - mock_conn.fetchrow.side_effect = [ - {"id": 1}, # INSERT RETURNING id - {"session_lifetime_days": 90}, # system settings - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.hash_password", return_value="hashed"): - with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: - mock_session.return_value = ("session_token", datetime.now(), "csrf_token") - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.hash_password", return_value="hashed_pw"): + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) assert result.status_code == 302 assert result.headers["location"] == "/setup/system" - @pytest.mark.asyncio - async def test_post_when_operator_exists_shows_confirmation(self): - """POST when operator exists returns 200 with confirmation, no insert.""" - mock_request = MagicMock() - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password123", - "confirm_password": "password123", - }) - - mock_templates = MagicMock() - mock_response = MagicMock() - mock_response.status_code = 200 - mock_templates.TemplateResponse.return_value = mock_response - - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 1 # Operator already exists - mock_conn.fetchrow.return_value = {"username": "existing_admin"} - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - mock_request.state.csrf_token = "test_csrf" - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) - - # Should return 200, not 500 or redirect - assert result.status_code == 200 - - # Should render confirmation state - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["existing_operator"] == {"username": "existing_admin"} - - # Should NOT call write_audit (no insert happened) - mock_audit.assert_not_called() - class TestSetupSystemForm: """Test system settings form (step 2).""" @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/system without auth redirects to /setup/operator.""" + async def test_no_wizard_cookie_redirects_to_operator(self): + """GET /setup/system without wizard cookie redirects to /setup/operator.""" mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_system_form(mock_request) + mock_request.cookies = {} + + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + result = await setup_system_form(mock_request) + assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" - @pytest.mark.asyncio - async def test_authenticated_returns_form(self): - """GET /setup/system with auth returns the form.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": "© OpenStreetMap contributors", - } - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_form(mock_request) - - mock_templates.TemplateResponse.assert_called_once() - - -class TestSetupSystemSubmit: - """Test system settings submission.""" - - @pytest.mark.asyncio - async def test_missing_placeholders_shows_error(self): - """POST without {z},{x},{y} placeholders shows error.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "map_tile_url": "https://example.com/tiles", - "map_attribution": "Test", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "", - "map_attribution": "", - } - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_submit(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert "map_tile_url" in context["errors"] - - @pytest.mark.asyncio - async def test_valid_updates_and_redirects(self): - """POST with valid data updates system and redirects to /setup/keys.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "map_tile_url": "https://example.com/{z}/{x}/{y}.png", - "map_attribution": "Test Attribution", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "old_url", - "map_attribution": "old_attr", - } - mock_conn.execute = AsyncMock() - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_system_submit(mock_request) - - assert result.status_code == 302 - assert result.headers["location"] == "/setup/keys" - - -class TestSetupKeysForm: - """Test API keys form (step 3).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/keys without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_keys_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - -class TestSetupKeysSubmit: - """Test API keys submission.""" - - @pytest.mark.asyncio - async def test_next_action_redirects_to_adapters(self): - """POST with action=next redirects to /setup/adapters.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "action": "next", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - # No need to mock get_pool since action="next" returns before it's called - result = await setup_keys_submit(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/adapters" - - @pytest.mark.asyncio - async def test_add_key_creates_and_rerenders(self): - """POST with action=add creates key and re-renders with success.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "action": "add", - "alias": "testkey", - "plaintext_key": "secret123", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.side_effect = [ - None, # No existing key - {"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}, - ] - mock_conn.fetch.side_effect = [ - [], # First list - [{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.crypto.encrypt", return_value=b"encrypted"): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_keys_submit(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["success"] == "API key 'testkey' added successfully." - - -class TestSetupAdaptersForm: - """Test adapters configuration form (step 4).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/adapters without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_adapters_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - -class TestSetupFinishForm: - """Test finish page (step 5).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/finish without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_finish_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - @pytest.mark.asyncio - async def test_authenticated_shows_summary(self): - """GET /setup/finish with auth shows summary.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys - mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"} - mock_conn.fetch.return_value = [ - {"name": "nws", "enabled": True, "cadence_s": 300}, - {"name": "firms", "enabled": False, "cadence_s": 600}, - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_finish_form(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["operator_count"] == 1 - assert context["key_count"] == 2 - assert len(context["adapters"]) == 2 - - -class TestSetupFinishSubmit: - """Test setup completion.""" - - @pytest.mark.asyncio - async def test_marks_setup_complete_and_redirects(self): - """POST /setup/finish marks setup_complete=true and redirects to /.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - # Mock form with CSRF token - form_data = MagicMock() - form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_conn = AsyncMock() - mock_conn.execute = AsyncMock() - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_finish_submit(mock_request) - - assert result.status_code == 302 - assert result.headers["location"] == "/" - mock_conn.execute.assert_called_once() - mock_audit.assert_called_once() - class TestSetupGateMiddlewareWizard: """Test SetupGateMiddleware with wizard paths.""" @@ -570,69 +169,6 @@ class TestSetupGateMiddlewareWizard: response = client.get("/setup/operator") assert response.status_code == 200 - @pytest.mark.asyncio - async def test_redirects_base_setup_to_wizard_step(self): - """SetupGateMiddleware redirects /setup to appropriate wizard step.""" - from starlette.testclient import TestClient - from fastapi import FastAPI - - mock_pool = MagicMock() - mock_conn = MagicMock() - mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) - mock_conn.fetchval = AsyncMock(return_value=0) # No operators - mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() - mock_pool.acquire = MagicMock(return_value=mock_conn) - - with patch("central.gui.middleware.get_pool", return_value=mock_pool): - app = FastAPI() - - @app.get("/setup") - async def setup(): - return {"message": "base setup"} - - @app.get("/setup/operator") - async def setup_operator(): - return {"message": "operator"} - - app.add_middleware(SetupGateMiddleware) - client = TestClient(app, follow_redirects=False) - - response = client.get("/setup") - assert response.status_code == 302 - assert response.headers["location"] == "/setup/operator" - - @pytest.mark.asyncio - async def test_redirects_login_to_setup_when_incomplete(self): - """SetupGateMiddleware redirects /login to /setup when setup_complete=False.""" - from starlette.testclient import TestClient - from fastapi import FastAPI - - mock_pool = MagicMock() - mock_conn = MagicMock() - mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) - mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() - mock_pool.acquire = MagicMock(return_value=mock_conn) - - with patch("central.gui.middleware.get_pool", return_value=mock_pool): - app = FastAPI() - - @app.get("/login") - async def login(): - return {"message": "login"} - - @app.get("/setup") - async def setup(): - return {"message": "setup"} - - app.add_middleware(SetupGateMiddleware) - client = TestClient(app, follow_redirects=False) - - response = client.get("/login") - assert response.status_code == 302 - assert response.headers["location"] == "/setup" - @pytest.mark.asyncio async def test_redirects_all_setup_paths_when_complete(self): """SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""