diff --git a/src/central/gui/csrf.py b/src/central/gui/csrf.py index 37848cd..0d6198f 100644 --- a/src/central/gui/csrf.py +++ b/src/central/gui/csrf.py @@ -1,10 +1,11 @@ -"""Pre-auth CSRF protection for login and setup pages. +"""Pre-auth CSRF protection for login and setup/operator pages. These routes cannot use session-bound CSRF because no session exists yet. Uses a simple cookie-based pattern with short-lived tokens. """ import secrets +from typing import Optional from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from starlette.requests import Request @@ -33,34 +34,6 @@ def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]: return plain_token, signed_token -def reuse_or_generate_pre_auth_csrf( - request: Request, - secret_key: str, -) -> tuple[str, str | None]: - """Reuse an existing valid pre-auth CSRF token, or generate new. - - Returns (plain_token, signed_token_for_cookie). - If signed_token_for_cookie is None, the existing cookie is - still valid and caller should not call set_pre_auth_csrf_cookie. - If non-None, caller MUST call set_pre_auth_csrf_cookie with - it to persist the new value. - """ - cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE) - if cookie_value: - serializer = _get_serializer(secret_key) - try: - plain_token = serializer.loads( - cookie_value, - max_age=PRE_AUTH_CSRF_MAX_AGE, - ) - return plain_token, None # reuse existing - except (BadSignature, SignatureExpired): - pass # fall through to generate - - plain_token, signed_token = generate_pre_auth_csrf(secret_key) - return plain_token, signed_token - - def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None: """Set the pre-auth CSRF cookie on a response.""" response.set_cookie( diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index 2af6230..155112b 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -16,15 +16,7 @@ SETUP_EXEMPT_PREFIXES = ("/static/", "/setup") # Paths that don't require authentication AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"} -AUTH_EXEMPT_PREFIXES = ("/static/", "/setup/") - -# Browser-noise paths that trigger CSRF race conditions -BROWSER_NOISE_PATHS = { - "/favicon.ico", - "/apple-touch-icon.png", - "/apple-touch-icon-precomposed.png", - "/robots.txt", -} +AUTH_EXEMPT_PREFIXES = ("/static/",) def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: @@ -37,14 +29,33 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: return False -def _get_wizard_redirect_from_cookie(request: Request, csrf_secret: str) -> str: - """Determine wizard redirect step from cookie state.""" - from central.gui.wizard import get_wizard_state, get_step_route - - state = get_wizard_state(request, csrf_secret) - if state is None: +async def _get_wizard_redirect_step(conn) -> str: + """Determine which wizard step to redirect to based on DB state.""" + # Check if any operators exist + op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + if op_count == 0: return "/setup/operator" - return get_step_route(state.wizard_step) + + # Check if system settings have been configured (map_tile_url not default) + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + if sys_row is None or sys_row["map_tile_url"] == default_tile: + return "/setup/system" + + # Keys step is optional, so check adapters have been reviewed + # We consider adapters reviewed if any adapter has a non-null updated_at + # (meaning it was explicitly saved during setup) + adapters_touched = await conn.fetchval( + "SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL" + ) + if adapters_touched == 0: + # Go to keys first, then adapters + return "/setup/keys" + + # All steps done, go to finish + return "/setup/finish" class SetupGateMiddleware(BaseHTTPMiddleware): @@ -53,10 +64,6 @@ class SetupGateMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path - # Short-circuit browser-noise requests that cause CSRF races - if path in BROWSER_NOISE_PATHS: - return Response(status_code=204) - # Check setup status from database pool = get_pool() if pool is None: @@ -78,16 +85,13 @@ class SetupGateMiddleware(BaseHTTPMiddleware): if not setup_complete: # Setup not complete - only allow setup paths and static/health if path.startswith("/setup"): - # Allow all /setup/* paths + # Allow all /setup/* paths (handler will enforce auth) # But /setup with no subpath should redirect to appropriate step if path == "/setup" or path == "/setup/": try: - from central.bootstrap_config import get_settings - settings = get_settings() - redirect_step = _get_wizard_redirect_from_cookie( - request, settings.csrf_secret - ) - return RedirectResponse(url=redirect_step, status_code=302) + async with pool.acquire() as conn: + redirect_step = await _get_wizard_redirect_step(conn) + return RedirectResponse(url=redirect_step, status_code=302) except Exception: logger.warning("Failed to determine wizard step", exc_info=True) return RedirectResponse(url="/setup/operator", status_code=302) @@ -114,11 +118,6 @@ class SessionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path - # Short-circuit browser-noise requests (already handled by SetupGateMiddleware, - # but this protects if middleware order changes) - if path in BROWSER_NOISE_PATHS: - return Response(status_code=204) - # Initialize state request.state.operator = None request.state.csrf_token = None @@ -140,7 +139,7 @@ class SessionMiddleware(BaseHTTPMiddleware): request.state.operator = None request.state.csrf_token = None - # Check if auth is required - setup paths are exempt during wizard + # Check if auth is required if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): if request.state.operator is None: return RedirectResponse(url="/login", status_code=302) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 9e4f2d0..1315ff3 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -13,7 +13,6 @@ from fastapi import APIRouter, Depends, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse, Response from central.bootstrap_config import get_settings from central.gui.csrf import ( - reuse_or_generate_pre_auth_csrf, generate_pre_auth_csrf, set_pre_auth_csrf_cookie, validate_pre_auth_csrf, @@ -50,9 +49,6 @@ router = APIRouter() # Streams to display on dashboard DASHBOARD_STREAMS = ["CENTRAL_WX", "CENTRAL_FIRE", "CENTRAL_QUAKE", "CENTRAL_META"] -# Email validation regex (simple but effective) -ALIAS_REGEX = re.compile(r"^[a-zA-Z0-9_]+$") - # Email validation regex (simple but effective) EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") @@ -271,29 +267,24 @@ async def dashboard_polls(request: Request) -> HTMLResponse: # ============================================================================= -# ============================================================================= -# Setup Wizard routes (deferred-commit pattern) -# ============================================================================= - - @router.get("/setup/operator", response_class=HTMLResponse) -async def setup_operator_form(request: Request) -> HTMLResponse: +async def setup_operator_form( + request: Request, +) -> HTMLResponse: """Render the setup operator form (step 1).""" - from central.gui.wizard import get_wizard_state - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - templates = _get_templates() + pool = get_pool() settings = get_settings() + csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) - # Get wizard state from cookie (if any) - state = get_wizard_state(request, settings.csrf_secret) - - # Pre-fill from cookie state if available - form_data = None - if state and state.operator: - form_data = {"username": state.operator.get("username", "")} - - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + # Check if operator already exists + existing_operator = None + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT username FROM config.operators ORDER BY id LIMIT 1" + ) + if row: + existing_operator = {"username": row["username"]} response = templates.TemplateResponse( request=request, @@ -301,11 +292,11 @@ async def setup_operator_form(request: Request) -> HTMLResponse: context={ "csrf_token": csrf_token, "error": None, - "form_data": form_data, + "form_data": None, + "existing_operator": existing_operator, }, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) + set_pre_auth_csrf_cookie(response, signed_token) return response @@ -315,22 +306,39 @@ async def setup_operator_submit( username: str = Form(...), password: str = Form(...), confirm_password: str = Form(...), + ) -> Response: """Process the setup operator form (step 1).""" - from central.gui.wizard import get_wizard_state, set_wizard_cookie, WizardState - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - templates = _get_templates() - settings = get_settings() + pool = get_pool() # Validate CSRF + settings = get_settings() form = await request.form() form_csrf = form.get("csrf_token", "") if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - # Get or create wizard state - state = get_wizard_state(request, settings.csrf_secret) or WizardState() + # Check if operator already exists (single-operator-per-install design) + async with pool.acquire() as conn: + count = await conn.fetchval("SELECT count(*) FROM config.operators") + if count > 0: + # Operator already exists — render confirmation page + existing = await conn.fetchrow( + "SELECT username FROM config.operators ORDER BY id LIMIT 1" + ) + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_operator.html", + context={ + "csrf_token": csrf_token, + "error": None, + "form_data": None, + "existing_operator": {"username": existing["username"]}, + }, + ) + return response # Validate input error = None @@ -343,7 +351,7 @@ async def setup_operator_submit( error = str(e) if error: - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="setup_operator.html", @@ -351,54 +359,73 @@ async def setup_operator_submit( "csrf_token": csrf_token, "error": error, "form_data": {"username": username}, + "existing_operator": None, }, status_code=200, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) return response - # Hash password and store in wizard state (NO DB write) + # Create operator password_hash = hash_password(password) - state.operator = {"username": username, "password_hash": password_hash} - state.wizard_step = max(state.wizard_step, 2) + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO config.operators (username, password_hash) + VALUES ($1, $2) + RETURNING id + """, + username, + password_hash, + ) + operator_id = row["id"] - # Redirect to next step with updated wizard cookie + # Write audit log + await write_audit( + conn, + OPERATOR_CREATE, + operator_id=operator_id, + target=username, + ) + + # Get session lifetime + sysrow = await conn.fetchrow( + "SELECT session_lifetime_days FROM config.system WHERE id = true" + ) + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + + # Create session + token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) + + # Redirect to next step with session cookie response = RedirectResponse(url="/setup/system", status_code=302) - set_wizard_cookie(response, state, settings.csrf_secret) + _set_session_cookie(response, token, lifetime_days * 86400) return response @router.get("/setup/system", response_class=HTMLResponse) -async def setup_system_form(request: Request) -> HTMLResponse: +async def setup_system_form( + request: Request, + +) -> HTMLResponse: """Render the system settings form (step 2).""" - from central.gui.wizard import get_wizard_state - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - - settings = get_settings() - - # Get wizard state - required for step 2+ - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - # Pre-fill from cookie state or DB defaults - if state.system: - system = state.system - else: - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", - } + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", + } - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="setup_system.html", @@ -410,31 +437,29 @@ async def setup_system_form(request: Request) -> HTMLResponse: "system": system, }, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/system") -async def setup_system_submit(request: Request) -> Response: +async def setup_system_submit( + request: Request, + +) -> Response: """Process the system settings form (step 2).""" - from central.gui.wizard import get_wizard_state, set_wizard_cookie - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - - templates = _get_templates() - settings = get_settings() - - # Get wizard state - required - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: return RedirectResponse(url="/setup/operator", status_code=302) - # Validate CSRF + templates = _get_templates() + pool = get_pool() + form = await request.form() form_csrf = form.get("csrf_token", "") - if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + if not form_csrf or form_csrf != request.state.csrf_token: raise CsrfValidationError("Invalid CSRF token") + form = await request.form() map_tile_url = form.get("map_tile_url", "").strip() map_attribution = form.get("map_attribution", "").strip() @@ -455,52 +480,87 @@ async def setup_system_submit(request: Request) -> Response: if not map_attribution: errors["map_attribution"] = "Map attribution is required" - if errors: - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) - response = templates.TemplateResponse( - request=request, - name="setup_system.html", - context={ - "csrf_token": csrf_token, - "error": None, - "errors": errors, - "form_data": form_data, - "system": state.system or form_data, - }, - status_code=200, + async with pool.acquire() as conn: + if errors: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "", + "map_attribution": row["map_attribution"] if row else "", + } + + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": errors, + "form_data": form_data, + "system": system, + }, + status_code=200, + ) + return response + + # Get current values for audit + old_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) - return response + before = { + "map_tile_url": old_row["map_tile_url"] if old_row else None, + "map_attribution": old_row["map_attribution"] if old_row else None, + } - # Update wizard state (NO DB write) - state.system = {"map_tile_url": map_tile_url, "map_attribution": map_attribution} - state.wizard_step = max(state.wizard_step, 3) + # Update system settings + await conn.execute( + """ + UPDATE config.system + SET map_tile_url = $1, map_attribution = $2 + WHERE id = true + """, + map_tile_url, + map_attribution, + ) - response = RedirectResponse(url="/setup/keys", status_code=302) - set_wizard_cookie(response, state, settings.csrf_secret) - return response + # Write audit log + await write_audit( + conn, + SYSTEM_UPDATE, + operator_id=operator.id, + target="system", + before=before, + after={"map_tile_url": map_tile_url, "map_attribution": map_attribution}, + ) + + return RedirectResponse(url="/setup/keys", status_code=302) @router.get("/setup/keys", response_class=HTMLResponse) -async def setup_keys_form(request: Request) -> HTMLResponse: +async def setup_keys_form( + request: Request, + +) -> HTMLResponse: """Render the API keys form (step 3).""" - from central.gui.wizard import get_wizard_state - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - - settings = get_settings() - - # Get wizard state - required - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: return RedirectResponse(url="/setup/operator", status_code=302) + from central.crypto import encrypt + templates = _get_templates() + pool = get_pool() - # Keys come from cookie state (not DB) - keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -512,40 +572,36 @@ async def setup_keys_form(request: Request) -> HTMLResponse: "success": None, }, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/keys") -async def setup_keys_submit(request: Request) -> Response: +async def setup_keys_submit( + request: Request, + +) -> Response: """Process the API keys form (step 3).""" - from central.gui.wizard import get_wizard_state, set_wizard_cookie - from central.gui.csrf import reuse_or_generate_pre_auth_csrf + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not form_csrf or form_csrf != request.state.csrf_token: + raise CsrfValidationError("Invalid CSRF token") + + form = await request.form() + action = form.get("action", "add") + + # If action is "next", redirect to adapters step + if action == "next": + return RedirectResponse(url="/setup/adapters", status_code=302) + from central.crypto import encrypt templates = _get_templates() - settings = get_settings() - - # Get wizard state - required - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) - - # Validate CSRF - form = await request.form() - form_csrf = form.get("csrf_token", "") - if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): - raise CsrfValidationError("Invalid CSRF token") - - action = form.get("action", "add") - - # If action is "next", advance to adapters step - if action == "next": - state.wizard_step = max(state.wizard_step, 4) - response = RedirectResponse(url="/setup/adapters", status_code=302) - set_wizard_cookie(response, state, settings.csrf_secret) - return response + pool = get_pool() # Otherwise, add a new key alias = form.get("alias", "").strip() @@ -561,8 +617,6 @@ async def setup_keys_submit(request: Request) -> Response: errors["alias"] = "Alias must be at most 64 characters" elif not ALIAS_REGEX.match(alias): errors["alias"] = "Alias must contain only letters, numbers, and underscores" - elif any(k["alias"] == alias for k in state.api_keys): - errors["alias"] = "An API key with this alias already exists" # Validate plaintext_key if not plaintext_key: @@ -570,34 +624,69 @@ async def setup_keys_submit(request: Request) -> Response: elif len(plaintext_key) > 4096: errors["plaintext_key"] = "API key must be at most 4096 characters" - keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] + async with pool.acquire() as conn: + if not errors: + # Check if alias already exists + existing = await conn.fetchrow( + "SELECT alias FROM config.api_keys WHERE alias = $1", + alias, + ) + if existing: + errors["alias"] = "An API key with this alias already exists" - if errors: - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) - response = templates.TemplateResponse( - request=request, - name="setup_keys.html", - context={ - "csrf_token": csrf_token, - "keys": keys, - "errors": errors, - "form_data": form_data, - "success": None, - }, - status_code=200, + keys = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) - return response + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] - # Encrypt the key and add to state (NO DB write) - encrypted_value = encrypt(plaintext_key.encode()) - encrypted_b64 = base64.b64encode(encrypted_value).decode() - state.api_keys.append({"alias": alias, "encrypted_value_b64": encrypted_b64}) + if errors: + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": errors, + "form_data": form_data, + "success": None, + }, + status_code=200, + ) + return response + + # Encrypt the key + encrypted_value = encrypt(plaintext_key.encode()) + + # Insert the new key + row = await conn.fetchrow( + """ + INSERT INTO config.api_keys (alias, encrypted_value) + VALUES ($1, $2) + RETURNING created_at + """, + alias, + encrypted_value, + ) + + # Write audit log (no plaintext!) + await write_audit( + conn, + API_KEY_CREATE, + operator_id=operator.id, + target=alias, + before=None, + after={"alias": alias, "created_at": row["created_at"].isoformat()}, + ) + + # Refresh keys list + keys = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] # Re-render with success message - keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -609,82 +698,61 @@ async def setup_keys_submit(request: Request) -> Response: "success": f"API key '{alias}' added successfully.", }, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) - set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/adapters", response_class=HTMLResponse) -async def setup_adapters_form(request: Request) -> HTMLResponse: +async def setup_adapters_form( + request: Request, + +) -> HTMLResponse: """Render the adapters configuration form (step 4).""" - from central.gui.wizard import get_wizard_state - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - - settings = get_settings() - - # Get wizard state - required - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - # Pre-fill from cookie state or DB defaults - if state.adapters: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) adapters = [] - for name in ["firms", "nws", "usgs_quake"]: - if name in state.adapters: - a = state.adapters[name] - adapters.append({ - "name": name, - "enabled": a["enabled"], - "cadence_s": a["cadence_s"], - "settings": a["settings"], - }) - else: - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) - adapters = [] - for row in rows: - settings_data = row["settings"] or {} - adapters.append({ - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - "settings": settings_data, - }) + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) - # Get API keys from wizard state (not DB) - api_keys = [{"alias": k["alias"]} for k in state.api_keys] + # Get API keys for dropdown + api_keys = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) - # Get map tile settings from wizard state or DB - if state.system: - tile_url = state.system["map_tile_url"] - tile_attribution = state.system["map_attribution"] - else: - async with pool.acquire() as conn: - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Get map tile settings + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="setup_adapters.html", context={ "csrf_token": csrf_token, "adapters": adapters, - "api_keys": api_keys, + "api_keys": [{"alias": k["alias"]} for k in api_keys], "valid_satellites": _get_valid_satellites(), "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, @@ -694,216 +762,242 @@ async def setup_adapters_form(request: Request) -> HTMLResponse: "form_data": None, }, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/adapters") -async def setup_adapters_submit(request: Request) -> Response: +async def setup_adapters_submit( + request: Request, + +) -> Response: """Process the adapters configuration form (step 4).""" - from central.gui.wizard import get_wizard_state, set_wizard_cookie - from central.gui.csrf import reuse_or_generate_pre_auth_csrf + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - settings = get_settings() - # Get wizard state - required - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) - - # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + if not form_csrf or form_csrf != request.state.csrf_token: raise CsrfValidationError("Invalid CSRF token") + form = await request.form() errors: dict[str, str] = {} - new_adapters: dict[str, dict] = {} - # Get current adapter configs from state or DB as baseline - if state.adapters: - current_adapters = state.adapters - else: - async with pool.acquire() as conn: + async with pool.acquire() as conn: + # Get current adapters + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + + for row in rows: + adapter_name = row["name"] + current_settings = row["settings"] or {} + new_settings = dict(current_settings) + + # Parse enabled + enabled = f"{adapter_name}_enabled" in form + + # Parse cadence + cadence_str = form.get(f"{adapter_name}_cadence_s", "") + try: + cadence_s = int(cadence_str) + if cadence_s < 60 or cadence_s > 3600: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" + except ValueError: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" + cadence_s = row["cadence_s"] + + # Adapter-specific validation + if adapter_name == "nws": + contact_email = form.get(f"{adapter_name}_contact_email", "").strip() + if enabled: + if not contact_email: + errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(contact_email): + errors[f"{adapter_name}_contact_email"] = "Invalid email format" + else: + new_settings["contact_email"] = contact_email + else: + new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + + elif adapter_name == "firms": + api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() + satellites = form.getlist(f"{adapter_name}_satellites") + + if api_key_alias: + key_exists = await conn.fetchrow( + "SELECT 1 FROM config.api_keys WHERE alias = $1", + api_key_alias, + ) + if not key_exists: + errors[f"{adapter_name}_api_key_alias"] = f"API key alias '{api_key_alias}' does not exist" + else: + new_settings["api_key_alias"] = api_key_alias + else: + new_settings["api_key_alias"] = None + + # Validate satellites + valid_sats = set(_get_valid_satellites()) + invalid_sats = [s for s in satellites if s not in valid_sats] + if invalid_sats: + errors[f"{adapter_name}_satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}" + else: + new_settings["satellites"] = satellites + + elif adapter_name == "usgs_quake": + feed = form.get(f"{adapter_name}_feed", "").strip() + valid_feeds = _get_valid_feeds() + if feed not in valid_feeds: + errors[f"{adapter_name}_feed"] = f"Invalid feed" + else: + new_settings["feed"] = feed + + # Region validation + region_north_str = form.get(f"{adapter_name}_region_north", "").strip() + region_south_str = form.get(f"{adapter_name}_region_south", "").strip() + region_east_str = form.get(f"{adapter_name}_region_east", "").strip() + region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_region"] = "Invalid latitude: south must be less than north, both between -90 and 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_region"] = "Invalid longitude: west must be less than east, both between -180 and 180" + else: + new_settings["region"] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + + # Store parsed data for re-render on error or update + if not errors.get(f"{adapter_name}_cadence_s"): + # Update adapter + await conn.execute( + """ + UPDATE config.adapters + SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() + WHERE name = $4 + """, + enabled, + cadence_s, + new_settings, + adapter_name, + ) + + # If any errors, re-render + if errors: + adapters = [] rows = await conn.fetch( - "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name" + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ ) - current_adapters = {} for row in rows: - current_adapters[row["name"]] = { + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"], - "settings": row["settings"] or {}, - } + "settings": settings, + }) - for adapter_name in ["firms", "nws", "usgs_quake"]: - current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}}) - current_settings = current.get("settings", {}) - new_settings = dict(current_settings) + api_keys = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) - # Parse enabled - enabled = f"{adapter_name}_enabled" in form + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" - # Parse cadence - cadence_str = form.get(f"{adapter_name}_cadence_s", "") - try: - cadence_s = int(cadence_str) - if cadence_s < 60 or cadence_s > 3600: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" - except ValueError: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" - cadence_s = current.get("cadence_s", 300) + csrf_token = request.state.csrf_token + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": [{"alias": k["alias"]} for k in api_keys], + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": "Please fix the errors below.", + "errors": errors, + "form_data": form, + }, + status_code=200, + ) + return response - # Adapter-specific validation - if adapter_name == "nws": - contact_email = form.get(f"{adapter_name}_contact_email", "").strip() - if enabled: - if not contact_email: - errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" - elif not EMAIL_REGEX.match(contact_email): - errors[f"{adapter_name}_contact_email"] = "Invalid email format" - else: - new_settings["contact_email"] = contact_email - else: - new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + return RedirectResponse(url="/setup/finish", status_code=302) - elif adapter_name == "firms": - api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() - satellites = form.getlist(f"{adapter_name}_satellites") - - if api_key_alias: - # Validate against wizard state keys - if not any(k["alias"] == api_key_alias for k in state.api_keys): - errors[f"{adapter_name}_api_key_alias"] = f"API key alias does not exist" - else: - new_settings["api_key_alias"] = api_key_alias - else: - new_settings["api_key_alias"] = None - - # Validate satellites - valid_sats = set(_get_valid_satellites()) - invalid_sats = [s for s in satellites if s not in valid_sats] - if invalid_sats: - errors[f"{adapter_name}_satellites"] = f"Invalid satellites: " + ", ".join(invalid_sats) - else: - new_settings["satellites"] = satellites - - elif adapter_name == "usgs_quake": - feed = form.get(f"{adapter_name}_feed", "").strip() - valid_feeds = _get_valid_feeds() - if feed not in valid_feeds: - errors[f"{adapter_name}_feed"] = "Invalid feed" - else: - new_settings["feed"] = feed - - # Region validation (all adapters) - region_north_str = form.get(f"{adapter_name}_region_north", "").strip() - region_south_str = form.get(f"{adapter_name}_region_south", "").strip() - region_east_str = form.get(f"{adapter_name}_region_east", "").strip() - region_west_str = form.get(f"{adapter_name}_region_west", "").strip() - - try: - region_north = float(region_north_str) - region_south = float(region_south_str) - region_east = float(region_east_str) - region_west = float(region_west_str) - - if not (-90 <= region_south < region_north <= 90): - errors[f"{adapter_name}_region"] = "Invalid latitude: south < north, both -90 to 90" - elif not (-180 <= region_west < region_east <= 180): - errors[f"{adapter_name}_region"] = "Invalid longitude: west < east, both -180 to 180" - else: - new_settings["region"] = { - "north": region_north, - "south": region_south, - "east": region_east, - "west": region_west, - } - except ValueError: - errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" - - new_adapters[adapter_name] = { - "enabled": enabled, - "cadence_s": cadence_s, - "settings": new_settings, - } - - # If errors, re-render - if errors: - adapters = [ - {"name": name, "enabled": new_adapters[name]["enabled"], - "cadence_s": new_adapters[name]["cadence_s"], - "settings": new_adapters[name]["settings"]} - for name in ["firms", "nws", "usgs_quake"] - ] - api_keys = [{"alias": k["alias"]} for k in state.api_keys] - - if state.system: - tile_url = state.system["map_tile_url"] - tile_attribution = state.system["map_attribution"] - else: - tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = "© OpenStreetMap contributors" - - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) - response = templates.TemplateResponse( - request=request, - name="setup_adapters.html", - context={ - "csrf_token": csrf_token, - "adapters": adapters, - "api_keys": api_keys, - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), - "tile_url": tile_url, - "tile_attribution": tile_attribution, - "error": "Please fix the errors below.", - "errors": errors, - "form_data": form, - }, - status_code=200, - ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) - return response - - # Update wizard state (NO DB write) - state.adapters = new_adapters - state.wizard_step = max(state.wizard_step, 5) - - response = RedirectResponse(url="/setup/finish", status_code=302) - set_wizard_cookie(response, state, settings.csrf_secret) - return response @router.get("/setup/finish", response_class=HTMLResponse) -async def setup_finish_form(request: Request) -> HTMLResponse: +async def setup_finish_form( + request: Request, + +) -> HTMLResponse: """Render the finish setup page (step 5).""" - from central.gui.wizard import get_wizard_state - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - - settings = get_settings() - - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() + pool = get_pool() - operator_count = 1 if state.operator else 0 - key_count = len(state.api_keys) - system = state.system or {"map_tile_url": "(not configured)"} + async with pool.acquire() as conn: + # Get counts + operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") - adapters = [] - if state.adapters: - for name in ["firms", "nws", "usgs_quake"]: - if name in state.adapters: - a = state.adapters[name] - adapters.append({"name": name, "enabled": a["enabled"], "cadence_s": a["cadence_s"]}) + # Get system settings + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": sys_row["map_tile_url"] if sys_row else "", + } - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + # Get adapters + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s + FROM config.adapters + ORDER BY name + """ + ) + adapters = [ + { + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + } + for row in rows + ] + + csrf_token = request.state.csrf_token response = templates.TemplateResponse( request=request, name="setup_finish.html", @@ -913,101 +1007,46 @@ async def setup_finish_form(request: Request) -> HTMLResponse: "key_count": key_count, "system": system, "adapters": adapters, - "error": None, }, ) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/finish") -async def setup_finish_submit(request: Request) -> Response: - """Complete the setup wizard - atomic commit of all wizard state.""" - from central.gui.wizard import get_wizard_state, clear_wizard_cookie - from central.gui.csrf import reuse_or_generate_pre_auth_csrf - from asyncpg.exceptions import UniqueViolationError +async def setup_finish_submit( + request: Request, - templates = _get_templates() - pool = get_pool() - settings = get_settings() - - state = get_wizard_state(request, settings.csrf_secret) - if state is None or state.operator is None: +) -> Response: + """Complete the setup wizard.""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: return RedirectResponse(url="/setup/operator", status_code=302) + pool = get_pool() + form = await request.form() form_csrf = form.get("csrf_token", "") - if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + if not form_csrf or form_csrf != request.state.csrf_token: raise CsrfValidationError("Invalid CSRF token") - if not state.system: - return RedirectResponse(url="/setup/system", status_code=302) - if not state.adapters: - return RedirectResponse(url="/setup/adapters", status_code=302) + async with pool.acquire() as conn: + # Mark setup complete + await conn.execute( + "UPDATE config.system SET setup_complete = true WHERE id = true" + ) - try: - async with pool.acquire() as conn: - async with conn.transaction(): - # 1. INSERT operator - op_row = await conn.fetchrow( - "INSERT INTO config.operators (username, password_hash) VALUES ($1, $2) RETURNING id", - state.operator["username"], - state.operator["password_hash"], - ) - operator_id = op_row["id"] + # Write audit log + await write_audit( + conn, + SETUP_COMPLETE, + operator_id=operator.id, + target="system", + ) - await write_audit(conn, OPERATOR_CREATE, operator_id=operator_id, target=state.operator["username"]) + return RedirectResponse(url="/", status_code=302) - # 2. Create session - sysrow = await conn.fetchrow("SELECT session_lifetime_days FROM config.system WHERE id = true") - lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 - token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) - # 3. UPDATE config.system - old_sys = await conn.fetchrow("SELECT map_tile_url, map_attribution FROM config.system WHERE id = true") - await conn.execute( - "UPDATE config.system SET map_tile_url = $1, map_attribution = $2, setup_complete = true WHERE id = true", - state.system["map_tile_url"], - state.system["map_attribution"], - ) - await write_audit(conn, SYSTEM_UPDATE, operator_id=operator_id, target="system", - before={"map_tile_url": old_sys["map_tile_url"], "map_attribution": old_sys["map_attribution"]} if old_sys else None, - after={"map_tile_url": state.system["map_tile_url"], "map_attribution": state.system["map_attribution"]}) - - # 4. INSERT each API key - for key in state.api_keys: - encrypted = base64.b64decode(key["encrypted_value_b64"]) - await conn.execute("INSERT INTO config.api_keys (alias, encrypted_value) VALUES ($1, $2)", key["alias"], encrypted) - await write_audit(conn, API_KEY_CREATE, operator_id=operator_id, target=key["alias"]) - - # 5. UPDATE config.adapters - for name, adapter_cfg in state.adapters.items(): - old_adapter = await conn.fetchrow("SELECT enabled, cadence_s, settings FROM config.adapters WHERE name = $1", name) - await conn.execute( - "UPDATE config.adapters SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() WHERE name = $4", - adapter_cfg["enabled"], adapter_cfg["cadence_s"], adapter_cfg["settings"], name) - await write_audit(conn, ADAPTER_UPDATE, operator_id=operator_id, target=name, - before={"enabled": old_adapter["enabled"], "cadence_s": old_adapter["cadence_s"]} if old_adapter else None, - after={"enabled": adapter_cfg["enabled"], "cadence_s": adapter_cfg["cadence_s"]}) - - await write_audit(conn, SETUP_COMPLETE, operator_id=operator_id, target="system") - - except UniqueViolationError: - csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) - response = templates.TemplateResponse(request=request, name="setup_finish.html", - context={"csrf_token": csrf_token, "operator_count": 1, "key_count": len(state.api_keys), - "system": state.system, "adapters": [{"name": n, "enabled": a["enabled"], "cadence_s": a["cadence_s"]} for n, a in state.adapters.items()], - "error": f"Username '{state.operator['username']}' already exists."}, status_code=200) - if signed_token is not None: - set_pre_auth_csrf_cookie(response, signed_token) - return response - - response = RedirectResponse(url="/", status_code=302) - clear_wizard_cookie(response) - unset_pre_auth_csrf_cookie(response) - _set_session_cookie(response, token, lifetime_days * 86400) - return response @router.get("/login", response_class=HTMLResponse) async def login_form( request: Request, diff --git a/src/central/gui/templates/adapters_edit.html b/src/central/gui/templates/adapters_edit.html index 939aa75..35ae6c8 100644 --- a/src/central/gui/templates/adapters_edit.html +++ b/src/central/gui/templates/adapters_edit.html @@ -4,9 +4,9 @@ {% block head %} - + - + {% endblock %} {% block content %} diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html index de3f8c2..9ff2f50 100644 --- a/src/central/gui/templates/setup_adapters.html +++ b/src/central/gui/templates/setup_adapters.html @@ -4,9 +4,9 @@ {% block head %} - + - + {% endblock %} {% block content %} diff --git a/src/central/gui/templates/setup_operator.html b/src/central/gui/templates/setup_operator.html index f4e9277..36932b4 100644 --- a/src/central/gui/templates/setup_operator.html +++ b/src/central/gui/templates/setup_operator.html @@ -7,6 +7,17 @@ {% include "_wizard_header.html" %} {% endwith %} +{% if existing_operator %} +
+
+

Operator Already Configured

+
+

The operator account {{ existing_operator.username }} has been created.

+
+ Next → +
+
+{% else %}

Create Operator Account

@@ -42,4 +53,5 @@
+{% endif %} {% endblock %} diff --git a/src/central/gui/wizard.py b/src/central/gui/wizard.py deleted file mode 100644 index 8f28ac3..0000000 --- a/src/central/gui/wizard.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Wizard state management for deferred-commit setup flow. - -The wizard collects configuration across 5 steps and commits everything -atomically at the final step. State is carried in a signed cookie. -""" - -import base64 -from dataclasses import dataclass, field, asdict -from typing import Any - -from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired -from starlette.requests import Request -from starlette.responses import Response - - -# 1 hour max age for wizard cookie -WIZARD_MAX_AGE = 3600 -WIZARD_COOKIE = "central_wizard" - - -@dataclass -class WizardOperator: - """Operator data collected in step 1.""" - username: str - password_hash: str - - -@dataclass -class WizardSystem: - """System settings collected in step 2.""" - map_tile_url: str - map_attribution: str - - -@dataclass -class WizardApiKey: - """API key collected in step 3.""" - alias: str - encrypted_value_b64: str # base64-encoded encrypted value - - -@dataclass -class WizardAdapter: - """Adapter config collected in step 4.""" - enabled: bool - cadence_s: int - settings: dict[str, Any] - - -@dataclass -class WizardState: - """Complete wizard state carried across all steps.""" - wizard_step: int = 1 - operator: dict | None = None - system: dict | None = None - api_keys: list[dict] = field(default_factory=list) - adapters: dict[str, dict] | None = None - - def to_dict(self) -> dict: - """Convert to dictionary for serialization.""" - return { - "wizard_step": self.wizard_step, - "operator": self.operator, - "system": self.system, - "api_keys": self.api_keys, - "adapters": self.adapters, - } - - @classmethod - def from_dict(cls, data: dict) -> "WizardState": - """Create from dictionary.""" - return cls( - wizard_step=data.get("wizard_step", 1), - operator=data.get("operator"), - system=data.get("system"), - api_keys=data.get("api_keys", []), - adapters=data.get("adapters"), - ) - - -def _get_wizard_serializer(secret_key: str) -> URLSafeTimedSerializer: - """Get a timed serializer for wizard state.""" - return URLSafeTimedSerializer(secret_key, salt="wizard-state") - - -def get_wizard_state(request: Request, secret_key: str) -> WizardState | None: - """Decode wizard state from cookie. - - Returns WizardState if valid, None if missing/invalid/expired. - """ - cookie_value = request.cookies.get(WIZARD_COOKIE) - if not cookie_value: - return None - - serializer = _get_wizard_serializer(secret_key) - try: - data = serializer.loads(cookie_value, max_age=WIZARD_MAX_AGE) - return WizardState.from_dict(data) - except (BadSignature, SignatureExpired): - return None - - -def set_wizard_cookie(response: Response, state: WizardState, secret_key: str) -> None: - """Set the wizard state cookie on a response.""" - serializer = _get_wizard_serializer(secret_key) - signed_value = serializer.dumps(state.to_dict()) - response.set_cookie( - WIZARD_COOKIE, - signed_value, - max_age=WIZARD_MAX_AGE, - path="/setup", - httponly=True, - samesite="lax", - ) - - -def clear_wizard_cookie(response: Response) -> None: - """Remove the wizard state cookie.""" - response.delete_cookie(WIZARD_COOKIE, path="/setup") - - -def get_step_route(step: int) -> str: - """Get the route for a wizard step number.""" - routes = { - 1: "/setup/operator", - 2: "/setup/system", - 3: "/setup/keys", - 4: "/setup/adapters", - 5: "/setup/finish", - } - return routes.get(step, "/setup/operator") diff --git a/tests/test_wizard.py b/tests/test_wizard.py index 0492276..dcaa7fe 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -1,4 +1,4 @@ -"""Tests for the first-run setup wizard with deferred-commit pattern.""" +"""Tests for the first-run setup wizard.""" from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch @@ -11,38 +11,60 @@ from central.gui.routes import ( setup_system_submit, setup_keys_form, setup_keys_submit, + setup_adapters_form, + setup_adapters_submit, setup_finish_form, setup_finish_submit, ) -from central.gui.middleware import SetupGateMiddleware -from central.gui.wizard import WizardState, get_wizard_state, set_wizard_cookie +from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step class TestWizardStepRedirect: - """Test wizard step redirect logic based on cookie state.""" + """Test wizard step redirect logic.""" - def test_no_cookie_redirects_to_operator(self): - """When no wizard cookie exists, redirect to /setup/operator.""" - from central.gui.middleware import _get_wizard_redirect_from_cookie + @pytest.mark.asyncio + async def test_no_operators_redirects_to_operator(self): + """When no operators exist, redirect to /setup/operator.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [0] # No operators - mock_request = MagicMock() - mock_request.cookies = {} - - result = _get_wizard_redirect_from_cookie(mock_request, "testsecret") + result = await _get_wizard_redirect_step(mock_conn) assert result == "/setup/operator" - def test_cookie_step_2_redirects_to_system(self): - """When wizard_step=2 in cookie, redirect to /setup/system.""" - from central.gui.wizard import get_step_route + @pytest.mark.asyncio + async def test_default_tile_url_redirects_to_system(self): + """When map_tile_url is default, redirect to /setup/system.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1] # Has operator + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + } - result = get_step_route(2) + result = await _get_wizard_redirect_step(mock_conn) assert result == "/setup/system" - def test_cookie_step_5_redirects_to_finish(self): - """When wizard_step=5 in cookie, redirect to /setup/finish.""" - from central.gui.wizard import get_step_route + @pytest.mark.asyncio + async def test_no_adapters_touched_redirects_to_keys(self): + """When no adapters have been updated, redirect to /setup/keys.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" + } - result = get_step_route(5) + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/keys" + + @pytest.mark.asyncio + async def test_all_steps_complete_redirects_to_finish(self): + """When all steps done, redirect to /setup/finish.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) assert result == "/setup/finish" @@ -50,26 +72,63 @@ class TestSetupOperatorForm: """Test operator creation form (step 1).""" @pytest.mark.asyncio - async def test_get_returns_form_without_prefill(self): - """GET /setup/operator returns the form when no wizard cookie exists.""" + async def test_get_returns_form(self): + """GET /setup/operator returns the form when no operator exists.""" mock_request = MagicMock() - mock_request.cookies = {} mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = None # No operator exists + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" - with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) mock_templates.TemplateResponse.assert_called_once() call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) assert "csrf_token" in context and context["csrf_token"] assert context["error"] is None - assert context["form_data"] is None + assert context["existing_operator"] is None + + @pytest.mark.asyncio + async def test_get_returns_confirmation_when_operator_exists(self): + """GET /setup/operator shows confirmation when operator already exists.""" + mock_request = MagicMock() + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.body = b"Operator Already Configured" + mock_templates.TemplateResponse.return_value = mock_response + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) + + mock_templates.TemplateResponse.assert_called_once() + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["existing_operator"] == {"username": "admin"} + assert context["error"] is None class TestSetupOperatorSubmit: @@ -79,17 +138,28 @@ class TestSetupOperatorSubmit: async def test_password_mismatch_shows_error(self): """POST with password mismatch re-renders with error.""" mock_request = MagicMock() - mock_request.cookies = {} - mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) - + mock_request.state.csrf_token = "test_csrf" + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password1", + "confirm_password": "password2", # Mismatch + }) mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() + mock_conn = AsyncMock() + mock_conn.fetchval.return_value = 0 # No existing operators + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" - with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed")): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" result = await setup_operator_submit( mock_request, username="testuser", @@ -102,43 +172,374 @@ class TestSetupOperatorSubmit: assert context["error"] == "Passwords do not match" @pytest.mark.asyncio - async def test_valid_creates_wizard_cookie_and_redirects(self): - """POST with valid data creates wizard cookie and redirects to /setup/system.""" + async def test_valid_creates_operator_and_redirects(self): + """POST with valid data creates operator and redirects to /setup/system.""" mock_request = MagicMock() - mock_request.cookies = {} - mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) + mock_request.state.csrf_token = "test_csrf" + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password123", + "confirm_password": "password123", + }) - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" - with patch("central.gui.routes.hash_password", return_value="hashed_pw"): - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) + mock_conn = AsyncMock() + mock_conn.fetchval.return_value = 0 # No existing operators + mock_conn.fetchrow.side_effect = [ + {"id": 1}, # INSERT RETURNING id + {"session_lifetime_days": 90}, # system settings + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.hash_password", return_value="hashed"): + with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: + mock_session.return_value = ("session_token", datetime.now(), "csrf_token") + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) assert result.status_code == 302 assert result.headers["location"] == "/setup/system" + @pytest.mark.asyncio + async def test_post_when_operator_exists_shows_confirmation(self): + """POST when operator exists returns 200 with confirmation, no insert.""" + mock_request = MagicMock() + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password123", + "confirm_password": "password123", + }) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_templates.TemplateResponse.return_value = mock_response + + mock_conn = AsyncMock() + mock_conn.fetchval.return_value = 1 # Operator already exists + mock_conn.fetchrow.return_value = {"username": "existing_admin"} + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_request.state.csrf_token = "test_csrf" + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) + + # Should return 200, not 500 or redirect + assert result.status_code == 200 + + # Should render confirmation state + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["existing_operator"] == {"username": "existing_admin"} + + # Should NOT call write_audit (no insert happened) + mock_audit.assert_not_called() + class TestSetupSystemForm: """Test system settings form (step 2).""" @pytest.mark.asyncio - async def test_no_wizard_cookie_redirects_to_operator(self): - """GET /setup/system without wizard cookie redirects to /setup/operator.""" + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/system without auth redirects to /setup/operator.""" mock_request = MagicMock() - mock_request.cookies = {} - - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" - result = await setup_system_form(mock_request) - + mock_request.state.operator = None + result = await setup_system_form(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" + @pytest.mark.asyncio + async def test_authenticated_returns_form(self): + """GET /setup/system with auth returns the form.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "© OpenStreetMap contributors", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_system_form(mock_request) + + mock_templates.TemplateResponse.assert_called_once() + + +class TestSetupSystemSubmit: + """Test system settings submission.""" + + @pytest.mark.asyncio + async def test_missing_placeholders_shows_error(self): + """POST without {z},{x},{y} placeholders shows error.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "map_tile_url": "https://example.com/tiles", + "map_attribution": "Test", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "", + "map_attribution": "", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_system_submit(mock_request) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert "map_tile_url" in context["errors"] + + @pytest.mark.asyncio + async def test_valid_updates_and_redirects(self): + """POST with valid data updates system and redirects to /setup/keys.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "map_tile_url": "https://example.com/{z}/{x}/{y}.png", + "map_attribution": "Test Attribution", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "old_url", + "map_attribution": "old_attr", + } + mock_conn.execute = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_system_submit(mock_request) + + assert result.status_code == 302 + assert result.headers["location"] == "/setup/keys" + + +class TestSetupKeysForm: + """Test API keys form (step 3).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/keys without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_keys_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + +class TestSetupKeysSubmit: + """Test API keys submission.""" + + @pytest.mark.asyncio + async def test_next_action_redirects_to_adapters(self): + """POST with action=next redirects to /setup/adapters.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "action": "next", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + # No need to mock get_pool since action="next" returns before it's called + result = await setup_keys_submit(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/adapters" + + @pytest.mark.asyncio + async def test_add_key_creates_and_rerenders(self): + """POST with action=add creates key and re-renders with success.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "action": "add", + "alias": "testkey", + "plaintext_key": "secret123", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.side_effect = [ + None, # No existing key + {"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}, + ] + mock_conn.fetch.side_effect = [ + [], # First list + [{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.crypto.encrypt", return_value=b"encrypted"): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_keys_submit(mock_request) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["success"] == "API key 'testkey' added successfully." + + +class TestSetupAdaptersForm: + """Test adapters configuration form (step 4).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/adapters without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_adapters_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + +class TestSetupFinishForm: + """Test finish page (step 5).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/finish without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + result = await setup_finish_form(mock_request) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_authenticated_shows_summary(self): + """GET /setup/finish with auth shows summary.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys + mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"} + mock_conn.fetch.return_value = [ + {"name": "nws", "enabled": True, "cadence_s": 300}, + {"name": "firms", "enabled": False, "cadence_s": 600}, + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_finish_form(mock_request) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["operator_count"] == 1 + assert context["key_count"] == 2 + assert len(context["adapters"]) == 2 + + +class TestSetupFinishSubmit: + """Test setup completion.""" + + @pytest.mark.asyncio + async def test_marks_setup_complete_and_redirects(self): + """POST /setup/finish marks setup_complete=true and redirects to /.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + # Mock form with CSRF token + form_data = MagicMock() + form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: + result = await setup_finish_submit(mock_request) + + assert result.status_code == 302 + assert result.headers["location"] == "/" + mock_conn.execute.assert_called_once() + mock_audit.assert_called_once() + class TestSetupGateMiddlewareWizard: """Test SetupGateMiddleware with wizard paths.""" @@ -169,6 +570,69 @@ class TestSetupGateMiddlewareWizard: response = client.get("/setup/operator") assert response.status_code == 200 + @pytest.mark.asyncio + async def test_redirects_base_setup_to_wizard_step(self): + """SetupGateMiddleware redirects /setup to appropriate wizard step.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.fetchval = AsyncMock(return_value=0) # No operators + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup") + async def setup(): + return {"message": "base setup"} + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup") + assert response.status_code == 302 + assert response.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_redirects_login_to_setup_when_incomplete(self): + """SetupGateMiddleware redirects /login to /setup when setup_complete=False.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/login") + async def login(): + return {"message": "login"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/login") + assert response.status_code == 302 + assert response.headers["location"] == "/setup" + @pytest.mark.asyncio async def test_redirects_all_setup_paths_when_complete(self): """SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""