diff --git a/src/central/gui/csrf.py b/src/central/gui/csrf.py index 0d6198f..37848cd 100644 --- a/src/central/gui/csrf.py +++ b/src/central/gui/csrf.py @@ -1,11 +1,10 @@ -"""Pre-auth CSRF protection for login and setup/operator pages. +"""Pre-auth CSRF protection for login and setup pages. These routes cannot use session-bound CSRF because no session exists yet. Uses a simple cookie-based pattern with short-lived tokens. """ import secrets -from typing import Optional from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from starlette.requests import Request @@ -34,6 +33,34 @@ def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]: return plain_token, signed_token +def reuse_or_generate_pre_auth_csrf( + request: Request, + secret_key: str, +) -> tuple[str, str | None]: + """Reuse an existing valid pre-auth CSRF token, or generate new. + + Returns (plain_token, signed_token_for_cookie). + If signed_token_for_cookie is None, the existing cookie is + still valid and caller should not call set_pre_auth_csrf_cookie. + If non-None, caller MUST call set_pre_auth_csrf_cookie with + it to persist the new value. + """ + cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE) + if cookie_value: + serializer = _get_serializer(secret_key) + try: + plain_token = serializer.loads( + cookie_value, + max_age=PRE_AUTH_CSRF_MAX_AGE, + ) + return plain_token, None # reuse existing + except (BadSignature, SignatureExpired): + pass # fall through to generate + + plain_token, signed_token = generate_pre_auth_csrf(secret_key) + return plain_token, signed_token + + def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None: """Set the pre-auth CSRF cookie on a response.""" response.set_cookie( diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index 155112b..e41fc07 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -16,7 +16,7 @@ SETUP_EXEMPT_PREFIXES = ("/static/", "/setup") # Paths that don't require authentication AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"} -AUTH_EXEMPT_PREFIXES = ("/static/",) +AUTH_EXEMPT_PREFIXES = ("/static/", "/setup/") def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: @@ -29,33 +29,14 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: return False -async def _get_wizard_redirect_step(conn) -> str: - """Determine which wizard step to redirect to based on DB state.""" - # Check if any operators exist - op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") - if op_count == 0: +def _get_wizard_redirect_from_cookie(request: Request, csrf_secret: str) -> str: + """Determine wizard redirect step from cookie state.""" + from central.gui.wizard import get_wizard_state, get_step_route + + state = get_wizard_state(request, csrf_secret) + if state is None: return "/setup/operator" - - # Check if system settings have been configured (map_tile_url not default) - sys_row = await conn.fetchrow( - "SELECT map_tile_url FROM config.system WHERE id = true" - ) - default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - if sys_row is None or sys_row["map_tile_url"] == default_tile: - return "/setup/system" - - # Keys step is optional, so check adapters have been reviewed - # We consider adapters reviewed if any adapter has a non-null updated_at - # (meaning it was explicitly saved during setup) - adapters_touched = await conn.fetchval( - "SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL" - ) - if adapters_touched == 0: - # Go to keys first, then adapters - return "/setup/keys" - - # All steps done, go to finish - return "/setup/finish" + return get_step_route(state.wizard_step) class SetupGateMiddleware(BaseHTTPMiddleware): @@ -85,13 +66,16 @@ class SetupGateMiddleware(BaseHTTPMiddleware): if not setup_complete: # Setup not complete - only allow setup paths and static/health if path.startswith("/setup"): - # Allow all /setup/* paths (handler will enforce auth) + # Allow all /setup/* paths # But /setup with no subpath should redirect to appropriate step if path == "/setup" or path == "/setup/": try: - async with pool.acquire() as conn: - redirect_step = await _get_wizard_redirect_step(conn) - return RedirectResponse(url=redirect_step, status_code=302) + from central.bootstrap_config import get_settings + settings = get_settings() + redirect_step = _get_wizard_redirect_from_cookie( + request, settings.csrf_secret + ) + return RedirectResponse(url=redirect_step, status_code=302) except Exception: logger.warning("Failed to determine wizard step", exc_info=True) return RedirectResponse(url="/setup/operator", status_code=302) @@ -139,7 +123,7 @@ class SessionMiddleware(BaseHTTPMiddleware): request.state.operator = None request.state.csrf_token = None - # Check if auth is required + # Check if auth is required - setup paths are exempt during wizard if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): if request.state.operator is None: return RedirectResponse(url="/login", status_code=302) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 0d08e77..f86b6e1 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -13,6 +13,7 @@ from fastapi import APIRouter, Depends, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse, Response from central.bootstrap_config import get_settings from central.gui.csrf import ( + reuse_or_generate_pre_auth_csrf, generate_pre_auth_csrf, set_pre_auth_csrf_cookie, validate_pre_auth_csrf, @@ -49,6 +50,9 @@ router = APIRouter() # Streams to display on dashboard DASHBOARD_STREAMS = ["CENTRAL_WX", "CENTRAL_FIRE", "CENTRAL_QUAKE", "CENTRAL_META"] +# Email validation regex (simple but effective) +ALIAS_REGEX = re.compile(r"^[a-zA-Z0-9_]+$") + # Email validation regex (simple but effective) EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") @@ -267,24 +271,29 @@ async def dashboard_polls(request: Request) -> HTMLResponse: # ============================================================================= -@router.get("/setup/operator", response_class=HTMLResponse) -async def setup_operator_form( - request: Request, -) -> HTMLResponse: - """Render the setup operator form (step 1).""" - templates = _get_templates() - pool = get_pool() - settings = get_settings() - csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) +# ============================================================================= +# Setup Wizard routes (deferred-commit pattern) +# ============================================================================= - # Check if operator already exists - existing_operator = None - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT username FROM config.operators ORDER BY id LIMIT 1" - ) - if row: - existing_operator = {"username": row["username"]} + +@router.get("/setup/operator", response_class=HTMLResponse) +async def setup_operator_form(request: Request) -> HTMLResponse: + """Render the setup operator form (step 1).""" + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + templates = _get_templates() + settings = get_settings() + + # Get wizard state from cookie (if any) + state = get_wizard_state(request, settings.csrf_secret) + + # Pre-fill from cookie state if available + form_data = None + if state and state.operator: + form_data = {"username": state.operator.get("username", "")} + + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, @@ -292,11 +301,11 @@ async def setup_operator_form( context={ "csrf_token": csrf_token, "error": None, - "form_data": None, - "existing_operator": existing_operator, + "form_data": form_data, }, ) - set_pre_auth_csrf_cookie(response, signed_token) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @@ -306,39 +315,22 @@ async def setup_operator_submit( username: str = Form(...), password: str = Form(...), confirm_password: str = Form(...), - ) -> Response: """Process the setup operator form (step 1).""" + from central.gui.wizard import get_wizard_state, set_wizard_cookie, WizardState + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + templates = _get_templates() - pool = get_pool() + settings = get_settings() # Validate CSRF - settings = get_settings() form = await request.form() form_csrf = form.get("csrf_token", "") if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - # Check if operator already exists (single-operator-per-install design) - async with pool.acquire() as conn: - count = await conn.fetchval("SELECT count(*) FROM config.operators") - if count > 0: - # Operator already exists — render confirmation page - existing = await conn.fetchrow( - "SELECT username FROM config.operators ORDER BY id LIMIT 1" - ) - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_operator.html", - context={ - "csrf_token": csrf_token, - "error": None, - "form_data": None, - "existing_operator": {"username": existing["username"]}, - }, - ) - return response + # Get or create wizard state + state = get_wizard_state(request, settings.csrf_secret) or WizardState() # Validate input error = None @@ -351,7 +343,7 @@ async def setup_operator_submit( error = str(e) if error: - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_operator.html", @@ -359,73 +351,54 @@ async def setup_operator_submit( "csrf_token": csrf_token, "error": error, "form_data": {"username": username}, - "existing_operator": None, }, status_code=200, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response - # Create operator + # Hash password and store in wizard state (NO DB write) password_hash = hash_password(password) - async with pool.acquire() as conn: - row = await conn.fetchrow( - """ - INSERT INTO config.operators (username, password_hash) - VALUES ($1, $2) - RETURNING id - """, - username, - password_hash, - ) - operator_id = row["id"] + state.operator = {"username": username, "password_hash": password_hash} + state.wizard_step = max(state.wizard_step, 2) - # Write audit log - await write_audit( - conn, - OPERATOR_CREATE, - operator_id=operator_id, - target=username, - ) - - # Get session lifetime - sysrow = await conn.fetchrow( - "SELECT session_lifetime_days FROM config.system WHERE id = true" - ) - lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 - - # Create session - token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) - - # Redirect to next step with session cookie + # Redirect to next step with updated wizard cookie response = RedirectResponse(url="/setup/system", status_code=302) - _set_session_cookie(response, token, lifetime_days * 86400) + set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/system", response_class=HTMLResponse) -async def setup_system_form( - request: Request, - -) -> HTMLResponse: +async def setup_system_form(request: Request) -> HTMLResponse: """Render the system settings form (step 2).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required for step 2+ + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", - } + # Pre-fill from cookie state or DB defaults + if state.system: + system = state.system + else: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", + } - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_system.html", @@ -437,29 +410,31 @@ async def setup_system_form( "system": system, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/system") -async def setup_system_submit( - request: Request, - -) -> Response: +async def setup_system_submit(request: Request) -> Response: """Process the system settings form (step 2).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf templates = _get_templates() - pool = get_pool() + settings = get_settings() + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - form = await request.form() map_tile_url = form.get("map_tile_url", "").strip() map_attribution = form.get("map_attribution", "").strip() @@ -480,87 +455,52 @@ async def setup_system_submit( if not map_attribution: errors["map_attribution"] = "Map attribution is required" - async with pool.acquire() as conn: - if errors: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "", - "map_attribution": row["map_attribution"] if row else "", - } - - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_system.html", - context={ - "csrf_token": csrf_token, - "error": None, - "errors": errors, - "form_data": form_data, - "system": system, - }, - status_code=200, - ) - return response - - # Get current values for audit - old_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + if errors: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": errors, + "form_data": form_data, + "system": state.system or form_data, + }, + status_code=200, ) - before = { - "map_tile_url": old_row["map_tile_url"] if old_row else None, - "map_attribution": old_row["map_attribution"] if old_row else None, - } + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response - # Update system settings - await conn.execute( - """ - UPDATE config.system - SET map_tile_url = $1, map_attribution = $2 - WHERE id = true - """, - map_tile_url, - map_attribution, - ) + # Update wizard state (NO DB write) + state.system = {"map_tile_url": map_tile_url, "map_attribution": map_attribution} + state.wizard_step = max(state.wizard_step, 3) - # Write audit log - await write_audit( - conn, - SYSTEM_UPDATE, - operator_id=operator.id, - target="system", - before=before, - after={"map_tile_url": map_tile_url, "map_attribution": map_attribution}, - ) - - return RedirectResponse(url="/setup/keys", status_code=302) + response = RedirectResponse(url="/setup/keys", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response @router.get("/setup/keys", response_class=HTMLResponse) -async def setup_keys_form( - request: Request, - -) -> HTMLResponse: +async def setup_keys_form(request: Request) -> HTMLResponse: """Render the API keys form (step 3).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) - from central.crypto import encrypt - templates = _get_templates() - pool = get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" - ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + # Keys come from cookie state (not DB) + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -572,36 +512,40 @@ async def setup_keys_form( "success": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/keys") -async def setup_keys_submit( - request: Request, - -) -> Response: +async def setup_keys_submit(request: Request) -> Response: """Process the API keys form (step 3).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) - - form = await request.form() - form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: - raise CsrfValidationError("Invalid CSRF token") - - form = await request.form() - action = form.get("action", "add") - - # If action is "next", redirect to adapters step - if action == "next": - return RedirectResponse(url="/setup/adapters", status_code=302) - + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf from central.crypto import encrypt templates = _get_templates() - pool = get_pool() + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + raise CsrfValidationError("Invalid CSRF token") + + action = form.get("action", "add") + + # If action is "next", advance to adapters step + if action == "next": + state.wizard_step = max(state.wizard_step, 4) + response = RedirectResponse(url="/setup/adapters", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response # Otherwise, add a new key alias = form.get("alias", "").strip() @@ -617,6 +561,8 @@ async def setup_keys_submit( errors["alias"] = "Alias must be at most 64 characters" elif not ALIAS_REGEX.match(alias): errors["alias"] = "Alias must contain only letters, numbers, and underscores" + elif any(k["alias"] == alias for k in state.api_keys): + errors["alias"] = "An API key with this alias already exists" # Validate plaintext_key if not plaintext_key: @@ -624,69 +570,34 @@ async def setup_keys_submit( elif len(plaintext_key) > 4096: errors["plaintext_key"] = "API key must be at most 4096 characters" - async with pool.acquire() as conn: - if not errors: - # Check if alias already exists - existing = await conn.fetchrow( - "SELECT alias FROM config.api_keys WHERE alias = $1", - alias, - ) - if existing: - errors["alias"] = "An API key with this alias already exists" + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - keys = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + if errors: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": errors, + "form_data": form_data, + "success": None, + }, + status_code=200, ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response - if errors: - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_keys.html", - context={ - "csrf_token": csrf_token, - "keys": keys, - "errors": errors, - "form_data": form_data, - "success": None, - }, - status_code=200, - ) - return response - - # Encrypt the key - encrypted_value = encrypt(plaintext_key.encode()) - - # Insert the new key - row = await conn.fetchrow( - """ - INSERT INTO config.api_keys (alias, encrypted_value) - VALUES ($1, $2) - RETURNING created_at - """, - alias, - encrypted_value, - ) - - # Write audit log (no plaintext!) - await write_audit( - conn, - API_KEY_CREATE, - operator_id=operator.id, - target=alias, - before=None, - after={"alias": alias, "created_at": row["created_at"].isoformat()}, - ) - - # Refresh keys list - keys = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" - ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + # Encrypt the key and add to state (NO DB write) + encrypted_value = encrypt(plaintext_key.encode()) + encrypted_b64 = base64.b64encode(encrypted_value).decode() + state.api_keys.append({"alias": alias, "encrypted_value_b64": encrypted_b64}) # Re-render with success message - csrf_token = request.state.csrf_token + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -698,61 +609,82 @@ async def setup_keys_submit( "success": f"API key '{alias}' added successfully.", }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/adapters", response_class=HTMLResponse) -async def setup_adapters_form( - request: Request, - -) -> HTMLResponse: +async def setup_adapters_form(request: Request) -> HTMLResponse: """Render the adapters configuration form (step 4).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) + # Pre-fill from cookie state or DB defaults + if state.adapters: adapters = [] - for row in rows: - settings = row["settings"] or {} - adapters.append({ - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - "settings": settings, - }) + for name in ["firms", "nws", "usgs_quake"]: + if name in state.adapters: + a = state.adapters[name] + adapters.append({ + "name": name, + "enabled": a["enabled"], + "cadence_s": a["cadence_s"], + "settings": a["settings"], + }) + else: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + adapters = [] + for row in rows: + settings_data = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings_data, + }) - # Get API keys for dropdown - api_keys = await conn.fetch( - "SELECT alias FROM config.api_keys ORDER BY alias" - ) + # Get API keys from wizard state (not DB) + api_keys = [{"alias": k["alias"]} for k in state.api_keys] - # Get map tile settings - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Get map tile settings from wizard state or DB + if state.system: + tile_url = state.system["map_tile_url"] + tile_attribution = state.system["map_attribution"] + else: + async with pool.acquire() as conn: + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_adapters.html", context={ "csrf_token": csrf_token, "adapters": adapters, - "api_keys": [{"alias": k["alias"]} for k in api_keys], + "api_keys": api_keys, "valid_satellites": _get_valid_satellites(), "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, @@ -762,242 +694,216 @@ async def setup_adapters_form( "form_data": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/adapters") -async def setup_adapters_submit( - request: Request, - -) -> Response: +async def setup_adapters_submit(request: Request) -> Response: """Process the adapters configuration form (step 4).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf templates = _get_templates() pool = get_pool() + settings = get_settings() + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - form = await request.form() errors: dict[str, str] = {} + new_adapters: dict[str, dict] = {} - async with pool.acquire() as conn: - # Get current adapters - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) - - for row in rows: - adapter_name = row["name"] - current_settings = row["settings"] or {} - new_settings = dict(current_settings) - - # Parse enabled - enabled = f"{adapter_name}_enabled" in form - - # Parse cadence - cadence_str = form.get(f"{adapter_name}_cadence_s", "") - try: - cadence_s = int(cadence_str) - if cadence_s < 60 or cadence_s > 3600: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" - except ValueError: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" - cadence_s = row["cadence_s"] - - # Adapter-specific validation - if adapter_name == "nws": - contact_email = form.get(f"{adapter_name}_contact_email", "").strip() - if enabled: - if not contact_email: - errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" - elif not EMAIL_REGEX.match(contact_email): - errors[f"{adapter_name}_contact_email"] = "Invalid email format" - else: - new_settings["contact_email"] = contact_email - else: - new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") - - elif adapter_name == "firms": - api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() - satellites = form.getlist(f"{adapter_name}_satellites") - - if api_key_alias: - key_exists = await conn.fetchrow( - "SELECT 1 FROM config.api_keys WHERE alias = $1", - api_key_alias, - ) - if not key_exists: - errors[f"{adapter_name}_api_key_alias"] = f"API key alias '{api_key_alias}' does not exist" - else: - new_settings["api_key_alias"] = api_key_alias - else: - new_settings["api_key_alias"] = None - - # Validate satellites - valid_sats = set(_get_valid_satellites()) - invalid_sats = [s for s in satellites if s not in valid_sats] - if invalid_sats: - errors[f"{adapter_name}_satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}" - else: - new_settings["satellites"] = satellites - - elif adapter_name == "usgs_quake": - feed = form.get(f"{adapter_name}_feed", "").strip() - valid_feeds = _get_valid_feeds() - if feed not in valid_feeds: - errors[f"{adapter_name}_feed"] = f"Invalid feed" - else: - new_settings["feed"] = feed - - # Region validation - region_north_str = form.get(f"{adapter_name}_region_north", "").strip() - region_south_str = form.get(f"{adapter_name}_region_south", "").strip() - region_east_str = form.get(f"{adapter_name}_region_east", "").strip() - region_west_str = form.get(f"{adapter_name}_region_west", "").strip() - - try: - region_north = float(region_north_str) - region_south = float(region_south_str) - region_east = float(region_east_str) - region_west = float(region_west_str) - - if not (-90 <= region_south < region_north <= 90): - errors[f"{adapter_name}_region"] = "Invalid latitude: south must be less than north, both between -90 and 90" - elif not (-180 <= region_west < region_east <= 180): - errors[f"{adapter_name}_region"] = "Invalid longitude: west must be less than east, both between -180 and 180" - else: - new_settings["region"] = { - "north": region_north, - "south": region_south, - "east": region_east, - "west": region_west, - } - except ValueError: - errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" - - # Store parsed data for re-render on error or update - if not errors.get(f"{adapter_name}_cadence_s"): - # Update adapter - await conn.execute( - """ - UPDATE config.adapters - SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() - WHERE name = $4 - """, - enabled, - cadence_s, - new_settings, - adapter_name, - ) - - # If any errors, re-render - if errors: - adapters = [] + # Get current adapter configs from state or DB as baseline + if state.adapters: + current_adapters = state.adapters + else: + async with pool.acquire() as conn: rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ + "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name" ) + current_adapters = {} for row in rows: - settings = row["settings"] or {} - adapters.append({ - "name": row["name"], + current_adapters[row["name"]] = { "enabled": row["enabled"], "cadence_s": row["cadence_s"], - "settings": settings, - }) + "settings": row["settings"] or {}, + } - api_keys = await conn.fetch( - "SELECT alias FROM config.api_keys ORDER BY alias" - ) + for adapter_name in ["firms", "nws", "usgs_quake"]: + current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}}) + current_settings = current.get("settings", {}) + new_settings = dict(current_settings) - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Parse enabled + enabled = f"{adapter_name}_enabled" in form - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_adapters.html", - context={ - "csrf_token": csrf_token, - "adapters": adapters, - "api_keys": [{"alias": k["alias"]} for k in api_keys], - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), - "tile_url": tile_url, - "tile_attribution": tile_attribution, - "error": "Please fix the errors below.", - "errors": errors, - "form_data": form, - }, - status_code=200, - ) - return response + # Parse cadence + cadence_str = form.get(f"{adapter_name}_cadence_s", "") + try: + cadence_s = int(cadence_str) + if cadence_s < 60 or cadence_s > 3600: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" + except ValueError: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" + cadence_s = current.get("cadence_s", 300) - return RedirectResponse(url="/setup/finish", status_code=302) + # Adapter-specific validation + if adapter_name == "nws": + contact_email = form.get(f"{adapter_name}_contact_email", "").strip() + if enabled: + if not contact_email: + errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(contact_email): + errors[f"{adapter_name}_contact_email"] = "Invalid email format" + else: + new_settings["contact_email"] = contact_email + else: + new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + elif adapter_name == "firms": + api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() + satellites = form.getlist(f"{adapter_name}_satellites") + + if api_key_alias: + # Validate against wizard state keys + if not any(k["alias"] == api_key_alias for k in state.api_keys): + errors[f"{adapter_name}_api_key_alias"] = f"API key alias does not exist" + else: + new_settings["api_key_alias"] = api_key_alias + else: + new_settings["api_key_alias"] = None + + # Validate satellites + valid_sats = set(_get_valid_satellites()) + invalid_sats = [s for s in satellites if s not in valid_sats] + if invalid_sats: + errors[f"{adapter_name}_satellites"] = f"Invalid satellites: " + ", ".join(invalid_sats) + else: + new_settings["satellites"] = satellites + + elif adapter_name == "usgs_quake": + feed = form.get(f"{adapter_name}_feed", "").strip() + valid_feeds = _get_valid_feeds() + if feed not in valid_feeds: + errors[f"{adapter_name}_feed"] = "Invalid feed" + else: + new_settings["feed"] = feed + + # Region validation (all adapters) + region_north_str = form.get(f"{adapter_name}_region_north", "").strip() + region_south_str = form.get(f"{adapter_name}_region_south", "").strip() + region_east_str = form.get(f"{adapter_name}_region_east", "").strip() + region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_region"] = "Invalid latitude: south < north, both -90 to 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_region"] = "Invalid longitude: west < east, both -180 to 180" + else: + new_settings["region"] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + + new_adapters[adapter_name] = { + "enabled": enabled, + "cadence_s": cadence_s, + "settings": new_settings, + } + + # If errors, re-render + if errors: + adapters = [ + {"name": name, "enabled": new_adapters[name]["enabled"], + "cadence_s": new_adapters[name]["cadence_s"], + "settings": new_adapters[name]["settings"]} + for name in ["firms", "nws", "usgs_quake"] + ] + api_keys = [{"alias": k["alias"]} for k in state.api_keys] + + if state.system: + tile_url = state.system["map_tile_url"] + tile_attribution = state.system["map_attribution"] + else: + tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = "© OpenStreetMap contributors" + + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": api_keys, + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": "Please fix the errors below.", + "errors": errors, + "form_data": form, + }, + status_code=200, + ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response + + # Update wizard state (NO DB write) + state.adapters = new_adapters + state.wizard_step = max(state.wizard_step, 5) + + response = RedirectResponse(url="/setup/finish", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response @router.get("/setup/finish", response_class=HTMLResponse) -async def setup_finish_form( - request: Request, - -) -> HTMLResponse: +async def setup_finish_form(request: Request) -> HTMLResponse: """Render the finish setup page (step 5).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() - pool = get_pool() - async with pool.acquire() as conn: - # Get counts - operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") - key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + operator_count = 1 if state.operator else 0 + key_count = len(state.api_keys) + system = state.system or {"map_tile_url": "(not configured)"} - # Get system settings - sys_row = await conn.fetchrow( - "SELECT map_tile_url FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": sys_row["map_tile_url"] if sys_row else "", - } + adapters = [] + if state.adapters: + for name in ["firms", "nws", "usgs_quake"]: + if name in state.adapters: + a = state.adapters[name] + adapters.append({"name": name, "enabled": a["enabled"], "cadence_s": a["cadence_s"]}) - # Get adapters - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s - FROM config.adapters - ORDER BY name - """ - ) - adapters = [ - { - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - } - for row in rows - ] - - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_finish.html", @@ -1007,46 +913,101 @@ async def setup_finish_form( "key_count": key_count, "system": system, "adapters": adapters, + "error": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/finish") -async def setup_finish_submit( - request: Request, - -) -> Response: - """Complete the setup wizard.""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) +async def setup_finish_submit(request: Request) -> Response: + """Complete the setup wizard - atomic commit of all wizard state.""" + from central.gui.wizard import get_wizard_state, clear_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + from asyncpg.exceptions import UniqueViolationError + templates = _get_templates() pool = get_pool() + settings = get_settings() + + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - async with pool.acquire() as conn: - # Mark setup complete - await conn.execute( - "UPDATE config.system SET setup_complete = true WHERE id = true" - ) + if not state.system: + return RedirectResponse(url="/setup/system", status_code=302) + if not state.adapters: + return RedirectResponse(url="/setup/adapters", status_code=302) - # Write audit log - await write_audit( - conn, - SETUP_COMPLETE, - operator_id=operator.id, - target="system", - ) + try: + async with pool.acquire() as conn: + async with conn.transaction(): + # 1. INSERT operator + op_row = await conn.fetchrow( + "INSERT INTO config.operators (username, password_hash) VALUES ($1, $2) RETURNING id", + state.operator["username"], + state.operator["password_hash"], + ) + operator_id = op_row["id"] - return RedirectResponse(url="/", status_code=302) + await write_audit(conn, OPERATOR_CREATE, operator_id=operator_id, target=state.operator["username"]) + # 2. Create session + sysrow = await conn.fetchrow("SELECT session_lifetime_days FROM config.system WHERE id = true") + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) + # 3. UPDATE config.system + old_sys = await conn.fetchrow("SELECT map_tile_url, map_attribution FROM config.system WHERE id = true") + await conn.execute( + "UPDATE config.system SET map_tile_url = $1, map_attribution = $2, setup_complete = true WHERE id = true", + state.system["map_tile_url"], + state.system["map_attribution"], + ) + await write_audit(conn, SYSTEM_UPDATE, operator_id=operator_id, target="system", + before={"map_tile_url": old_sys["map_tile_url"], "map_attribution": old_sys["map_attribution"]} if old_sys else None, + after={"map_tile_url": state.system["map_tile_url"], "map_attribution": state.system["map_attribution"]}) + + # 4. INSERT each API key + for key in state.api_keys: + encrypted = base64.b64decode(key["encrypted_value_b64"]) + await conn.execute("INSERT INTO config.api_keys (alias, encrypted_value) VALUES ($1, $2)", key["alias"], encrypted) + await write_audit(conn, API_KEY_CREATE, operator_id=operator_id, target=key["alias"]) + + # 5. UPDATE config.adapters + for name, adapter_cfg in state.adapters.items(): + old_adapter = await conn.fetchrow("SELECT enabled, cadence_s, settings FROM config.adapters WHERE name = $1", name) + await conn.execute( + "UPDATE config.adapters SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() WHERE name = $4", + adapter_cfg["enabled"], adapter_cfg["cadence_s"], adapter_cfg["settings"], name) + await write_audit(conn, ADAPTER_UPDATE, operator_id=operator_id, target=name, + before={"enabled": old_adapter["enabled"], "cadence_s": old_adapter["cadence_s"]} if old_adapter else None, + after={"enabled": adapter_cfg["enabled"], "cadence_s": adapter_cfg["cadence_s"]}) + + await write_audit(conn, SETUP_COMPLETE, operator_id=operator_id, target="system") + + except UniqueViolationError: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse(request=request, name="setup_finish.html", + context={"csrf_token": csrf_token, "operator_count": 1, "key_count": len(state.api_keys), + "system": state.system, "adapters": [{"name": n, "enabled": a["enabled"], "cadence_s": a["cadence_s"]} for n, a in state.adapters.items()], + "error": f"Username '{state.operator['username']}' already exists."}, status_code=200) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response + + response = RedirectResponse(url="/", status_code=302) + clear_wizard_cookie(response) + unset_pre_auth_csrf_cookie(response) + _set_session_cookie(response, token, lifetime_days * 86400) + return response @router.get("/login", response_class=HTMLResponse) async def login_form( request: Request, diff --git a/src/central/gui/templates/setup_operator.html b/src/central/gui/templates/setup_operator.html index 36932b4..f4e9277 100644 --- a/src/central/gui/templates/setup_operator.html +++ b/src/central/gui/templates/setup_operator.html @@ -7,17 +7,6 @@ {% include "_wizard_header.html" %} {% endwith %} -{% if existing_operator %} -
-
-

Operator Already Configured

-
-

The operator account {{ existing_operator.username }} has been created.

-
- Next → -
-
-{% else %}

Create Operator Account

@@ -53,5 +42,4 @@
-{% endif %} {% endblock %} diff --git a/src/central/gui/wizard.py b/src/central/gui/wizard.py new file mode 100644 index 0000000..8f28ac3 --- /dev/null +++ b/src/central/gui/wizard.py @@ -0,0 +1,131 @@ +"""Wizard state management for deferred-commit setup flow. + +The wizard collects configuration across 5 steps and commits everything +atomically at the final step. State is carried in a signed cookie. +""" + +import base64 +from dataclasses import dataclass, field, asdict +from typing import Any + +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired +from starlette.requests import Request +from starlette.responses import Response + + +# 1 hour max age for wizard cookie +WIZARD_MAX_AGE = 3600 +WIZARD_COOKIE = "central_wizard" + + +@dataclass +class WizardOperator: + """Operator data collected in step 1.""" + username: str + password_hash: str + + +@dataclass +class WizardSystem: + """System settings collected in step 2.""" + map_tile_url: str + map_attribution: str + + +@dataclass +class WizardApiKey: + """API key collected in step 3.""" + alias: str + encrypted_value_b64: str # base64-encoded encrypted value + + +@dataclass +class WizardAdapter: + """Adapter config collected in step 4.""" + enabled: bool + cadence_s: int + settings: dict[str, Any] + + +@dataclass +class WizardState: + """Complete wizard state carried across all steps.""" + wizard_step: int = 1 + operator: dict | None = None + system: dict | None = None + api_keys: list[dict] = field(default_factory=list) + adapters: dict[str, dict] | None = None + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "wizard_step": self.wizard_step, + "operator": self.operator, + "system": self.system, + "api_keys": self.api_keys, + "adapters": self.adapters, + } + + @classmethod + def from_dict(cls, data: dict) -> "WizardState": + """Create from dictionary.""" + return cls( + wizard_step=data.get("wizard_step", 1), + operator=data.get("operator"), + system=data.get("system"), + api_keys=data.get("api_keys", []), + adapters=data.get("adapters"), + ) + + +def _get_wizard_serializer(secret_key: str) -> URLSafeTimedSerializer: + """Get a timed serializer for wizard state.""" + return URLSafeTimedSerializer(secret_key, salt="wizard-state") + + +def get_wizard_state(request: Request, secret_key: str) -> WizardState | None: + """Decode wizard state from cookie. + + Returns WizardState if valid, None if missing/invalid/expired. + """ + cookie_value = request.cookies.get(WIZARD_COOKIE) + if not cookie_value: + return None + + serializer = _get_wizard_serializer(secret_key) + try: + data = serializer.loads(cookie_value, max_age=WIZARD_MAX_AGE) + return WizardState.from_dict(data) + except (BadSignature, SignatureExpired): + return None + + +def set_wizard_cookie(response: Response, state: WizardState, secret_key: str) -> None: + """Set the wizard state cookie on a response.""" + serializer = _get_wizard_serializer(secret_key) + signed_value = serializer.dumps(state.to_dict()) + response.set_cookie( + WIZARD_COOKIE, + signed_value, + max_age=WIZARD_MAX_AGE, + path="/setup", + httponly=True, + samesite="lax", + ) + + +def clear_wizard_cookie(response: Response) -> None: + """Remove the wizard state cookie.""" + response.delete_cookie(WIZARD_COOKIE, path="/setup") + + +def get_step_route(step: int) -> str: + """Get the route for a wizard step number.""" + routes = { + 1: "/setup/operator", + 2: "/setup/system", + 3: "/setup/keys", + 4: "/setup/adapters", + 5: "/setup/finish", + } + return routes.get(step, "/setup/operator") diff --git a/tests/test_wizard.py b/tests/test_wizard.py index dcaa7fe..0492276 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -1,4 +1,4 @@ -"""Tests for the first-run setup wizard.""" +"""Tests for the first-run setup wizard with deferred-commit pattern.""" from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch @@ -11,60 +11,38 @@ from central.gui.routes import ( setup_system_submit, setup_keys_form, setup_keys_submit, - setup_adapters_form, - setup_adapters_submit, setup_finish_form, setup_finish_submit, ) -from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step +from central.gui.middleware import SetupGateMiddleware +from central.gui.wizard import WizardState, get_wizard_state, set_wizard_cookie class TestWizardStepRedirect: - """Test wizard step redirect logic.""" + """Test wizard step redirect logic based on cookie state.""" - @pytest.mark.asyncio - async def test_no_operators_redirects_to_operator(self): - """When no operators exist, redirect to /setup/operator.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [0] # No operators + def test_no_cookie_redirects_to_operator(self): + """When no wizard cookie exists, redirect to /setup/operator.""" + from central.gui.middleware import _get_wizard_redirect_from_cookie - result = await _get_wizard_redirect_step(mock_conn) + mock_request = MagicMock() + mock_request.cookies = {} + + result = _get_wizard_redirect_from_cookie(mock_request, "testsecret") assert result == "/setup/operator" - @pytest.mark.asyncio - async def test_default_tile_url_redirects_to_system(self): - """When map_tile_url is default, redirect to /setup/system.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1] # Has operator - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - } + def test_cookie_step_2_redirects_to_system(self): + """When wizard_step=2 in cookie, redirect to /setup/system.""" + from central.gui.wizard import get_step_route - result = await _get_wizard_redirect_step(mock_conn) + result = get_step_route(2) assert result == "/setup/system" - @pytest.mark.asyncio - async def test_no_adapters_touched_redirects_to_keys(self): - """When no adapters have been updated, redirect to /setup/keys.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" - } + def test_cookie_step_5_redirects_to_finish(self): + """When wizard_step=5 in cookie, redirect to /setup/finish.""" + from central.gui.wizard import get_step_route - result = await _get_wizard_redirect_step(mock_conn) - assert result == "/setup/keys" - - @pytest.mark.asyncio - async def test_all_steps_complete_redirects_to_finish(self): - """When all steps done, redirect to /setup/finish.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" - } - - result = await _get_wizard_redirect_step(mock_conn) + result = get_step_route(5) assert result == "/setup/finish" @@ -72,63 +50,26 @@ class TestSetupOperatorForm: """Test operator creation form (step 1).""" @pytest.mark.asyncio - async def test_get_returns_form(self): - """GET /setup/operator returns the form when no operator exists.""" + async def test_get_returns_form_without_prefill(self): + """GET /setup/operator returns the form when no wizard cookie exists.""" mock_request = MagicMock() + mock_request.cookies = {} mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = None # No operator exists - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) mock_templates.TemplateResponse.assert_called_once() call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) assert "csrf_token" in context and context["csrf_token"] assert context["error"] is None - assert context["existing_operator"] is None - - @pytest.mark.asyncio - async def test_get_returns_confirmation_when_operator_exists(self): - """GET /setup/operator shows confirmation when operator already exists.""" - mock_request = MagicMock() - - mock_templates = MagicMock() - mock_response = MagicMock() - mock_response.body = b"Operator Already Configured" - mock_templates.TemplateResponse.return_value = mock_response - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) - - mock_templates.TemplateResponse.assert_called_once() - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["existing_operator"] == {"username": "admin"} - assert context["error"] is None + assert context["form_data"] is None class TestSetupOperatorSubmit: @@ -138,28 +79,17 @@ class TestSetupOperatorSubmit: async def test_password_mismatch_shows_error(self): """POST with password mismatch re-renders with error.""" mock_request = MagicMock() - mock_request.state.csrf_token = "test_csrf" - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password1", - "confirm_password": "password2", # Mismatch - }) + mock_request.cookies = {} + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) + mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 0 # No existing operators - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed")): result = await setup_operator_submit( mock_request, username="testuser", @@ -172,374 +102,43 @@ class TestSetupOperatorSubmit: assert context["error"] == "Passwords do not match" @pytest.mark.asyncio - async def test_valid_creates_operator_and_redirects(self): - """POST with valid data creates operator and redirects to /setup/system.""" + async def test_valid_creates_wizard_cookie_and_redirects(self): + """POST with valid data creates wizard cookie and redirects to /setup/system.""" mock_request = MagicMock() - mock_request.state.csrf_token = "test_csrf" - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password123", - "confirm_password": "password123", - }) + mock_request.cookies = {} + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 0 # No existing operators - mock_conn.fetchrow.side_effect = [ - {"id": 1}, # INSERT RETURNING id - {"session_lifetime_days": 90}, # system settings - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.hash_password", return_value="hashed"): - with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: - mock_session.return_value = ("session_token", datetime.now(), "csrf_token") - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.hash_password", return_value="hashed_pw"): + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) assert result.status_code == 302 assert result.headers["location"] == "/setup/system" - @pytest.mark.asyncio - async def test_post_when_operator_exists_shows_confirmation(self): - """POST when operator exists returns 200 with confirmation, no insert.""" - mock_request = MagicMock() - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password123", - "confirm_password": "password123", - }) - - mock_templates = MagicMock() - mock_response = MagicMock() - mock_response.status_code = 200 - mock_templates.TemplateResponse.return_value = mock_response - - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 1 # Operator already exists - mock_conn.fetchrow.return_value = {"username": "existing_admin"} - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - mock_request.state.csrf_token = "test_csrf" - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) - - # Should return 200, not 500 or redirect - assert result.status_code == 200 - - # Should render confirmation state - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["existing_operator"] == {"username": "existing_admin"} - - # Should NOT call write_audit (no insert happened) - mock_audit.assert_not_called() - class TestSetupSystemForm: """Test system settings form (step 2).""" @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/system without auth redirects to /setup/operator.""" + async def test_no_wizard_cookie_redirects_to_operator(self): + """GET /setup/system without wizard cookie redirects to /setup/operator.""" mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_system_form(mock_request) + mock_request.cookies = {} + + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + result = await setup_system_form(mock_request) + assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" - @pytest.mark.asyncio - async def test_authenticated_returns_form(self): - """GET /setup/system with auth returns the form.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": "© OpenStreetMap contributors", - } - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_form(mock_request) - - mock_templates.TemplateResponse.assert_called_once() - - -class TestSetupSystemSubmit: - """Test system settings submission.""" - - @pytest.mark.asyncio - async def test_missing_placeholders_shows_error(self): - """POST without {z},{x},{y} placeholders shows error.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "map_tile_url": "https://example.com/tiles", - "map_attribution": "Test", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "", - "map_attribution": "", - } - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_submit(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert "map_tile_url" in context["errors"] - - @pytest.mark.asyncio - async def test_valid_updates_and_redirects(self): - """POST with valid data updates system and redirects to /setup/keys.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "map_tile_url": "https://example.com/{z}/{x}/{y}.png", - "map_attribution": "Test Attribution", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "old_url", - "map_attribution": "old_attr", - } - mock_conn.execute = AsyncMock() - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_system_submit(mock_request) - - assert result.status_code == 302 - assert result.headers["location"] == "/setup/keys" - - -class TestSetupKeysForm: - """Test API keys form (step 3).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/keys without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_keys_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - -class TestSetupKeysSubmit: - """Test API keys submission.""" - - @pytest.mark.asyncio - async def test_next_action_redirects_to_adapters(self): - """POST with action=next redirects to /setup/adapters.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "action": "next", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - # No need to mock get_pool since action="next" returns before it's called - result = await setup_keys_submit(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/adapters" - - @pytest.mark.asyncio - async def test_add_key_creates_and_rerenders(self): - """POST with action=add creates key and re-renders with success.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "action": "add", - "alias": "testkey", - "plaintext_key": "secret123", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.side_effect = [ - None, # No existing key - {"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}, - ] - mock_conn.fetch.side_effect = [ - [], # First list - [{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.crypto.encrypt", return_value=b"encrypted"): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_keys_submit(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["success"] == "API key 'testkey' added successfully." - - -class TestSetupAdaptersForm: - """Test adapters configuration form (step 4).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/adapters without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_adapters_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - -class TestSetupFinishForm: - """Test finish page (step 5).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/finish without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_finish_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - @pytest.mark.asyncio - async def test_authenticated_shows_summary(self): - """GET /setup/finish with auth shows summary.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys - mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"} - mock_conn.fetch.return_value = [ - {"name": "nws", "enabled": True, "cadence_s": 300}, - {"name": "firms", "enabled": False, "cadence_s": 600}, - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_finish_form(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["operator_count"] == 1 - assert context["key_count"] == 2 - assert len(context["adapters"]) == 2 - - -class TestSetupFinishSubmit: - """Test setup completion.""" - - @pytest.mark.asyncio - async def test_marks_setup_complete_and_redirects(self): - """POST /setup/finish marks setup_complete=true and redirects to /.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - # Mock form with CSRF token - form_data = MagicMock() - form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_conn = AsyncMock() - mock_conn.execute = AsyncMock() - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_finish_submit(mock_request) - - assert result.status_code == 302 - assert result.headers["location"] == "/" - mock_conn.execute.assert_called_once() - mock_audit.assert_called_once() - class TestSetupGateMiddlewareWizard: """Test SetupGateMiddleware with wizard paths.""" @@ -570,69 +169,6 @@ class TestSetupGateMiddlewareWizard: response = client.get("/setup/operator") assert response.status_code == 200 - @pytest.mark.asyncio - async def test_redirects_base_setup_to_wizard_step(self): - """SetupGateMiddleware redirects /setup to appropriate wizard step.""" - from starlette.testclient import TestClient - from fastapi import FastAPI - - mock_pool = MagicMock() - mock_conn = MagicMock() - mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) - mock_conn.fetchval = AsyncMock(return_value=0) # No operators - mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() - mock_pool.acquire = MagicMock(return_value=mock_conn) - - with patch("central.gui.middleware.get_pool", return_value=mock_pool): - app = FastAPI() - - @app.get("/setup") - async def setup(): - return {"message": "base setup"} - - @app.get("/setup/operator") - async def setup_operator(): - return {"message": "operator"} - - app.add_middleware(SetupGateMiddleware) - client = TestClient(app, follow_redirects=False) - - response = client.get("/setup") - assert response.status_code == 302 - assert response.headers["location"] == "/setup/operator" - - @pytest.mark.asyncio - async def test_redirects_login_to_setup_when_incomplete(self): - """SetupGateMiddleware redirects /login to /setup when setup_complete=False.""" - from starlette.testclient import TestClient - from fastapi import FastAPI - - mock_pool = MagicMock() - mock_conn = MagicMock() - mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) - mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() - mock_pool.acquire = MagicMock(return_value=mock_conn) - - with patch("central.gui.middleware.get_pool", return_value=mock_pool): - app = FastAPI() - - @app.get("/login") - async def login(): - return {"message": "login"} - - @app.get("/setup") - async def setup(): - return {"message": "setup"} - - app.add_middleware(SetupGateMiddleware) - client = TestClient(app, follow_redirects=False) - - response = client.get("/login") - assert response.status_code == 302 - assert response.headers["location"] == "/setup" - @pytest.mark.asyncio async def test_redirects_all_setup_paths_when_complete(self): """SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""