feat(wizard): implement deferred-commit pattern for setup wizard

Replace the current "POST each step -> DB write -> redirect" architecture
with "collect values across steps in a signed cookie, commit everything
in one transaction at Finish."

Key changes:
- Add wizard.py: WizardState dataclass and cookie helpers
- csrf.py: Add reuse_or_generate_pre_auth_csrf helper
- routes.py: All wizard handlers now use cookie state, no DB writes until finish
- middleware.py: Cookie-based wizard step routing instead of DB queries
- setup_operator.html: Remove "Operator Already Configured" branch

Benefits:
- Back navigation works: can return to any step and edit values
- Atomic commit: all DB writes happen in single transaction at finish
- No orphaned state: failed wizard leaves no DB artifacts
- Simpler auth: pre-auth CSRF for all 5 steps (no session until finish)

Tests updated for new behavior. 287 tests passing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Matt Johnson 2026-05-18 05:58:39 +00:00
commit 52e0f0e616
6 changed files with 689 additions and 1062 deletions

View file

@ -1,11 +1,10 @@
"""Pre-auth CSRF protection for login and setup/operator pages.
"""Pre-auth CSRF protection for login and setup pages.
These routes cannot use session-bound CSRF because no session exists yet.
Uses a simple cookie-based pattern with short-lived tokens.
"""
import secrets
from typing import Optional
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from starlette.requests import Request
@ -34,6 +33,34 @@ def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]:
return plain_token, signed_token
def reuse_or_generate_pre_auth_csrf(
request: Request,
secret_key: str,
) -> tuple[str, str | None]:
"""Reuse an existing valid pre-auth CSRF token, or generate new.
Returns (plain_token, signed_token_for_cookie).
If signed_token_for_cookie is None, the existing cookie is
still valid and caller should not call set_pre_auth_csrf_cookie.
If non-None, caller MUST call set_pre_auth_csrf_cookie with
it to persist the new value.
"""
cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE)
if cookie_value:
serializer = _get_serializer(secret_key)
try:
plain_token = serializer.loads(
cookie_value,
max_age=PRE_AUTH_CSRF_MAX_AGE,
)
return plain_token, None # reuse existing
except (BadSignature, SignatureExpired):
pass # fall through to generate
plain_token, signed_token = generate_pre_auth_csrf(secret_key)
return plain_token, signed_token
def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None:
"""Set the pre-auth CSRF cookie on a response."""
response.set_cookie(

View file

@ -16,7 +16,7 @@ SETUP_EXEMPT_PREFIXES = ("/static/", "/setup")
# Paths that don't require authentication
AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"}
AUTH_EXEMPT_PREFIXES = ("/static/",)
AUTH_EXEMPT_PREFIXES = ("/static/", "/setup/")
def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool:
@ -29,33 +29,14 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool:
return False
async def _get_wizard_redirect_step(conn) -> str:
"""Determine which wizard step to redirect to based on DB state."""
# Check if any operators exist
op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
if op_count == 0:
def _get_wizard_redirect_from_cookie(request: Request, csrf_secret: str) -> str:
"""Determine wizard redirect step from cookie state."""
from central.gui.wizard import get_wizard_state, get_step_route
state = get_wizard_state(request, csrf_secret)
if state is None:
return "/setup/operator"
# Check if system settings have been configured (map_tile_url not default)
sys_row = await conn.fetchrow(
"SELECT map_tile_url FROM config.system WHERE id = true"
)
default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
if sys_row is None or sys_row["map_tile_url"] == default_tile:
return "/setup/system"
# Keys step is optional, so check adapters have been reviewed
# We consider adapters reviewed if any adapter has a non-null updated_at
# (meaning it was explicitly saved during setup)
adapters_touched = await conn.fetchval(
"SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL"
)
if adapters_touched == 0:
# Go to keys first, then adapters
return "/setup/keys"
# All steps done, go to finish
return "/setup/finish"
return get_step_route(state.wizard_step)
class SetupGateMiddleware(BaseHTTPMiddleware):
@ -85,13 +66,16 @@ class SetupGateMiddleware(BaseHTTPMiddleware):
if not setup_complete:
# Setup not complete - only allow setup paths and static/health
if path.startswith("/setup"):
# Allow all /setup/* paths (handler will enforce auth)
# Allow all /setup/* paths
# But /setup with no subpath should redirect to appropriate step
if path == "/setup" or path == "/setup/":
try:
async with pool.acquire() as conn:
redirect_step = await _get_wizard_redirect_step(conn)
return RedirectResponse(url=redirect_step, status_code=302)
from central.bootstrap_config import get_settings
settings = get_settings()
redirect_step = _get_wizard_redirect_from_cookie(
request, settings.csrf_secret
)
return RedirectResponse(url=redirect_step, status_code=302)
except Exception:
logger.warning("Failed to determine wizard step", exc_info=True)
return RedirectResponse(url="/setup/operator", status_code=302)
@ -139,7 +123,7 @@ class SessionMiddleware(BaseHTTPMiddleware):
request.state.operator = None
request.state.csrf_token = None
# Check if auth is required
# Check if auth is required - setup paths are exempt during wizard
if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES):
if request.state.operator is None:
return RedirectResponse(url="/login", status_code=302)

File diff suppressed because it is too large Load diff

View file

@ -7,17 +7,6 @@
{% include "_wizard_header.html" %}
{% endwith %}
{% if existing_operator %}
<article>
<header>
<h1>Operator Already Configured</h1>
</header>
<p>The operator account <strong>{{ existing_operator.username }}</strong> has been created.</p>
<div style="display: flex; gap: 1rem; margin-top: 1rem;">
<a href="/setup/system" role="button">Next &rarr;</a>
</div>
</article>
{% else %}
<article>
<header>
<h1>Create Operator Account</h1>
@ -53,5 +42,4 @@
<button type="submit">Create Operator &rarr;</button>
</form>
</article>
{% endif %}
{% endblock %}

131
src/central/gui/wizard.py Normal file
View file

@ -0,0 +1,131 @@
"""Wizard state management for deferred-commit setup flow.
The wizard collects configuration across 5 steps and commits everything
atomically at the final step. State is carried in a signed cookie.
"""
import base64
from dataclasses import dataclass, field, asdict
from typing import Any
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from starlette.requests import Request
from starlette.responses import Response
# 1 hour max age for wizard cookie
WIZARD_MAX_AGE = 3600
WIZARD_COOKIE = "central_wizard"
@dataclass
class WizardOperator:
"""Operator data collected in step 1."""
username: str
password_hash: str
@dataclass
class WizardSystem:
"""System settings collected in step 2."""
map_tile_url: str
map_attribution: str
@dataclass
class WizardApiKey:
"""API key collected in step 3."""
alias: str
encrypted_value_b64: str # base64-encoded encrypted value
@dataclass
class WizardAdapter:
"""Adapter config collected in step 4."""
enabled: bool
cadence_s: int
settings: dict[str, Any]
@dataclass
class WizardState:
"""Complete wizard state carried across all steps."""
wizard_step: int = 1
operator: dict | None = None
system: dict | None = None
api_keys: list[dict] = field(default_factory=list)
adapters: dict[str, dict] | None = None
def to_dict(self) -> dict:
"""Convert to dictionary for serialization."""
return {
"wizard_step": self.wizard_step,
"operator": self.operator,
"system": self.system,
"api_keys": self.api_keys,
"adapters": self.adapters,
}
@classmethod
def from_dict(cls, data: dict) -> "WizardState":
"""Create from dictionary."""
return cls(
wizard_step=data.get("wizard_step", 1),
operator=data.get("operator"),
system=data.get("system"),
api_keys=data.get("api_keys", []),
adapters=data.get("adapters"),
)
def _get_wizard_serializer(secret_key: str) -> URLSafeTimedSerializer:
"""Get a timed serializer for wizard state."""
return URLSafeTimedSerializer(secret_key, salt="wizard-state")
def get_wizard_state(request: Request, secret_key: str) -> WizardState | None:
"""Decode wizard state from cookie.
Returns WizardState if valid, None if missing/invalid/expired.
"""
cookie_value = request.cookies.get(WIZARD_COOKIE)
if not cookie_value:
return None
serializer = _get_wizard_serializer(secret_key)
try:
data = serializer.loads(cookie_value, max_age=WIZARD_MAX_AGE)
return WizardState.from_dict(data)
except (BadSignature, SignatureExpired):
return None
def set_wizard_cookie(response: Response, state: WizardState, secret_key: str) -> None:
"""Set the wizard state cookie on a response."""
serializer = _get_wizard_serializer(secret_key)
signed_value = serializer.dumps(state.to_dict())
response.set_cookie(
WIZARD_COOKIE,
signed_value,
max_age=WIZARD_MAX_AGE,
path="/setup",
httponly=True,
samesite="lax",
)
def clear_wizard_cookie(response: Response) -> None:
"""Remove the wizard state cookie."""
response.delete_cookie(WIZARD_COOKIE, path="/setup")
def get_step_route(step: int) -> str:
"""Get the route for a wizard step number."""
routes = {
1: "/setup/operator",
2: "/setup/system",
3: "/setup/keys",
4: "/setup/adapters",
5: "/setup/finish",
}
return routes.get(step, "/setup/operator")

View file

@ -1,4 +1,4 @@
"""Tests for the first-run setup wizard."""
"""Tests for the first-run setup wizard with deferred-commit pattern."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
@ -11,60 +11,38 @@ from central.gui.routes import (
setup_system_submit,
setup_keys_form,
setup_keys_submit,
setup_adapters_form,
setup_adapters_submit,
setup_finish_form,
setup_finish_submit,
)
from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step
from central.gui.middleware import SetupGateMiddleware
from central.gui.wizard import WizardState, get_wizard_state, set_wizard_cookie
class TestWizardStepRedirect:
"""Test wizard step redirect logic."""
"""Test wizard step redirect logic based on cookie state."""
@pytest.mark.asyncio
async def test_no_operators_redirects_to_operator(self):
"""When no operators exist, redirect to /setup/operator."""
mock_conn = AsyncMock()
mock_conn.fetchval.side_effect = [0] # No operators
def test_no_cookie_redirects_to_operator(self):
"""When no wizard cookie exists, redirect to /setup/operator."""
from central.gui.middleware import _get_wizard_redirect_from_cookie
result = await _get_wizard_redirect_step(mock_conn)
mock_request = MagicMock()
mock_request.cookies = {}
result = _get_wizard_redirect_from_cookie(mock_request, "testsecret")
assert result == "/setup/operator"
@pytest.mark.asyncio
async def test_default_tile_url_redirects_to_system(self):
"""When map_tile_url is default, redirect to /setup/system."""
mock_conn = AsyncMock()
mock_conn.fetchval.side_effect = [1] # Has operator
mock_conn.fetchrow.return_value = {
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
}
def test_cookie_step_2_redirects_to_system(self):
"""When wizard_step=2 in cookie, redirect to /setup/system."""
from central.gui.wizard import get_step_route
result = await _get_wizard_redirect_step(mock_conn)
result = get_step_route(2)
assert result == "/setup/system"
@pytest.mark.asyncio
async def test_no_adapters_touched_redirects_to_keys(self):
"""When no adapters have been updated, redirect to /setup/keys."""
mock_conn = AsyncMock()
mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched
mock_conn.fetchrow.return_value = {
"map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png"
}
def test_cookie_step_5_redirects_to_finish(self):
"""When wizard_step=5 in cookie, redirect to /setup/finish."""
from central.gui.wizard import get_step_route
result = await _get_wizard_redirect_step(mock_conn)
assert result == "/setup/keys"
@pytest.mark.asyncio
async def test_all_steps_complete_redirects_to_finish(self):
"""When all steps done, redirect to /setup/finish."""
mock_conn = AsyncMock()
mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched
mock_conn.fetchrow.return_value = {
"map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png"
}
result = await _get_wizard_redirect_step(mock_conn)
result = get_step_route(5)
assert result == "/setup/finish"
@ -72,63 +50,26 @@ class TestSetupOperatorForm:
"""Test operator creation form (step 1)."""
@pytest.mark.asyncio
async def test_get_returns_form(self):
"""GET /setup/operator returns the form when no operator exists."""
async def test_get_returns_form_without_prefill(self):
"""GET /setup/operator returns the form when no wizard cookie exists."""
mock_request = MagicMock()
mock_request.cookies = {}
mock_templates = MagicMock()
mock_templates.TemplateResponse.return_value = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetchrow.return_value = None # No operator exists
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret"
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
result = await setup_operator_form(mock_request)
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
result = await setup_operator_form(mock_request)
mock_templates.TemplateResponse.assert_called_once()
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert "csrf_token" in context and context["csrf_token"]
assert context["error"] is None
assert context["existing_operator"] is None
@pytest.mark.asyncio
async def test_get_returns_confirmation_when_operator_exists(self):
"""GET /setup/operator shows confirmation when operator already exists."""
mock_request = MagicMock()
mock_templates = MagicMock()
mock_response = MagicMock()
mock_response.body = b"Operator Already Configured"
mock_templates.TemplateResponse.return_value = mock_response
mock_conn = AsyncMock()
mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret"
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
result = await setup_operator_form(mock_request)
mock_templates.TemplateResponse.assert_called_once()
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["existing_operator"] == {"username": "admin"}
assert context["error"] is None
assert context["form_data"] is None
class TestSetupOperatorSubmit:
@ -138,28 +79,17 @@ class TestSetupOperatorSubmit:
async def test_password_mismatch_shows_error(self):
"""POST with password mismatch re-renders with error."""
mock_request = MagicMock()
mock_request.state.csrf_token = "test_csrf"
mock_request.form = AsyncMock(return_value={
"csrf_token": "test_csrf",
"username": "testuser",
"password": "password1",
"confirm_password": "password2", # Mismatch
})
mock_request.cookies = {}
mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"})
mock_templates = MagicMock()
mock_templates.TemplateResponse.return_value = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetchval.return_value = 0 # No existing operators
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret"
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed")):
result = await setup_operator_submit(
mock_request,
username="testuser",
@ -172,374 +102,43 @@ class TestSetupOperatorSubmit:
assert context["error"] == "Passwords do not match"
@pytest.mark.asyncio
async def test_valid_creates_operator_and_redirects(self):
"""POST with valid data creates operator and redirects to /setup/system."""
async def test_valid_creates_wizard_cookie_and_redirects(self):
"""POST with valid data creates wizard cookie and redirects to /setup/system."""
mock_request = MagicMock()
mock_request.state.csrf_token = "test_csrf"
mock_request.form = AsyncMock(return_value={
"csrf_token": "test_csrf",
"username": "testuser",
"password": "password123",
"confirm_password": "password123",
})
mock_request.cookies = {}
mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"})
mock_conn = AsyncMock()
mock_conn.fetchval.return_value = 0 # No existing operators
mock_conn.fetchrow.side_effect = [
{"id": 1}, # INSERT RETURNING id
{"session_lifetime_days": 90}, # system settings
]
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret"
with patch("central.gui.routes.hash_password", return_value="hashed"):
with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session:
mock_session.return_value = ("session_token", datetime.now(), "csrf_token")
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
result = await setup_operator_submit(
mock_request,
username="testuser",
password="password123",
confirm_password="password123",
)
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.hash_password", return_value="hashed_pw"):
result = await setup_operator_submit(
mock_request,
username="testuser",
password="password123",
confirm_password="password123",
)
assert result.status_code == 302
assert result.headers["location"] == "/setup/system"
@pytest.mark.asyncio
async def test_post_when_operator_exists_shows_confirmation(self):
"""POST when operator exists returns 200 with confirmation, no insert."""
mock_request = MagicMock()
mock_request.form = AsyncMock(return_value={
"csrf_token": "test_csrf",
"username": "testuser",
"password": "password123",
"confirm_password": "password123",
})
mock_templates = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_templates.TemplateResponse.return_value = mock_response
mock_conn = AsyncMock()
mock_conn.fetchval.return_value = 1 # Operator already exists
mock_conn.fetchrow.return_value = {"username": "existing_admin"}
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
mock_request.state.csrf_token = "test_csrf"
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret"
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
result = await setup_operator_submit(
mock_request,
username="testuser",
password="password123",
confirm_password="password123",
)
# Should return 200, not 500 or redirect
assert result.status_code == 200
# Should render confirmation state
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["existing_operator"] == {"username": "existing_admin"}
# Should NOT call write_audit (no insert happened)
mock_audit.assert_not_called()
class TestSetupSystemForm:
"""Test system settings form (step 2)."""
@pytest.mark.asyncio
async def test_unauthenticated_redirects_to_operator(self):
"""GET /setup/system without auth redirects to /setup/operator."""
async def test_no_wizard_cookie_redirects_to_operator(self):
"""GET /setup/system without wizard cookie redirects to /setup/operator."""
mock_request = MagicMock()
mock_request.state.operator = None
result = await setup_system_form(mock_request)
mock_request.cookies = {}
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
result = await setup_system_form(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/setup/operator"
@pytest.mark.asyncio
async def test_authenticated_returns_form(self):
"""GET /setup/system with auth returns the form."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_templates = MagicMock()
mock_templates.TemplateResponse.return_value = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetchrow.return_value = {
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
"map_attribution": "&copy; OpenStreetMap contributors",
}
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
result = await setup_system_form(mock_request)
mock_templates.TemplateResponse.assert_called_once()
class TestSetupSystemSubmit:
"""Test system settings submission."""
@pytest.mark.asyncio
async def test_missing_placeholders_shows_error(self):
"""POST without {z},{x},{y} placeholders shows error."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_request.state.csrf_token = "test_csrf_token"
form_data = MagicMock()
form_data.get = lambda k, default="": {
"csrf_token": "test_csrf_token",
"map_tile_url": "https://example.com/tiles",
"map_attribution": "Test",
}.get(k, default)
mock_request.form = AsyncMock(return_value=form_data)
mock_templates = MagicMock()
mock_templates.TemplateResponse.return_value = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetchrow.return_value = {
"map_tile_url": "",
"map_attribution": "",
}
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
result = await setup_system_submit(mock_request)
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert "map_tile_url" in context["errors"]
@pytest.mark.asyncio
async def test_valid_updates_and_redirects(self):
"""POST with valid data updates system and redirects to /setup/keys."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_request.state.csrf_token = "test_csrf_token"
form_data = MagicMock()
form_data.get = lambda k, default="": {
"csrf_token": "test_csrf_token",
"map_tile_url": "https://example.com/{z}/{x}/{y}.png",
"map_attribution": "Test Attribution",
}.get(k, default)
mock_request.form = AsyncMock(return_value=form_data)
mock_conn = AsyncMock()
mock_conn.fetchrow.return_value = {
"map_tile_url": "old_url",
"map_attribution": "old_attr",
}
mock_conn.execute = AsyncMock()
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
result = await setup_system_submit(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/setup/keys"
class TestSetupKeysForm:
"""Test API keys form (step 3)."""
@pytest.mark.asyncio
async def test_unauthenticated_redirects_to_operator(self):
"""GET /setup/keys without auth redirects to /setup/operator."""
mock_request = MagicMock()
mock_request.state.operator = None
result = await setup_keys_form(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/setup/operator"
class TestSetupKeysSubmit:
"""Test API keys submission."""
@pytest.mark.asyncio
async def test_next_action_redirects_to_adapters(self):
"""POST with action=next redirects to /setup/adapters."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_request.state.csrf_token = "test_csrf_token"
form_data = MagicMock()
form_data.get = lambda k, default="": {
"csrf_token": "test_csrf_token",
"action": "next",
}.get(k, default)
mock_request.form = AsyncMock(return_value=form_data)
# No need to mock get_pool since action="next" returns before it's called
result = await setup_keys_submit(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/setup/adapters"
@pytest.mark.asyncio
async def test_add_key_creates_and_rerenders(self):
"""POST with action=add creates key and re-renders with success."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_request.state.csrf_token = "test_csrf_token"
form_data = MagicMock()
form_data.get = lambda k, default="": {
"csrf_token": "test_csrf_token",
"action": "add",
"alias": "testkey",
"plaintext_key": "secret123",
}.get(k, default)
mock_request.form = AsyncMock(return_value=form_data)
mock_templates = MagicMock()
mock_templates.TemplateResponse.return_value = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetchrow.side_effect = [
None, # No existing key
{"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)},
]
mock_conn.fetch.side_effect = [
[], # First list
[{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert
]
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.crypto.encrypt", return_value=b"encrypted"):
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
result = await setup_keys_submit(mock_request)
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["success"] == "API key 'testkey' added successfully."
class TestSetupAdaptersForm:
"""Test adapters configuration form (step 4)."""
@pytest.mark.asyncio
async def test_unauthenticated_redirects_to_operator(self):
"""GET /setup/adapters without auth redirects to /setup/operator."""
mock_request = MagicMock()
mock_request.state.operator = None
result = await setup_adapters_form(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/setup/operator"
class TestSetupFinishForm:
"""Test finish page (step 5)."""
@pytest.mark.asyncio
async def test_unauthenticated_redirects_to_operator(self):
"""GET /setup/finish without auth redirects to /setup/operator."""
mock_request = MagicMock()
mock_request.state.operator = None
result = await setup_finish_form(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/setup/operator"
@pytest.mark.asyncio
async def test_authenticated_shows_summary(self):
"""GET /setup/finish with auth shows summary."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_templates = MagicMock()
mock_templates.TemplateResponse.return_value = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys
mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"}
mock_conn.fetch.return_value = [
{"name": "nws", "enabled": True, "cadence_s": 300},
{"name": "firms", "enabled": False, "cadence_s": 600},
]
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
result = await setup_finish_form(mock_request)
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["operator_count"] == 1
assert context["key_count"] == 2
assert len(context["adapters"]) == 2
class TestSetupFinishSubmit:
"""Test setup completion."""
@pytest.mark.asyncio
async def test_marks_setup_complete_and_redirects(self):
"""POST /setup/finish marks setup_complete=true and redirects to /."""
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="admin")
mock_request.state.csrf_token = "test_csrf_token"
# Mock form with CSRF token
form_data = MagicMock()
form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default)
mock_request.form = AsyncMock(return_value=form_data)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
mock_pool.acquire.return_value.__aexit__.return_value = None
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
result = await setup_finish_submit(mock_request)
assert result.status_code == 302
assert result.headers["location"] == "/"
mock_conn.execute.assert_called_once()
mock_audit.assert_called_once()
class TestSetupGateMiddlewareWizard:
"""Test SetupGateMiddleware with wizard paths."""
@ -570,69 +169,6 @@ class TestSetupGateMiddlewareWizard:
response = client.get("/setup/operator")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_redirects_base_setup_to_wizard_step(self):
"""SetupGateMiddleware redirects /setup to appropriate wizard step."""
from starlette.testclient import TestClient
from fastapi import FastAPI
mock_pool = MagicMock()
mock_conn = MagicMock()
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
mock_conn.fetchval = AsyncMock(return_value=0) # No operators
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
app = FastAPI()
@app.get("/setup")
async def setup():
return {"message": "base setup"}
@app.get("/setup/operator")
async def setup_operator():
return {"message": "operator"}
app.add_middleware(SetupGateMiddleware)
client = TestClient(app, follow_redirects=False)
response = client.get("/setup")
assert response.status_code == 302
assert response.headers["location"] == "/setup/operator"
@pytest.mark.asyncio
async def test_redirects_login_to_setup_when_incomplete(self):
"""SetupGateMiddleware redirects /login to /setup when setup_complete=False."""
from starlette.testclient import TestClient
from fastapi import FastAPI
mock_pool = MagicMock()
mock_conn = MagicMock()
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
app = FastAPI()
@app.get("/login")
async def login():
return {"message": "login"}
@app.get("/setup")
async def setup():
return {"message": "setup"}
app.add_middleware(SetupGateMiddleware)
client = TestClient(app, follow_redirects=False)
response = client.get("/login")
assert response.status_code == 302
assert response.headers["location"] == "/setup"
@pytest.mark.asyncio
async def test_redirects_all_setup_paths_when_complete(self):
"""SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""