From 62116ca6a4c8bdf3b6e2ee400b09ae9363858998 Mon Sep 17 00:00:00 2001 From: zvx-echo6 Date: Sun, 17 May 2026 19:06:23 -0600 Subject: [PATCH] feat(gui): implement first-run setup wizard (1b-8) Add a 5-step setup wizard that replaces the single-step /setup: 1. Create Operator - create initial operator account 2. System Settings - configure map tile URL and attribution 3. API Keys - optionally add API keys for adapters 4. Configure Adapters - enable/disable adapters with region picker 5. Finish Setup - review and complete setup Key changes: - Update middleware to handle wizard URL structure and step routing - Add wizard routes for each step with proper auth checks - Create new templates using base_wizard.html for consistent styling - Add audit events for system.update and setup.complete - Update tests for new middleware behavior Co-Authored-By: Claude Opus 4.5 --- src/central/gui/audit.py | 2 + src/central/gui/middleware.py | 60 +- src/central/gui/routes.py | 675 +++++++++++++++++- src/central/gui/templates/_wizard_header.html | 6 + src/central/gui/templates/base_wizard.html | 24 + src/central/gui/templates/setup_adapters.html | 217 ++++++ src/central/gui/templates/setup_finish.html | 69 ++ src/central/gui/templates/setup_keys.html | 84 +++ src/central/gui/templates/setup_operator.html | 45 ++ src/central/gui/templates/setup_system.html | 49 ++ tests/test_setup_gate.py | 50 +- tests/test_wizard.py | 586 +++++++++++++++ 12 files changed, 1840 insertions(+), 27 deletions(-) create mode 100644 src/central/gui/templates/_wizard_header.html create mode 100644 src/central/gui/templates/base_wizard.html create mode 100644 src/central/gui/templates/setup_adapters.html create mode 100644 src/central/gui/templates/setup_finish.html create mode 100644 src/central/gui/templates/setup_keys.html create mode 100644 src/central/gui/templates/setup_operator.html create mode 100644 src/central/gui/templates/setup_system.html create mode 100644 tests/test_wizard.py diff --git a/src/central/gui/audit.py b/src/central/gui/audit.py index 1bdb66b..b7cfd47 100644 --- a/src/central/gui/audit.py +++ b/src/central/gui/audit.py @@ -14,6 +14,8 @@ STREAM_UPDATE = "stream.update" API_KEY_CREATE = "api_key.create" API_KEY_ROTATE = "api_key.rotate" API_KEY_DELETE = "api_key.delete" +SYSTEM_UPDATE = "system.update" +SETUP_COMPLETE = "setup.complete" async def write_audit( diff --git a/src/central/gui/middleware.py b/src/central/gui/middleware.py index be5b25f..776554d 100644 --- a/src/central/gui/middleware.py +++ b/src/central/gui/middleware.py @@ -12,11 +12,10 @@ from central.gui.db import get_pool logger = logging.getLogger(__name__) # Paths that don't require setup to be complete -SETUP_EXEMPT_PATHS = {"/setup", "/health"} -SETUP_EXEMPT_PREFIXES = ("/static/",) +SETUP_EXEMPT_PREFIXES = ("/static/", "/setup") # Paths that don't require authentication -AUTH_EXEMPT_PATHS = {"/setup", "/login", "/health"} +AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"} AUTH_EXEMPT_PREFIXES = ("/static/",) @@ -30,6 +29,35 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool: return False +async def _get_wizard_redirect_step(conn) -> str: + """Determine which wizard step to redirect to based on DB state.""" + # Check if any operators exist + op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + if op_count == 0: + return "/setup/operator" + + # Check if system settings have been configured (map_tile_url not default) + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + if sys_row is None or sys_row["map_tile_url"] == default_tile: + return "/setup/system" + + # Keys step is optional, so check adapters have been reviewed + # We consider adapters reviewed if any adapter has a non-null updated_at + # (meaning it was explicitly saved during setup) + adapters_touched = await conn.fetchval( + "SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL" + ) + if adapters_touched == 0: + # Go to keys first, then adapters + return "/setup/keys" + + # All steps done, go to finish + return "/setup/finish" + + class SetupGateMiddleware(BaseHTTPMiddleware): """Redirect to /setup if setup is not complete.""" @@ -55,12 +83,30 @@ class SetupGateMiddleware(BaseHTTPMiddleware): return await call_next(request) if not setup_complete: - # Setup not complete - only allow exempt paths - if not _is_exempt(path, SETUP_EXEMPT_PATHS, SETUP_EXEMPT_PREFIXES): + # Setup not complete - only allow setup paths and static/health + if path.startswith("/setup"): + # Allow all /setup/* paths (handler will enforce auth) + # But /setup with no subpath should redirect to appropriate step + if path == "/setup" or path == "/setup/": + try: + async with pool.acquire() as conn: + redirect_step = await _get_wizard_redirect_step(conn) + return RedirectResponse(url=redirect_step, status_code=302) + except Exception: + logger.warning("Failed to determine wizard step", exc_info=True) + return RedirectResponse(url="/setup/operator", status_code=302) + return await call_next(request) + elif path == "/health" or path.startswith("/static/"): + return await call_next(request) + elif path == "/login": + # During setup, login redirects to /setup + return RedirectResponse(url="/setup", status_code=302) + else: + # All other paths redirect to /setup return RedirectResponse(url="/setup", status_code=302) else: - # Setup complete - redirect /setup to / - if path == "/setup": + # Setup complete - redirect /setup* to / + if path.startswith("/setup"): return RedirectResponse(url="/", status_code=302) return await call_next(request) diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 37a5c37..1b8be2e 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -28,7 +28,9 @@ from central.gui.audit import ( AUTH_LOGOUT, AUTH_PASSWORD_CHANGE, OPERATOR_CREATE, + SETUP_COMPLETE, STREAM_UPDATE, + SYSTEM_UPDATE, write_audit, ) from central.gui.db import get_pool @@ -252,32 +254,37 @@ async def dashboard_polls(request: Request) -> HTMLResponse: ) -@router.get("/setup", response_class=HTMLResponse) -async def setup_form( +# ============================================================================= +# Setup Wizard routes +# ============================================================================= + + +@router.get("/setup/operator", response_class=HTMLResponse) +async def setup_operator_form( request: Request, csrf_protect: CsrfProtect = Depends(), ) -> HTMLResponse: - """Render the setup form.""" + """Render the setup operator form (step 1).""" templates = _get_templates() csrf_token, signed_token = csrf_protect.generate_csrf_tokens() response = templates.TemplateResponse( request=request, - name="setup.html", - context={"csrf_token": csrf_token, "error": None}, + name="setup_operator.html", + context={"csrf_token": csrf_token, "error": None, "form_data": None}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response -@router.post("/setup") -async def setup_submit( +@router.post("/setup/operator") +async def setup_operator_submit( request: Request, username: str = Form(...), password: str = Form(...), confirm_password: str = Form(...), csrf_protect: CsrfProtect = Depends(), ) -> Response: - """Process the setup form.""" + """Process the setup operator form (step 1).""" templates = _get_templates() pool = get_pool() @@ -298,8 +305,12 @@ async def setup_submit( csrf_token, signed_token = csrf_protect.generate_csrf_tokens() response = templates.TemplateResponse( request=request, - name="setup.html", - context={"csrf_token": csrf_token, "error": error}, + name="setup_operator.html", + context={ + "csrf_token": csrf_token, + "error": error, + "form_data": {"username": username}, + }, status_code=200, ) csrf_protect.set_csrf_cookie(signed_token, response) @@ -336,15 +347,651 @@ async def setup_submit( # Create session token, expires_at = await create_session(conn, operator_id, lifetime_days) + # Redirect to next step with session cookie + response = RedirectResponse(url="/setup/system", status_code=302) + _set_session_cookie(response, token, lifetime_days * 86400) + return response + + +@router.get("/setup/system", response_class=HTMLResponse) +async def setup_system_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the system settings form (step 2).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": row["map_attribution"] if row else "© OpenStreetMap contributors", + } + + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": None, + "form_data": None, + "system": system, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/setup/system") +async def setup_system_submit( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the system settings form (step 2).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + await csrf_protect.validate_csrf(request) + + form = await request.form() + map_tile_url = form.get("map_tile_url", "").strip() + map_attribution = form.get("map_attribution", "").strip() + + form_data = { + "map_tile_url": map_tile_url, + "map_attribution": map_attribution, + } + + errors: dict[str, str] = {} + + # Validate map_tile_url + if not map_tile_url: + errors["map_tile_url"] = "Map tile URL is required" + elif "{z}" not in map_tile_url or "{x}" not in map_tile_url or "{y}" not in map_tile_url: + errors["map_tile_url"] = "URL must contain {z}, {x}, and {y} placeholders" + + # Validate map_attribution + if not map_attribution: + errors["map_attribution"] = "Map attribution is required" + + async with pool.acquire() as conn: + if errors: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": row["map_tile_url"] if row else "", + "map_attribution": row["map_attribution"] if row else "", + } + + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": None, + "errors": errors, + "form_data": form_data, + "system": system, + }, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Get current values for audit + old_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + before = { + "map_tile_url": old_row["map_tile_url"] if old_row else None, + "map_attribution": old_row["map_attribution"] if old_row else None, + } + + # Update system settings + await conn.execute( + """ + UPDATE config.system + SET map_tile_url = $1, map_attribution = $2 + WHERE id = true + """, + map_tile_url, + map_attribution, + ) + + # Write audit log + await write_audit( + conn, + SYSTEM_UPDATE, + operator_id=operator.id, + target="system", + before=before, + after={"map_tile_url": map_tile_url, "map_attribution": map_attribution}, + ) + + return RedirectResponse(url="/setup/keys", status_code=302) + + +@router.get("/setup/keys", response_class=HTMLResponse) +async def setup_keys_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the API keys form (step 3).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + from central.crypto import encrypt + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": None, + "form_data": None, + "success": None, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/setup/keys") +async def setup_keys_submit( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the API keys form (step 3).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + await csrf_protect.validate_csrf(request) + + form = await request.form() + action = form.get("action", "add") + + # If action is "next", redirect to adapters step + if action == "next": + return RedirectResponse(url="/setup/adapters", status_code=302) + + from central.crypto import encrypt + + templates = _get_templates() + pool = get_pool() + + # Otherwise, add a new key + alias = form.get("alias", "").strip() + plaintext_key = form.get("plaintext_key", "") + + form_data = {"alias": alias} + errors: dict[str, str] = {} + + # Validate alias + if not alias: + errors["alias"] = "Alias is required" + elif len(alias) > 64: + errors["alias"] = "Alias must be at most 64 characters" + elif not ALIAS_REGEX.match(alias): + errors["alias"] = "Alias must contain only letters, numbers, and underscores" + + # Validate plaintext_key + if not plaintext_key: + errors["plaintext_key"] = "API key is required" + elif len(plaintext_key) > 4096: + errors["plaintext_key"] = "API key must be at most 4096 characters" + + async with pool.acquire() as conn: + if not errors: + # Check if alias already exists + existing = await conn.fetchrow( + "SELECT alias FROM config.api_keys WHERE alias = $1", + alias, + ) + if existing: + errors["alias"] = "An API key with this alias already exists" + + keys = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + + if errors: + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": errors, + "form_data": form_data, + "success": None, + }, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + # Encrypt the key + encrypted_value = encrypt(plaintext_key.encode()) + + # Insert the new key + row = await conn.fetchrow( + """ + INSERT INTO config.api_keys (alias, encrypted_value) + VALUES ($1, $2) + RETURNING created_at + """, + alias, + encrypted_value, + ) + + # Write audit log (no plaintext!) + await write_audit( + conn, + API_KEY_CREATE, + operator_id=operator.id, + target=alias, + before=None, + after={"alias": alias, "created_at": row["created_at"].isoformat()}, + ) + + # Refresh keys list + keys = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in keys] + + # Re-render with success message + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": None, + "form_data": None, + "success": f"API key '{alias}' added successfully.", + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.get("/setup/adapters", response_class=HTMLResponse) +async def setup_adapters_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the adapters configuration form (step 4).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + adapters = [] + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) + + # Get API keys for dropdown + api_keys = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) + + # Get map tile settings + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": [{"alias": k["alias"]} for k in api_keys], + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": None, + "errors": None, + "form_data": None, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/setup/adapters") +async def setup_adapters_submit( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Process the adapters configuration form (step 4).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + await csrf_protect.validate_csrf(request) + + form = await request.form() + errors: dict[str, str] = {} + + async with pool.acquire() as conn: + # Get current adapters + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + + for row in rows: + adapter_name = row["name"] + current_settings = row["settings"] or {} + new_settings = dict(current_settings) + + # Parse enabled + enabled = f"{adapter_name}_enabled" in form + + # Parse cadence + cadence_str = form.get(f"{adapter_name}_cadence_s", "") + try: + cadence_s = int(cadence_str) + if cadence_s < 60 or cadence_s > 3600: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be between 60 and 3600 seconds" + except ValueError: + errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" + cadence_s = row["cadence_s"] + + # Adapter-specific validation + if adapter_name == "nws": + contact_email = form.get(f"{adapter_name}_contact_email", "").strip() + if enabled: + if not contact_email: + errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(contact_email): + errors[f"{adapter_name}_contact_email"] = "Invalid email format" + else: + new_settings["contact_email"] = contact_email + else: + new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + + elif adapter_name == "firms": + api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() + satellites = form.getlist(f"{adapter_name}_satellites") + + if api_key_alias: + key_exists = await conn.fetchrow( + "SELECT 1 FROM config.api_keys WHERE alias = $1", + api_key_alias, + ) + if not key_exists: + errors[f"{adapter_name}_api_key_alias"] = f"API key alias '{api_key_alias}' does not exist" + else: + new_settings["api_key_alias"] = api_key_alias + else: + new_settings["api_key_alias"] = None + + # Validate satellites + valid_sats = set(_get_valid_satellites()) + invalid_sats = [s for s in satellites if s not in valid_sats] + if invalid_sats: + errors[f"{adapter_name}_satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}" + else: + new_settings["satellites"] = satellites + + elif adapter_name == "usgs_quake": + feed = form.get(f"{adapter_name}_feed", "").strip() + valid_feeds = _get_valid_feeds() + if feed not in valid_feeds: + errors[f"{adapter_name}_feed"] = f"Invalid feed" + else: + new_settings["feed"] = feed + + # Region validation + region_north_str = form.get(f"{adapter_name}_region_north", "").strip() + region_south_str = form.get(f"{adapter_name}_region_south", "").strip() + region_east_str = form.get(f"{adapter_name}_region_east", "").strip() + region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_region"] = "Invalid latitude: south must be less than north, both between -90 and 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_region"] = "Invalid longitude: west must be less than east, both between -180 and 180" + else: + new_settings["region"] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + + # Store parsed data for re-render on error or update + if not errors.get(f"{adapter_name}_cadence_s"): + # Update adapter + await conn.execute( + """ + UPDATE config.adapters + SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now() + WHERE name = $4 + """, + enabled, + cadence_s, + new_settings, + adapter_name, + ) + + # If any errors, re-render + if errors: + adapters = [] + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s, settings + FROM config.adapters + ORDER BY name + """ + ) + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) + + api_keys = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) + + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + tile_url = sys_row["map_tile_url"] if sys_row else "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = sys_row["map_attribution"] if sys_row else "© OpenStreetMap contributors" + + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": [{"alias": k["alias"]} for k in api_keys], + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": "Please fix the errors below.", + "errors": errors, + "form_data": form, + }, + status_code=200, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + return RedirectResponse(url="/setup/finish", status_code=302) + + +@router.get("/setup/finish", response_class=HTMLResponse) +async def setup_finish_form( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> HTMLResponse: + """Render the finish setup page (step 5).""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + templates = _get_templates() + pool = get_pool() + + async with pool.acquire() as conn: + # Get counts + operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + + # Get system settings + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + system = { + "map_tile_url": sys_row["map_tile_url"] if sys_row else "", + } + + # Get adapters + rows = await conn.fetch( + """ + SELECT name, enabled, cadence_s + FROM config.adapters + ORDER BY name + """ + ) + adapters = [ + { + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + } + for row in rows + ] + + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( + request=request, + name="setup_finish.html", + context={ + "csrf_token": csrf_token, + "operator_count": operator_count, + "key_count": key_count, + "system": system, + "adapters": adapters, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + +@router.post("/setup/finish") +async def setup_finish_submit( + request: Request, + csrf_protect: CsrfProtect = Depends(), +) -> Response: + """Complete the setup wizard.""" + # Require authentication for this step + operator = getattr(request.state, "operator", None) + if operator is None: + return RedirectResponse(url="/setup/operator", status_code=302) + + pool = get_pool() + + await csrf_protect.validate_csrf(request) + + async with pool.acquire() as conn: # Mark setup complete await conn.execute( "UPDATE config.system SET setup_complete = true WHERE id = true" ) - # Redirect with session cookie - response = RedirectResponse(url="/", status_code=302) - _set_session_cookie(response, token, lifetime_days * 86400) - return response + # Write audit log + await write_audit( + conn, + SETUP_COMPLETE, + operator_id=operator.id, + target="system", + ) + + return RedirectResponse(url="/", status_code=302) @router.get("/login", response_class=HTMLResponse) diff --git a/src/central/gui/templates/_wizard_header.html b/src/central/gui/templates/_wizard_header.html new file mode 100644 index 0000000..941d18e --- /dev/null +++ b/src/central/gui/templates/_wizard_header.html @@ -0,0 +1,6 @@ +
+
+ Step {{ step }} of 5 — {{ step_name }} +
+ +
diff --git a/src/central/gui/templates/base_wizard.html b/src/central/gui/templates/base_wizard.html new file mode 100644 index 0000000..a3eacac --- /dev/null +++ b/src/central/gui/templates/base_wizard.html @@ -0,0 +1,24 @@ + + + + + + {% block title %}Central - Setup{% endblock %} + + + {% block head %}{% endblock %} + + + +
+ {% block content %}{% endblock %} +
+ + diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html new file mode 100644 index 0000000..0411e28 --- /dev/null +++ b/src/central/gui/templates/setup_adapters.html @@ -0,0 +1,217 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - Configure Adapters{% endblock %} + +{% block head %} + + + + +{% endblock %} + +{% block content %} +{% with step=4, step_name="Configure Adapters" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

Configure Adapters

+

Enable and configure data source adapters. Each adapter polls an external API and normalizes events.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + {% for adapter in adapters %} +
+ {{ adapter.name }} + +
+ + {% if errors and errors.get(adapter.name + '_enabled') %} + {{ errors[adapter.name + '_enabled'] }} + {% endif %} + + + + {% if errors and errors.get(adapter.name + '_cadence_s') %} + {{ errors[adapter.name + '_cadence_s'] }} + {% endif %} + + {% if adapter.name == 'nws' %} + + + {% if errors and errors.get(adapter.name + '_contact_email') %} + {{ errors[adapter.name + '_contact_email'] }} + {% endif %} + {% endif %} + + {% if adapter.name == 'firms' %} + + + {% if errors and errors.get(adapter.name + '_api_key_alias') %} + {{ errors[adapter.name + '_api_key_alias'] }} + {% endif %} + + + {% for sat in valid_satellites %} + + {% endfor %} + {% endif %} + + {% if adapter.name == 'usgs_quake' %} + + + {% if errors and errors.get(adapter.name + '_feed') %} + {{ errors[adapter.name + '_feed'] }} + {% endif %} + {% endif %} + +

Region

+ {% set region = form_data if form_data else adapter.settings.region %} +
+ +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ {% if errors and errors.get(adapter.name + '_region') %} + {{ errors[adapter.name + '_region'] }} + {% endif %} +
+
+
+ {% endfor %} + +
+ ← Back + +
+
+
+ + +{% endblock %} diff --git a/src/central/gui/templates/setup_finish.html b/src/central/gui/templates/setup_finish.html new file mode 100644 index 0000000..7e0ac7e --- /dev/null +++ b/src/central/gui/templates/setup_finish.html @@ -0,0 +1,69 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - Finish Setup{% endblock %} + +{% block content %} +{% with step=5, step_name="Finish Setup" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

Setup Complete

+

Review your configuration and finish the setup wizard.

+
+ +

Summary

+ + + + + + + + + + + + + + + + +
Operators{{ operator_count }} configured
API Keys{{ key_count }} configured
Map Tile URL{{ system.map_tile_url }}
+ +

Adapters

+ + + + + + + + + + {% for adapter in adapters %} + + + + + + {% endfor %} + +
AdapterStatusCadence
{{ adapter.name }} + {% if adapter.enabled %} + Enabled + {% else %} + Disabled + {% endif %} + {{ adapter.cadence_s }}s
+ +
+ +
+ ← Back + +
+
+
+{% endblock %} diff --git a/src/central/gui/templates/setup_keys.html b/src/central/gui/templates/setup_keys.html new file mode 100644 index 0000000..28457cc --- /dev/null +++ b/src/central/gui/templates/setup_keys.html @@ -0,0 +1,84 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - API Keys{% endblock %} + +{% block content %} +{% with step=3, step_name="API Keys" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

API Keys

+

Add API keys for adapters that require external service credentials (e.g., FIRMS).

+
+ + {% if success %} +

{{ success }}

+ {% endif %} + + {% if keys %} +

Existing Keys

+ + + + + + + + + {% for key in keys %} + + + + + {% endfor %} + +
AliasCreated
{{ key.alias }}{{ key.created_at.strftime('%Y-%m-%d %H:%M') if key.created_at else '(never)' }}
+ {% else %} +

No API keys configured yet.

+ {% endif %} + +

Add New Key

+
+ + + +
+
+ + + {% if errors and errors.alias %} + {{ errors.alias }} + {% else %} + Letters, numbers, and underscores only. + {% endif %} +
+
+ + + {% if errors and errors.plaintext_key %} + {{ errors.plaintext_key }} + {% else %} + Will be encrypted before storage. + {% endif %} +
+
+ + +
+ +
+ +
+ + +
+ ← Back + +
+
+
+{% endblock %} diff --git a/src/central/gui/templates/setup_operator.html b/src/central/gui/templates/setup_operator.html new file mode 100644 index 0000000..f4e9277 --- /dev/null +++ b/src/central/gui/templates/setup_operator.html @@ -0,0 +1,45 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - Create Operator{% endblock %} + +{% block content %} +{% with step=1, step_name="Create Operator" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

Create Operator Account

+

Create the initial operator account to manage Central.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + + + + + + + +
+
+{% endblock %} diff --git a/src/central/gui/templates/setup_system.html b/src/central/gui/templates/setup_system.html new file mode 100644 index 0000000..c49cb9c --- /dev/null +++ b/src/central/gui/templates/setup_system.html @@ -0,0 +1,49 @@ +{% extends "base_wizard.html" %} + +{% block title %}Central - System Settings{% endblock %} + +{% block content %} +{% with step=2, step_name="System Settings" %} +{% include "_wizard_header.html" %} +{% endwith %} + +
+
+

System Settings

+

Configure map tile provider for the region picker.

+
+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + + {% if errors and errors.map_tile_url %} + {{ errors.map_tile_url }} + {% endif %} + + + {% if errors and errors.map_attribution %} + {{ errors.map_attribution }} + {% endif %} + +
+ ← Back + +
+
+
+{% endblock %} diff --git a/tests/test_setup_gate.py b/tests/test_setup_gate.py index 9aa11ce..7138162 100644 --- a/tests/test_setup_gate.py +++ b/tests/test_setup_gate.py @@ -12,8 +12,8 @@ class TestSetupGateMiddleware: """Tests for SetupGateMiddleware.""" @pytest.mark.asyncio - async def test_allows_setup_route_when_incomplete(self): - """SetupGateMiddleware allows /setup when setup_complete=False.""" + async def test_allows_setup_subpath_when_incomplete(self): + """SetupGateMiddleware allows /setup/operator when setup_complete=False.""" mock_pool = MagicMock() mock_conn = MagicMock() mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) @@ -21,6 +21,31 @@ class TestSetupGateMiddleware: mock_conn.__aexit__ = AsyncMock() mock_pool.acquire = MagicMock(return_value=mock_conn) + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/setup/operator") + assert response.status_code == 200 + assert response.json() == {"message": "operator"} + + @pytest.mark.asyncio + async def test_redirects_setup_base_to_wizard_step(self): + """SetupGateMiddleware redirects /setup to wizard step when incomplete.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.fetchval = AsyncMock(return_value=0) # No operators + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + with patch("central.gui.middleware.get_pool", return_value=mock_pool): app = FastAPI() @@ -28,12 +53,16 @@ class TestSetupGateMiddleware: async def setup(): return {"message": "setup"} + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + app.add_middleware(SetupGateMiddleware) - client = TestClient(app) + client = TestClient(app, follow_redirects=False) response = client.get("/setup") - assert response.status_code == 200 - assert response.json() == {"message": "setup"} + assert response.status_code == 302 + assert response.headers["location"] == "/setup/operator" @pytest.mark.asyncio async def test_allows_health_when_incomplete(self): @@ -135,7 +164,7 @@ class TestSetupGateMiddleware: @pytest.mark.asyncio async def test_redirects_setup_when_complete(self): - """SetupGateMiddleware redirects /setup to / when setup_complete=True.""" + """SetupGateMiddleware redirects /setup/* to / when setup_complete=True.""" mock_pool = MagicMock() mock_conn = MagicMock() mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) @@ -154,9 +183,18 @@ class TestSetupGateMiddleware: async def setup(): return {"message": "setup"} + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + app.add_middleware(SetupGateMiddleware) client = TestClient(app, follow_redirects=False) + # Both /setup and /setup/operator should redirect to / response = client.get("/setup") assert response.status_code == 302 assert response.headers["location"] == "/" + + response = client.get("/setup/operator") + assert response.status_code == 302 + assert response.headers["location"] == "/" diff --git a/tests/test_wizard.py b/tests/test_wizard.py new file mode 100644 index 0000000..e92c35c --- /dev/null +++ b/tests/test_wizard.py @@ -0,0 +1,586 @@ +"""Tests for the first-run setup wizard.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from central.gui.routes import ( + setup_operator_form, + setup_operator_submit, + setup_system_form, + setup_system_submit, + setup_keys_form, + setup_keys_submit, + setup_adapters_form, + setup_adapters_submit, + setup_finish_form, + setup_finish_submit, +) +from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step + + +class TestWizardStepRedirect: + """Test wizard step redirect logic.""" + + @pytest.mark.asyncio + async def test_no_operators_redirects_to_operator(self): + """When no operators exist, redirect to /setup/operator.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [0] # No operators + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/operator" + + @pytest.mark.asyncio + async def test_default_tile_url_redirects_to_system(self): + """When map_tile_url is default, redirect to /setup/system.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1] # Has operator + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/system" + + @pytest.mark.asyncio + async def test_no_adapters_touched_redirects_to_keys(self): + """When no adapters have been updated, redirect to /setup/keys.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/keys" + + @pytest.mark.asyncio + async def test_all_steps_complete_redirects_to_finish(self): + """When all steps done, redirect to /setup/finish.""" + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png" + } + + result = await _get_wizard_redirect_step(mock_conn) + assert result == "/setup/finish" + + +class TestSetupOperatorForm: + """Test operator creation form (step 1).""" + + @pytest.mark.asyncio + async def test_get_returns_form(self): + """GET /setup/operator returns the form.""" + mock_request = MagicMock() + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_csrf = MagicMock() + mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") + mock_csrf.set_csrf_cookie = MagicMock() + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + result = await setup_operator_form(mock_request, mock_csrf) + + mock_templates.TemplateResponse.assert_called_once() + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["csrf_token"] == "token" + assert context["error"] is None + + +class TestSetupOperatorSubmit: + """Test operator creation submission.""" + + @pytest.mark.asyncio + async def test_password_mismatch_shows_error(self): + """POST with password mismatch re-renders with error.""" + mock_request = MagicMock() + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_pool = MagicMock() + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") + mock_csrf.set_csrf_cookie = MagicMock() + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_operator_submit( + mock_request, + username="admin", + password="password123", + confirm_password="different", + csrf_protect=mock_csrf, + ) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["error"] == "Passwords do not match" + + @pytest.mark.asyncio + async def test_valid_creates_operator_and_redirects(self): + """POST with valid data creates operator and redirects to /setup/system.""" + mock_request = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.side_effect = [ + {"id": 1}, # INSERT RETURNING id + {"session_lifetime_days": 90}, # system settings + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.hash_password", return_value="hashed"): + with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: + mock_session.return_value = ("session_token", datetime.now()) + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_operator_submit( + mock_request, + username="admin", + password="password123", + confirm_password="password123", + csrf_protect=mock_csrf, + ) + + assert result.status_code == 302 + assert result.headers["location"] == "/setup/system" + + +class TestSetupSystemForm: + """Test system settings form (step 2).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/system without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + + mock_csrf = MagicMock() + + result = await setup_system_form(mock_request, mock_csrf) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_authenticated_returns_form(self): + """GET /setup/system with auth returns the form.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "© OpenStreetMap contributors", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") + mock_csrf.set_csrf_cookie = MagicMock() + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_system_form(mock_request, mock_csrf) + + mock_templates.TemplateResponse.assert_called_once() + + +class TestSetupSystemSubmit: + """Test system settings submission.""" + + @pytest.mark.asyncio + async def test_missing_placeholders_shows_error(self): + """POST without {z},{x},{y} placeholders shows error.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "map_tile_url": "https://example.com/tiles", + "map_attribution": "Test", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "", + "map_attribution": "", + } + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") + mock_csrf.set_csrf_cookie = MagicMock() + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_system_submit(mock_request, mock_csrf) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert "map_tile_url" in context["errors"] + + @pytest.mark.asyncio + async def test_valid_updates_and_redirects(self): + """POST with valid data updates system and redirects to /setup/keys.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "map_tile_url": "https://example.com/{z}/{x}/{y}.png", + "map_attribution": "Test Attribution", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_conn = AsyncMock() + mock_conn.fetchrow.return_value = { + "map_tile_url": "old_url", + "map_attribution": "old_attr", + } + mock_conn.execute = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_system_submit(mock_request, mock_csrf) + + assert result.status_code == 302 + assert result.headers["location"] == "/setup/keys" + + +class TestSetupKeysForm: + """Test API keys form (step 3).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/keys without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + + mock_csrf = MagicMock() + + result = await setup_keys_form(mock_request, mock_csrf) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + +class TestSetupKeysSubmit: + """Test API keys submission.""" + + @pytest.mark.asyncio + async def test_next_action_redirects_to_adapters(self): + """POST with action=next redirects to /setup/adapters.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + form_data = MagicMock() + form_data.get = lambda k, default="": {"action": "next"}.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + + # No need to mock get_pool since action="next" returns before it's called + result = await setup_keys_submit(mock_request, mock_csrf) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/adapters" + + @pytest.mark.asyncio + async def test_add_key_creates_and_rerenders(self): + """POST with action=add creates key and re-renders with success.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + form_data = MagicMock() + form_data.get = lambda k, default="": { + "action": "add", + "alias": "testkey", + "plaintext_key": "secret123", + }.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchrow.side_effect = [ + None, # No existing key + {"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}, + ] + mock_conn.fetch.side_effect = [ + [], # First list + [{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") + mock_csrf.set_csrf_cookie = MagicMock() + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.crypto.encrypt", return_value=b"encrypted"): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_keys_submit(mock_request, mock_csrf) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["success"] == "API key 'testkey' added successfully." + + +class TestSetupAdaptersForm: + """Test adapters configuration form (step 4).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/adapters without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + + mock_csrf = MagicMock() + + result = await setup_adapters_form(mock_request, mock_csrf) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + +class TestSetupFinishForm: + """Test finish page (step 5).""" + + @pytest.mark.asyncio + async def test_unauthenticated_redirects_to_operator(self): + """GET /setup/finish without auth redirects to /setup/operator.""" + mock_request = MagicMock() + mock_request.state.operator = None + + mock_csrf = MagicMock() + + result = await setup_finish_form(mock_request, mock_csrf) + assert result.status_code == 302 + assert result.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_authenticated_shows_summary(self): + """GET /setup/finish with auth shows summary.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_templates = MagicMock() + mock_templates.TemplateResponse.return_value = MagicMock() + + mock_conn = AsyncMock() + mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys + mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"} + mock_conn.fetch.return_value = [ + {"name": "nws", "enabled": True, "cadence_s": 300}, + {"name": "firms", "enabled": False, "cadence_s": 600}, + ] + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") + mock_csrf.set_csrf_cookie = MagicMock() + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await setup_finish_form(mock_request, mock_csrf) + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert context["operator_count"] == 1 + assert context["key_count"] == 2 + assert len(context["adapters"]) == 2 + + +class TestSetupFinishSubmit: + """Test setup completion.""" + + @pytest.mark.asyncio + async def test_marks_setup_complete_and_redirects(self): + """POST /setup/finish marks setup_complete=true and redirects to /.""" + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="admin") + + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_conn + mock_pool.acquire.return_value.__aexit__.return_value = None + + mock_csrf = MagicMock() + mock_csrf.validate_csrf = AsyncMock() + + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: + result = await setup_finish_submit(mock_request, mock_csrf) + + assert result.status_code == 302 + assert result.headers["location"] == "/" + mock_conn.execute.assert_called_once() + mock_audit.assert_called_once() + + +class TestSetupGateMiddlewareWizard: + """Test SetupGateMiddleware with wizard paths.""" + + @pytest.mark.asyncio + async def test_allows_setup_operator_when_incomplete(self): + """SetupGateMiddleware allows /setup/operator when setup_complete=False.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator form"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app) + + response = client.get("/setup/operator") + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_redirects_base_setup_to_wizard_step(self): + """SetupGateMiddleware redirects /setup to appropriate wizard step.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.fetchval = AsyncMock(return_value=0) # No operators + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/setup") + async def setup(): + return {"message": "base setup"} + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup") + assert response.status_code == 302 + assert response.headers["location"] == "/setup/operator" + + @pytest.mark.asyncio + async def test_redirects_login_to_setup_when_incomplete(self): + """SetupGateMiddleware redirects /login to /setup when setup_complete=False.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/login") + async def login(): + return {"message": "login"} + + @app.get("/setup") + async def setup(): + return {"message": "setup"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/login") + assert response.status_code == 302 + assert response.headers["location"] == "/setup" + + @pytest.mark.asyncio + async def test_redirects_all_setup_paths_when_complete(self): + """SetupGateMiddleware redirects /setup/* to / when setup_complete=True.""" + from starlette.testclient import TestClient + from fastapi import FastAPI + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True}) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + with patch("central.gui.middleware.get_pool", return_value=mock_pool): + app = FastAPI() + + @app.get("/") + async def index(): + return {"message": "home"} + + @app.get("/setup/operator") + async def setup_operator(): + return {"message": "operator"} + + app.add_middleware(SetupGateMiddleware) + client = TestClient(app, follow_redirects=False) + + response = client.get("/setup/operator") + assert response.status_code == 302 + assert response.headers["location"] == "/"