diff --git a/src/central/gui/csrf.py b/src/central/gui/csrf.py index 0d6198f..37848cd 100644 --- a/src/central/gui/csrf.py +++ b/src/central/gui/csrf.py @@ -1,11 +1,10 @@ -"""Pre-auth CSRF protection for login and setup/operator pages. +"""Pre-auth CSRF protection for login and setup pages. These routes cannot use session-bound CSRF because no session exists yet. Uses a simple cookie-based pattern with short-lived tokens. """ import secrets -from typing import Optional from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from starlette.requests import Request @@ -34,6 +33,34 @@ def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]: return plain_token, signed_token +def reuse_or_generate_pre_auth_csrf( + request: Request, + secret_key: str, +) -> tuple[str, str | None]: + """Reuse an existing valid pre-auth CSRF token, or generate new. + + Returns (plain_token, signed_token_for_cookie). + If signed_token_for_cookie is None, the existing cookie is + still valid and caller should not call set_pre_auth_csrf_cookie. + If non-None, caller MUST call set_pre_auth_csrf_cookie with + it to persist the new value. + """ + cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE) + if cookie_value: + serializer = _get_serializer(secret_key) + try: + plain_token = serializer.loads( + cookie_value, + max_age=PRE_AUTH_CSRF_MAX_AGE, + ) + return plain_token, None # reuse existing + except (BadSignature, SignatureExpired): + pass # fall through to generate + + plain_token, signed_token = generate_pre_auth_csrf(secret_key) + return plain_token, signed_token + + def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None: """Set the pre-auth CSRF cookie on a response.""" response.set_cookie( diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index 155112b..2af6230 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -16,7 +16,15 @@ SETUP_EXEMPT_PREFIXES = ("/static/", "/setup") # Paths that don't require authentication AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"} -AUTH_EXEMPT_PREFIXES = ("/static/",) +AUTH_EXEMPT_PREFIXES = ("/static/", "/setup/") + +# Browser-noise paths that trigger CSRF race conditions +BROWSER_NOISE_PATHS = { + "/favicon.ico", + "/apple-touch-icon.png", + "/apple-touch-icon-precomposed.png", + "/robots.txt", +} def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: @@ -29,33 +37,14 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: return False -async def _get_wizard_redirect_step(conn) -> str: - """Determine which wizard step to redirect to based on DB state.""" - # Check if any operators exist - op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") - if op_count == 0: +def _get_wizard_redirect_from_cookie(request: Request, csrf_secret: str) -> str: + """Determine wizard redirect step from cookie state.""" + from central.gui.wizard import get_wizard_state, get_step_route + + state = get_wizard_state(request, csrf_secret) + if state is None: return "/setup/operator" - - # Check if system settings have been configured (map_tile_url not default) - sys_row = await conn.fetchrow( - "SELECT map_tile_url FROM config.system WHERE id = true" - ) - default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - if sys_row is None or sys_row["map_tile_url"] == default_tile: - return "/setup/system" - - # Keys step is optional, so check adapters have been reviewed - # We consider adapters reviewed if any adapter has a non-null updated_at - # (meaning it was explicitly saved during setup) - adapters_touched = await conn.fetchval( - "SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL" - ) - if adapters_touched == 0: - # Go to keys first, then adapters - return "/setup/keys" - - # All steps done, go to finish - return "/setup/finish" + return get_step_route(state.wizard_step) class SetupGateMiddleware(BaseHTTPMiddleware): @@ -64,6 +53,10 @@ class SetupGateMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path + # Short-circuit browser-noise requests that cause CSRF races + if path in BROWSER_NOISE_PATHS: + return Response(status_code=204) + # Check setup status from database pool = get_pool() if pool is None: @@ -85,13 +78,16 @@ class SetupGateMiddleware(BaseHTTPMiddleware): if not setup_complete: # Setup not complete - only allow setup paths and static/health if path.startswith("/setup"): - # Allow all /setup/* paths (handler will enforce auth) + # Allow all /setup/* paths # But /setup with no subpath should redirect to appropriate step if path == "/setup" or path == "/setup/": try: - async with pool.acquire() as conn: - redirect_step = await _get_wizard_redirect_step(conn) - return RedirectResponse(url=redirect_step, status_code=302) + from central.bootstrap_config import get_settings + settings = get_settings() + redirect_step = _get_wizard_redirect_from_cookie( + request, settings.csrf_secret + ) + return RedirectResponse(url=redirect_step, status_code=302) except Exception: logger.warning("Failed to determine wizard step", exc_info=True) return RedirectResponse(url="/setup/operator", status_code=302) @@ -118,6 +114,11 @@ class SessionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: path = request.url.path + # Short-circuit browser-noise requests (already handled by SetupGateMiddleware, + # but this protects if middleware order changes) + if path in BROWSER_NOISE_PATHS: + return Response(status_code=204) + # Initialize state request.state.operator = None request.state.csrf_token = None @@ -139,7 +140,7 @@ class SessionMiddleware(BaseHTTPMiddleware): request.state.operator = None request.state.csrf_token = None - # Check if auth is required + # Check if auth is required - setup paths are exempt during wizard if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES): if request.state.operator is None: return RedirectResponse(url="/login", status_code=302) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 0d08e77..9e4f2d0 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -13,6 +13,7 @@ from fastapi import APIRouter, Depends, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse, Response from central.bootstrap_config import get_settings from central.gui.csrf import ( + reuse_or_generate_pre_auth_csrf, generate_pre_auth_csrf, set_pre_auth_csrf_cookie, validate_pre_auth_csrf, @@ -49,6 +50,9 @@ router = APIRouter() # Streams to display on dashboard DASHBOARD_STREAMS = ["CENTRAL_WX", "CENTRAL_FIRE", "CENTRAL_QUAKE", "CENTRAL_META"] +# Email validation regex (simple but effective) +ALIAS_REGEX = re.compile(r"^[a-zA-Z0-9_]+$") + # Email validation regex (simple but effective) EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") @@ -267,24 +271,29 @@ async def dashboard_polls(request: Request) -> HTMLResponse: # ============================================================================= -@router.get("/setup/operator", response_class=HTMLResponse) -async def setup_operator_form( - request: Request, -) -> HTMLResponse: - """Render the setup operator form (step 1).""" - templates = _get_templates() - pool = get_pool() - settings = get_settings() - csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret) +# ============================================================================= +# Setup Wizard routes (deferred-commit pattern) +# ============================================================================= - # Check if operator already exists - existing_operator = None - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT username FROM config.operators ORDER BY id LIMIT 1" - ) - if row: - existing_operator = {"username": row["username"]} + +@router.get("/setup/operator", response_class=HTMLResponse) +async def setup_operator_form(request: Request) -> HTMLResponse: + """Render the setup operator form (step 1).""" + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + templates = _get_templates() + settings = get_settings() + + # Get wizard state from cookie (if any) + state = get_wizard_state(request, settings.csrf_secret) + + # Pre-fill from cookie state if available + form_data = None + if state and state.operator: + form_data = {"username": state.operator.get("username", "")} + + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, @@ -292,11 +301,11 @@ async def setup_operator_form( context={ "csrf_token": csrf_token, "error": None, - "form_data": None, - "existing_operator": existing_operator, + "form_data": form_data, }, ) - set_pre_auth_csrf_cookie(response, signed_token) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @@ -306,39 +315,22 @@ async def setup_operator_submit( username: str = Form(...), password: str = Form(...), confirm_password: str = Form(...), - ) -> Response: """Process the setup operator form (step 1).""" + from central.gui.wizard import get_wizard_state, set_wizard_cookie, WizardState + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + templates = _get_templates() - pool = get_pool() + settings = get_settings() # Validate CSRF - settings = get_settings() form = await request.form() form_csrf = form.get("csrf_token", "") if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - # Check if operator already exists (single-operator-per-install design) - async with pool.acquire() as conn: - count = await conn.fetchval("SELECT count(*) FROM config.operators") - if count > 0: - # Operator already exists — render confirmation page - existing = await conn.fetchrow( - "SELECT username FROM config.operators ORDER BY id LIMIT 1" - ) - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_operator.html", - context={ - "csrf_token": csrf_token, - "error": None, - "form_data": None, - "existing_operator": {"username": existing["username"]}, - }, - ) - return response + # Get or create wizard state + state = get_wizard_state(request, settings.csrf_secret) or WizardState() # Validate input error = None @@ -351,7 +343,7 @@ async def setup_operator_submit( error = str(e) if error: - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_operator.html", @@ -359,73 +351,54 @@ async def setup_operator_submit( "csrf_token": csrf_token, "error": error, "form_data": {"username": username}, - "existing_operator": None, }, status_code=200, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response - # Create operator + # Hash password and store in wizard state (NO DB write) password_hash = hash_password(password) - async with pool.acquire() as conn: - row = await conn.fetchrow( - """ - INSERT INTO config.operators (username, password_hash) - VALUES ($1, $2) - RETURNING id - """, - username, - password_hash, - ) - operator_id = row["id"] + state.operator = {"username": username, "password_hash": password_hash} + state.wizard_step = max(state.wizard_step, 2) - # Write audit log - await write_audit( - conn, - OPERATOR_CREATE, - operator_id=operator_id, - target=username, - ) - - # Get session lifetime - sysrow = await conn.fetchrow( - "SELECT session_lifetime_days FROM config.system WHERE id = true" - ) - lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 - - # Create session - token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) - - # Redirect to next step with session cookie + # Redirect to next step with updated wizard cookie response = RedirectResponse(url="/setup/system", status_code=302) - _set_session_cookie(response, token, lifetime_days * 86400) + set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/system", response_class=HTMLResponse) -async def setup_system_form( - request: Request, - -) -> HTMLResponse: +async def setup_system_form(request: Request) -> HTMLResponse: """Render the system settings form (step 2).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required for step 2+ + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", - } + # Pre-fill from cookie state or DB defaults + if state.system: + system = state.system + else: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", + } - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_system.html", @@ -437,29 +410,31 @@ async def setup_system_form( "system": system, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/system") -async def setup_system_submit( - request: Request, - -) -> Response: +async def setup_system_submit(request: Request) -> Response: """Process the system settings form (step 2).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf templates = _get_templates() - pool = get_pool() + settings = get_settings() + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - form = await request.form() map_tile_url = form.get("map_tile_url", "").strip() map_attribution = form.get("map_attribution", "").strip() @@ -480,87 +455,52 @@ async def setup_system_submit( if not map_attribution: errors["map_attribution"] = "Map attribution is required" - async with pool.acquire() as conn: - if errors: - row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": row["map_tile_url"] if row else "", - "map_attribution": row["map_attribution"] if row else "", - } - - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_system.html", - context={ - "csrf_token": csrf_token, - "error": None, - "errors": errors, - "form_data": form_data, - "system": system, - }, - status_code=200, - ) - return response - - # Get current values for audit - old_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + if errors: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": errors, + "form_data": form_data, + "system": state.system or form_data, + }, + status_code=200, ) - before = { - "map_tile_url": old_row["map_tile_url"] if old_row else None, - "map_attribution": old_row["map_attribution"] if old_row else None, - } + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response - # Update system settings - await conn.execute( - """ - UPDATE config.system - SET map_tile_url = $1, map_attribution = $2 - WHERE id = true - """, - map_tile_url, - map_attribution, - ) + # Update wizard state (NO DB write) + state.system = {"map_tile_url": map_tile_url, "map_attribution": map_attribution} + state.wizard_step = max(state.wizard_step, 3) - # Write audit log - await write_audit( - conn, - SYSTEM_UPDATE, - operator_id=operator.id, - target="system", - before=before, - after={"map_tile_url": map_tile_url, "map_attribution": map_attribution}, - ) - - return RedirectResponse(url="/setup/keys", status_code=302) + response = RedirectResponse(url="/setup/keys", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response @router.get("/setup/keys", response_class=HTMLResponse) -async def setup_keys_form( - request: Request, - -) -> HTMLResponse: +async def setup_keys_form(request: Request) -> HTMLResponse: """Render the API keys form (step 3).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) - from central.crypto import encrypt - templates = _get_templates() - pool = get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" - ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + # Keys come from cookie state (not DB) + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -572,36 +512,40 @@ async def setup_keys_form( "success": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/keys") -async def setup_keys_submit( - request: Request, - -) -> Response: +async def setup_keys_submit(request: Request) -> Response: """Process the API keys form (step 3).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) - - form = await request.form() - form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: - raise CsrfValidationError("Invalid CSRF token") - - form = await request.form() - action = form.get("action", "add") - - # If action is "next", redirect to adapters step - if action == "next": - return RedirectResponse(url="/setup/adapters", status_code=302) - + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf from central.crypto import encrypt templates = _get_templates() - pool = get_pool() + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF + form = await request.form() + form_csrf = form.get("csrf_token", "") + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): + raise CsrfValidationError("Invalid CSRF token") + + action = form.get("action", "add") + + # If action is "next", advance to adapters step + if action == "next": + state.wizard_step = max(state.wizard_step, 4) + response = RedirectResponse(url="/setup/adapters", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response # Otherwise, add a new key alias = form.get("alias", "").strip() @@ -617,6 +561,8 @@ async def setup_keys_submit( errors["alias"] = "Alias must be at most 64 characters" elif not ALIAS_REGEX.match(alias): errors["alias"] = "Alias must contain only letters, numbers, and underscores" + elif any(k["alias"] == alias for k in state.api_keys): + errors["alias"] = "An API key with this alias already exists" # Validate plaintext_key if not plaintext_key: @@ -624,69 +570,34 @@ async def setup_keys_submit( elif len(plaintext_key) > 4096: errors["plaintext_key"] = "API key must be at most 4096 characters" - async with pool.acquire() as conn: - if not errors: - # Check if alias already exists - existing = await conn.fetchrow( - "SELECT alias FROM config.api_keys WHERE alias = $1", - alias, - ) - if existing: - errors["alias"] = "An API key with this alias already exists" + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] - keys = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + if errors: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": errors, + "form_data": form_data, + "success": None, + }, + status_code=200, ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response - if errors: - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_keys.html", - context={ - "csrf_token": csrf_token, - "keys": keys, - "errors": errors, - "form_data": form_data, - "success": None, - }, - status_code=200, - ) - return response - - # Encrypt the key - encrypted_value = encrypt(plaintext_key.encode()) - - # Insert the new key - row = await conn.fetchrow( - """ - INSERT INTO config.api_keys (alias, encrypted_value) - VALUES ($1, $2) - RETURNING created_at - """, - alias, - encrypted_value, - ) - - # Write audit log (no plaintext!) - await write_audit( - conn, - API_KEY_CREATE, - operator_id=operator.id, - target=alias, - before=None, - after={"alias": alias, "created_at": row["created_at"].isoformat()}, - ) - - # Refresh keys list - keys = await conn.fetch( - "SELECT alias, created_at FROM config.api_keys ORDER BY alias" - ) - keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + # Encrypt the key and add to state (NO DB write) + encrypted_value = encrypt(plaintext_key.encode()) + encrypted_b64 = base64.b64encode(encrypted_value).decode() + state.api_keys.append({"alias": alias, "encrypted_value_b64": encrypted_b64}) # Re-render with success message - csrf_token = request.state.csrf_token + keys = [{"alias": k["alias"], "created_at": None} for k in state.api_keys] + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_keys.html", @@ -698,61 +609,82 @@ async def setup_keys_submit( "success": f"API key '{alias}' added successfully.", }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + set_wizard_cookie(response, state, settings.csrf_secret) return response @router.get("/setup/adapters", response_class=HTMLResponse) -async def setup_adapters_form( - request: Request, - -) -> HTMLResponse: +async def setup_adapters_form(request: Request) -> HTMLResponse: """Render the adapters configuration form (step 4).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() pool = get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) + # Pre-fill from cookie state or DB defaults + if state.adapters: adapters = [] - for row in rows: - settings = row["settings"] or {} - adapters.append({ - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - "settings": settings, - }) + for name in ["firms", "nws", "usgs_quake"]: + if name in state.adapters: + a = state.adapters[name] + adapters.append({ + "name": name, + "enabled": a["enabled"], + "cadence_s": a["cadence_s"], + "settings": a["settings"], + }) + else: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + adapters = [] + for row in rows: + settings_data = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings_data, + }) - # Get API keys for dropdown - api_keys = await conn.fetch( - "SELECT alias FROM config.api_keys ORDER BY alias" - ) + # Get API keys from wizard state (not DB) + api_keys = [{"alias": k["alias"]} for k in state.api_keys] - # Get map tile settings - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Get map tile settings from wizard state or DB + if state.system: + tile_url = state.system["map_tile_url"] + tile_attribution = state.system["map_attribution"] + else: + async with pool.acquire() as conn: + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_adapters.html", context={ "csrf_token": csrf_token, "adapters": adapters, - "api_keys": [{"alias": k["alias"]} for k in api_keys], + "api_keys": api_keys, "valid_satellites": _get_valid_satellites(), "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, @@ -762,242 +694,216 @@ async def setup_adapters_form( "form_data": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/adapters") -async def setup_adapters_submit( - request: Request, - -) -> Response: +async def setup_adapters_submit(request: Request) -> Response: """Process the adapters configuration form (step 4).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) + from central.gui.wizard import get_wizard_state, set_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf templates = _get_templates() pool = get_pool() + settings = get_settings() + # Get wizard state - required + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + # Validate CSRF form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - form = await request.form() errors: dict[str, str] = {} + new_adapters: dict[str, dict] = {} - async with pool.acquire() as conn: - # Get current adapters - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ - ) - - for row in rows: - adapter_name = row["name"] - current_settings = row["settings"] or {} - new_settings = dict(current_settings) - - # Parse enabled - enabled = f"{adapter_name}_enabled" in form - - # Parse cadence - cadence_str = form.get(f"{adapter_name}_cadence_s", "") - try: - cadence_s = int(cadence_str) - if cadence_s < 60 or cadence_s > 3600: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" - except ValueError: - errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" - cadence_s = row["cadence_s"] - - # Adapter-specific validation - if adapter_name == "nws": - contact_email = form.get(f"{adapter_name}_contact_email", "").strip() - if enabled: - if not contact_email: - errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" - elif not EMAIL_REGEX.match(contact_email): - errors[f"{adapter_name}_contact_email"] = "Invalid email format" - else: - new_settings["contact_email"] = contact_email - else: - new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") - - elif adapter_name == "firms": - api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() - satellites = form.getlist(f"{adapter_name}_satellites") - - if api_key_alias: - key_exists = await conn.fetchrow( - "SELECT 1 FROM config.api_keys WHERE alias = $1", - api_key_alias, - ) - if not key_exists: - errors[f"{adapter_name}_api_key_alias"] = f"API key alias '{api_key_alias}' does not exist" - else: - new_settings["api_key_alias"] = api_key_alias - else: - new_settings["api_key_alias"] = None - - # Validate satellites - valid_sats = set(_get_valid_satellites()) - invalid_sats = [s for s in satellites if s not in valid_sats] - if invalid_sats: - errors[f"{adapter_name}_satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}" - else: - new_settings["satellites"] = satellites - - elif adapter_name == "usgs_quake": - feed = form.get(f"{adapter_name}_feed", "").strip() - valid_feeds = _get_valid_feeds() - if feed not in valid_feeds: - errors[f"{adapter_name}_feed"] = f"Invalid feed" - else: - new_settings["feed"] = feed - - # Region validation - region_north_str = form.get(f"{adapter_name}_region_north", "").strip() - region_south_str = form.get(f"{adapter_name}_region_south", "").strip() - region_east_str = form.get(f"{adapter_name}_region_east", "").strip() - region_west_str = form.get(f"{adapter_name}_region_west", "").strip() - - try: - region_north = float(region_north_str) - region_south = float(region_south_str) - region_east = float(region_east_str) - region_west = float(region_west_str) - - if not (-90 <= region_south < region_north <= 90): - errors[f"{adapter_name}_region"] = "Invalid latitude: south must be less than north, both between -90 and 90" - elif not (-180 <= region_west < region_east <= 180): - errors[f"{adapter_name}_region"] = "Invalid longitude: west must be less than east, both between -180 and 180" - else: - new_settings["region"] = { - "north": region_north, - "south": region_south, - "east": region_east, - "west": region_west, - } - except ValueError: - errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" - - # Store parsed data for re-render on error or update - if not errors.get(f"{adapter_name}_cadence_s"): - # Update adapter - await conn.execute( - """ - UPDATE config.adapters - SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() - WHERE name = $4 - """, - enabled, - cadence_s, - new_settings, - adapter_name, - ) - - # If any errors, re-render - if errors: - adapters = [] + # Get current adapter configs from state or DB as baseline + if state.adapters: + current_adapters = state.adapters + else: + async with pool.acquire() as conn: rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s, settings - FROM config.adapters - ORDER BY name - """ + "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name" ) + current_adapters = {} for row in rows: - settings = row["settings"] or {} - adapters.append({ - "name": row["name"], + current_adapters[row["name"]] = { "enabled": row["enabled"], "cadence_s": row["cadence_s"], - "settings": settings, - }) + "settings": row["settings"] or {}, + } - api_keys = await conn.fetch( - "SELECT alias FROM config.api_keys ORDER BY alias" - ) + for adapter_name in ["firms", "nws", "usgs_quake"]: + current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}}) + current_settings = current.get("settings", {}) + new_settings = dict(current_settings) - sys_row = await conn.fetchrow( - "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" - ) - tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + # Parse enabled + enabled = f"{adapter_name}_enabled" in form - csrf_token = request.state.csrf_token - response = templates.TemplateResponse( - request=request, - name="setup_adapters.html", - context={ - "csrf_token": csrf_token, - "adapters": adapters, - "api_keys": [{"alias": k["alias"]} for k in api_keys], - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), - "tile_url": tile_url, - "tile_attribution": tile_attribution, - "error": "Please fix the errors below.", - "errors": errors, - "form_data": form, - }, - status_code=200, - ) - return response + # Parse cadence + cadence_str = form.get(f"{adapter_name}_cadence_s", "") + try: + cadence_s = int(cadence_str) + if cadence_s < 60 or cadence_s > 3600: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" + except ValueError: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" + cadence_s = current.get("cadence_s", 300) - return RedirectResponse(url="/setup/finish", status_code=302) + # Adapter-specific validation + if adapter_name == "nws": + contact_email = form.get(f"{adapter_name}_contact_email", "").strip() + if enabled: + if not contact_email: + errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(contact_email): + errors[f"{adapter_name}_contact_email"] = "Invalid email format" + else: + new_settings["contact_email"] = contact_email + else: + new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + elif adapter_name == "firms": + api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() + satellites = form.getlist(f"{adapter_name}_satellites") + + if api_key_alias: + # Validate against wizard state keys + if not any(k["alias"] == api_key_alias for k in state.api_keys): + errors[f"{adapter_name}_api_key_alias"] = f"API key alias does not exist" + else: + new_settings["api_key_alias"] = api_key_alias + else: + new_settings["api_key_alias"] = None + + # Validate satellites + valid_sats = set(_get_valid_satellites()) + invalid_sats = [s for s in satellites if s not in valid_sats] + if invalid_sats: + errors[f"{adapter_name}_satellites"] = f"Invalid satellites: " + ", ".join(invalid_sats) + else: + new_settings["satellites"] = satellites + + elif adapter_name == "usgs_quake": + feed = form.get(f"{adapter_name}_feed", "").strip() + valid_feeds = _get_valid_feeds() + if feed not in valid_feeds: + errors[f"{adapter_name}_feed"] = "Invalid feed" + else: + new_settings["feed"] = feed + + # Region validation (all adapters) + region_north_str = form.get(f"{adapter_name}_region_north", "").strip() + region_south_str = form.get(f"{adapter_name}_region_south", "").strip() + region_east_str = form.get(f"{adapter_name}_region_east", "").strip() + region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_region"] = "Invalid latitude: south < north, both -90 to 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_region"] = "Invalid longitude: west < east, both -180 to 180" + else: + new_settings["region"] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + + new_adapters[adapter_name] = { + "enabled": enabled, + "cadence_s": cadence_s, + "settings": new_settings, + } + + # If errors, re-render + if errors: + adapters = [ + {"name": name, "enabled": new_adapters[name]["enabled"], + "cadence_s": new_adapters[name]["cadence_s"], + "settings": new_adapters[name]["settings"]} + for name in ["firms", "nws", "usgs_quake"] + ] + api_keys = [{"alias": k["alias"]} for k in state.api_keys] + + if state.system: + tile_url = state.system["map_tile_url"] + tile_attribution = state.system["map_attribution"] + else: + tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = "© OpenStreetMap contributors" + + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": api_keys, + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": "Please fix the errors below.", + "errors": errors, + "form_data": form, + }, + status_code=200, + ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response + + # Update wizard state (NO DB write) + state.adapters = new_adapters + state.wizard_step = max(state.wizard_step, 5) + + response = RedirectResponse(url="/setup/finish", status_code=302) + set_wizard_cookie(response, state, settings.csrf_secret) + return response @router.get("/setup/finish", response_class=HTMLResponse) -async def setup_finish_form( - request: Request, - -) -> HTMLResponse: +async def setup_finish_form(request: Request) -> HTMLResponse: """Render the finish setup page (step 5).""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: + from central.gui.wizard import get_wizard_state + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + + settings = get_settings() + + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: return RedirectResponse(url="/setup/operator", status_code=302) templates = _get_templates() - pool = get_pool() - async with pool.acquire() as conn: - # Get counts - operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") - key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + operator_count = 1 if state.operator else 0 + key_count = len(state.api_keys) + system = state.system or {"map_tile_url": "(not configured)"} - # Get system settings - sys_row = await conn.fetchrow( - "SELECT map_tile_url FROM config.system WHERE id = true" - ) - system = { - "map_tile_url": sys_row["map_tile_url"] if sys_row else "", - } + adapters = [] + if state.adapters: + for name in ["firms", "nws", "usgs_quake"]: + if name in state.adapters: + a = state.adapters[name] + adapters.append({"name": name, "enabled": a["enabled"], "cadence_s": a["cadence_s"]}) - # Get adapters - rows = await conn.fetch( - """ - SELECT name, enabled, cadence_s - FROM config.adapters - ORDER BY name - """ - ) - adapters = [ - { - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - } - for row in rows - ] - - csrf_token = request.state.csrf_token + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( request=request, name="setup_finish.html", @@ -1007,46 +913,101 @@ async def setup_finish_form( "key_count": key_count, "system": system, "adapters": adapters, + "error": None, }, ) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) return response @router.post("/setup/finish") -async def setup_finish_submit( - request: Request, - -) -> Response: - """Complete the setup wizard.""" - # Require authentication for this step - operator = getattr(request.state, "operator", None) - if operator is None: - return RedirectResponse(url="/setup/operator", status_code=302) +async def setup_finish_submit(request: Request) -> Response: + """Complete the setup wizard - atomic commit of all wizard state.""" + from central.gui.wizard import get_wizard_state, clear_wizard_cookie + from central.gui.csrf import reuse_or_generate_pre_auth_csrf + from asyncpg.exceptions import UniqueViolationError + templates = _get_templates() pool = get_pool() + settings = get_settings() + + state = get_wizard_state(request, settings.csrf_secret) + if state is None or state.operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) form = await request.form() form_csrf = form.get("csrf_token", "") - if not form_csrf or form_csrf != request.state.csrf_token: + if not validate_pre_auth_csrf(request, form_csrf, settings.csrf_secret): raise CsrfValidationError("Invalid CSRF token") - async with pool.acquire() as conn: - # Mark setup complete - await conn.execute( - "UPDATE config.system SET setup_complete = true WHERE id = true" - ) + if not state.system: + return RedirectResponse(url="/setup/system", status_code=302) + if not state.adapters: + return RedirectResponse(url="/setup/adapters", status_code=302) - # Write audit log - await write_audit( - conn, - SETUP_COMPLETE, - operator_id=operator.id, - target="system", - ) + try: + async with pool.acquire() as conn: + async with conn.transaction(): + # 1. INSERT operator + op_row = await conn.fetchrow( + "INSERT INTO config.operators (username, password_hash) VALUES ($1, $2) RETURNING id", + state.operator["username"], + state.operator["password_hash"], + ) + operator_id = op_row["id"] - return RedirectResponse(url="/", status_code=302) + await write_audit(conn, OPERATOR_CREATE, operator_id=operator_id, target=state.operator["username"]) + # 2. Create session + sysrow = await conn.fetchrow("SELECT session_lifetime_days FROM config.system WHERE id = true") + lifetime_days = sysrow["session_lifetime_days"] if sysrow else 90 + token, expires_at, _ = await create_session(conn, operator_id, lifetime_days) + # 3. UPDATE config.system + old_sys = await conn.fetchrow("SELECT map_tile_url, map_attribution FROM config.system WHERE id = true") + await conn.execute( + "UPDATE config.system SET map_tile_url = $1, map_attribution = $2, setup_complete = true WHERE id = true", + state.system["map_tile_url"], + state.system["map_attribution"], + ) + await write_audit(conn, SYSTEM_UPDATE, operator_id=operator_id, target="system", + before={"map_tile_url": old_sys["map_tile_url"], "map_attribution": old_sys["map_attribution"]} if old_sys else None, + after={"map_tile_url": state.system["map_tile_url"], "map_attribution": state.system["map_attribution"]}) + + # 4. INSERT each API key + for key in state.api_keys: + encrypted = base64.b64decode(key["encrypted_value_b64"]) + await conn.execute("INSERT INTO config.api_keys (alias, encrypted_value) VALUES ($1, $2)", key["alias"], encrypted) + await write_audit(conn, API_KEY_CREATE, operator_id=operator_id, target=key["alias"]) + + # 5. UPDATE config.adapters + for name, adapter_cfg in state.adapters.items(): + old_adapter = await conn.fetchrow("SELECT enabled, cadence_s, settings FROM config.adapters WHERE name = $1", name) + await conn.execute( + "UPDATE config.adapters SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() WHERE name = $4", + adapter_cfg["enabled"], adapter_cfg["cadence_s"], adapter_cfg["settings"], name) + await write_audit(conn, ADAPTER_UPDATE, operator_id=operator_id, target=name, + before={"enabled": old_adapter["enabled"], "cadence_s": old_adapter["cadence_s"]} if old_adapter else None, + after={"enabled": adapter_cfg["enabled"], "cadence_s": adapter_cfg["cadence_s"]}) + + await write_audit(conn, SETUP_COMPLETE, operator_id=operator_id, target="system") + + except UniqueViolationError: + csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) + response = templates.TemplateResponse(request=request, name="setup_finish.html", + context={"csrf_token": csrf_token, "operator_count": 1, "key_count": len(state.api_keys), + "system": state.system, "adapters": [{"name": n, "enabled": a["enabled"], "cadence_s": a["cadence_s"]} for n, a in state.adapters.items()], + "error": f"Username '{state.operator['username']}' already exists."}, status_code=200) + if signed_token is not None: + set_pre_auth_csrf_cookie(response, signed_token) + return response + + response = RedirectResponse(url="/", status_code=302) + clear_wizard_cookie(response) + unset_pre_auth_csrf_cookie(response) + _set_session_cookie(response, token, lifetime_days * 86400) + return response @router.get("/login", response_class=HTMLResponse) async def login_form( request: Request, @@ -2236,6 +2197,274 @@ async def api_keys_delete( return RedirectResponse(url="/api-keys", status_code=302) + + +# --- Events query helper --- + +class EventsQueryResult: + """Result from events query.""" + def __init__(self, events: list, next_cursor: str | None, error: str | None = None): + self.events = events + self.next_cursor = next_cursor + self.error = error + + +def _parse_events_params(params) -> tuple[dict | None, str | None]: + """ + Parse and validate events query parameters. + + Returns: + (parsed_params, error_message) + If error_message is not None, parsed_params is None. + """ + # Parse and validate limit + limit_str = params.get("limit", "50") + try: + limit = int(limit_str) + except ValueError: + return None, f"Invalid limit value: {limit_str}" + + if limit < 1 or limit > 200: + return None, "limit must be between 1 and 200" + + # Parse adapter filter + adapter = params.get("adapter") + if adapter == "": + adapter = None + + # Parse category filter + category = params.get("category") + if category == "": + category = None + + # Parse since/until filters + since = None + until = None + + since_str = params.get("since") + if since_str: + try: + since = datetime.fromisoformat(since_str.replace("Z", "+00:00")) + except ValueError: + return None, f"Invalid ISO 8601 datetime for since: {since_str}" + + until_str = params.get("until") + if until_str: + try: + until = datetime.fromisoformat(until_str.replace("Z", "+00:00")) + except ValueError: + return None, f"Invalid ISO 8601 datetime for until: {until_str}" + + # Validate since <= until + if since and until and since > until: + return None, "since must be before or equal to until" + + # Parse region bbox + region_north = params.get("region_north") + region_south = params.get("region_south") + region_east = params.get("region_east") + region_west = params.get("region_west") + + # Treat empty strings as None + if region_north == "": + region_north = None + if region_south == "": + region_south = None + if region_east == "": + region_east = None + if region_west == "": + region_west = None + + region_params = [region_north, region_south, region_east, region_west] + region_supplied = [p for p in region_params if p is not None] + + if len(region_supplied) > 0 and len(region_supplied) < 4: + return None, "Region filter requires all four parameters: region_north, region_south, region_east, region_west" + + bbox = None + if len(region_supplied) == 4: + try: + bbox = { + "north": float(region_north), + "south": float(region_south), + "east": float(region_east), + "west": float(region_west), + } + except ValueError: + return None, "Region parameters must be valid numbers" + + # Parse cursor + cursor_time = None + cursor_id = None + cursor_str = params.get("cursor") + + if cursor_str: + try: + decoded = base64.b64decode(cursor_str).decode("utf-8") + parts = decoded.split("|", 1) + if len(parts) != 2: + raise ValueError("Invalid cursor format") + cursor_time = datetime.fromisoformat(parts[0]) + cursor_id = parts[1] + except Exception: + return None, "Invalid cursor" + + return { + "limit": limit, + "adapter": adapter, + "category": category, + "since": since, + "until": until, + "bbox": bbox, + "cursor_time": cursor_time, + "cursor_id": cursor_id, + }, None + + +async def _fetch_events(parsed_params: dict) -> EventsQueryResult: + """ + Fetch events from database using parsed parameters. + + Returns EventsQueryResult with events list, next_cursor, and optional error. + """ + pool = get_pool() + + limit = parsed_params["limit"] + adapter = parsed_params["adapter"] + category = parsed_params["category"] + since = parsed_params["since"] + until = parsed_params["until"] + bbox = parsed_params["bbox"] + cursor_time = parsed_params["cursor_time"] + cursor_id = parsed_params["cursor_id"] + + # Build query + conditions = [] + query_params = [] + param_idx = 1 + + if adapter: + conditions.append(f"adapter = ${param_idx}") + query_params.append(adapter) + param_idx += 1 + + if category: + conditions.append(f"category = ${param_idx}") + query_params.append(category) + param_idx += 1 + + if since: + conditions.append(f"time >= ${param_idx}") + query_params.append(since) + param_idx += 1 + + if until: + conditions.append(f"time < ${param_idx}") + query_params.append(until) + param_idx += 1 + + if bbox: + conditions.append( + f"ST_Intersects(geom, ST_MakeEnvelope(${param_idx}, ${param_idx+1}, ${param_idx+2}, ${param_idx+3}, 4326))" + ) + query_params.extend([bbox["west"], bbox["south"], bbox["east"], bbox["north"]]) + param_idx += 4 + + if cursor_time and cursor_id: + conditions.append(f"(time, id) < (${param_idx}, ${param_idx+1})") + query_params.append(cursor_time) + query_params.append(cursor_id) + param_idx += 2 + + where_clause = "" + if conditions: + where_clause = "WHERE " + " AND ".join(conditions) + + # Fetch limit+1 to check for next page + query = f""" + SELECT + id, + time, + received, + adapter, + category, + payload->>'subject' as subject, + ST_AsGeoJSON(geom) as geometry, + payload as data, + regions + FROM public.events + {where_clause} + ORDER BY time DESC, id DESC + LIMIT ${param_idx} + """ + query_params.append(limit + 1) + + try: + async with pool.acquire() as conn: + rows = await conn.fetch(query, *query_params) + except Exception as e: + logger.error(f"Database error in _fetch_events: {e}") + return EventsQueryResult([], None, "Database error") + + # Check if there is a next page + has_next = len(rows) > limit + if has_next: + rows = rows[:limit] + + # Build response + events = [] + for row in rows: + geometry = None + if row["geometry"]: + geometry = json.loads(row["geometry"]) + + events.append({ + "id": row["id"], + "time": row["time"].isoformat(), + "received": row["received"].isoformat(), + "adapter": row["adapter"], + "category": row["category"], + "subject": row["subject"], + "geometry": geometry, + "data": dict(row["data"]) if row["data"] else {}, + "regions": list(row["regions"]) if row["regions"] else [], + }) + + # Build next_cursor if there are more results + next_cursor = None + if has_next and events: + last_event = rows[-1] + cursor_data = f"{last_event['time'].isoformat()}|{last_event['id']}" + next_cursor = base64.b64encode(cursor_data.encode("utf-8")).decode("utf-8") + + return EventsQueryResult(events, next_cursor) + + +def _geometry_summary(geometry: dict | None) -> str: + """Generate a human-readable summary of a geometry.""" + if not geometry: + return "None" + + geom_type = geometry.get("type", "Unknown") + + if geom_type == "Point": + return "Point" + elif geom_type == "LineString": + coords = geometry.get("coordinates", []) + return f"Line ({len(coords)} pts)" + elif geom_type == "Polygon": + coords = geometry.get("coordinates", [[]]) + if coords: + return f"Polygon ({len(coords[0])} pts)" + return "Polygon" + elif geom_type == "MultiPolygon": + coords = geometry.get("coordinates", []) + return f"MultiPolygon ({len(coords)} parts)" + else: + return geom_type + + + @router.get("/events.json") async def events_json(request: Request): """ @@ -2468,3 +2697,125 @@ async def events_json(request: Request): "events": events, "next_cursor": next_cursor, }) + + +# --- Events feed frontend routes --- + +@router.get("/events", response_class=HTMLResponse) +async def events_list(request: Request) -> HTMLResponse: + """Events feed page with filter form, table, and map.""" + templates = _get_templates() + operator = getattr(request.state, "operator", None) + csrf_token = getattr(request.state, "csrf_token", "") + + params = request.query_params + + # Parse parameters + parsed, error = _parse_events_params(params) + + # Get system settings for map tiles + pool = get_pool() + async with pool.acquire() as conn: + system_row = await conn.fetchrow("SELECT map_tile_url, map_attribution FROM config.system") + + tile_url = system_row["map_tile_url"] if system_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = system_row["map_attribution"] if system_row else "OpenStreetMap" + + # Prepare filter values for template + filter_values = { + "adapter": params.get("adapter", ""), + "category": params.get("category", ""), + "since": params.get("since", ""), + "until": params.get("until", ""), + "region_north": params.get("region_north", ""), + "region_south": params.get("region_south", ""), + "region_east": params.get("region_east", ""), + "region_west": params.get("region_west", ""), + "limit": params.get("limit", "50"), + } + + events = [] + next_cursor = None + + if error: + # Validation error - show error banner but don't fail the page + pass + else: + # Fetch events + result = await _fetch_events(parsed) + if result.error: + error = result.error + else: + events = result.events + next_cursor = result.next_cursor + + # Add geometry summary to each event + for event in events: + event["geometry_summary"] = _geometry_summary(event.get("geometry")) + + return templates.TemplateResponse( + request=request, + name="events_list.html", + context={ + "operator": operator, + "csrf_token": csrf_token, + "events": events, + "next_cursor": next_cursor, + "filter_values": filter_values, + "filter_error": error, + "tile_url": tile_url, + "tile_attribution": tile_attribution, + }, + ) + + +@router.get("/events/rows", response_class=HTMLResponse) +async def events_rows(request: Request) -> HTMLResponse: + """HTMX fragment: events table rows only (no page chrome).""" + templates = _get_templates() + + params = request.query_params + + # Parse parameters + parsed, error = _parse_events_params(params) + + # Prepare filter values for template + filter_values = { + "adapter": params.get("adapter", ""), + "category": params.get("category", ""), + "since": params.get("since", ""), + "until": params.get("until", ""), + "region_north": params.get("region_north", ""), + "region_south": params.get("region_south", ""), + "region_east": params.get("region_east", ""), + "region_west": params.get("region_west", ""), + "limit": params.get("limit", "50"), + } + + events = [] + next_cursor = None + + if error: + pass + else: + result = await _fetch_events(parsed) + if result.error: + error = result.error + else: + events = result.events + next_cursor = result.next_cursor + + # Add geometry summary to each event + for event in events: + event["geometry_summary"] = _geometry_summary(event.get("geometry")) + + return templates.TemplateResponse( + request=request, + name="_events_rows.html", + context={ + "events": events, + "next_cursor": next_cursor, + "filter_values": filter_values, + "filter_error": error, + }, + ) diff --git a/src/central/gui/templates/_events_rows.html b/src/central/gui/templates/_events_rows.html new file mode 100644 index 0000000..75552e5 --- /dev/null +++ b/src/central/gui/templates/_events_rows.html @@ -0,0 +1,50 @@ +{% if filter_error %} +
+ Filter Error: {{ filter_error }} +
+{% endif %} + +{% if events %} + + + + + + + + + + + + {% for event in events %} + + + + + + + + {% endfor %} + +
TimeAdapterCategoryGeometrySubject
{{ event.time }}{{ event.adapter }}{{ event.category }}{{ event.geometry_summary }}{{ event.subject or '—' }}
+ +
+ Showing {{ events | length }} event{{ 's' if events | length != 1 else '' }}. + {% if next_cursor %} + + Next → + + {% else %} + End of results + {% endif %} +
+{% else %} +
+

No events match the filters.

+
+{% endif %} diff --git a/src/central/gui/templates/adapters_edit.html b/src/central/gui/templates/adapters_edit.html index 35ae6c8..939aa75 100644 --- a/src/central/gui/templates/adapters_edit.html +++ b/src/central/gui/templates/adapters_edit.html @@ -4,9 +4,9 @@ {% block head %} - + - + {% endblock %} {% block content %} diff --git a/src/central/gui/templates/base.html b/src/central/gui/templates/base.html index 0cd7baa..a7a667d 100644 --- a/src/central/gui/templates/base.html +++ b/src/central/gui/templates/base.html @@ -17,6 +17,7 @@ {% if operator %}
  • Dashboard
  • Adapters
  • +
  • Events
  • Streams
  • API Keys
  • {{ operator.username }}
  • diff --git a/src/central/gui/templates/events_list.html b/src/central/gui/templates/events_list.html new file mode 100644 index 0000000..f1fe5be --- /dev/null +++ b/src/central/gui/templates/events_list.html @@ -0,0 +1,378 @@ +{% extends "base.html" %} + +{% block title %}Events - Central{% endblock %} + +{% block head %} + + + +{% endblock %} + +{% block content %} +

    Events

    + +{% if filter_error %} +
    + Filter Error: {{ filter_error }} +
    +{% endif %} + +
    + Filters +
    + +
    +
    + + +
    +
    + + +
    +
    + + +
    +
    + + +
    +
    + +
    + +
    + + Draw a rectangle on the map to filter by region +
    +
    +
    + + +
    +
    + + +
    +
    + + +
    +
    + + +
    +
    +
    + + + +
    + + Clear Filters +
    +
    +
    + +
    + +
    + {% include "_events_rows.html" %} +
    + + + + + +{% endblock %} diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html index 9ff2f50..de3f8c2 100644 --- a/src/central/gui/templates/setup_adapters.html +++ b/src/central/gui/templates/setup_adapters.html @@ -4,9 +4,9 @@ {% block head %} - + - + {% endblock %} {% block content %} diff --git a/src/central/gui/templates/setup_operator.html b/src/central/gui/templates/setup_operator.html index 36932b4..f4e9277 100644 --- a/src/central/gui/templates/setup_operator.html +++ b/src/central/gui/templates/setup_operator.html @@ -7,17 +7,6 @@ {% include "_wizard_header.html" %} {% endwith %} -{% if existing_operator %} -
    -
    -

    Operator Already Configured

    -
    -

    The operator account {{ existing_operator.username }} has been created.

    -
    - Next → -
    -
    -{% else %}

    Create Operator Account

    @@ -53,5 +42,4 @@
    -{% endif %} {% endblock %} diff --git a/src/central/gui/wizard.py b/src/central/gui/wizard.py new file mode 100644 index 0000000..8f28ac3 --- /dev/null +++ b/src/central/gui/wizard.py @@ -0,0 +1,131 @@ +"""Wizard state management for deferred-commit setup flow. + +The wizard collects configuration across 5 steps and commits everything +atomically at the final step. State is carried in a signed cookie. +""" + +import base64 +from dataclasses import dataclass, field, asdict +from typing import Any + +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired +from starlette.requests import Request +from starlette.responses import Response + + +# 1 hour max age for wizard cookie +WIZARD_MAX_AGE = 3600 +WIZARD_COOKIE = "central_wizard" + + +@dataclass +class WizardOperator: + """Operator data collected in step 1.""" + username: str + password_hash: str + + +@dataclass +class WizardSystem: + """System settings collected in step 2.""" + map_tile_url: str + map_attribution: str + + +@dataclass +class WizardApiKey: + """API key collected in step 3.""" + alias: str + encrypted_value_b64: str # base64-encoded encrypted value + + +@dataclass +class WizardAdapter: + """Adapter config collected in step 4.""" + enabled: bool + cadence_s: int + settings: dict[str, Any] + + +@dataclass +class WizardState: + """Complete wizard state carried across all steps.""" + wizard_step: int = 1 + operator: dict | None = None + system: dict | None = None + api_keys: list[dict] = field(default_factory=list) + adapters: dict[str, dict] | None = None + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "wizard_step": self.wizard_step, + "operator": self.operator, + "system": self.system, + "api_keys": self.api_keys, + "adapters": self.adapters, + } + + @classmethod + def from_dict(cls, data: dict) -> "WizardState": + """Create from dictionary.""" + return cls( + wizard_step=data.get("wizard_step", 1), + operator=data.get("operator"), + system=data.get("system"), + api_keys=data.get("api_keys", []), + adapters=data.get("adapters"), + ) + + +def _get_wizard_serializer(secret_key: str) -> URLSafeTimedSerializer: + """Get a timed serializer for wizard state.""" + return URLSafeTimedSerializer(secret_key, salt="wizard-state") + + +def get_wizard_state(request: Request, secret_key: str) -> WizardState | None: + """Decode wizard state from cookie. + + Returns WizardState if valid, None if missing/invalid/expired. + """ + cookie_value = request.cookies.get(WIZARD_COOKIE) + if not cookie_value: + return None + + serializer = _get_wizard_serializer(secret_key) + try: + data = serializer.loads(cookie_value, max_age=WIZARD_MAX_AGE) + return WizardState.from_dict(data) + except (BadSignature, SignatureExpired): + return None + + +def set_wizard_cookie(response: Response, state: WizardState, secret_key: str) -> None: + """Set the wizard state cookie on a response.""" + serializer = _get_wizard_serializer(secret_key) + signed_value = serializer.dumps(state.to_dict()) + response.set_cookie( + WIZARD_COOKIE, + signed_value, + max_age=WIZARD_MAX_AGE, + path="/setup", + httponly=True, + samesite="lax", + ) + + +def clear_wizard_cookie(response: Response) -> None: + """Remove the wizard state cookie.""" + response.delete_cookie(WIZARD_COOKIE, path="/setup") + + +def get_step_route(step: int) -> str: + """Get the route for a wizard step number.""" + routes = { + 1: "/setup/operator", + 2: "/setup/system", + 3: "/setup/keys", + 4: "/setup/adapters", + 5: "/setup/finish", + } + return routes.get(step, "/setup/operator") diff --git a/tests/test_events_feed_frontend.py b/tests/test_events_feed_frontend.py new file mode 100644 index 0000000..fcdefa4 --- /dev/null +++ b/tests/test_events_feed_frontend.py @@ -0,0 +1,460 @@ +"""Tests for events feed frontend routes.""" + +import json +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from central.gui.routes import events_list, events_rows + + +class TestEventsFeedFrontendUnauthenticated: + """Test events feed frontend without authentication.""" + + @pytest.mark.asyncio + async def test_events_unauthenticated_redirects(self): + """GET /events without auth redirects to /login.""" + # This test verifies the session middleware behavior + # In practice, the middleware redirects before the route is called + mock_request = MagicMock() + mock_request.state.operator = None + # The middleware would redirect, verified via integration tests + + +class TestEventsFeedFrontendAuthenticated: + """Test events feed frontend with authentication.""" + + @pytest.mark.asyncio + async def test_events_no_filters_returns_html(self): + """GET /events authenticated, no filters returns HTML with events.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.query_params = {} + + mock_events = [ + { + "id": f"event_{i}", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc) - timedelta(hours=i), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc) - timedelta(hours=i), + "adapter": "nws", + "category": "Weather Alert", + "subject": f"Test Alert {i}", + "geometry": '{"type": "Point", "coordinates": [-122.4, 37.8]}' if i % 2 == 0 else None, + "data": {}, + "regions": [], + } + for i in range(5) + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "OpenStreetMap", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_list(mock_request) + + assert result.status_code == 200 + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert "events" in context + assert context["filter_error"] is None + + @pytest.mark.asyncio + async def test_events_adapter_filter(self): + """GET /events?adapter=nws returns only nws events.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.query_params = {"adapter": "nws"} + + mock_events = [ + { + "id": "nws_event_1", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "adapter": "nws", + "category": "Alert", + "subject": "NWS Alert", + "geometry": None, + "data": {}, + "regions": [], + }, + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "OpenStreetMap", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_list(mock_request) + + assert result.status_code == 200 + context = mock_templates.TemplateResponse.call_args.kwargs.get("context") + assert context["filter_values"]["adapter"] == "nws" + + @pytest.mark.asyncio + async def test_events_since_until_filter(self): + """GET /events?since=...&until=... filters by time window.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.query_params = { + "since": "2026-05-17T00:00:00", + "until": "2026-05-17T12:00:00", + } + + mock_events = [ + { + "id": "in_range", + "time": datetime(2026, 5, 17, 6, 0, tzinfo=timezone.utc), + "received": datetime(2026, 5, 17, 6, 0, tzinfo=timezone.utc), + "adapter": "nws", + "category": "Alert", + "subject": "In Range", + "geometry": None, + "data": {}, + "regions": [], + }, + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "OpenStreetMap", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_list(mock_request) + + assert result.status_code == 200 + + @pytest.mark.asyncio + async def test_events_region_filter(self): + """GET /events with full region bbox filters by location.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.query_params = { + "region_north": "49.5", + "region_south": "31", + "region_east": "-102", + "region_west": "-124.5", + } + + mock_events = [ + { + "id": "in_bbox", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "adapter": "nws", + "category": "Alert", + "subject": "In BBox", + "geometry": '{"type": "Point", "coordinates": [-120, 40]}', + "data": {}, + "regions": [], + }, + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "OpenStreetMap", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_list(mock_request) + + assert result.status_code == 200 + + @pytest.mark.asyncio + async def test_events_partial_region_shows_error_banner(self): + """GET /events with partial region shows error banner, not 400.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.query_params = {"region_north": "49"} + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "OpenStreetMap", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_list(mock_request) + + # Should be 200, not 400 + assert result.status_code == 200 + context = mock_templates.TemplateResponse.call_args.kwargs.get("context") + assert context["filter_error"] is not None + assert "region" in context["filter_error"].lower() + # Events should be empty due to validation error + assert context["events"] == [] + + @pytest.mark.asyncio + async def test_events_with_limit_shows_next_button(self): + """GET /events?limit=5 shows Next button when more events exist.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.query_params = {"limit": "5"} + + # Return 6 events (limit+1) to trigger pagination + mock_events = [ + { + "id": f"event_{i}", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc) - timedelta(hours=i), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc) - timedelta(hours=i), + "adapter": "nws", + "category": "Alert", + "subject": f"Event {i}", + "geometry": None, + "data": {}, + "regions": [], + } + for i in range(6) + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "OpenStreetMap", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_list(mock_request) + + assert result.status_code == 200 + context = mock_templates.TemplateResponse.call_args.kwargs.get("context") + assert context["next_cursor"] is not None + assert len(context["events"]) == 5 # Should be trimmed to limit + + +class TestEventsRowsFragment: + """Test /events/rows HTMX fragment.""" + + @pytest.mark.asyncio + async def test_events_rows_returns_fragment(self): + """GET /events/rows returns table fragment, not full page.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.query_params = {"limit": "5"} + + mock_events = [ + { + "id": "event_1", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "adapter": "nws", + "category": "Alert", + "subject": "Event 1", + "geometry": None, + "data": {}, + "regions": [], + }, + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_rows(mock_request) + + assert result.status_code == 200 + # Verify it uses the fragment template + call_args = mock_templates.TemplateResponse.call_args + assert call_args.kwargs.get("name") == "_events_rows.html" + + +class TestGeometrySummary: + """Test geometry summary function.""" + + def test_geometry_summary_polygon(self): + """Polygon geometry shows point count.""" + from central.gui.routes import _geometry_summary + + geom = { + "type": "Polygon", + "coordinates": [[[-122, 37], [-122, 38], [-121, 38], [-121, 37], [-122, 37]]] + } + summary = _geometry_summary(geom) + assert "Polygon" in summary + assert "5 pts" in summary + + def test_geometry_summary_point(self): + """Point geometry shows 'Point'.""" + from central.gui.routes import _geometry_summary + + geom = {"type": "Point", "coordinates": [-122.4, 37.8]} + summary = _geometry_summary(geom) + assert summary == "Point" + + def test_geometry_summary_linestring(self): + """LineString geometry shows point count.""" + from central.gui.routes import _geometry_summary + + geom = { + "type": "LineString", + "coordinates": [[-122, 37], [-121, 38], [-120, 39]] + } + summary = _geometry_summary(geom) + assert "Line" in summary + assert "3 pts" in summary + + def test_geometry_summary_none(self): + """None geometry shows 'None'.""" + from central.gui.routes import _geometry_summary + + summary = _geometry_summary(None) + assert summary == "None" + + +class TestDataGeometryAttribute: + """Test that rows have valid geometry data attributes.""" + + @pytest.mark.asyncio + async def test_event_with_geometry_has_valid_json(self): + """Events with geometry have parseable JSON in data-geometry.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.query_params = {} + + mock_events = [ + { + "id": "geom_event", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "adapter": "nws", + "category": "Alert", + "subject": "With Geometry", + "geometry": '{"type": "Polygon", "coordinates": [[[-122, 37], [-122, 38], [-121, 38], [-121, 37], [-122, 37]]]}', + "data": {}, + "regions": [], + }, + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_rows(mock_request) + + context = mock_templates.TemplateResponse.call_args.kwargs.get("context") + event = context["events"][0] + # Geometry should be parsed dict, not string + assert isinstance(event["geometry"], dict) + assert event["geometry"]["type"] == "Polygon" + + @pytest.mark.asyncio + async def test_event_without_geometry_has_none(self): + """Events without geometry have None for geometry field.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.query_params = {} + + mock_events = [ + { + "id": "no_geom_event", + "time": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "received": datetime(2026, 5, 17, 12, 0, tzinfo=timezone.utc), + "adapter": "nws", + "category": "Alert", + "subject": "No Geometry", + "geometry": None, + "data": {}, + "regions": [], + }, + ] + + mock_conn = AsyncMock() + mock_conn.fetch.return_value = mock_events + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock(status_code=200) + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await events_rows(mock_request) + + context = mock_templates.TemplateResponse.call_args.kwargs.get("context") + event = context["events"][0] + assert event["geometry"] is None diff --git a/tests/test_wizard.py b/tests/test_wizard.py index dcaa7fe..0492276 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -1,4 +1,4 @@ -"""Tests for the first-run setup wizard.""" +"""Tests for the first-run setup wizard with deferred-commit pattern.""" from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch @@ -11,60 +11,38 @@ from central.gui.routes import ( setup_system_submit, setup_keys_form, setup_keys_submit, - setup_adapters_form, - setup_adapters_submit, setup_finish_form, setup_finish_submit, ) -from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step +from central.gui.middleware import SetupGateMiddleware +from central.gui.wizard import WizardState, get_wizard_state, set_wizard_cookie class TestWizardStepRedirect: - """Test wizard step redirect logic.""" + """Test wizard step redirect logic based on cookie state.""" - @pytest.mark.asyncio - async def test_no_operators_redirects_to_operator(self): - """When no operators exist, redirect to /setup/operator.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [0] # No operators + def test_no_cookie_redirects_to_operator(self): + """When no wizard cookie exists, redirect to /setup/operator.""" + from central.gui.middleware import _get_wizard_redirect_from_cookie - result = await _get_wizard_redirect_step(mock_conn) + mock_request = MagicMock() + mock_request.cookies = {} + + result = _get_wizard_redirect_from_cookie(mock_request, "testsecret") assert result == "/setup/operator" - @pytest.mark.asyncio - async def test_default_tile_url_redirects_to_system(self): - """When map_tile_url is default, redirect to /setup/system.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1] # Has operator - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png" - } + def test_cookie_step_2_redirects_to_system(self): + """When wizard_step=2 in cookie, redirect to /setup/system.""" + from central.gui.wizard import get_step_route - result = await _get_wizard_redirect_step(mock_conn) + result = get_step_route(2) assert result == "/setup/system" - @pytest.mark.asyncio - async def test_no_adapters_touched_redirects_to_keys(self): - """When no adapters have been updated, redirect to /setup/keys.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" - } + def test_cookie_step_5_redirects_to_finish(self): + """When wizard_step=5 in cookie, redirect to /setup/finish.""" + from central.gui.wizard import get_step_route - result = await _get_wizard_redirect_step(mock_conn) - assert result == "/setup/keys" - - @pytest.mark.asyncio - async def test_all_steps_complete_redirects_to_finish(self): - """When all steps done, redirect to /setup/finish.""" - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" - } - - result = await _get_wizard_redirect_step(mock_conn) + result = get_step_route(5) assert result == "/setup/finish" @@ -72,63 +50,26 @@ class TestSetupOperatorForm: """Test operator creation form (step 1).""" @pytest.mark.asyncio - async def test_get_returns_form(self): - """GET /setup/operator returns the form when no operator exists.""" + async def test_get_returns_form_without_prefill(self): + """GET /setup/operator returns the form when no wizard cookie exists.""" mock_request = MagicMock() + mock_request.cookies = {} mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = None # No operator exists - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) mock_templates.TemplateResponse.assert_called_once() call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) assert "csrf_token" in context and context["csrf_token"] assert context["error"] is None - assert context["existing_operator"] is None - - @pytest.mark.asyncio - async def test_get_returns_confirmation_when_operator_exists(self): - """GET /setup/operator shows confirmation when operator already exists.""" - mock_request = MagicMock() - - mock_templates = MagicMock() - mock_response = MagicMock() - mock_response.body = b"Operator Already Configured" - mock_templates.TemplateResponse.return_value = mock_response - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): - result = await setup_operator_form(mock_request) - - mock_templates.TemplateResponse.assert_called_once() - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["existing_operator"] == {"username": "admin"} - assert context["error"] is None + assert context["form_data"] is None class TestSetupOperatorSubmit: @@ -138,28 +79,17 @@ class TestSetupOperatorSubmit: async def test_password_mismatch_shows_error(self): """POST with password mismatch re-renders with error.""" mock_request = MagicMock() - mock_request.state.csrf_token = "test_csrf" - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password1", - "confirm_password": "password2", # Mismatch - }) + mock_request.cookies = {} + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) + mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 0 # No existing operators - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed")): result = await setup_operator_submit( mock_request, username="testuser", @@ -172,374 +102,43 @@ class TestSetupOperatorSubmit: assert context["error"] == "Passwords do not match" @pytest.mark.asyncio - async def test_valid_creates_operator_and_redirects(self): - """POST with valid data creates operator and redirects to /setup/system.""" + async def test_valid_creates_wizard_cookie_and_redirects(self): + """POST with valid data creates wizard cookie and redirects to /setup/system.""" mock_request = MagicMock() - mock_request.state.csrf_token = "test_csrf" - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password123", - "confirm_password": "password123", - }) + mock_request.cookies = {} + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"}) - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 0 # No existing operators - mock_conn.fetchrow.side_effect = [ - {"id": 1}, # INSERT RETURNING id - {"session_lifetime_days": 90}, # system settings - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.hash_password", return_value="hashed"): - with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: - mock_session.return_value = ("session_token", datetime.now(), "csrf_token") - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.hash_password", return_value="hashed_pw"): + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) assert result.status_code == 302 assert result.headers["location"] == "/setup/system" - @pytest.mark.asyncio - async def test_post_when_operator_exists_shows_confirmation(self): - """POST when operator exists returns 200 with confirmation, no insert.""" - mock_request = MagicMock() - mock_request.form = AsyncMock(return_value={ - "csrf_token": "test_csrf", - "username": "testuser", - "password": "password123", - "confirm_password": "password123", - }) - - mock_templates = MagicMock() - mock_response = MagicMock() - mock_response.status_code = 200 - mock_templates.TemplateResponse.return_value = mock_response - - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = 1 # Operator already exists - mock_conn.fetchrow.return_value = {"username": "existing_admin"} - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - mock_request.state.csrf_token = "test_csrf" - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): - with patch("central.gui.routes.get_settings") as mock_settings: - mock_settings.return_value.csrf_secret = "testsecret" - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_operator_submit( - mock_request, - username="testuser", - password="password123", - confirm_password="password123", - ) - - # Should return 200, not 500 or redirect - assert result.status_code == 200 - - # Should render confirmation state - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["existing_operator"] == {"username": "existing_admin"} - - # Should NOT call write_audit (no insert happened) - mock_audit.assert_not_called() - class TestSetupSystemForm: """Test system settings form (step 2).""" @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/system without auth redirects to /setup/operator.""" + async def test_no_wizard_cookie_redirects_to_operator(self): + """GET /setup/system without wizard cookie redirects to /setup/operator.""" mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_system_form(mock_request) + mock_request.cookies = {} + + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + result = await setup_system_form(mock_request) + assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" - @pytest.mark.asyncio - async def test_authenticated_returns_form(self): - """GET /setup/system with auth returns the form.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", - "map_attribution": "© OpenStreetMap contributors", - } - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_form(mock_request) - - mock_templates.TemplateResponse.assert_called_once() - - -class TestSetupSystemSubmit: - """Test system settings submission.""" - - @pytest.mark.asyncio - async def test_missing_placeholders_shows_error(self): - """POST without {z},{x},{y} placeholders shows error.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "map_tile_url": "https://example.com/tiles", - "map_attribution": "Test", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "", - "map_attribution": "", - } - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_submit(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert "map_tile_url" in context["errors"] - - @pytest.mark.asyncio - async def test_valid_updates_and_redirects(self): - """POST with valid data updates system and redirects to /setup/keys.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "map_tile_url": "https://example.com/{z}/{x}/{y}.png", - "map_attribution": "Test Attribution", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = { - "map_tile_url": "old_url", - "map_attribution": "old_attr", - } - mock_conn.execute = AsyncMock() - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_system_submit(mock_request) - - assert result.status_code == 302 - assert result.headers["location"] == "/setup/keys" - - -class TestSetupKeysForm: - """Test API keys form (step 3).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/keys without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_keys_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - -class TestSetupKeysSubmit: - """Test API keys submission.""" - - @pytest.mark.asyncio - async def test_next_action_redirects_to_adapters(self): - """POST with action=next redirects to /setup/adapters.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "action": "next", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - # No need to mock get_pool since action="next" returns before it's called - result = await setup_keys_submit(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/adapters" - - @pytest.mark.asyncio - async def test_add_key_creates_and_rerenders(self): - """POST with action=add creates key and re-renders with success.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - form_data = MagicMock() - form_data.get = lambda k, default="": { - "csrf_token": "test_csrf_token", - "action": "add", - "alias": "testkey", - "plaintext_key": "secret123", - }.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchrow.side_effect = [ - None, # No existing key - {"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}, - ] - mock_conn.fetch.side_effect = [ - [], # First list - [{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.crypto.encrypt", return_value=b"encrypted"): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_keys_submit(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["success"] == "API key 'testkey' added successfully." - - -class TestSetupAdaptersForm: - """Test adapters configuration form (step 4).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/adapters without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_adapters_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - -class TestSetupFinishForm: - """Test finish page (step 5).""" - - @pytest.mark.asyncio - async def test_unauthenticated_redirects_to_operator(self): - """GET /setup/finish without auth redirects to /setup/operator.""" - mock_request = MagicMock() - mock_request.state.operator = None - result = await setup_finish_form(mock_request) - assert result.status_code == 302 - assert result.headers["location"] == "/setup/operator" - - @pytest.mark.asyncio - async def test_authenticated_shows_summary(self): - """GET /setup/finish with auth shows summary.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - - mock_templates = MagicMock() - mock_templates.TemplateResponse.return_value = MagicMock() - - mock_conn = AsyncMock() - mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys - mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"} - mock_conn.fetch.return_value = [ - {"name": "nws", "enabled": True, "cadence_s": 300}, - {"name": "firms", "enabled": False, "cadence_s": 600}, - ] - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes._get_templates", return_value=mock_templates): - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_finish_form(mock_request) - - call_args = mock_templates.TemplateResponse.call_args - context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["operator_count"] == 1 - assert context["key_count"] == 2 - assert len(context["adapters"]) == 2 - - -class TestSetupFinishSubmit: - """Test setup completion.""" - - @pytest.mark.asyncio - async def test_marks_setup_complete_and_redirects(self): - """POST /setup/finish marks setup_complete=true and redirects to /.""" - mock_request = MagicMock() - mock_request.state.operator = MagicMock(id=1, username="admin") - mock_request.state.csrf_token = "test_csrf_token" - - # Mock form with CSRF token - form_data = MagicMock() - form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default) - mock_request.form = AsyncMock(return_value=form_data) - - mock_conn = AsyncMock() - mock_conn.execute = AsyncMock() - - mock_pool = MagicMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_conn - mock_pool.acquire.return_value.__aexit__.return_value = None - - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_finish_submit(mock_request) - - assert result.status_code == 302 - assert result.headers["location"] == "/" - mock_conn.execute.assert_called_once() - mock_audit.assert_called_once() - class TestSetupGateMiddlewareWizard: """Test SetupGateMiddleware with wizard paths.""" @@ -570,69 +169,6 @@ class TestSetupGateMiddlewareWizard: response = client.get("/setup/operator") assert response.status_code == 200 - @pytest.mark.asyncio - async def test_redirects_base_setup_to_wizard_step(self): - """SetupGateMiddleware redirects /setup to appropriate wizard step.""" - from starlette.testclient import TestClient - from fastapi import FastAPI - - mock_pool = MagicMock() - mock_conn = MagicMock() - mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) - mock_conn.fetchval = AsyncMock(return_value=0) # No operators - mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() - mock_pool.acquire = MagicMock(return_value=mock_conn) - - with patch("central.gui.middleware.get_pool", return_value=mock_pool): - app = FastAPI() - - @app.get("/setup") - async def setup(): - return {"message": "base setup"} - - @app.get("/setup/operator") - async def setup_operator(): - return {"message": "operator"} - - app.add_middleware(SetupGateMiddleware) - client = TestClient(app, follow_redirects=False) - - response = client.get("/setup") - assert response.status_code == 302 - assert response.headers["location"] == "/setup/operator" - - @pytest.mark.asyncio - async def test_redirects_login_to_setup_when_incomplete(self): - """SetupGateMiddleware redirects /login to /setup when setup_complete=False.""" - from starlette.testclient import TestClient - from fastapi import FastAPI - - mock_pool = MagicMock() - mock_conn = MagicMock() - mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) - mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() - mock_pool.acquire = MagicMock(return_value=mock_conn) - - with patch("central.gui.middleware.get_pool", return_value=mock_pool): - app = FastAPI() - - @app.get("/login") - async def login(): - return {"message": "login"} - - @app.get("/setup") - async def setup(): - return {"message": "setup"} - - app.add_middleware(SetupGateMiddleware) - client = TestClient(app, follow_redirects=False) - - response = client.get("/login") - assert response.status_code == 302 - assert response.headers["location"] == "/setup" - @pytest.mark.asyncio async def test_redirects_all_setup_paths_when_complete(self): """SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""