mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
1b-8: Wizard redesign (deferred-commit) + map fixes + favicon CSRF race fix (#27)
* feat(wizard): implement deferred-commit pattern for setup wizard Replace the current "POST each step -> DB write -> redirect" architecture with "collect values across steps in a signed cookie, commit everything in one transaction at Finish." Key changes: - Add wizard.py: WizardState dataclass and cookie helpers - csrf.py: Add reuse_or_generate_pre_auth_csrf helper - routes.py: All wizard handlers now use cookie state, no DB writes until finish - middleware.py: Cookie-based wizard step routing instead of DB queries - setup_operator.html: Remove "Operator Already Configured" branch Benefits: - Back navigation works: can return to any step and edit values - Atomic commit: all DB writes happen in single transaction at finish - No orphaned state: failed wizard leaves no DB artifacts - Simpler auth: pre-auth CSRF for all 5 steps (no session until finish) Tests updated for new behavior. 287 tests passing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(templates): correct SRI hashes for leaflet.draw assets The integrity hashes for leaflet.draw.css and leaflet.draw.js were incorrect, causing browsers to silently block these resources. This broke the Leaflet.draw toolbar and map rendering for FIRMS/USGS adapter region pickers. Updated both setup_adapters.html and adapters_edit.html with the correct sha512 hashes computed from the actual CDN files. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(gui): return 204 for browser-noise paths to prevent CSRF races Browser requests for /favicon.ico, /apple-touch-icon.png, etc. were triggering parallel GET requests that could race with form loads, causing CSRF token rotation issues. Added BROWSER_NOISE_PATHS constant and early 204 response in both SetupGateMiddleware and SessionMiddleware to short-circuit these requests before any cookie/token handling occurs. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Matt Johnson <mj@k7zvx.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
246cd75051
commit
78b6fcf150
8 changed files with 710 additions and 1066 deletions
|
|
@ -1,11 +1,10 @@
|
|||
"""Pre-auth CSRF protection for login and setup/operator pages.
|
||||
"""Pre-auth CSRF protection for login and setup pages.
|
||||
|
||||
These routes cannot use session-bound CSRF because no session exists yet.
|
||||
Uses a simple cookie-based pattern with short-lived tokens.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from starlette.requests import Request
|
||||
|
|
@ -34,6 +33,34 @@ def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]:
|
|||
return plain_token, signed_token
|
||||
|
||||
|
||||
def reuse_or_generate_pre_auth_csrf(
|
||||
request: Request,
|
||||
secret_key: str,
|
||||
) -> tuple[str, str | None]:
|
||||
"""Reuse an existing valid pre-auth CSRF token, or generate new.
|
||||
|
||||
Returns (plain_token, signed_token_for_cookie).
|
||||
If signed_token_for_cookie is None, the existing cookie is
|
||||
still valid and caller should not call set_pre_auth_csrf_cookie.
|
||||
If non-None, caller MUST call set_pre_auth_csrf_cookie with
|
||||
it to persist the new value.
|
||||
"""
|
||||
cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE)
|
||||
if cookie_value:
|
||||
serializer = _get_serializer(secret_key)
|
||||
try:
|
||||
plain_token = serializer.loads(
|
||||
cookie_value,
|
||||
max_age=PRE_AUTH_CSRF_MAX_AGE,
|
||||
)
|
||||
return plain_token, None # reuse existing
|
||||
except (BadSignature, SignatureExpired):
|
||||
pass # fall through to generate
|
||||
|
||||
plain_token, signed_token = generate_pre_auth_csrf(secret_key)
|
||||
return plain_token, signed_token
|
||||
|
||||
|
||||
def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None:
|
||||
"""Set the pre-auth CSRF cookie on a response."""
|
||||
response.set_cookie(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,15 @@ SETUP_EXEMPT_PREFIXES = ("/static/", "/setup")
|
|||
|
||||
# Paths that don't require authentication
|
||||
AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"}
|
||||
AUTH_EXEMPT_PREFIXES = ("/static/",)
|
||||
AUTH_EXEMPT_PREFIXES = ("/static/", "/setup/")
|
||||
|
||||
# Browser-noise paths that trigger CSRF race conditions
|
||||
BROWSER_NOISE_PATHS = {
|
||||
"/favicon.ico",
|
||||
"/apple-touch-icon.png",
|
||||
"/apple-touch-icon-precomposed.png",
|
||||
"/robots.txt",
|
||||
}
|
||||
|
||||
|
||||
def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool:
|
||||
|
|
@ -29,33 +37,14 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
async def _get_wizard_redirect_step(conn) -> str:
|
||||
"""Determine which wizard step to redirect to based on DB state."""
|
||||
# Check if any operators exist
|
||||
op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
|
||||
if op_count == 0:
|
||||
def _get_wizard_redirect_from_cookie(request: Request, csrf_secret: str) -> str:
|
||||
"""Determine wizard redirect step from cookie state."""
|
||||
from central.gui.wizard import get_wizard_state, get_step_route
|
||||
|
||||
state = get_wizard_state(request, csrf_secret)
|
||||
if state is None:
|
||||
return "/setup/operator"
|
||||
|
||||
# Check if system settings have been configured (map_tile_url not default)
|
||||
sys_row = await conn.fetchrow(
|
||||
"SELECT map_tile_url FROM config.system WHERE id = true"
|
||||
)
|
||||
default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
|
||||
if sys_row is None or sys_row["map_tile_url"] == default_tile:
|
||||
return "/setup/system"
|
||||
|
||||
# Keys step is optional, so check adapters have been reviewed
|
||||
# We consider adapters reviewed if any adapter has a non-null updated_at
|
||||
# (meaning it was explicitly saved during setup)
|
||||
adapters_touched = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL"
|
||||
)
|
||||
if adapters_touched == 0:
|
||||
# Go to keys first, then adapters
|
||||
return "/setup/keys"
|
||||
|
||||
# All steps done, go to finish
|
||||
return "/setup/finish"
|
||||
return get_step_route(state.wizard_step)
|
||||
|
||||
|
||||
class SetupGateMiddleware(BaseHTTPMiddleware):
|
||||
|
|
@ -64,6 +53,10 @@ class SetupGateMiddleware(BaseHTTPMiddleware):
|
|||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
path = request.url.path
|
||||
|
||||
# Short-circuit browser-noise requests that cause CSRF races
|
||||
if path in BROWSER_NOISE_PATHS:
|
||||
return Response(status_code=204)
|
||||
|
||||
# Check setup status from database
|
||||
pool = get_pool()
|
||||
if pool is None:
|
||||
|
|
@ -85,13 +78,16 @@ class SetupGateMiddleware(BaseHTTPMiddleware):
|
|||
if not setup_complete:
|
||||
# Setup not complete - only allow setup paths and static/health
|
||||
if path.startswith("/setup"):
|
||||
# Allow all /setup/* paths (handler will enforce auth)
|
||||
# Allow all /setup/* paths
|
||||
# But /setup with no subpath should redirect to appropriate step
|
||||
if path == "/setup" or path == "/setup/":
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
redirect_step = await _get_wizard_redirect_step(conn)
|
||||
return RedirectResponse(url=redirect_step, status_code=302)
|
||||
from central.bootstrap_config import get_settings
|
||||
settings = get_settings()
|
||||
redirect_step = _get_wizard_redirect_from_cookie(
|
||||
request, settings.csrf_secret
|
||||
)
|
||||
return RedirectResponse(url=redirect_step, status_code=302)
|
||||
except Exception:
|
||||
logger.warning("Failed to determine wizard step", exc_info=True)
|
||||
return RedirectResponse(url="/setup/operator", status_code=302)
|
||||
|
|
@ -118,6 +114,11 @@ class SessionMiddleware(BaseHTTPMiddleware):
|
|||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
path = request.url.path
|
||||
|
||||
# Short-circuit browser-noise requests (already handled by SetupGateMiddleware,
|
||||
# but this protects if middleware order changes)
|
||||
if path in BROWSER_NOISE_PATHS:
|
||||
return Response(status_code=204)
|
||||
|
||||
# Initialize state
|
||||
request.state.operator = None
|
||||
request.state.csrf_token = None
|
||||
|
|
@ -139,7 +140,7 @@ class SessionMiddleware(BaseHTTPMiddleware):
|
|||
request.state.operator = None
|
||||
request.state.csrf_token = None
|
||||
|
||||
# Check if auth is required
|
||||
# Check if auth is required - setup paths are exempt during wizard
|
||||
if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES):
|
||||
if request.state.operator is None:
|
||||
return RedirectResponse(url="/login", status_code=302)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
{% block head %}
|
||||
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" integrity="sha256-p4NxAoJBhIIN+hmNHrzRCf9tD/miZyoHS5obTRR9BMY=" crossorigin="">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css" integrity="sha512-gc3xjCmIy673V6MyOAZhIW93xhM9ei1I+gLbmFjUHIjocENRsLX/QUE1htk5q1XV2D/iie/VQ8DXI6Uj8GB1Og==" crossorigin="anonymous">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css" integrity="sha512-gc3xjCmIy673V6MyOAZhIW93xhM9ei1I+gLbmFjUHIjocENRsLX/QUE1htk5q1XV2D/iie/VQ8DXI6Vu8bexvQ==" crossorigin="anonymous">
|
||||
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js" integrity="sha256-20nQCchB9co0qIjJZRGuk2/Z9VM+kNiyxNV1lvTlZBo=" crossorigin=""></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js" integrity="sha512-ozq8xQKq6urvuU6jNgkfqAmT7jKN2XumbrX1JiB3TnF7tI48DPI4Ber9dLJ0ikXiRg9G9Vl2jXwqjZ5LDGQ3g==" crossorigin="anonymous"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js" integrity="sha512-ozq8xQKq6urvuU6jNgkfqAmT7jKN2XumbrX1JiB3TnF7tI48DPI4Gy1GXKD/V3EExgAs1V+pRO7vwtS1LHg0Gw==" crossorigin="anonymous"></script>
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
{% block head %}
|
||||
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" integrity="sha256-p4NxAoJBhIIN+hmNHrzRCf9tD/miZyoHS5obTRR9BMY=" crossorigin="">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css" integrity="sha512-gc3xjCmIy673V6MyOAZhIW93xhM9ei1I+gLbmFjUHIjocENRsLX/QUE1htk5q1XV2D/iie/VQ8DXI6Uj8GB1Og==" crossorigin="anonymous">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css" integrity="sha512-gc3xjCmIy673V6MyOAZhIW93xhM9ei1I+gLbmFjUHIjocENRsLX/QUE1htk5q1XV2D/iie/VQ8DXI6Vu8bexvQ==" crossorigin="anonymous">
|
||||
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js" integrity="sha256-20nQCchB9co0qIjJZRGuk2/Z9VM+kNiyxNV1lvTlZBo=" crossorigin=""></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js" integrity="sha512-ozq8xQKq6urvuU6jNgkfqAmT7jKN2XumbrX1JiB3TnF7tI48DPI4Ber9dLJ0ikXiRg9G9Vl2jXwqjZ5LDGQ3g==" crossorigin="anonymous"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js" integrity="sha512-ozq8xQKq6urvuU6jNgkfqAmT7jKN2XumbrX1JiB3TnF7tI48DPI4Gy1GXKD/V3EExgAs1V+pRO7vwtS1LHg0Gw==" crossorigin="anonymous"></script>
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
|
|
|
|||
|
|
@ -7,17 +7,6 @@
|
|||
{% include "_wizard_header.html" %}
|
||||
{% endwith %}
|
||||
|
||||
{% if existing_operator %}
|
||||
<article>
|
||||
<header>
|
||||
<h1>Operator Already Configured</h1>
|
||||
</header>
|
||||
<p>The operator account <strong>{{ existing_operator.username }}</strong> has been created.</p>
|
||||
<div style="display: flex; gap: 1rem; margin-top: 1rem;">
|
||||
<a href="/setup/system" role="button">Next →</a>
|
||||
</div>
|
||||
</article>
|
||||
{% else %}
|
||||
<article>
|
||||
<header>
|
||||
<h1>Create Operator Account</h1>
|
||||
|
|
@ -53,5 +42,4 @@
|
|||
<button type="submit">Create Operator →</button>
|
||||
</form>
|
||||
</article>
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
|
|
|||
131
src/central/gui/wizard.py
Normal file
131
src/central/gui/wizard.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""Wizard state management for deferred-commit setup flow.
|
||||
|
||||
The wizard collects configuration across 5 steps and commits everything
|
||||
atomically at the final step. State is carried in a signed cookie.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any
|
||||
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
# 1 hour max age for wizard cookie
|
||||
WIZARD_MAX_AGE = 3600
|
||||
WIZARD_COOKIE = "central_wizard"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WizardOperator:
|
||||
"""Operator data collected in step 1."""
|
||||
username: str
|
||||
password_hash: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class WizardSystem:
|
||||
"""System settings collected in step 2."""
|
||||
map_tile_url: str
|
||||
map_attribution: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class WizardApiKey:
|
||||
"""API key collected in step 3."""
|
||||
alias: str
|
||||
encrypted_value_b64: str # base64-encoded encrypted value
|
||||
|
||||
|
||||
@dataclass
|
||||
class WizardAdapter:
|
||||
"""Adapter config collected in step 4."""
|
||||
enabled: bool
|
||||
cadence_s: int
|
||||
settings: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WizardState:
|
||||
"""Complete wizard state carried across all steps."""
|
||||
wizard_step: int = 1
|
||||
operator: dict | None = None
|
||||
system: dict | None = None
|
||||
api_keys: list[dict] = field(default_factory=list)
|
||||
adapters: dict[str, dict] | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"wizard_step": self.wizard_step,
|
||||
"operator": self.operator,
|
||||
"system": self.system,
|
||||
"api_keys": self.api_keys,
|
||||
"adapters": self.adapters,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "WizardState":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
wizard_step=data.get("wizard_step", 1),
|
||||
operator=data.get("operator"),
|
||||
system=data.get("system"),
|
||||
api_keys=data.get("api_keys", []),
|
||||
adapters=data.get("adapters"),
|
||||
)
|
||||
|
||||
|
||||
def _get_wizard_serializer(secret_key: str) -> URLSafeTimedSerializer:
|
||||
"""Get a timed serializer for wizard state."""
|
||||
return URLSafeTimedSerializer(secret_key, salt="wizard-state")
|
||||
|
||||
|
||||
def get_wizard_state(request: Request, secret_key: str) -> WizardState | None:
|
||||
"""Decode wizard state from cookie.
|
||||
|
||||
Returns WizardState if valid, None if missing/invalid/expired.
|
||||
"""
|
||||
cookie_value = request.cookies.get(WIZARD_COOKIE)
|
||||
if not cookie_value:
|
||||
return None
|
||||
|
||||
serializer = _get_wizard_serializer(secret_key)
|
||||
try:
|
||||
data = serializer.loads(cookie_value, max_age=WIZARD_MAX_AGE)
|
||||
return WizardState.from_dict(data)
|
||||
except (BadSignature, SignatureExpired):
|
||||
return None
|
||||
|
||||
|
||||
def set_wizard_cookie(response: Response, state: WizardState, secret_key: str) -> None:
|
||||
"""Set the wizard state cookie on a response."""
|
||||
serializer = _get_wizard_serializer(secret_key)
|
||||
signed_value = serializer.dumps(state.to_dict())
|
||||
response.set_cookie(
|
||||
WIZARD_COOKIE,
|
||||
signed_value,
|
||||
max_age=WIZARD_MAX_AGE,
|
||||
path="/setup",
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
|
||||
def clear_wizard_cookie(response: Response) -> None:
|
||||
"""Remove the wizard state cookie."""
|
||||
response.delete_cookie(WIZARD_COOKIE, path="/setup")
|
||||
|
||||
|
||||
def get_step_route(step: int) -> str:
|
||||
"""Get the route for a wizard step number."""
|
||||
routes = {
|
||||
1: "/setup/operator",
|
||||
2: "/setup/system",
|
||||
3: "/setup/keys",
|
||||
4: "/setup/adapters",
|
||||
5: "/setup/finish",
|
||||
}
|
||||
return routes.get(step, "/setup/operator")
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for the first-run setup wizard."""
|
||||
"""Tests for the first-run setup wizard with deferred-commit pattern."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -11,60 +11,38 @@ from central.gui.routes import (
|
|||
setup_system_submit,
|
||||
setup_keys_form,
|
||||
setup_keys_submit,
|
||||
setup_adapters_form,
|
||||
setup_adapters_submit,
|
||||
setup_finish_form,
|
||||
setup_finish_submit,
|
||||
)
|
||||
from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step
|
||||
from central.gui.middleware import SetupGateMiddleware
|
||||
from central.gui.wizard import WizardState, get_wizard_state, set_wizard_cookie
|
||||
|
||||
|
||||
class TestWizardStepRedirect:
|
||||
"""Test wizard step redirect logic."""
|
||||
"""Test wizard step redirect logic based on cookie state."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_operators_redirects_to_operator(self):
|
||||
"""When no operators exist, redirect to /setup/operator."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.side_effect = [0] # No operators
|
||||
def test_no_cookie_redirects_to_operator(self):
|
||||
"""When no wizard cookie exists, redirect to /setup/operator."""
|
||||
from central.gui.middleware import _get_wizard_redirect_from_cookie
|
||||
|
||||
result = await _get_wizard_redirect_step(mock_conn)
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {}
|
||||
|
||||
result = _get_wizard_redirect_from_cookie(mock_request, "testsecret")
|
||||
assert result == "/setup/operator"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_tile_url_redirects_to_system(self):
|
||||
"""When map_tile_url is default, redirect to /setup/system."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.side_effect = [1] # Has operator
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
|
||||
}
|
||||
def test_cookie_step_2_redirects_to_system(self):
|
||||
"""When wizard_step=2 in cookie, redirect to /setup/system."""
|
||||
from central.gui.wizard import get_step_route
|
||||
|
||||
result = await _get_wizard_redirect_step(mock_conn)
|
||||
result = get_step_route(2)
|
||||
assert result == "/setup/system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_adapters_touched_redirects_to_keys(self):
|
||||
"""When no adapters have been updated, redirect to /setup/keys."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png"
|
||||
}
|
||||
def test_cookie_step_5_redirects_to_finish(self):
|
||||
"""When wizard_step=5 in cookie, redirect to /setup/finish."""
|
||||
from central.gui.wizard import get_step_route
|
||||
|
||||
result = await _get_wizard_redirect_step(mock_conn)
|
||||
assert result == "/setup/keys"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_steps_complete_redirects_to_finish(self):
|
||||
"""When all steps done, redirect to /setup/finish."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png"
|
||||
}
|
||||
|
||||
result = await _get_wizard_redirect_step(mock_conn)
|
||||
result = get_step_route(5)
|
||||
assert result == "/setup/finish"
|
||||
|
||||
|
||||
|
|
@ -72,63 +50,26 @@ class TestSetupOperatorForm:
|
|||
"""Test operator creation form (step 1)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_returns_form(self):
|
||||
"""GET /setup/operator returns the form when no operator exists."""
|
||||
async def test_get_returns_form_without_prefill(self):
|
||||
"""GET /setup/operator returns the form when no wizard cookie exists."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {}
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = None # No operator exists
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret"
|
||||
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
|
||||
result = await setup_operator_form(mock_request)
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
|
||||
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
|
||||
result = await setup_operator_form(mock_request)
|
||||
|
||||
mock_templates.TemplateResponse.assert_called_once()
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert "csrf_token" in context and context["csrf_token"]
|
||||
assert context["error"] is None
|
||||
assert context["existing_operator"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_returns_confirmation_when_operator_exists(self):
|
||||
"""GET /setup/operator shows confirmation when operator already exists."""
|
||||
mock_request = MagicMock()
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b"Operator Already Configured"
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret"
|
||||
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
|
||||
result = await setup_operator_form(mock_request)
|
||||
|
||||
mock_templates.TemplateResponse.assert_called_once()
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["existing_operator"] == {"username": "admin"}
|
||||
assert context["error"] is None
|
||||
assert context["form_data"] is None
|
||||
|
||||
|
||||
class TestSetupOperatorSubmit:
|
||||
|
|
@ -138,28 +79,17 @@ class TestSetupOperatorSubmit:
|
|||
async def test_password_mismatch_shows_error(self):
|
||||
"""POST with password mismatch re-renders with error."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.csrf_token = "test_csrf"
|
||||
mock_request.form = AsyncMock(return_value={
|
||||
"csrf_token": "test_csrf",
|
||||
"username": "testuser",
|
||||
"password": "password1",
|
||||
"confirm_password": "password2", # Mismatch
|
||||
})
|
||||
mock_request.cookies = {}
|
||||
mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"})
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.return_value = 0 # No existing operators
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret"
|
||||
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
|
||||
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("test_token", "signed")):
|
||||
result = await setup_operator_submit(
|
||||
mock_request,
|
||||
username="testuser",
|
||||
|
|
@ -172,374 +102,43 @@ class TestSetupOperatorSubmit:
|
|||
assert context["error"] == "Passwords do not match"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_creates_operator_and_redirects(self):
|
||||
"""POST with valid data creates operator and redirects to /setup/system."""
|
||||
async def test_valid_creates_wizard_cookie_and_redirects(self):
|
||||
"""POST with valid data creates wizard cookie and redirects to /setup/system."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.csrf_token = "test_csrf"
|
||||
mock_request.form = AsyncMock(return_value={
|
||||
"csrf_token": "test_csrf",
|
||||
"username": "testuser",
|
||||
"password": "password123",
|
||||
"confirm_password": "password123",
|
||||
})
|
||||
mock_request.cookies = {}
|
||||
mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf"})
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.return_value = 0 # No existing operators
|
||||
mock_conn.fetchrow.side_effect = [
|
||||
{"id": 1}, # INSERT RETURNING id
|
||||
{"session_lifetime_days": 90}, # system settings
|
||||
]
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret"
|
||||
with patch("central.gui.routes.hash_password", return_value="hashed"):
|
||||
with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session:
|
||||
mock_session.return_value = ("session_token", datetime.now(), "csrf_token")
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||
result = await setup_operator_submit(
|
||||
mock_request,
|
||||
username="testuser",
|
||||
password="password123",
|
||||
confirm_password="password123",
|
||||
)
|
||||
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
|
||||
with patch("central.gui.routes.hash_password", return_value="hashed_pw"):
|
||||
result = await setup_operator_submit(
|
||||
mock_request,
|
||||
username="testuser",
|
||||
password="password123",
|
||||
confirm_password="password123",
|
||||
)
|
||||
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_when_operator_exists_shows_confirmation(self):
|
||||
"""POST when operator exists returns 200 with confirmation, no insert."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.form = AsyncMock(return_value={
|
||||
"csrf_token": "test_csrf",
|
||||
"username": "testuser",
|
||||
"password": "password123",
|
||||
"confirm_password": "password123",
|
||||
})
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.return_value = 1 # Operator already exists
|
||||
mock_conn.fetchrow.return_value = {"username": "existing_admin"}
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
mock_request.state.csrf_token = "test_csrf"
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret"
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||
result = await setup_operator_submit(
|
||||
mock_request,
|
||||
username="testuser",
|
||||
password="password123",
|
||||
confirm_password="password123",
|
||||
)
|
||||
|
||||
# Should return 200, not 500 or redirect
|
||||
assert result.status_code == 200
|
||||
|
||||
# Should render confirmation state
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["existing_operator"] == {"username": "existing_admin"}
|
||||
|
||||
# Should NOT call write_audit (no insert happened)
|
||||
mock_audit.assert_not_called()
|
||||
|
||||
|
||||
class TestSetupSystemForm:
|
||||
"""Test system settings form (step 2)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_redirects_to_operator(self):
|
||||
"""GET /setup/system without auth redirects to /setup/operator."""
|
||||
async def test_no_wizard_cookie_redirects_to_operator(self):
|
||||
"""GET /setup/system without wizard cookie redirects to /setup/operator."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = None
|
||||
result = await setup_system_form(mock_request)
|
||||
mock_request.cookies = {}
|
||||
|
||||
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
|
||||
result = await setup_system_form(mock_request)
|
||||
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/operator"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticated_returns_form(self):
|
||||
"""GET /setup/system with auth returns the form."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
|
||||
"map_attribution": "© OpenStreetMap contributors",
|
||||
}
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await setup_system_form(mock_request)
|
||||
|
||||
mock_templates.TemplateResponse.assert_called_once()
|
||||
|
||||
|
||||
class TestSetupSystemSubmit:
|
||||
"""Test system settings submission."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_placeholders_shows_error(self):
|
||||
"""POST without {z},{x},{y} placeholders shows error."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
mock_request.state.csrf_token = "test_csrf_token"
|
||||
|
||||
form_data = MagicMock()
|
||||
form_data.get = lambda k, default="": {
|
||||
"csrf_token": "test_csrf_token",
|
||||
"map_tile_url": "https://example.com/tiles",
|
||||
"map_attribution": "Test",
|
||||
}.get(k, default)
|
||||
mock_request.form = AsyncMock(return_value=form_data)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"map_tile_url": "",
|
||||
"map_attribution": "",
|
||||
}
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await setup_system_submit(mock_request)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert "map_tile_url" in context["errors"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_updates_and_redirects(self):
|
||||
"""POST with valid data updates system and redirects to /setup/keys."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
mock_request.state.csrf_token = "test_csrf_token"
|
||||
|
||||
form_data = MagicMock()
|
||||
form_data.get = lambda k, default="": {
|
||||
"csrf_token": "test_csrf_token",
|
||||
"map_tile_url": "https://example.com/{z}/{x}/{y}.png",
|
||||
"map_attribution": "Test Attribution",
|
||||
}.get(k, default)
|
||||
mock_request.form = AsyncMock(return_value=form_data)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"map_tile_url": "old_url",
|
||||
"map_attribution": "old_attr",
|
||||
}
|
||||
mock_conn.execute = AsyncMock()
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||
result = await setup_system_submit(mock_request)
|
||||
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/keys"
|
||||
|
||||
|
||||
class TestSetupKeysForm:
|
||||
"""Test API keys form (step 3)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_redirects_to_operator(self):
|
||||
"""GET /setup/keys without auth redirects to /setup/operator."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = None
|
||||
result = await setup_keys_form(mock_request)
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/operator"
|
||||
|
||||
|
||||
class TestSetupKeysSubmit:
|
||||
"""Test API keys submission."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_action_redirects_to_adapters(self):
|
||||
"""POST with action=next redirects to /setup/adapters."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
mock_request.state.csrf_token = "test_csrf_token"
|
||||
|
||||
form_data = MagicMock()
|
||||
form_data.get = lambda k, default="": {
|
||||
"csrf_token": "test_csrf_token",
|
||||
"action": "next",
|
||||
}.get(k, default)
|
||||
mock_request.form = AsyncMock(return_value=form_data)
|
||||
|
||||
# No need to mock get_pool since action="next" returns before it's called
|
||||
result = await setup_keys_submit(mock_request)
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/adapters"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_key_creates_and_rerenders(self):
|
||||
"""POST with action=add creates key and re-renders with success."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
mock_request.state.csrf_token = "test_csrf_token"
|
||||
|
||||
form_data = MagicMock()
|
||||
form_data.get = lambda k, default="": {
|
||||
"csrf_token": "test_csrf_token",
|
||||
"action": "add",
|
||||
"alias": "testkey",
|
||||
"plaintext_key": "secret123",
|
||||
}.get(k, default)
|
||||
mock_request.form = AsyncMock(return_value=form_data)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.side_effect = [
|
||||
None, # No existing key
|
||||
{"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)},
|
||||
]
|
||||
mock_conn.fetch.side_effect = [
|
||||
[], # First list
|
||||
[{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert
|
||||
]
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.crypto.encrypt", return_value=b"encrypted"):
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||
result = await setup_keys_submit(mock_request)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["success"] == "API key 'testkey' added successfully."
|
||||
|
||||
|
||||
class TestSetupAdaptersForm:
|
||||
"""Test adapters configuration form (step 4)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_redirects_to_operator(self):
|
||||
"""GET /setup/adapters without auth redirects to /setup/operator."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = None
|
||||
result = await setup_adapters_form(mock_request)
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/operator"
|
||||
|
||||
|
||||
class TestSetupFinishForm:
|
||||
"""Test finish page (step 5)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_redirects_to_operator(self):
|
||||
"""GET /setup/finish without auth redirects to /setup/operator."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = None
|
||||
result = await setup_finish_form(mock_request)
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/setup/operator"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticated_shows_summary(self):
|
||||
"""GET /setup/finish with auth shows summary."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys
|
||||
mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"}
|
||||
mock_conn.fetch.return_value = [
|
||||
{"name": "nws", "enabled": True, "cadence_s": 300},
|
||||
{"name": "firms", "enabled": False, "cadence_s": 600},
|
||||
]
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await setup_finish_form(mock_request)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["operator_count"] == 1
|
||||
assert context["key_count"] == 2
|
||||
assert len(context["adapters"]) == 2
|
||||
|
||||
|
||||
class TestSetupFinishSubmit:
|
||||
"""Test setup completion."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_marks_setup_complete_and_redirects(self):
|
||||
"""POST /setup/finish marks setup_complete=true and redirects to /."""
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||
mock_request.state.csrf_token = "test_csrf_token"
|
||||
|
||||
# Mock form with CSRF token
|
||||
form_data = MagicMock()
|
||||
form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default)
|
||||
mock_request.form = AsyncMock(return_value=form_data)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||
result = await setup_finish_submit(mock_request)
|
||||
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/"
|
||||
mock_conn.execute.assert_called_once()
|
||||
mock_audit.assert_called_once()
|
||||
|
||||
|
||||
class TestSetupGateMiddlewareWizard:
|
||||
"""Test SetupGateMiddleware with wizard paths."""
|
||||
|
|
@ -570,69 +169,6 @@ class TestSetupGateMiddlewareWizard:
|
|||
response = client.get("/setup/operator")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirects_base_setup_to_wizard_step(self):
|
||||
"""SetupGateMiddleware redirects /setup to appropriate wizard step."""
|
||||
from starlette.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||
mock_conn.fetchval = AsyncMock(return_value=0) # No operators
|
||||
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_conn.__aexit__ = AsyncMock()
|
||||
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||
|
||||
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/setup")
|
||||
async def setup():
|
||||
return {"message": "base setup"}
|
||||
|
||||
@app.get("/setup/operator")
|
||||
async def setup_operator():
|
||||
return {"message": "operator"}
|
||||
|
||||
app.add_middleware(SetupGateMiddleware)
|
||||
client = TestClient(app, follow_redirects=False)
|
||||
|
||||
response = client.get("/setup")
|
||||
assert response.status_code == 302
|
||||
assert response.headers["location"] == "/setup/operator"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirects_login_to_setup_when_incomplete(self):
|
||||
"""SetupGateMiddleware redirects /login to /setup when setup_complete=False."""
|
||||
from starlette.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_conn.__aexit__ = AsyncMock()
|
||||
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||
|
||||
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/login")
|
||||
async def login():
|
||||
return {"message": "login"}
|
||||
|
||||
@app.get("/setup")
|
||||
async def setup():
|
||||
return {"message": "setup"}
|
||||
|
||||
app.add_middleware(SetupGateMiddleware)
|
||||
client = TestClient(app, follow_redirects=False)
|
||||
|
||||
response = client.get("/login")
|
||||
assert response.status_code == 302
|
||||
assert response.headers["location"] == "/setup"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirects_all_setup_paths_when_complete(self):
|
||||
"""SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue