mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
feat(gui): implement first-run setup wizard (1b-8) (#24)
* feat(gui): implement first-run setup wizard (1b-8) Add a 5-step setup wizard that replaces the single-step /setup: 1. Create Operator - create initial operator account 2. System Settings - configure map tile URL and attribution 3. API Keys - optionally add API keys for adapters 4. Configure Adapters - enable/disable adapters with region picker 5. Finish Setup - review and complete setup Key changes: - Update middleware to handle wizard URL structure and step routing - Add wizard routes for each step with proper auth checks - Create new templates using base_wizard.html for consistent styling - Add audit events for system.update and setup.complete - Update tests for new middleware behavior Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(gui): handle CSRF errors on wizard paths Update csrf_exception_handler to re-render wizard forms with error message instead of redirecting to /login when CSRF validation fails. - /setup/operator: re-render with error - /setup/system: re-render with current system values + error - /setup/keys: re-render with current keys list + error - /setup/adapters: re-render with current adapter config + error - /setup/finish: re-render with summary data + error - /setup: redirect to /setup (middleware routes to appropriate step) Add error display to setup_keys.html and setup_finish.html templates. Add 7 new CSRF handler tests for wizard paths. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(gui): region picker render + click-to-draw Bug A: Maps render blank on /setup/adapters for FIRMS and USGS because Leaflet computed zero dimensions before container layout settled. Fix: add setTimeout invalidateSize() after map creation. Bug B: No click-to-draw functionality - only drag corners. Fix: add L.Control.Draw for rectangle drawing with CREATED event handler to replace existing rectangle. Both fixes applied to: - setup_adapters.html (wizard inline JS) - _region_picker.html (standalone edit page) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(gui): handle revisiting /setup/operator after operator created When an operator already exists, /setup/operator now shows a confirmation page instead of the create form. This prevents: - Unique constraint violations on duplicate username - Silent creation of duplicate operators GET /setup/operator: queries config.operators; if any exist, renders confirmation state with existing_operator context. POST /setup/operator: checks operator count before INSERT; if non-zero, renders confirmation state without inserting. Template updated with conditional to show "Operator Already Configured" message when existing_operator is set. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(csrf): replace fastapi-csrf-protect with session-bound CSRF Fixes CSRF race condition where every GET rotated the CSRF token, causing POST failures when users had multiple tabs or slow connections. Changes: - Remove fastapi-csrf-protect dependency - Add session-bound CSRF tokens stored in config.sessions table - Add pre-auth CSRF for unauthenticated routes (/login, /setup/operator) - Add csrf.py module for pre-auth token generation/validation - Update routes to use new CSRF token handling - Add migration 013 to add csrf_token column to sessions The session-bound approach ensures CSRF tokens remain stable for the duration of a session, eliminating the race condition. Note: Route tests (test_wizard.py, test_adapters.py, etc.) need refactoring to mock get_settings() instead of CsrfProtect dependency. Core auth/CSRF handler tests pass (74 tests). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * test(csrf): update test suite for session-bound CSRF tokens - Add CSRF fixtures to conftest.py for pre-auth and session CSRF - Update test_wizard.py: use bypass_pre_auth_csrf and patch_route_settings - Update test_adapters.py: set request.state.csrf_token and form mock data - Update test_api_keys.py: add CSRF token to form data for POST routes - Update test_streams.py: change return_value to side_effect for CSRF support - Update test_region_picker.py: add CSRF token handling - Update test_config_store.py: set CENTRAL_CSRF_SECRET env var in fixture All 285 tests now pass with session-bound CSRF validation. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Matt Johnson <mj@k7zvx.com>
This commit is contained in:
parent
96ec88883c
commit
494ad1c799
28 changed files with 2897 additions and 377 deletions
|
|
@ -16,7 +16,6 @@ dependencies = [
|
||||||
"asyncpg>=0.31.0",
|
"asyncpg>=0.31.0",
|
||||||
"cloudevents>=2.0.0",
|
"cloudevents>=2.0.0",
|
||||||
"cryptography>=44.0.0",
|
"cryptography>=44.0.0",
|
||||||
"fastapi-csrf-protect>=0.4.0",
|
|
||||||
"fastapi>=0.115.0",
|
"fastapi>=0.115.0",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"nats-py>=2.14.0",
|
"nats-py>=2.14.0",
|
||||||
|
|
|
||||||
9
sql/migrations/013_add_session_csrf_token.sql
Normal file
9
sql/migrations/013_add_session_csrf_token.sql
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
-- Add CSRF token column to sessions table
|
||||||
|
-- Session-bound CSRF tokens prevent race conditions from cookie rotation
|
||||||
|
|
||||||
|
ALTER TABLE config.sessions
|
||||||
|
ADD COLUMN csrf_token TEXT NOT NULL
|
||||||
|
DEFAULT encode(gen_random_bytes(32), 'hex');
|
||||||
|
|
||||||
|
-- Comment
|
||||||
|
COMMENT ON COLUMN config.sessions.csrf_token IS 'Session-bound CSRF token for synchronizer token pattern';
|
||||||
|
|
@ -29,23 +29,6 @@ _cleanup_task: asyncio.Task | None = None
|
||||||
_app: FastAPI | None = None
|
_app: FastAPI | None = None
|
||||||
|
|
||||||
|
|
||||||
def _configure_csrf() -> None:
|
|
||||||
"""Configure CSRF protection. Must be called before app starts."""
|
|
||||||
from fastapi_csrf_protect import CsrfProtect
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from central.bootstrap_config import get_settings
|
|
||||||
|
|
||||||
class CsrfSettings(BaseModel):
|
|
||||||
secret_key: str
|
|
||||||
token_location: str = "body"
|
|
||||||
token_key: str = "csrf_token"
|
|
||||||
|
|
||||||
@CsrfProtect.load_config
|
|
||||||
def get_csrf_config():
|
|
||||||
settings = get_settings()
|
|
||||||
return CsrfSettings(secret_key=settings.csrf_secret)
|
|
||||||
|
|
||||||
|
|
||||||
async def _session_cleanup_loop() -> None:
|
async def _session_cleanup_loop() -> None:
|
||||||
"""Periodically clean up expired sessions."""
|
"""Periodically clean up expired sessions."""
|
||||||
global _shutdown_event
|
global _shutdown_event
|
||||||
|
|
@ -117,9 +100,6 @@ def _create_app() -> FastAPI:
|
||||||
from central.gui.middleware import SessionMiddleware, SetupGateMiddleware
|
from central.gui.middleware import SessionMiddleware, SetupGateMiddleware
|
||||||
from central.gui.routes import router
|
from central.gui.routes import router
|
||||||
|
|
||||||
# Configure CSRF before creating app
|
|
||||||
_configure_csrf()
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Central GUI",
|
title="Central GUI",
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
|
|
@ -137,45 +117,214 @@ def _create_app() -> FastAPI:
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
# CSRF exception handler - return friendly error instead of 500
|
# CSRF exception handler - return friendly error instead of 500
|
||||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
from central.gui.auth import CsrfValidationError
|
||||||
|
from central.gui.csrf import generate_pre_auth_csrf, set_pre_auth_csrf_cookie
|
||||||
|
from central.bootstrap_config import get_settings
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
@app.exception_handler(CsrfProtectError)
|
@app.exception_handler(CsrfValidationError)
|
||||||
async def csrf_exception_handler(request, exc: CsrfProtectError):
|
async def csrf_exception_handler(request, exc: CsrfValidationError):
|
||||||
from fastapi_csrf_protect import CsrfProtect
|
from central.gui.db import get_pool
|
||||||
|
|
||||||
csrf_protect = CsrfProtect()
|
settings = get_settings()
|
||||||
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
|
# For pre-auth paths, generate a new pre-auth token
|
||||||
|
# For session paths, we'll just show the error (session token stays valid)
|
||||||
|
csrf_token, signed_token = generate_pre_auth_csrf(settings.csrf_secret)
|
||||||
|
error_msg = "Your session expired. Please try again."
|
||||||
|
|
||||||
if request.url.path == "/login":
|
if request.url.path == "/login":
|
||||||
response = templates.TemplateResponse(
|
response = templates.TemplateResponse(
|
||||||
request=request,
|
request=request,
|
||||||
name="login.html",
|
name="login.html",
|
||||||
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
|
context={"csrf_token": csrf_token, "error": error_msg},
|
||||||
)
|
)
|
||||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
elif request.url.path == "/setup":
|
elif request.url.path == "/setup":
|
||||||
|
# /setup is a redirect path now, not a form
|
||||||
|
return RedirectResponse("/setup", status_code=302)
|
||||||
|
|
||||||
|
elif request.url.path == "/setup/operator":
|
||||||
response = templates.TemplateResponse(
|
response = templates.TemplateResponse(
|
||||||
request=request,
|
request=request,
|
||||||
name="setup.html",
|
name="setup_operator.html",
|
||||||
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
|
context={"csrf_token": csrf_token, "error": error_msg, "form_data": None},
|
||||||
)
|
)
|
||||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
elif request.url.path == "/setup/system":
|
||||||
|
pool = get_pool()
|
||||||
|
system = {
|
||||||
|
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
|
||||||
|
"map_attribution": "© OpenStreetMap contributors",
|
||||||
|
}
|
||||||
|
if pool:
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
|
||||||
|
)
|
||||||
|
if row:
|
||||||
|
system = {
|
||||||
|
"map_tile_url": row["map_tile_url"],
|
||||||
|
"map_attribution": row["map_attribution"],
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
response = templates.TemplateResponse(
|
||||||
|
request=request,
|
||||||
|
name="setup_system.html",
|
||||||
|
context={
|
||||||
|
"csrf_token": csrf_token,
|
||||||
|
"error": error_msg,
|
||||||
|
"errors": None,
|
||||||
|
"form_data": None,
|
||||||
|
"system": system,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
elif request.url.path == "/setup/keys":
|
||||||
|
pool = get_pool()
|
||||||
|
keys = []
|
||||||
|
if pool:
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"SELECT alias, created_at FROM config.api_keys ORDER BY alias"
|
||||||
|
)
|
||||||
|
keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
response = templates.TemplateResponse(
|
||||||
|
request=request,
|
||||||
|
name="setup_keys.html",
|
||||||
|
context={
|
||||||
|
"csrf_token": csrf_token,
|
||||||
|
"keys": keys,
|
||||||
|
"errors": None,
|
||||||
|
"form_data": None,
|
||||||
|
"success": None,
|
||||||
|
"error": error_msg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
elif request.url.path == "/setup/adapters":
|
||||||
|
pool = get_pool()
|
||||||
|
adapters = []
|
||||||
|
api_keys = []
|
||||||
|
tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
|
||||||
|
tile_attribution = "© OpenStreetMap contributors"
|
||||||
|
if pool:
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name"
|
||||||
|
)
|
||||||
|
for row in rows:
|
||||||
|
settings = row["settings"] or {}
|
||||||
|
adapters.append({
|
||||||
|
"name": row["name"],
|
||||||
|
"enabled": row["enabled"],
|
||||||
|
"cadence_s": row["cadence_s"],
|
||||||
|
"settings": settings,
|
||||||
|
})
|
||||||
|
key_rows = await conn.fetch(
|
||||||
|
"SELECT alias FROM config.api_keys ORDER BY alias"
|
||||||
|
)
|
||||||
|
api_keys = [{"alias": k["alias"]} for k in key_rows]
|
||||||
|
sys_row = await conn.fetchrow(
|
||||||
|
"SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
|
||||||
|
)
|
||||||
|
if sys_row:
|
||||||
|
tile_url = sys_row["map_tile_url"]
|
||||||
|
tile_attribution = sys_row["map_attribution"]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Import helper functions for valid values
|
||||||
|
from central.gui.routes import _get_valid_satellites, _get_valid_feeds
|
||||||
|
|
||||||
|
response = templates.TemplateResponse(
|
||||||
|
request=request,
|
||||||
|
name="setup_adapters.html",
|
||||||
|
context={
|
||||||
|
"csrf_token": csrf_token,
|
||||||
|
"adapters": adapters,
|
||||||
|
"api_keys": api_keys,
|
||||||
|
"valid_satellites": _get_valid_satellites(),
|
||||||
|
"valid_feeds": sorted(_get_valid_feeds()),
|
||||||
|
"tile_url": tile_url,
|
||||||
|
"tile_attribution": tile_attribution,
|
||||||
|
"error": error_msg,
|
||||||
|
"errors": None,
|
||||||
|
"form_data": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
elif request.url.path == "/setup/finish":
|
||||||
|
pool = get_pool()
|
||||||
|
operator_count = 0
|
||||||
|
key_count = 0
|
||||||
|
system = {"map_tile_url": ""}
|
||||||
|
adapters = []
|
||||||
|
if pool:
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
|
||||||
|
key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys")
|
||||||
|
sys_row = await conn.fetchrow(
|
||||||
|
"SELECT map_tile_url FROM config.system WHERE id = true"
|
||||||
|
)
|
||||||
|
if sys_row:
|
||||||
|
system = {"map_tile_url": sys_row["map_tile_url"]}
|
||||||
|
rows = await conn.fetch(
|
||||||
|
"SELECT name, enabled, cadence_s FROM config.adapters ORDER BY name"
|
||||||
|
)
|
||||||
|
adapters = [
|
||||||
|
{"name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"]}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
response = templates.TemplateResponse(
|
||||||
|
request=request,
|
||||||
|
name="setup_finish.html",
|
||||||
|
context={
|
||||||
|
"csrf_token": csrf_token,
|
||||||
|
"operator_count": operator_count,
|
||||||
|
"key_count": key_count,
|
||||||
|
"system": system,
|
||||||
|
"adapters": adapters,
|
||||||
|
"error": error_msg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
|
return response
|
||||||
|
|
||||||
elif request.url.path == "/logout":
|
elif request.url.path == "/logout":
|
||||||
return RedirectResponse("/login", status_code=302)
|
return RedirectResponse("/login", status_code=302)
|
||||||
|
|
||||||
elif request.url.path == "/change-password":
|
elif request.url.path == "/change-password":
|
||||||
response = templates.TemplateResponse(
|
response = templates.TemplateResponse(
|
||||||
request=request,
|
request=request,
|
||||||
name="change_password.html",
|
name="change_password.html",
|
||||||
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
|
context={"csrf_token": csrf_token, "error": error_msg},
|
||||||
)
|
)
|
||||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
set_pre_auth_csrf_cookie(response, signed_token)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
elif request.url.path.startswith("/adapters/"):
|
elif request.url.path.startswith("/adapters/"):
|
||||||
# Redirect back to adapters list
|
# Redirect back to adapters list
|
||||||
return RedirectResponse("/adapters", status_code=302)
|
return RedirectResponse("/adapters", status_code=302)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Fallback: redirect to login
|
# Fallback: redirect to login
|
||||||
return RedirectResponse("/login", status_code=302)
|
return RedirectResponse("/login", status_code=302)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ STREAM_UPDATE = "stream.update"
|
||||||
API_KEY_CREATE = "api_key.create"
|
API_KEY_CREATE = "api_key.create"
|
||||||
API_KEY_ROTATE = "api_key.rotate"
|
API_KEY_ROTATE = "api_key.rotate"
|
||||||
API_KEY_DELETE = "api_key.delete"
|
API_KEY_DELETE = "api_key.delete"
|
||||||
|
SYSTEM_UPDATE = "system.update"
|
||||||
|
SETUP_COMPLETE = "setup.complete"
|
||||||
|
|
||||||
|
|
||||||
async def write_audit(
|
async def write_audit(
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,11 @@ from argon2.exceptions import VerifyMismatchError
|
||||||
_hasher = PasswordHasher()
|
_hasher = PasswordHasher()
|
||||||
|
|
||||||
|
|
||||||
|
class CsrfValidationError(Exception):
|
||||||
|
"""Raised when CSRF token validation fails."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Operator:
|
class Operator:
|
||||||
"""Operator account."""
|
"""Operator account."""
|
||||||
|
|
@ -46,39 +51,46 @@ def generate_token() -> str:
|
||||||
return secrets.token_urlsafe(32)
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_csrf_token() -> str:
|
||||||
|
"""Generate a cryptographically secure CSRF token."""
|
||||||
|
return secrets.token_hex(32)
|
||||||
|
|
||||||
|
|
||||||
async def create_session(
|
async def create_session(
|
||||||
conn: Any, # asyncpg.Connection
|
conn: Any, # asyncpg.Connection
|
||||||
operator_id: int,
|
operator_id: int,
|
||||||
lifetime_days: int,
|
lifetime_days: int,
|
||||||
) -> tuple[str, datetime]:
|
) -> tuple[str, datetime, str]:
|
||||||
"""Create a new session for an operator.
|
"""Create a new session for an operator.
|
||||||
|
|
||||||
Returns (token, expires_at).
|
Returns (token, expires_at, csrf_token).
|
||||||
"""
|
"""
|
||||||
token = generate_token()
|
token = generate_token()
|
||||||
|
csrf_token = generate_csrf_token()
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(days=lifetime_days)
|
expires_at = datetime.now(timezone.utc) + timedelta(days=lifetime_days)
|
||||||
|
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO config.sessions (token, operator_id, expires_at)
|
INSERT INTO config.sessions (token, operator_id, expires_at, csrf_token)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3, $4)
|
||||||
""",
|
""",
|
||||||
token,
|
token,
|
||||||
operator_id,
|
operator_id,
|
||||||
expires_at,
|
expires_at,
|
||||||
|
csrf_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
return token, expires_at
|
return token, expires_at, csrf_token
|
||||||
|
|
||||||
|
|
||||||
async def get_session(conn: Any, token: str) -> Operator | None:
|
async def get_session(conn: Any, token: str) -> tuple[Operator, str] | None:
|
||||||
"""Look up a session and return the associated operator.
|
"""Look up a session and return the associated operator and csrf_token.
|
||||||
|
|
||||||
Returns None if token is invalid or expired.
|
Returns (Operator, csrf_token) or None if token is invalid or expired.
|
||||||
"""
|
"""
|
||||||
row = await conn.fetchrow(
|
row = await conn.fetchrow(
|
||||||
"""
|
"""
|
||||||
SELECT o.id, o.username, o.created_at, o.password_changed_at
|
SELECT o.id, o.username, o.created_at, o.password_changed_at, s.csrf_token
|
||||||
FROM config.sessions s
|
FROM config.sessions s
|
||||||
JOIN config.operators o ON s.operator_id = o.id
|
JOIN config.operators o ON s.operator_id = o.id
|
||||||
WHERE s.token = $1 AND s.expires_at > now()
|
WHERE s.token = $1 AND s.expires_at > now()
|
||||||
|
|
@ -89,12 +101,14 @@ async def get_session(conn: Any, token: str) -> Operator | None:
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return Operator(
|
operator = Operator(
|
||||||
id=row["id"],
|
id=row["id"],
|
||||||
username=row["username"],
|
username=row["username"],
|
||||||
created_at=row["created_at"],
|
created_at=row["created_at"],
|
||||||
password_changed_at=row.get("password_changed_at"),
|
password_changed_at=row.get("password_changed_at"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return operator, row["csrf_token"]
|
||||||
|
|
||||||
|
|
||||||
async def delete_session(conn: Any, token: str) -> None:
|
async def delete_session(conn: Any, token: str) -> None:
|
||||||
|
|
|
||||||
72
src/central/gui/csrf.py
Normal file
72
src/central/gui/csrf.py
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
"""Pre-auth CSRF protection for login and setup/operator pages.
|
||||||
|
|
||||||
|
These routes cannot use session-bound CSRF because no session exists yet.
|
||||||
|
Uses a simple cookie-based pattern with short-lived tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
|
||||||
|
# 10 minute max age for pre-auth CSRF tokens
|
||||||
|
PRE_AUTH_CSRF_MAX_AGE = 600
|
||||||
|
PRE_AUTH_CSRF_COOKIE = "central_preauth_csrf"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_serializer(secret_key: str) -> URLSafeTimedSerializer:
|
||||||
|
"""Get a timed serializer for CSRF tokens."""
|
||||||
|
return URLSafeTimedSerializer(secret_key, salt="preauth-csrf")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pre_auth_csrf(secret_key: str) -> tuple[str, str]:
|
||||||
|
"""Generate a pre-auth CSRF token pair.
|
||||||
|
|
||||||
|
Returns (plain_token, signed_token).
|
||||||
|
The plain_token goes in the form, signed_token goes in the cookie.
|
||||||
|
"""
|
||||||
|
plain_token = secrets.token_hex(32)
|
||||||
|
serializer = _get_serializer(secret_key)
|
||||||
|
signed_token = serializer.dumps(plain_token)
|
||||||
|
return plain_token, signed_token
|
||||||
|
|
||||||
|
|
||||||
|
def set_pre_auth_csrf_cookie(response: Response, signed_token: str) -> None:
|
||||||
|
"""Set the pre-auth CSRF cookie on a response."""
|
||||||
|
response.set_cookie(
|
||||||
|
PRE_AUTH_CSRF_COOKIE,
|
||||||
|
signed_token,
|
||||||
|
max_age=PRE_AUTH_CSRF_MAX_AGE,
|
||||||
|
path="/",
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_pre_auth_csrf(
|
||||||
|
request: Request,
|
||||||
|
form_token: str,
|
||||||
|
secret_key: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Validate a pre-auth CSRF token.
|
||||||
|
|
||||||
|
Returns True if valid, False otherwise.
|
||||||
|
"""
|
||||||
|
cookie_value = request.cookies.get(PRE_AUTH_CSRF_COOKIE)
|
||||||
|
if not cookie_value or not form_token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
serializer = _get_serializer(secret_key)
|
||||||
|
try:
|
||||||
|
expected_token = serializer.loads(cookie_value, max_age=PRE_AUTH_CSRF_MAX_AGE)
|
||||||
|
return secrets.compare_digest(form_token, expected_token)
|
||||||
|
except (BadSignature, SignatureExpired):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def unset_pre_auth_csrf_cookie(response: Response) -> None:
|
||||||
|
"""Remove the pre-auth CSRF cookie."""
|
||||||
|
response.delete_cookie(PRE_AUTH_CSRF_COOKIE, path="/")
|
||||||
|
|
@ -12,11 +12,10 @@ from central.gui.db import get_pool
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Paths that don't require setup to be complete
|
# Paths that don't require setup to be complete
|
||||||
SETUP_EXEMPT_PATHS = {"/setup", "/health"}
|
SETUP_EXEMPT_PREFIXES = ("/static/", "/setup")
|
||||||
SETUP_EXEMPT_PREFIXES = ("/static/",)
|
|
||||||
|
|
||||||
# Paths that don't require authentication
|
# Paths that don't require authentication
|
||||||
AUTH_EXEMPT_PATHS = {"/setup", "/login", "/health"}
|
AUTH_EXEMPT_PATHS = {"/setup/operator", "/login", "/health"}
|
||||||
AUTH_EXEMPT_PREFIXES = ("/static/",)
|
AUTH_EXEMPT_PREFIXES = ("/static/",)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,6 +29,35 @@ def _is_exempt(path: str, exempt_paths: set, exempt_prefixes: tuple) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_wizard_redirect_step(conn) -> str:
|
||||||
|
"""Determine which wizard step to redirect to based on DB state."""
|
||||||
|
# Check if any operators exist
|
||||||
|
op_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
|
||||||
|
if op_count == 0:
|
||||||
|
return "/setup/operator"
|
||||||
|
|
||||||
|
# Check if system settings have been configured (map_tile_url not default)
|
||||||
|
sys_row = await conn.fetchrow(
|
||||||
|
"SELECT map_tile_url FROM config.system WHERE id = true"
|
||||||
|
)
|
||||||
|
default_tile = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
|
||||||
|
if sys_row is None or sys_row["map_tile_url"] == default_tile:
|
||||||
|
return "/setup/system"
|
||||||
|
|
||||||
|
# Keys step is optional, so check adapters have been reviewed
|
||||||
|
# We consider adapters reviewed if any adapter has a non-null updated_at
|
||||||
|
# (meaning it was explicitly saved during setup)
|
||||||
|
adapters_touched = await conn.fetchval(
|
||||||
|
"SELECT COUNT(*) FROM config.adapters WHERE updated_at IS NOT NULL"
|
||||||
|
)
|
||||||
|
if adapters_touched == 0:
|
||||||
|
# Go to keys first, then adapters
|
||||||
|
return "/setup/keys"
|
||||||
|
|
||||||
|
# All steps done, go to finish
|
||||||
|
return "/setup/finish"
|
||||||
|
|
||||||
|
|
||||||
class SetupGateMiddleware(BaseHTTPMiddleware):
|
class SetupGateMiddleware(BaseHTTPMiddleware):
|
||||||
"""Redirect to /setup if setup is not complete."""
|
"""Redirect to /setup if setup is not complete."""
|
||||||
|
|
||||||
|
|
@ -55,25 +83,44 @@ class SetupGateMiddleware(BaseHTTPMiddleware):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
if not setup_complete:
|
if not setup_complete:
|
||||||
# Setup not complete - only allow exempt paths
|
# Setup not complete - only allow setup paths and static/health
|
||||||
if not _is_exempt(path, SETUP_EXEMPT_PATHS, SETUP_EXEMPT_PREFIXES):
|
if path.startswith("/setup"):
|
||||||
|
# Allow all /setup/* paths (handler will enforce auth)
|
||||||
|
# But /setup with no subpath should redirect to appropriate step
|
||||||
|
if path == "/setup" or path == "/setup/":
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
redirect_step = await _get_wizard_redirect_step(conn)
|
||||||
|
return RedirectResponse(url=redirect_step, status_code=302)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to determine wizard step", exc_info=True)
|
||||||
|
return RedirectResponse(url="/setup/operator", status_code=302)
|
||||||
|
return await call_next(request)
|
||||||
|
elif path == "/health" or path.startswith("/static/"):
|
||||||
|
return await call_next(request)
|
||||||
|
elif path == "/login":
|
||||||
|
# During setup, login redirects to /setup
|
||||||
|
return RedirectResponse(url="/setup", status_code=302)
|
||||||
|
else:
|
||||||
|
# All other paths redirect to /setup
|
||||||
return RedirectResponse(url="/setup", status_code=302)
|
return RedirectResponse(url="/setup", status_code=302)
|
||||||
else:
|
else:
|
||||||
# Setup complete - redirect /setup to /
|
# Setup complete - redirect /setup* to /
|
||||||
if path == "/setup":
|
if path.startswith("/setup"):
|
||||||
return RedirectResponse(url="/", status_code=302)
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
class SessionMiddleware(BaseHTTPMiddleware):
|
class SessionMiddleware(BaseHTTPMiddleware):
|
||||||
"""Load session from cookie and attach operator to request.state."""
|
"""Load session from cookie and attach operator + csrf_token to request.state."""
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response:
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
|
||||||
# Initialize operator to None
|
# Initialize state
|
||||||
request.state.operator = None
|
request.state.operator = None
|
||||||
|
request.state.csrf_token = None
|
||||||
|
|
||||||
# Try to load session from cookie
|
# Try to load session from cookie
|
||||||
session_token = request.cookies.get("central_session")
|
session_token = request.cookies.get("central_session")
|
||||||
|
|
@ -82,11 +129,15 @@ class SessionMiddleware(BaseHTTPMiddleware):
|
||||||
if pool is not None:
|
if pool is not None:
|
||||||
try:
|
try:
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
operator = await get_session(conn, session_token)
|
result = await get_session(conn, session_token)
|
||||||
request.state.operator = operator
|
if result is not None:
|
||||||
|
operator, csrf_token = result
|
||||||
|
request.state.operator = operator
|
||||||
|
request.state.csrf_token = csrf_token
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to load session", exc_info=True)
|
logger.warning("Failed to load session", exc_info=True)
|
||||||
request.state.operator = None
|
request.state.operator = None
|
||||||
|
request.state.csrf_token = None
|
||||||
|
|
||||||
# Check if auth is required
|
# Check if auth is required
|
||||||
if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES):
|
if not _is_exempt(path, AUTH_EXEMPT_PATHS, AUTH_EXEMPT_PREFIXES):
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -59,6 +59,10 @@
|
||||||
maxZoom: 18
|
maxZoom: 18
|
||||||
}).addTo(map);
|
}).addTo(map);
|
||||||
|
|
||||||
|
// Ensure map renders correctly even if container has not
|
||||||
|
// finished laying out at init time
|
||||||
|
setTimeout(function() { map.invalidateSize(); }, 100);
|
||||||
|
|
||||||
// Create initial rectangle
|
// Create initial rectangle
|
||||||
const bounds = L.latLngBounds(
|
const bounds = L.latLngBounds(
|
||||||
L.latLng(savedSouth, savedWest),
|
L.latLng(savedSouth, savedWest),
|
||||||
|
|
@ -69,11 +73,34 @@
|
||||||
map.fitBounds(bounds.pad(0.1));
|
map.fitBounds(bounds.pad(0.1));
|
||||||
|
|
||||||
// Create editable rectangle
|
// Create editable rectangle
|
||||||
const rectangle = L.rectangle(bounds, {
|
let rectangle = L.rectangle(bounds, {
|
||||||
color: '#3388ff',
|
color: '#3388ff',
|
||||||
weight: 2,
|
weight: 2,
|
||||||
fillOpacity: 0.2
|
fillOpacity: 0.2
|
||||||
}).addTo(map);
|
});
|
||||||
|
|
||||||
|
// Set up Leaflet.draw for click-to-draw
|
||||||
|
const drawnItems = new L.FeatureGroup();
|
||||||
|
drawnItems.addLayer(rectangle);
|
||||||
|
map.addLayer(drawnItems);
|
||||||
|
|
||||||
|
const drawControl = new L.Control.Draw({
|
||||||
|
draw: {
|
||||||
|
rectangle: { shapeOptions: { color: '#3388ff', weight: 2,
|
||||||
|
fillOpacity: 0.2 } },
|
||||||
|
polyline: false,
|
||||||
|
polygon: false,
|
||||||
|
circle: false,
|
||||||
|
marker: false,
|
||||||
|
circlemarker: false
|
||||||
|
},
|
||||||
|
edit: {
|
||||||
|
featureGroup: drawnItems,
|
||||||
|
edit: false,
|
||||||
|
remove: false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
map.addControl(drawControl);
|
||||||
|
|
||||||
// Make rectangle editable
|
// Make rectangle editable
|
||||||
rectangle.editing.enable();
|
rectangle.editing.enable();
|
||||||
|
|
@ -96,13 +123,33 @@
|
||||||
// Listen for rectangle edit events
|
// Listen for rectangle edit events
|
||||||
rectangle.on('edit', updateInputs);
|
rectangle.on('edit', updateInputs);
|
||||||
|
|
||||||
|
// When user draws a new rectangle, replace the existing one
|
||||||
|
map.on(L.Draw.Event.CREATED, function(e) {
|
||||||
|
drawnItems.clearLayers();
|
||||||
|
rectangle = e.layer;
|
||||||
|
rectangle.setStyle({ color: '#3388ff', weight: 2,
|
||||||
|
fillOpacity: 0.2 });
|
||||||
|
drawnItems.addLayer(rectangle);
|
||||||
|
rectangle.editing.enable();
|
||||||
|
rectangle.on('edit', updateInputs);
|
||||||
|
updateInputs();
|
||||||
|
});
|
||||||
|
|
||||||
// Reset button
|
// Reset button
|
||||||
document.getElementById('region-reset-btn').addEventListener('click', function() {
|
document.getElementById('region-reset-btn').addEventListener('click', function() {
|
||||||
const originalBounds = L.latLngBounds(
|
const originalBounds = L.latLngBounds(
|
||||||
L.latLng(savedSouth, savedWest),
|
L.latLng(savedSouth, savedWest),
|
||||||
L.latLng(savedNorth, savedEast)
|
L.latLng(savedNorth, savedEast)
|
||||||
);
|
);
|
||||||
rectangle.setBounds(originalBounds);
|
drawnItems.clearLayers();
|
||||||
|
rectangle = L.rectangle(originalBounds, {
|
||||||
|
color: '#3388ff',
|
||||||
|
weight: 2,
|
||||||
|
fillOpacity: 0.2
|
||||||
|
});
|
||||||
|
drawnItems.addLayer(rectangle);
|
||||||
|
rectangle.editing.enable();
|
||||||
|
rectangle.on('edit', updateInputs);
|
||||||
updateInputs();
|
updateInputs();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
6
src/central/gui/templates/_wizard_header.html
Normal file
6
src/central/gui/templates/_wizard_header.html
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
<article style="margin-bottom: 2rem;">
|
||||||
|
<header>
|
||||||
|
<strong>Step {{ step }} of 5</strong> — {{ step_name }}
|
||||||
|
</header>
|
||||||
|
<progress value="{{ step }}" max="5" style="margin-bottom: 0;"></progress>
|
||||||
|
</article>
|
||||||
24
src/central/gui/templates/base_wizard.html
Normal file
24
src/central/gui/templates/base_wizard.html
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en" data-theme="light">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{% block title %}Central - Setup{% endblock %}</title>
|
||||||
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css">
|
||||||
|
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
||||||
|
{% block head %}{% endblock %}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<nav class="container">
|
||||||
|
<ul>
|
||||||
|
<li><strong>Central</strong></li>
|
||||||
|
</ul>
|
||||||
|
<ul>
|
||||||
|
<li>Setup Wizard</li>
|
||||||
|
</ul>
|
||||||
|
</nav>
|
||||||
|
<main class="container">
|
||||||
|
{% block content %}{% endblock %}
|
||||||
|
</main>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
256
src/central/gui/templates/setup_adapters.html
Normal file
256
src/central/gui/templates/setup_adapters.html
Normal file
|
|
@ -0,0 +1,256 @@
|
||||||
|
{% extends "base_wizard.html" %}
|
||||||
|
|
||||||
|
{% block title %}Central - Configure Adapters{% endblock %}
|
||||||
|
|
||||||
|
{% block head %}
|
||||||
|
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" integrity="sha256-p4NxAoJBhIIN+hmNHrzRCf9tD/miZyoHS5obTRR9BMY=" crossorigin="">
|
||||||
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css" integrity="sha512-gc3xjCmIy673V6MyOAZhIW93xhM9ei1I+gLbmFjUHIjocENRsLX/QUE1htk5q1XV2D/iie/VQ8DXI6Uj8GB1Og==" crossorigin="anonymous">
|
||||||
|
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js" integrity="sha256-20nQCchB9co0qIjJZRGuk2/Z9VM+kNiyxNV1lvTlZBo=" crossorigin=""></script>
|
||||||
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js" integrity="sha512-ozq8xQKq6urvuU6jNgkfqAmT7jKN2XumbrX1JiB3TnF7tI48DPI4Ber9dLJ0ikXiRg9G9Vl2jXwqjZ5LDGQ3g==" crossorigin="anonymous"></script>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
{% with step=4, step_name="Configure Adapters" %}
|
||||||
|
{% include "_wizard_header.html" %}
|
||||||
|
{% endwith %}
|
||||||
|
|
||||||
|
<article>
|
||||||
|
<header>
|
||||||
|
<h1>Configure Adapters</h1>
|
||||||
|
<p>Enable and configure data source adapters. Each adapter polls an external API and normalizes events.</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<form action="/setup/adapters" method="post">
|
||||||
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
|
||||||
|
{% for adapter in adapters %}
|
||||||
|
<details open style="margin-bottom: 2rem;">
|
||||||
|
<summary><strong>{{ adapter.name }}</strong></summary>
|
||||||
|
|
||||||
|
<div style="padding: 1rem; border-left: 3px solid var(--pico-primary);">
|
||||||
|
<label>
|
||||||
|
<input type="checkbox" name="{{ adapter.name }}_enabled"
|
||||||
|
{% if form_data and form_data.get(adapter.name + '_enabled') %}checked
|
||||||
|
{% elif not form_data and adapter.enabled %}checked{% endif %}>
|
||||||
|
Enabled
|
||||||
|
</label>
|
||||||
|
{% if errors and errors.get(adapter.name + '_enabled') %}
|
||||||
|
<small style="color: var(--pico-color-red-500); display: block;">{{ errors[adapter.name + '_enabled'] }}</small>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<label for="{{ adapter.name }}_cadence_s">Cadence (seconds)</label>
|
||||||
|
<input type="number" id="{{ adapter.name }}_cadence_s" name="{{ adapter.name }}_cadence_s"
|
||||||
|
value="{{ form_data.get(adapter.name + '_cadence_s') if form_data else adapter.cadence_s }}"
|
||||||
|
min="60" max="3600">
|
||||||
|
{% if errors and errors.get(adapter.name + '_cadence_s') %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_cadence_s'] }}</small>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if adapter.name == 'nws' %}
|
||||||
|
<label for="{{ adapter.name }}_contact_email">Contact Email</label>
|
||||||
|
<input type="email" id="{{ adapter.name }}_contact_email" name="{{ adapter.name }}_contact_email"
|
||||||
|
value="{{ form_data.get(adapter.name + '_contact_email') if form_data else adapter.settings.contact_email }}">
|
||||||
|
{% if errors and errors.get(adapter.name + '_contact_email') %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_contact_email'] }}</small>
|
||||||
|
{% endif %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if adapter.name == 'firms' %}
|
||||||
|
<label for="{{ adapter.name }}_api_key_alias">API Key Alias</label>
|
||||||
|
<select id="{{ adapter.name }}_api_key_alias" name="{{ adapter.name }}_api_key_alias">
|
||||||
|
<option value="">(none)</option>
|
||||||
|
{% for key in api_keys %}
|
||||||
|
<option value="{{ key.alias }}"
|
||||||
|
{% if (form_data.get(adapter.name + '_api_key_alias') if form_data else adapter.settings.api_key_alias) == key.alias %}selected{% endif %}>
|
||||||
|
{{ key.alias }}
|
||||||
|
</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
{% if errors and errors.get(adapter.name + '_api_key_alias') %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_api_key_alias'] }}</small>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<label>Satellites</label>
|
||||||
|
{% for sat in valid_satellites %}
|
||||||
|
<label style="display: inline-block; margin-right: 1rem;">
|
||||||
|
<input type="checkbox" name="{{ adapter.name }}_satellites" value="{{ sat }}"
|
||||||
|
{% if sat in (form_data.getlist(adapter.name + '_satellites') if form_data else adapter.settings.satellites or []) %}checked{% endif %}>
|
||||||
|
{{ sat }}
|
||||||
|
</label>
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if adapter.name == 'usgs_quake' %}
|
||||||
|
<label for="{{ adapter.name }}_feed">Feed</label>
|
||||||
|
<select id="{{ adapter.name }}_feed" name="{{ adapter.name }}_feed">
|
||||||
|
{% for f in valid_feeds %}
|
||||||
|
<option value="{{ f }}"
|
||||||
|
{% if (form_data.get(adapter.name + '_feed') if form_data else adapter.settings.feed) == f %}selected{% endif %}>
|
||||||
|
{{ f }}
|
||||||
|
</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
{% if errors and errors.get(adapter.name + '_feed') %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_feed'] }}</small>
|
||||||
|
{% endif %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<h4>Region</h4>
|
||||||
|
{% set region = form_data if form_data else adapter.settings.region %}
|
||||||
|
<div id="region-picker-{{ adapter.name }}"
|
||||||
|
data-adapter="{{ adapter.name }}"
|
||||||
|
data-north="{{ form_data.get(adapter.name + '_region_north') if form_data else (adapter.settings.region.north if adapter.settings.region else 49.5) }}"
|
||||||
|
data-south="{{ form_data.get(adapter.name + '_region_south') if form_data else (adapter.settings.region.south if adapter.settings.region else 31.0) }}"
|
||||||
|
data-east="{{ form_data.get(adapter.name + '_region_east') if form_data else (adapter.settings.region.east if adapter.settings.region else -102.0) }}"
|
||||||
|
data-west="{{ form_data.get(adapter.name + '_region_west') if form_data else (adapter.settings.region.west if adapter.settings.region else -124.5) }}"
|
||||||
|
data-tile-url="{{ tile_url }}"
|
||||||
|
data-tile-attr="{{ tile_attribution }}">
|
||||||
|
|
||||||
|
<div id="region-map-{{ adapter.name }}" style="height: 300px; margin-bottom: 1rem;"></div>
|
||||||
|
|
||||||
|
<div class="grid">
|
||||||
|
<div>
|
||||||
|
<label>North</label>
|
||||||
|
<input type="number" name="{{ adapter.name }}_region_north" step="0.0001" min="-90" max="90" readonly
|
||||||
|
value="{{ form_data.get(adapter.name + '_region_north') if form_data else (adapter.settings.region.north if adapter.settings.region else 49.5) }}">
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label>South</label>
|
||||||
|
<input type="number" name="{{ adapter.name }}_region_south" step="0.0001" min="-90" max="90" readonly
|
||||||
|
value="{{ form_data.get(adapter.name + '_region_south') if form_data else (adapter.settings.region.south if adapter.settings.region else 31.0) }}">
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label>East</label>
|
||||||
|
<input type="number" name="{{ adapter.name }}_region_east" step="0.0001" min="-180" max="180" readonly
|
||||||
|
value="{{ form_data.get(adapter.name + '_region_east') if form_data else (adapter.settings.region.east if adapter.settings.region else -102.0) }}">
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label>West</label>
|
||||||
|
<input type="number" name="{{ adapter.name }}_region_west" step="0.0001" min="-180" max="180" readonly
|
||||||
|
value="{{ form_data.get(adapter.name + '_region_west') if form_data else (adapter.settings.region.west if adapter.settings.region else -124.5) }}">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% if errors and errors.get(adapter.name + '_region') %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_region'] }}</small>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
<div style="display: flex; gap: 1rem; margin-top: 1rem;">
|
||||||
|
<a href="/setup/keys" role="button" class="outline">← Back</a>
|
||||||
|
<button type="submit">Next →</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</article>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
const adapters = ['nws', 'firms', 'usgs_quake'];
|
||||||
|
|
||||||
|
adapters.forEach(function(adapterName) {
|
||||||
|
const container = document.getElementById('region-picker-' + adapterName);
|
||||||
|
if (!container) return;
|
||||||
|
|
||||||
|
const savedNorth = parseFloat(container.dataset.north);
|
||||||
|
const savedSouth = parseFloat(container.dataset.south);
|
||||||
|
const savedEast = parseFloat(container.dataset.east);
|
||||||
|
const savedWest = parseFloat(container.dataset.west);
|
||||||
|
const tileUrl = container.dataset.tileUrl || 'https://tile.openstreetmap.org/{z}/{x}/{y}.png';
|
||||||
|
const tileAttr = container.dataset.tileAttr || '© OpenStreetMap contributors';
|
||||||
|
|
||||||
|
const centerLat = (savedNorth + savedSouth) / 2;
|
||||||
|
const centerLng = (savedEast + savedWest) / 2;
|
||||||
|
const mapEl = document.getElementById('region-map-' + adapterName);
|
||||||
|
const map = L.map(mapEl).setView([centerLat, centerLng], 4);
|
||||||
|
|
||||||
|
L.tileLayer(tileUrl, {
|
||||||
|
attribution: tileAttr,
|
||||||
|
maxZoom: 18
|
||||||
|
}).addTo(map);
|
||||||
|
|
||||||
|
// Ensure map renders correctly even if container has not
|
||||||
|
// finished laying out at init time
|
||||||
|
setTimeout(function() { map.invalidateSize(); }, 100);
|
||||||
|
|
||||||
|
const bounds = L.latLngBounds(
|
||||||
|
L.latLng(savedSouth, savedWest),
|
||||||
|
L.latLng(savedNorth, savedEast)
|
||||||
|
);
|
||||||
|
map.fitBounds(bounds.pad(0.1));
|
||||||
|
|
||||||
|
let rectangle = L.rectangle(bounds, {
|
||||||
|
color: '#3388ff',
|
||||||
|
weight: 2,
|
||||||
|
fillOpacity: 0.2
|
||||||
|
});
|
||||||
|
|
||||||
|
// Set up Leaflet.draw for click-to-draw
|
||||||
|
const drawnItems = new L.FeatureGroup();
|
||||||
|
drawnItems.addLayer(rectangle);
|
||||||
|
map.addLayer(drawnItems);
|
||||||
|
|
||||||
|
const drawControl = new L.Control.Draw({
|
||||||
|
draw: {
|
||||||
|
rectangle: { shapeOptions: { color: '#3388ff', weight: 2,
|
||||||
|
fillOpacity: 0.2 } },
|
||||||
|
polyline: false,
|
||||||
|
polygon: false,
|
||||||
|
circle: false,
|
||||||
|
marker: false,
|
||||||
|
circlemarker: false
|
||||||
|
},
|
||||||
|
edit: {
|
||||||
|
featureGroup: drawnItems,
|
||||||
|
edit: false,
|
||||||
|
remove: false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
map.addControl(drawControl);
|
||||||
|
|
||||||
|
rectangle.editing.enable();
|
||||||
|
|
||||||
|
const northInput = container.querySelector('input[name="' + adapterName + '_region_north"]');
|
||||||
|
const southInput = container.querySelector('input[name="' + adapterName + '_region_south"]');
|
||||||
|
const eastInput = container.querySelector('input[name="' + adapterName + '_region_east"]');
|
||||||
|
const westInput = container.querySelector('input[name="' + adapterName + '_region_west"]');
|
||||||
|
|
||||||
|
function updateInputs() {
|
||||||
|
const b = rectangle.getBounds();
|
||||||
|
northInput.value = b.getNorth().toFixed(4);
|
||||||
|
southInput.value = b.getSouth().toFixed(4);
|
||||||
|
eastInput.value = b.getEast().toFixed(4);
|
||||||
|
westInput.value = b.getWest().toFixed(4);
|
||||||
|
}
|
||||||
|
|
||||||
|
rectangle.on('edit', updateInputs);
|
||||||
|
updateInputs();
|
||||||
|
|
||||||
|
// When user draws a new rectangle, replace the existing one
|
||||||
|
map.on(L.Draw.Event.CREATED, function(e) {
|
||||||
|
drawnItems.clearLayers();
|
||||||
|
rectangle = e.layer;
|
||||||
|
rectangle.setStyle({ color: '#3388ff', weight: 2,
|
||||||
|
fillOpacity: 0.2 });
|
||||||
|
drawnItems.addLayer(rectangle);
|
||||||
|
rectangle.editing.enable();
|
||||||
|
rectangle.on('edit', updateInputs);
|
||||||
|
updateInputs();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Fix map size when details is opened
|
||||||
|
const details = container.closest('details');
|
||||||
|
if (details) {
|
||||||
|
details.addEventListener('toggle', function() {
|
||||||
|
setTimeout(function() { map.invalidateSize(); }, 100);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
{% endblock %}
|
||||||
73
src/central/gui/templates/setup_finish.html
Normal file
73
src/central/gui/templates/setup_finish.html
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
{% extends "base_wizard.html" %}
|
||||||
|
|
||||||
|
{% block title %}Central - Finish Setup{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
{% with step=5, step_name="Finish Setup" %}
|
||||||
|
{% include "_wizard_header.html" %}
|
||||||
|
{% endwith %}
|
||||||
|
|
||||||
|
<article>
|
||||||
|
<header>
|
||||||
|
<h1>Setup Complete</h1>
|
||||||
|
<p>Review your configuration and finish the setup wizard.</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<h2>Summary</h2>
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<th>Operators</th>
|
||||||
|
<td>{{ operator_count }} configured</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<th>API Keys</th>
|
||||||
|
<td>{{ key_count }} configured</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<th>Map Tile URL</th>
|
||||||
|
<td style="word-break: break-all;">{{ system.map_tile_url }}</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
<h3>Adapters</h3>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>Adapter</th>
|
||||||
|
<th>Status</th>
|
||||||
|
<th>Cadence</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for adapter in adapters %}
|
||||||
|
<tr>
|
||||||
|
<td><strong>{{ adapter.name }}</strong></td>
|
||||||
|
<td>
|
||||||
|
{% if adapter.enabled %}
|
||||||
|
<span style="color: var(--pico-color-green-500);">Enabled</span>
|
||||||
|
{% else %}
|
||||||
|
<span style="color: var(--pico-color-grey-500);">Disabled</span>
|
||||||
|
{% endif %}
|
||||||
|
</td>
|
||||||
|
<td>{{ adapter.cadence_s }}s</td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
<form action="/setup/finish" method="post">
|
||||||
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
<div style="display: flex; gap: 1rem; margin-top: 2rem;">
|
||||||
|
<a href="/setup/adapters" role="button" class="outline">← Back</a>
|
||||||
|
<button type="submit">Finish Setup</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</article>
|
||||||
|
{% endblock %}
|
||||||
88
src/central/gui/templates/setup_keys.html
Normal file
88
src/central/gui/templates/setup_keys.html
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
{% extends "base_wizard.html" %}
|
||||||
|
|
||||||
|
{% block title %}Central - API Keys{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
{% with step=3, step_name="API Keys" %}
|
||||||
|
{% include "_wizard_header.html" %}
|
||||||
|
{% endwith %}
|
||||||
|
|
||||||
|
<article>
|
||||||
|
<header>
|
||||||
|
<h1>API Keys</h1>
|
||||||
|
<p>Add API keys for adapters that require external service credentials (e.g., FIRMS).</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if success %}
|
||||||
|
<p style="color: var(--pico-color-green-500);">{{ success }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if keys %}
|
||||||
|
<h2>Existing Keys</h2>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>Alias</th>
|
||||||
|
<th>Created</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for key in keys %}
|
||||||
|
<tr>
|
||||||
|
<td><strong>{{ key.alias }}</strong></td>
|
||||||
|
<td>{{ key.created_at.strftime('%Y-%m-%d %H:%M') if key.created_at else '(never)' }}</td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
{% else %}
|
||||||
|
<p><em>No API keys configured yet.</em></p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<h2>Add New Key</h2>
|
||||||
|
<form action="/setup/keys" method="post">
|
||||||
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
<input type="hidden" name="action" value="add">
|
||||||
|
|
||||||
|
<div class="grid">
|
||||||
|
<div>
|
||||||
|
<label for="alias">Alias</label>
|
||||||
|
<input type="text" id="alias" name="alias" placeholder="e.g., firms"
|
||||||
|
value="{{ form_data.alias if form_data else '' }}" maxlength="64">
|
||||||
|
{% if errors and errors.alias %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors.alias }}</small>
|
||||||
|
{% else %}
|
||||||
|
<small>Letters, numbers, and underscores only.</small>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="plaintext_key">API Key</label>
|
||||||
|
<input type="password" id="plaintext_key" name="plaintext_key"
|
||||||
|
placeholder="Paste your API key">
|
||||||
|
{% if errors and errors.plaintext_key %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors.plaintext_key }}</small>
|
||||||
|
{% else %}
|
||||||
|
<small>Will be encrypted before storage.</small>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit" class="outline">Add Key</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<hr>
|
||||||
|
|
||||||
|
<form action="/setup/keys" method="post">
|
||||||
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
<input type="hidden" name="action" value="next">
|
||||||
|
<div style="display: flex; gap: 1rem;">
|
||||||
|
<a href="/setup/system" role="button" class="outline">← Back</a>
|
||||||
|
<button type="submit">Next →</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</article>
|
||||||
|
{% endblock %}
|
||||||
57
src/central/gui/templates/setup_operator.html
Normal file
57
src/central/gui/templates/setup_operator.html
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
{% extends "base_wizard.html" %}
|
||||||
|
|
||||||
|
{% block title %}Central - Create Operator{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
{% with step=1, step_name="Create Operator" %}
|
||||||
|
{% include "_wizard_header.html" %}
|
||||||
|
{% endwith %}
|
||||||
|
|
||||||
|
{% if existing_operator %}
|
||||||
|
<article>
|
||||||
|
<header>
|
||||||
|
<h1>Operator Already Configured</h1>
|
||||||
|
</header>
|
||||||
|
<p>The operator account <strong>{{ existing_operator.username }}</strong> has been created.</p>
|
||||||
|
<div style="display: flex; gap: 1rem; margin-top: 1rem;">
|
||||||
|
<a href="/setup/system" role="button">Next →</a>
|
||||||
|
</div>
|
||||||
|
</article>
|
||||||
|
{% else %}
|
||||||
|
<article>
|
||||||
|
<header>
|
||||||
|
<h1>Create Operator Account</h1>
|
||||||
|
<p>Create the initial operator account to manage Central.</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<form action="/setup/operator" method="post">
|
||||||
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
|
||||||
|
<label for="username">
|
||||||
|
Username
|
||||||
|
<input type="text" id="username" name="username" required
|
||||||
|
autocomplete="username" autofocus value="{{ form_data.username if form_data else '' }}">
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label for="password">
|
||||||
|
Password
|
||||||
|
<input type="password" id="password" name="password" required
|
||||||
|
autocomplete="new-password" minlength="8">
|
||||||
|
<small>Minimum 8 characters</small>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label for="confirm_password">
|
||||||
|
Confirm Password
|
||||||
|
<input type="password" id="confirm_password" name="confirm_password" required
|
||||||
|
autocomplete="new-password">
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<button type="submit">Create Operator →</button>
|
||||||
|
</form>
|
||||||
|
</article>
|
||||||
|
{% endif %}
|
||||||
|
{% endblock %}
|
||||||
49
src/central/gui/templates/setup_system.html
Normal file
49
src/central/gui/templates/setup_system.html
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
{% extends "base_wizard.html" %}
|
||||||
|
|
||||||
|
{% block title %}Central - System Settings{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
{% with step=2, step_name="System Settings" %}
|
||||||
|
{% include "_wizard_header.html" %}
|
||||||
|
{% endwith %}
|
||||||
|
|
||||||
|
<article>
|
||||||
|
<header>
|
||||||
|
<h1>System Settings</h1>
|
||||||
|
<p>Configure map tile provider for the region picker.</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
{% if error %}
|
||||||
|
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<form action="/setup/system" method="post">
|
||||||
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
|
||||||
|
<label for="map_tile_url">
|
||||||
|
Map Tile URL
|
||||||
|
<input type="text" id="map_tile_url" name="map_tile_url"
|
||||||
|
value="{{ form_data.map_tile_url if form_data else system.map_tile_url }}" required>
|
||||||
|
<small>Use {z}, {x}, {y} placeholders. Example: https://tile.openstreetmap.org/{z}/{x}/{y}.png</small>
|
||||||
|
</label>
|
||||||
|
{% if errors and errors.map_tile_url %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors.map_tile_url }}</small>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<label for="map_attribution">
|
||||||
|
Map Attribution
|
||||||
|
<input type="text" id="map_attribution" name="map_attribution"
|
||||||
|
value="{{ form_data.map_attribution if form_data else system.map_attribution }}" required>
|
||||||
|
<small>Credit the map provider (required by most tile services).</small>
|
||||||
|
</label>
|
||||||
|
{% if errors and errors.map_attribution %}
|
||||||
|
<small style="color: var(--pico-color-red-500);">{{ errors.map_attribution }}</small>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<div style="display: flex; gap: 1rem; margin-top: 1rem;">
|
||||||
|
<a href="/setup/operator" role="button" class="outline">← Back</a>
|
||||||
|
<button type="submit">Next →</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</article>
|
||||||
|
{% endblock %}
|
||||||
|
|
@ -48,3 +48,49 @@ def mock_conn():
|
||||||
conn.fetchval = AsyncMock()
|
conn.fetchval = AsyncMock()
|
||||||
conn.execute = AsyncMock()
|
conn.execute = AsyncMock()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
# CSRF fixtures for route tests
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def bypass_pre_auth_csrf():
|
||||||
|
"""Patch pre-auth CSRF validation to always pass.
|
||||||
|
|
||||||
|
Use for tests of pre-auth routes: /login, /setup/operator
|
||||||
|
"""
|
||||||
|
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||||
|
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_csrf_token", "test_signed_token")):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def bypass_session_csrf():
|
||||||
|
"""Create a mock request with session CSRF properly configured.
|
||||||
|
|
||||||
|
Use for tests of authenticated routes that check request.state.csrf_token.
|
||||||
|
Returns a configured mock_request.
|
||||||
|
"""
|
||||||
|
request = MagicMock()
|
||||||
|
request.state.csrf_token = "test_csrf_token_12345"
|
||||||
|
request.state.operator = MagicMock()
|
||||||
|
request.state.operator.id = 1
|
||||||
|
request.state.operator.username = "testuser"
|
||||||
|
|
||||||
|
# Mock form() to return dict with matching CSRF token
|
||||||
|
form_data = {"csrf_token": "test_csrf_token_12345"}
|
||||||
|
|
||||||
|
async def mock_form():
|
||||||
|
return form_data
|
||||||
|
|
||||||
|
request.form = mock_form
|
||||||
|
request._form_data = form_data # Allow tests to modify form data
|
||||||
|
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patch_route_settings():
|
||||||
|
"""Patch get_settings in routes module."""
|
||||||
|
with patch("central.gui.routes.get_settings") as mock:
|
||||||
|
mock.return_value.csrf_secret = "test-csrf-secret-for-testing-only-32chars"
|
||||||
|
yield mock
|
||||||
|
|
|
||||||
|
|
@ -55,13 +55,9 @@ class TestAdaptersListAuthenticated:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_list(mock_request, mock_csrf)
|
result = await adapters_list(mock_request)
|
||||||
|
|
||||||
# Verify template was called with adapters
|
# Verify template was called with adapters
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
|
@ -105,13 +101,9 @@ class TestAdaptersEditForm:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_form(mock_request, "nws", mock_csrf)
|
result = await adapters_edit_form(mock_request, "nws")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -133,11 +125,8 @@ class TestAdaptersEditForm:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_form(mock_request, "nonexistent", mock_csrf)
|
result = await adapters_edit_form(mock_request, "nonexistent")
|
||||||
|
|
||||||
assert result.status_code == 404
|
assert result.status_code == 404
|
||||||
|
|
||||||
|
|
@ -156,7 +145,9 @@ class TestAdaptersEditSubmit:
|
||||||
|
|
||||||
# Mock form data
|
# Mock form data
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "120",
|
"cadence_s": "120",
|
||||||
"contact_email": "new@example.com",
|
"contact_email": "new@example.com",
|
||||||
"region_north": "49.0",
|
"region_north": "49.0",
|
||||||
|
|
@ -183,12 +174,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||||
result = await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "nws")
|
||||||
|
|
||||||
assert result.status_code == 302
|
assert result.status_code == 302
|
||||||
assert result.headers["location"] == "/adapters"
|
assert result.headers["location"] == "/adapters"
|
||||||
|
|
@ -204,7 +192,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "30",
|
"cadence_s": "30",
|
||||||
"contact_email": "test@example.com",
|
"contact_email": "test@example.com",
|
||||||
"region_north": "49.0",
|
"region_north": "49.0",
|
||||||
|
|
@ -239,14 +229,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "nws")
|
||||||
|
|
||||||
# Should re-render form with error
|
# Should re-render form with error
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
|
@ -263,7 +248,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "300",
|
"cadence_s": "300",
|
||||||
"api_key_alias": "nonexistent_key",
|
"api_key_alias": "nonexistent_key",
|
||||||
"region_north": "49.5",
|
"region_north": "49.5",
|
||||||
|
|
@ -299,14 +286,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -322,7 +304,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "120",
|
"cadence_s": "120",
|
||||||
"feed": "invalid_feed",
|
"feed": "invalid_feed",
|
||||||
"region_north": "49.0",
|
"region_north": "49.0",
|
||||||
|
|
@ -357,14 +341,9 @@ class TestAdaptersEditSubmit:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_submit(mock_request, "usgs_quake", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "usgs_quake")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -383,7 +362,9 @@ class TestAdaptersAudit:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "120",
|
"cadence_s": "120",
|
||||||
"contact_email": "new@example.com",
|
"contact_email": "new@example.com",
|
||||||
"region_north": "49.0",
|
"region_north": "49.0",
|
||||||
|
|
@ -410,9 +391,6 @@ class TestAdaptersAudit:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
captured_audit = {}
|
captured_audit = {}
|
||||||
|
|
||||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||||
|
|
@ -423,7 +401,7 @@ class TestAdaptersAudit:
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||||
result = await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "nws")
|
||||||
|
|
||||||
assert captured_audit["action"] == "adapter.update"
|
assert captured_audit["action"] == "adapter.update"
|
||||||
assert captured_audit["target"] == "nws"
|
assert captured_audit["target"] == "nws"
|
||||||
|
|
@ -449,7 +427,9 @@ class TestAdaptersJsonbRegression:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "120",
|
"cadence_s": "120",
|
||||||
"contact_email": "test@example.com",
|
"contact_email": "test@example.com",
|
||||||
"region_north": "49.0",
|
"region_north": "49.0",
|
||||||
|
|
@ -476,12 +456,9 @@ class TestAdaptersJsonbRegression:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||||
await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
await adapters_edit_submit(mock_request, "nws")
|
||||||
|
|
||||||
# Get the settings argument passed to execute (3rd positional arg after query)
|
# Get the settings argument passed to execute (3rd positional arg after query)
|
||||||
call_args = mock_conn.execute.call_args
|
call_args = mock_conn.execute.call_args
|
||||||
|
|
@ -502,7 +479,9 @@ class TestAdaptersJsonbRegression:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "120",
|
"cadence_s": "120",
|
||||||
"contact_email": "new@example.com",
|
"contact_email": "new@example.com",
|
||||||
"region_north": "49.0",
|
"region_north": "49.0",
|
||||||
|
|
@ -529,9 +508,6 @@ class TestAdaptersJsonbRegression:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
captured_audit = {}
|
captured_audit = {}
|
||||||
|
|
||||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||||
|
|
@ -540,7 +516,7 @@ class TestAdaptersJsonbRegression:
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||||
await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
await adapters_edit_submit(mock_request, "nws")
|
||||||
|
|
||||||
# CRITICAL: before and after must be dicts, NOT strings
|
# CRITICAL: before and after must be dicts, NOT strings
|
||||||
assert isinstance(captured_audit["before"], dict), f"before should be dict, got {type(captured_audit['before'])}"
|
assert isinstance(captured_audit["before"], dict), f"before should be dict, got {type(captured_audit['before'])}"
|
||||||
|
|
|
||||||
|
|
@ -75,13 +75,9 @@ class TestApiKeysListAuthenticated:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await api_keys_list(mock_request, mock_csrf)
|
result = await api_keys_list(mock_request)
|
||||||
|
|
||||||
# Check template was called with correct context
|
# Check template was called with correct context
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
|
@ -104,7 +100,8 @@ class TestApiKeysCreate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
form_data = {"alias": "test1", "plaintext_key": "secret-api-key-123"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "alias": "test1", "plaintext_key": "secret-api-key-123"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -119,13 +116,10 @@ class TestApiKeysCreate:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.crypto.encrypt", return_value=b"encrypted_data"):
|
with patch("central.crypto.encrypt", return_value=b"encrypted_data"):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||||
result = await api_keys_create(mock_request, mock_csrf)
|
result = await api_keys_create(mock_request)
|
||||||
|
|
||||||
assert result.status_code == 302
|
assert result.status_code == 302
|
||||||
assert result.headers["location"] == "/api-keys"
|
assert result.headers["location"] == "/api-keys"
|
||||||
|
|
@ -136,7 +130,8 @@ class TestApiKeysCreate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
form_data = {"alias": "firms", "plaintext_key": "secret-key"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "alias": "firms", "plaintext_key": "secret-key"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_templates = MagicMock()
|
mock_templates = MagicMock()
|
||||||
|
|
@ -150,15 +145,10 @@ class TestApiKeysCreate:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.crypto.encrypt", return_value=b"encrypted"):
|
with patch("central.crypto.encrypt", return_value=b"encrypted"):
|
||||||
result = await api_keys_create(mock_request, mock_csrf)
|
result = await api_keys_create(mock_request)
|
||||||
|
|
||||||
# Should re-render form with error
|
# Should re-render form with error
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
|
@ -172,7 +162,8 @@ class TestApiKeysCreate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
form_data = {"alias": "", "plaintext_key": "secret-key"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "alias": "", "plaintext_key": "secret-key"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_templates = MagicMock()
|
mock_templates = MagicMock()
|
||||||
|
|
@ -183,14 +174,9 @@ class TestApiKeysCreate:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await api_keys_create(mock_request, mock_csrf)
|
result = await api_keys_create(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -203,7 +189,8 @@ class TestApiKeysCreate:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
# Test with space
|
# Test with space
|
||||||
form_data = {"alias": "test key", "plaintext_key": "secret-key"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "alias": "test key", "plaintext_key": "secret-key"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_templates = MagicMock()
|
mock_templates = MagicMock()
|
||||||
|
|
@ -214,14 +201,9 @@ class TestApiKeysCreate:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await api_keys_create(mock_request, mock_csrf)
|
result = await api_keys_create(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -233,7 +215,8 @@ class TestApiKeysCreate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
form_data = {"alias": "test-key", "plaintext_key": "secret-key"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "alias": "test-key", "plaintext_key": "secret-key"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_templates = MagicMock()
|
mock_templates = MagicMock()
|
||||||
|
|
@ -244,14 +227,9 @@ class TestApiKeysCreate:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await api_keys_create(mock_request, mock_csrf)
|
result = await api_keys_create(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -267,7 +245,8 @@ class TestApiKeysRotate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
form_data = {"new_plaintext_key": "new-secret-key-456"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "new_plaintext_key": "new-secret-key-456"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -290,13 +269,10 @@ class TestApiKeysRotate:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.crypto.encrypt", return_value=b"new_encrypted"):
|
with patch("central.crypto.encrypt", return_value=b"new_encrypted"):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||||
result = await api_keys_rotate(mock_request, "test1", mock_csrf)
|
result = await api_keys_rotate(mock_request, "test1")
|
||||||
|
|
||||||
assert result.status_code == 302
|
assert result.status_code == 302
|
||||||
# Check audit was called with no plaintext
|
# Check audit was called with no plaintext
|
||||||
|
|
@ -313,6 +289,8 @@ class TestApiKeysDelete:
|
||||||
"""POST /api-keys/{alias}/delete with references shows error."""
|
"""POST /api-keys/{alias}/delete with references shows error."""
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf_token"})
|
||||||
|
|
||||||
mock_templates = MagicMock()
|
mock_templates = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = MagicMock()
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
@ -331,14 +309,9 @@ class TestApiKeysDelete:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await api_keys_delete(mock_request, "firms", mock_csrf)
|
result = await api_keys_delete(mock_request, "firms")
|
||||||
|
|
||||||
# Should re-render with error
|
# Should re-render with error
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
|
@ -351,6 +324,8 @@ class TestApiKeysDelete:
|
||||||
"""POST /api-keys/{alias}/delete without references deletes and redirects."""
|
"""POST /api-keys/{alias}/delete without references deletes and redirects."""
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf_token"})
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
mock_conn.fetchrow.return_value = {
|
mock_conn.fetchrow.return_value = {
|
||||||
|
|
@ -367,12 +342,9 @@ class TestApiKeysDelete:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||||
result = await api_keys_delete(mock_request, "test1", mock_csrf)
|
result = await api_keys_delete(mock_request, "test1")
|
||||||
|
|
||||||
assert result.status_code == 302
|
assert result.status_code == 302
|
||||||
assert result.headers["location"] == "/api-keys"
|
assert result.headers["location"] == "/api-keys"
|
||||||
|
|
@ -388,7 +360,8 @@ class TestApiKeysAuditNoPlaintext:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="admin")
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
form_data = {"alias": "newkey", "plaintext_key": "super-secret-value"}
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
form_data = {"csrf_token": "test_csrf_token", "alias": "newkey", "plaintext_key": "super-secret-value"}
|
||||||
mock_request.form = AsyncMock(return_value=form_data)
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -401,13 +374,10 @@ class TestApiKeysAuditNoPlaintext:
|
||||||
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
mock_pool.acquire.return_value.__aexit__.return_value = None
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.crypto.encrypt", return_value=b"encrypted"):
|
with patch("central.crypto.encrypt", return_value=b"encrypted"):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||||
await api_keys_create(mock_request, mock_csrf)
|
await api_keys_create(mock_request)
|
||||||
|
|
||||||
# Check audit call arguments
|
# Check audit call arguments
|
||||||
call_kwargs = mock_audit.call_args.kwargs
|
call_kwargs = mock_audit.call_args.kwargs
|
||||||
|
|
|
||||||
|
|
@ -92,29 +92,33 @@ class TestSessionManagement:
|
||||||
mock_conn = MagicMock()
|
mock_conn = MagicMock()
|
||||||
mock_conn.execute = AsyncMock()
|
mock_conn.execute = AsyncMock()
|
||||||
|
|
||||||
token, expires_at = await create_session(mock_conn, operator_id=1, lifetime_days=90)
|
token, expires_at, csrf_token = await create_session(mock_conn, operator_id=1, lifetime_days=90)
|
||||||
|
|
||||||
assert len(token) == 43
|
assert len(token) == 43
|
||||||
|
assert len(csrf_token) == 64 # 32 bytes hex = 64 chars
|
||||||
mock_conn.execute.assert_called_once()
|
mock_conn.execute.assert_called_once()
|
||||||
call_args = mock_conn.execute.call_args
|
call_args = mock_conn.execute.call_args
|
||||||
assert "INSERT INTO config.sessions" in call_args[0][0]
|
assert "INSERT INTO config.sessions" in call_args[0][0]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_session_found(self):
|
async def test_get_session_found(self):
|
||||||
"""get_session returns Operator when session exists."""
|
"""get_session returns (Operator, csrf_token) when session exists."""
|
||||||
mock_conn = MagicMock()
|
mock_conn = MagicMock()
|
||||||
mock_conn.fetchrow = AsyncMock(return_value={
|
mock_conn.fetchrow = AsyncMock(return_value={
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"username": "testuser",
|
"username": "testuser",
|
||||||
"created_at": datetime.now(timezone.utc),
|
"created_at": datetime.now(timezone.utc),
|
||||||
"password_changed_at": datetime.now(timezone.utc),
|
"password_changed_at": datetime.now(timezone.utc),
|
||||||
|
"csrf_token": "test_csrf_token_12345",
|
||||||
})
|
})
|
||||||
|
|
||||||
operator = await get_session(mock_conn, "valid-token")
|
result = await get_session(mock_conn, "valid-token")
|
||||||
|
|
||||||
assert operator is not None
|
assert result is not None
|
||||||
|
operator, csrf_token = result
|
||||||
assert operator.id == 1
|
assert operator.id == 1
|
||||||
assert operator.username == "testuser"
|
assert operator.username == "testuser"
|
||||||
|
assert csrf_token == "test_csrf_token_12345"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_session_not_found(self):
|
async def test_get_session_not_found(self):
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) ->
|
||||||
clear_key_cache()
|
clear_key_cache()
|
||||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||||
|
monkeypatch.setenv("CENTRAL_CSRF_SECRET", "test-csrf-secret-for-testing-only-32chars")
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
|
|
|
||||||
|
|
@ -14,23 +14,18 @@ class TestCsrfExceptionHandlerRegistered:
|
||||||
"""Verify CSRF exception handler is properly registered."""
|
"""Verify CSRF exception handler is properly registered."""
|
||||||
|
|
||||||
def test_csrf_exception_handler_is_registered(self):
|
def test_csrf_exception_handler_is_registered(self):
|
||||||
"""The app has a CsrfProtectError exception handler registered."""
|
"""The app has a CsrfValidationError exception handler registered."""
|
||||||
from central.gui import app
|
from central.gui import app
|
||||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
assert CsrfProtectError in app.exception_handlers, \
|
assert CsrfValidationError in app.exception_handlers, \
|
||||||
"CsrfProtectError handler should be registered"
|
"CsrfValidationError handler should be registered"
|
||||||
|
|
||||||
def test_csrf_subclasses_are_caught(self):
|
def test_csrf_validation_error_is_exception(self):
|
||||||
"""MissingTokenError and TokenValidationError inherit from CsrfProtectError."""
|
"""CsrfValidationError is a proper Exception subclass."""
|
||||||
from fastapi_csrf_protect.exceptions import (
|
from central.gui.auth import CsrfValidationError
|
||||||
CsrfProtectError,
|
|
||||||
MissingTokenError,
|
|
||||||
TokenValidationError,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert issubclass(MissingTokenError, CsrfProtectError)
|
assert issubclass(CsrfValidationError, Exception)
|
||||||
assert issubclass(TokenValidationError, CsrfProtectError)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCsrfExceptionHandlerBehavior:
|
class TestCsrfExceptionHandlerBehavior:
|
||||||
|
|
@ -40,10 +35,10 @@ class TestCsrfExceptionHandlerBehavior:
|
||||||
"""CSRF handler checks request path for /login."""
|
"""CSRF handler checks request path for /login."""
|
||||||
import inspect
|
import inspect
|
||||||
from central.gui import _create_app
|
from central.gui import _create_app
|
||||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
app = _create_app()
|
app = _create_app()
|
||||||
handler = app.exception_handlers.get(CsrfProtectError)
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
# Verify handler source contains /login path check
|
# Verify handler source contains /login path check
|
||||||
source = inspect.getsource(handler)
|
source = inspect.getsource(handler)
|
||||||
|
|
@ -54,17 +49,16 @@ class TestCsrfExceptionHandlerBehavior:
|
||||||
async def test_logout_csrf_error_redirects_to_login(self):
|
async def test_logout_csrf_error_redirects_to_login(self):
|
||||||
"""CSRF error on /logout should redirect to /login."""
|
"""CSRF error on /logout should redirect to /login."""
|
||||||
from central.gui import _create_app
|
from central.gui import _create_app
|
||||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
from central.gui.auth import CsrfValidationError
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
app = _create_app()
|
app = _create_app()
|
||||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
handler = app.exception_handlers.get(CsrfProtectError)
|
|
||||||
|
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.url.path = "/logout"
|
mock_request.url.path = "/logout"
|
||||||
|
|
||||||
exc = TokenValidationError("Invalid token")
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
result = await handler(mock_request, exc)
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
|
@ -75,17 +69,16 @@ class TestCsrfExceptionHandlerBehavior:
|
||||||
async def test_adapters_csrf_error_redirects_to_adapters(self):
|
async def test_adapters_csrf_error_redirects_to_adapters(self):
|
||||||
"""CSRF error on /adapters/{name} should redirect to /adapters."""
|
"""CSRF error on /adapters/{name} should redirect to /adapters."""
|
||||||
from central.gui import _create_app
|
from central.gui import _create_app
|
||||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
from central.gui.auth import CsrfValidationError
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
app = _create_app()
|
app = _create_app()
|
||||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
handler = app.exception_handlers.get(CsrfProtectError)
|
|
||||||
|
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.url.path = "/adapters/nws"
|
mock_request.url.path = "/adapters/nws"
|
||||||
|
|
||||||
exc = TokenValidationError("Invalid token")
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
result = await handler(mock_request, exc)
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
|
@ -94,16 +87,171 @@ class TestCsrfExceptionHandlerBehavior:
|
||||||
|
|
||||||
|
|
||||||
class TestCsrfHandlerNoTraceback:
|
class TestCsrfHandlerNoTraceback:
|
||||||
"""Verify exception handler doesn't expose Python internals."""
|
"""Verify exception handler does not expose Python internals."""
|
||||||
|
|
||||||
def test_handler_exists_and_is_async(self):
|
def test_handler_exists_and_is_async(self):
|
||||||
"""The CSRF handler should be an async function."""
|
"""The CSRF handler should be an async function."""
|
||||||
import inspect
|
import inspect
|
||||||
from central.gui import _create_app
|
from central.gui import _create_app
|
||||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
app = _create_app()
|
app = _create_app()
|
||||||
handler = app.exception_handlers.get(CsrfProtectError)
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
assert handler is not None
|
assert handler is not None
|
||||||
assert inspect.iscoroutinefunction(handler)
|
assert inspect.iscoroutinefunction(handler)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCsrfHandlerWizardPaths:
|
||||||
|
"""Test CSRF exception handler for wizard paths."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_operator_csrf_error_renders_form_with_error(self):
|
||||||
|
"""CSRF error on /setup/operator re-renders form with error message."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/setup/operator"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
# Should be HTML response, not redirect
|
||||||
|
assert hasattr(result, "body")
|
||||||
|
assert result.status_code == 200
|
||||||
|
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||||
|
assert "session expired" in body.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_system_csrf_error_renders_form_with_error(self):
|
||||||
|
"""CSRF error on /setup/system re-renders form with error message."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/setup/system"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
with patch("central.gui.db.get_pool", return_value=None):
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
assert hasattr(result, "body")
|
||||||
|
assert result.status_code == 200
|
||||||
|
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||||
|
assert "session expired" in body.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_keys_csrf_error_renders_form_with_error(self):
|
||||||
|
"""CSRF error on /setup/keys re-renders form with error message."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/setup/keys"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
with patch("central.gui.db.get_pool", return_value=None):
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
assert hasattr(result, "body")
|
||||||
|
assert result.status_code == 200
|
||||||
|
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||||
|
assert "session expired" in body.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_adapters_csrf_error_renders_form_with_error(self):
|
||||||
|
"""CSRF error on /setup/adapters re-renders form with error message."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/setup/adapters"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
with patch("central.gui.db.get_pool", return_value=None):
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
assert hasattr(result, "body")
|
||||||
|
assert result.status_code == 200
|
||||||
|
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||||
|
assert "session expired" in body.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_finish_csrf_error_renders_form_with_error(self):
|
||||||
|
"""CSRF error on /setup/finish re-renders form with error message."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/setup/finish"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
with patch("central.gui.db.get_pool", return_value=None):
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
assert hasattr(result, "body")
|
||||||
|
assert result.status_code == 200
|
||||||
|
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||||
|
assert "session expired" in body.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_base_csrf_error_redirects_to_setup(self):
|
||||||
|
"""CSRF error on /setup redirects to /setup (middleware routes to step)."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/setup"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
assert isinstance(result, RedirectResponse)
|
||||||
|
assert result.status_code == 302
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_csrf_error_still_works(self):
|
||||||
|
"""CSRF error on /login still renders login form with error (regression test)."""
|
||||||
|
from central.gui import _create_app
|
||||||
|
from central.gui.auth import CsrfValidationError
|
||||||
|
|
||||||
|
app = _create_app()
|
||||||
|
handler = app.exception_handlers.get(CsrfValidationError)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.url.path = "/login"
|
||||||
|
|
||||||
|
exc = CsrfValidationError("Invalid token")
|
||||||
|
|
||||||
|
result = await handler(mock_request, exc)
|
||||||
|
|
||||||
|
assert hasattr(result, "body")
|
||||||
|
assert result.status_code == 200
|
||||||
|
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||||
|
assert "session expired" in body.lower()
|
||||||
|
|
|
||||||
108
tests/test_csrf_race_condition.py
Normal file
108
tests/test_csrf_race_condition.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
"""
|
||||||
|
Integration test for CSRF race condition fix.
|
||||||
|
|
||||||
|
This test verifies that the session-bound CSRF implementation fixes the race
|
||||||
|
condition where interleaved GET requests would invalidate CSRF tokens.
|
||||||
|
|
||||||
|
See: PR #24 - Central 1b-8 fix-up phase 2
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestCsrfRaceConditionFix:
|
||||||
|
"""Verify that interleaved GETs don't break CSRF validation."""
|
||||||
|
|
||||||
|
def test_session_bound_csrf_consistent_across_gets(self):
|
||||||
|
"""Session-bound CSRF tokens remain consistent across multiple GETs.
|
||||||
|
|
||||||
|
This was the core bug: fastapi-csrf-protect rotated tokens on every GET,
|
||||||
|
causing race conditions when users had multiple tabs or slow connections.
|
||||||
|
|
||||||
|
With session-bound CSRF, the token is stored in the session row and
|
||||||
|
remains constant until the session is destroyed.
|
||||||
|
"""
|
||||||
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
from central.gui.auth import get_session
|
||||||
|
|
||||||
|
# Mock a session with a csrf_token
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(return_value={
|
||||||
|
"id": 1,
|
||||||
|
"username": "testuser",
|
||||||
|
"created_at": "2024-01-01T00:00:00Z",
|
||||||
|
"password_changed_at": "2024-01-01T00:00:00Z",
|
||||||
|
"csrf_token": "fixed_csrf_token_12345",
|
||||||
|
})
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
# First GET
|
||||||
|
result1 = await get_session(mock_conn, "test-token")
|
||||||
|
assert result1 is not None
|
||||||
|
op1, csrf1 = result1
|
||||||
|
|
||||||
|
# Second GET (simulating interleaved request)
|
||||||
|
result2 = await get_session(mock_conn, "test-token")
|
||||||
|
assert result2 is not None
|
||||||
|
op2, csrf2 = result2
|
||||||
|
|
||||||
|
# CSRF tokens should be identical (the fix!)
|
||||||
|
assert csrf1 == csrf2 == "fixed_csrf_token_12345"
|
||||||
|
|
||||||
|
asyncio.run(test())
|
||||||
|
|
||||||
|
def test_pre_auth_csrf_tokens_independently_valid(self):
|
||||||
|
"""Pre-auth CSRF tokens are independently valid.
|
||||||
|
|
||||||
|
For unauthenticated routes, each GET generates a new token+cookie pair.
|
||||||
|
Each pair should validate independently, allowing the original token
|
||||||
|
to work even if another GET happened in between.
|
||||||
|
"""
|
||||||
|
from central.gui.csrf import generate_pre_auth_csrf, validate_pre_auth_csrf
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
secret = "testsecret12345678901234567890ab"
|
||||||
|
|
||||||
|
# First GET generates token1 + cookie1
|
||||||
|
token1, signed1 = generate_pre_auth_csrf(secret)
|
||||||
|
|
||||||
|
# Second GET generates token2 + cookie2
|
||||||
|
token2, signed2 = generate_pre_auth_csrf(secret)
|
||||||
|
|
||||||
|
# Tokens should be different (fresh random tokens)
|
||||||
|
assert token1 != token2
|
||||||
|
assert signed1 != signed2
|
||||||
|
|
||||||
|
# But each pair should validate independently
|
||||||
|
mock_request1 = MagicMock()
|
||||||
|
mock_request1.cookies = {"central_preauth_csrf": signed1}
|
||||||
|
|
||||||
|
mock_request2 = MagicMock()
|
||||||
|
mock_request2.cookies = {"central_preauth_csrf": signed2}
|
||||||
|
|
||||||
|
# Original token still validates with original cookie
|
||||||
|
assert validate_pre_auth_csrf(mock_request1, token1, secret) is True
|
||||||
|
|
||||||
|
# Second token validates with second cookie
|
||||||
|
assert validate_pre_auth_csrf(mock_request2, token2, secret) is True
|
||||||
|
|
||||||
|
# Cross-validation should fail
|
||||||
|
assert validate_pre_auth_csrf(mock_request1, token2, secret) is False
|
||||||
|
assert validate_pre_auth_csrf(mock_request2, token1, secret) is False
|
||||||
|
|
||||||
|
def test_csrf_token_generation_is_secure(self):
|
||||||
|
"""CSRF tokens are cryptographically secure."""
|
||||||
|
from central.gui.auth import generate_csrf_token
|
||||||
|
|
||||||
|
# Generate multiple tokens
|
||||||
|
tokens = [generate_csrf_token() for _ in range(100)]
|
||||||
|
|
||||||
|
# All tokens should be unique
|
||||||
|
assert len(set(tokens)) == 100
|
||||||
|
|
||||||
|
# Tokens should be 64 hex chars (32 bytes)
|
||||||
|
for token in tokens:
|
||||||
|
assert len(token) == 64
|
||||||
|
assert all(c in "0123456789abcdef" for c in token)
|
||||||
|
|
@ -51,13 +51,9 @@ class TestRegionPickerInTemplate:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_form(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_form(mock_request, "firms")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -79,7 +75,9 @@ class TestRegionValidation:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "300",
|
"cadence_s": "300",
|
||||||
"api_key_alias": "firms",
|
"api_key_alias": "firms",
|
||||||
"region_north": "45.0",
|
"region_north": "45.0",
|
||||||
|
|
@ -109,9 +107,6 @@ class TestRegionValidation:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
captured_settings = {}
|
captured_settings = {}
|
||||||
|
|
||||||
async def capture_execute(query, *args):
|
async def capture_execute(query, *args):
|
||||||
|
|
@ -122,7 +117,7 @@ class TestRegionValidation:
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
assert result.status_code == 302
|
assert result.status_code == 302
|
||||||
assert captured_settings["settings"]["region"]["north"] == 45.0
|
assert captured_settings["settings"]["region"]["north"] == 45.0
|
||||||
|
|
@ -139,7 +134,9 @@ class TestRegionValidation:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "300",
|
"cadence_s": "300",
|
||||||
"api_key_alias": "firms",
|
"api_key_alias": "firms",
|
||||||
"region_north": "30.0", # Less than south!
|
"region_north": "30.0", # Less than south!
|
||||||
|
|
@ -175,14 +172,9 @@ class TestRegionValidation:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -198,7 +190,9 @@ class TestRegionValidation:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "300",
|
"cadence_s": "300",
|
||||||
"api_key_alias": "firms",
|
"api_key_alias": "firms",
|
||||||
"region_north": "45.0",
|
"region_north": "45.0",
|
||||||
|
|
@ -234,14 +228,9 @@ class TestRegionValidation:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -257,7 +246,9 @@ class TestRegionValidation:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "300",
|
"cadence_s": "300",
|
||||||
"api_key_alias": "firms",
|
"api_key_alias": "firms",
|
||||||
"region_north": "95.0", # > 90!
|
"region_north": "95.0", # > 90!
|
||||||
|
|
@ -293,14 +284,9 @@ class TestRegionValidation:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -319,7 +305,9 @@ class TestRegionAuditLog:
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form.get.side_effect = lambda k, d="": {
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
"cadence_s": "300",
|
"cadence_s": "300",
|
||||||
"api_key_alias": "firms",
|
"api_key_alias": "firms",
|
||||||
"region_north": "45.0",
|
"region_north": "45.0",
|
||||||
|
|
@ -352,9 +340,6 @@ class TestRegionAuditLog:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
captured_audit = {}
|
captured_audit = {}
|
||||||
|
|
||||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||||
|
|
@ -363,7 +348,7 @@ class TestRegionAuditLog:
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
# Before should have old region
|
# Before should have old region
|
||||||
assert captured_audit["before"]["settings"]["region"]["north"] == 49.5
|
assert captured_audit["before"]["settings"]["region"]["north"] == 49.5
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ class TestSessionMiddleware:
|
||||||
"username": "admin",
|
"username": "admin",
|
||||||
"created_at": datetime.now(timezone.utc),
|
"created_at": datetime.now(timezone.utc),
|
||||||
"password_changed_at": datetime.now(timezone.utc),
|
"password_changed_at": datetime.now(timezone.utc),
|
||||||
|
"csrf_token": "mock_csrf_token_12345",
|
||||||
})
|
})
|
||||||
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_conn.__aexit__ = AsyncMock()
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
|
@ -99,6 +100,7 @@ class TestSessionMiddleware:
|
||||||
"username": "admin",
|
"username": "admin",
|
||||||
"created_at": datetime.now(timezone.utc),
|
"created_at": datetime.now(timezone.utc),
|
||||||
"password_changed_at": datetime.now(timezone.utc),
|
"password_changed_at": datetime.now(timezone.utc),
|
||||||
|
"csrf_token": "mock_csrf_token_12345",
|
||||||
})
|
})
|
||||||
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_conn.__aexit__ = AsyncMock()
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,8 @@ class TestSetupGateMiddleware:
|
||||||
"""Tests for SetupGateMiddleware."""
|
"""Tests for SetupGateMiddleware."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_allows_setup_route_when_incomplete(self):
|
async def test_allows_setup_subpath_when_incomplete(self):
|
||||||
"""SetupGateMiddleware allows /setup when setup_complete=False."""
|
"""SetupGateMiddleware allows /setup/operator when setup_complete=False."""
|
||||||
mock_pool = MagicMock()
|
mock_pool = MagicMock()
|
||||||
mock_conn = MagicMock()
|
mock_conn = MagicMock()
|
||||||
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||||
|
|
@ -21,6 +21,31 @@ class TestSetupGateMiddleware:
|
||||||
mock_conn.__aexit__ = AsyncMock()
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/setup/operator")
|
||||||
|
async def setup_operator():
|
||||||
|
return {"message": "operator"}
|
||||||
|
|
||||||
|
app.add_middleware(SetupGateMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/setup/operator")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "operator"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirects_setup_base_to_wizard_step(self):
|
||||||
|
"""SetupGateMiddleware redirects /setup to wizard step when incomplete."""
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||||
|
mock_conn.fetchval = AsyncMock(return_value=0) # No operators
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
@ -28,12 +53,16 @@ class TestSetupGateMiddleware:
|
||||||
async def setup():
|
async def setup():
|
||||||
return {"message": "setup"}
|
return {"message": "setup"}
|
||||||
|
|
||||||
|
@app.get("/setup/operator")
|
||||||
|
async def setup_operator():
|
||||||
|
return {"message": "operator"}
|
||||||
|
|
||||||
app.add_middleware(SetupGateMiddleware)
|
app.add_middleware(SetupGateMiddleware)
|
||||||
client = TestClient(app)
|
client = TestClient(app, follow_redirects=False)
|
||||||
|
|
||||||
response = client.get("/setup")
|
response = client.get("/setup")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 302
|
||||||
assert response.json() == {"message": "setup"}
|
assert response.headers["location"] == "/setup/operator"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_allows_health_when_incomplete(self):
|
async def test_allows_health_when_incomplete(self):
|
||||||
|
|
@ -135,7 +164,7 @@ class TestSetupGateMiddleware:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redirects_setup_when_complete(self):
|
async def test_redirects_setup_when_complete(self):
|
||||||
"""SetupGateMiddleware redirects /setup to / when setup_complete=True."""
|
"""SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""
|
||||||
mock_pool = MagicMock()
|
mock_pool = MagicMock()
|
||||||
mock_conn = MagicMock()
|
mock_conn = MagicMock()
|
||||||
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True})
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True})
|
||||||
|
|
@ -154,9 +183,18 @@ class TestSetupGateMiddleware:
|
||||||
async def setup():
|
async def setup():
|
||||||
return {"message": "setup"}
|
return {"message": "setup"}
|
||||||
|
|
||||||
|
@app.get("/setup/operator")
|
||||||
|
async def setup_operator():
|
||||||
|
return {"message": "operator"}
|
||||||
|
|
||||||
app.add_middleware(SetupGateMiddleware)
|
app.add_middleware(SetupGateMiddleware)
|
||||||
client = TestClient(app, follow_redirects=False)
|
client = TestClient(app, follow_redirects=False)
|
||||||
|
|
||||||
|
# Both /setup and /setup/operator should redirect to /
|
||||||
response = client.get("/setup")
|
response = client.get("/setup")
|
||||||
assert response.status_code == 302
|
assert response.status_code == 302
|
||||||
assert response.headers["location"] == "/"
|
assert response.headers["location"] == "/"
|
||||||
|
|
||||||
|
response = client.get("/setup/operator")
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "/"
|
||||||
|
|
|
||||||
|
|
@ -52,10 +52,6 @@ class TestStreamsListAuthenticated:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
# Mock JetStream with proper state fields
|
# Mock JetStream with proper state fields
|
||||||
mock_js = AsyncMock()
|
mock_js = AsyncMock()
|
||||||
mock_stream_info = MagicMock()
|
mock_stream_info = MagicMock()
|
||||||
|
|
@ -74,7 +70,7 @@ class TestStreamsListAuthenticated:
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=mock_js):
|
with patch("central.gui.nats.get_js", return_value=mock_js):
|
||||||
result = await streams_list(mock_request, mock_csrf)
|
result = await streams_list(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -117,14 +113,10 @@ class TestStreamsListNatsUnavailable:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=None):
|
with patch("central.gui.nats.get_js", return_value=None):
|
||||||
result = await streams_list(mock_request, mock_csrf)
|
result = await streams_list(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -157,10 +149,6 @@ class TestStreamsListPartialFailure:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
# Mock JetStream - CENTRAL_FIRE raises ValueError, CENTRAL_WX works
|
# Mock JetStream - CENTRAL_FIRE raises ValueError, CENTRAL_WX works
|
||||||
mock_js = AsyncMock()
|
mock_js = AsyncMock()
|
||||||
test_ts = datetime(2026, 5, 17, 12, 0, 0, tzinfo=timezone.utc)
|
test_ts = datetime(2026, 5, 17, 12, 0, 0, tzinfo=timezone.utc)
|
||||||
|
|
@ -184,7 +172,7 @@ class TestStreamsListPartialFailure:
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=mock_js):
|
with patch("central.gui.nats.get_js", return_value=mock_js):
|
||||||
result = await streams_list(mock_request, mock_csrf)
|
result = await streams_list(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -222,10 +210,6 @@ class TestStreamsListEmptyStream:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
# Mock JetStream with empty stream (first_seq = 0)
|
# Mock JetStream with empty stream (first_seq = 0)
|
||||||
mock_js = AsyncMock()
|
mock_js = AsyncMock()
|
||||||
mock_stream_info = MagicMock()
|
mock_stream_info = MagicMock()
|
||||||
|
|
@ -239,7 +223,7 @@ class TestStreamsListEmptyStream:
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=mock_js):
|
with patch("central.gui.nats.get_js", return_value=mock_js):
|
||||||
result = await streams_list(mock_request, mock_csrf)
|
result = await streams_list(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -278,10 +262,6 @@ class TestStreamsListSingleMessage:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
# Mock JetStream with single message (first_seq == last_seq)
|
# Mock JetStream with single message (first_seq == last_seq)
|
||||||
mock_js = AsyncMock()
|
mock_js = AsyncMock()
|
||||||
mock_stream_info = MagicMock()
|
mock_stream_info = MagicMock()
|
||||||
|
|
@ -299,7 +279,7 @@ class TestStreamsListSingleMessage:
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=mock_js):
|
with patch("central.gui.nats.get_js", return_value=mock_js):
|
||||||
result = await streams_list(mock_request, mock_csrf)
|
result = await streams_list(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -337,10 +317,6 @@ class TestStreamsListGetMsgFailure:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
# Mock JetStream
|
# Mock JetStream
|
||||||
mock_js = AsyncMock()
|
mock_js = AsyncMock()
|
||||||
mock_stream_info = MagicMock()
|
mock_stream_info = MagicMock()
|
||||||
|
|
@ -365,7 +341,7 @@ class TestStreamsListGetMsgFailure:
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=mock_js):
|
with patch("central.gui.nats.get_js", return_value=mock_js):
|
||||||
result = await streams_list(mock_request, mock_csrf)
|
result = await streams_list(mock_request)
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -394,8 +370,12 @@ class TestStreamsUpdate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
mock_form.get.return_value = "1209600" # 14 days
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"max_age_s": "1209600",
|
||||||
|
}.get(k, d)
|
||||||
mock_request.form = AsyncMock(return_value=mock_form)
|
mock_request.form = AsyncMock(return_value=mock_form)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -406,9 +386,6 @@ class TestStreamsUpdate:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
captured_audit = {}
|
captured_audit = {}
|
||||||
|
|
||||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||||
|
|
@ -419,7 +396,7 @@ class TestStreamsUpdate:
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||||
result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf)
|
result = await streams_update(mock_request, "CENTRAL_WX")
|
||||||
|
|
||||||
assert result.status_code == 302
|
assert result.status_code == 302
|
||||||
assert result.headers["location"] == "/streams"
|
assert result.headers["location"] == "/streams"
|
||||||
|
|
@ -438,8 +415,12 @@ class TestStreamsUpdate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
mock_form.get.return_value = "60" # 1 minute - too small
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"max_age_s": "60",
|
||||||
|
}.get(k, d)
|
||||||
mock_request.form = AsyncMock(return_value=mock_form)
|
mock_request.form = AsyncMock(return_value=mock_form)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -458,15 +439,10 @@ class TestStreamsUpdate:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=None):
|
with patch("central.gui.nats.get_js", return_value=None):
|
||||||
result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf)
|
result = await streams_update(mock_request, "CENTRAL_WX")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -480,9 +456,10 @@ class TestStreamsUpdate:
|
||||||
|
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
mock_form.get.return_value = "999999999" # Way too large
|
mock_form.get.side_effect = lambda k, d="": {"csrf_token": "test_csrf_token", "max_age_s": "999999999"}.get(k, d) # Way too large
|
||||||
mock_request.form = AsyncMock(return_value=mock_form)
|
mock_request.form = AsyncMock(return_value=mock_form)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -501,15 +478,10 @@ class TestStreamsUpdate:
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
|
||||||
mock_csrf.set_csrf_cookie = MagicMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.nats.get_js", return_value=None):
|
with patch("central.gui.nats.get_js", return_value=None):
|
||||||
result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf)
|
result = await streams_update(mock_request, "CENTRAL_WX")
|
||||||
|
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
|
@ -523,8 +495,12 @@ class TestStreamsUpdate:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
mock_form.get.return_value = "604800"
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"max_age_s": "604800",
|
||||||
|
}.get(k, d)
|
||||||
mock_request.form = AsyncMock(return_value=mock_form)
|
mock_request.form = AsyncMock(return_value=mock_form)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -534,11 +510,8 @@ class TestStreamsUpdate:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await streams_update(mock_request, "nonexistent", mock_csrf)
|
result = await streams_update(mock_request, "nonexistent")
|
||||||
|
|
||||||
assert result.status_code == 404
|
assert result.status_code == 404
|
||||||
|
|
||||||
|
|
@ -554,8 +527,12 @@ class TestStreamsAudit:
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||||
|
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
mock_form = MagicMock()
|
mock_form = MagicMock()
|
||||||
mock_form.get.return_value = "1209600" # 14 days
|
mock_form.get.side_effect = lambda k, d="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"max_age_s": "1209600",
|
||||||
|
}.get(k, d)
|
||||||
mock_request.form = AsyncMock(return_value=mock_form)
|
mock_request.form = AsyncMock(return_value=mock_form)
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
|
|
@ -566,9 +543,6 @@ class TestStreamsAudit:
|
||||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
mock_csrf = MagicMock()
|
|
||||||
mock_csrf.validate_csrf = AsyncMock()
|
|
||||||
|
|
||||||
captured_audit = {}
|
captured_audit = {}
|
||||||
|
|
||||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||||
|
|
@ -580,7 +554,7 @@ class TestStreamsAudit:
|
||||||
|
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||||
await streams_update(mock_request, "CENTRAL_QUAKE", mock_csrf)
|
await streams_update(mock_request, "CENTRAL_QUAKE")
|
||||||
|
|
||||||
assert captured_audit["action"] == "stream.update"
|
assert captured_audit["action"] == "stream.update"
|
||||||
assert captured_audit["operator_id"] == 1
|
assert captured_audit["operator_id"] == 1
|
||||||
|
|
|
||||||
665
tests/test_wizard.py
Normal file
665
tests/test_wizard.py
Normal file
|
|
@ -0,0 +1,665 @@
|
||||||
|
"""Tests for the first-run setup wizard."""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from central.gui.routes import (
|
||||||
|
setup_operator_form,
|
||||||
|
setup_operator_submit,
|
||||||
|
setup_system_form,
|
||||||
|
setup_system_submit,
|
||||||
|
setup_keys_form,
|
||||||
|
setup_keys_submit,
|
||||||
|
setup_adapters_form,
|
||||||
|
setup_adapters_submit,
|
||||||
|
setup_finish_form,
|
||||||
|
setup_finish_submit,
|
||||||
|
)
|
||||||
|
from central.gui.middleware import SetupGateMiddleware, _get_wizard_redirect_step
|
||||||
|
|
||||||
|
|
||||||
|
class TestWizardStepRedirect:
|
||||||
|
"""Test wizard step redirect logic."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_operators_redirects_to_operator(self):
|
||||||
|
"""When no operators exist, redirect to /setup/operator."""
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.side_effect = [0] # No operators
|
||||||
|
|
||||||
|
result = await _get_wizard_redirect_step(mock_conn)
|
||||||
|
assert result == "/setup/operator"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_tile_url_redirects_to_system(self):
|
||||||
|
"""When map_tile_url is default, redirect to /setup/system."""
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.side_effect = [1] # Has operator
|
||||||
|
mock_conn.fetchrow.return_value = {
|
||||||
|
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await _get_wizard_redirect_step(mock_conn)
|
||||||
|
assert result == "/setup/system"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_adapters_touched_redirects_to_keys(self):
|
||||||
|
"""When no adapters have been updated, redirect to /setup/keys."""
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.side_effect = [1, 0] # Has operator, no adapters touched
|
||||||
|
mock_conn.fetchrow.return_value = {
|
||||||
|
"map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png"
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await _get_wizard_redirect_step(mock_conn)
|
||||||
|
assert result == "/setup/keys"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_steps_complete_redirects_to_finish(self):
|
||||||
|
"""When all steps done, redirect to /setup/finish."""
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.side_effect = [1, 1] # Has operator, adapters touched
|
||||||
|
mock_conn.fetchrow.return_value = {
|
||||||
|
"map_tile_url": "https://custom.example.com/{z}/{x}/{y}.png"
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await _get_wizard_redirect_step(mock_conn)
|
||||||
|
assert result == "/setup/finish"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupOperatorForm:
|
||||||
|
"""Test operator creation form (step 1)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_returns_form(self):
|
||||||
|
"""GET /setup/operator returns the form when no operator exists."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchrow.return_value = None # No operator exists
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value.csrf_secret = "testsecret"
|
||||||
|
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
|
||||||
|
result = await setup_operator_form(mock_request)
|
||||||
|
|
||||||
|
mock_templates.TemplateResponse.assert_called_once()
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert "csrf_token" in context and context["csrf_token"]
|
||||||
|
assert context["error"] is None
|
||||||
|
assert context["existing_operator"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_returns_confirmation_when_operator_exists(self):
|
||||||
|
"""GET /setup/operator shows confirmation when operator already exists."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.body = b"Operator Already Configured"
|
||||||
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchrow.return_value = {"username": "admin"} # Operator exists
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value.csrf_secret = "testsecret"
|
||||||
|
with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")):
|
||||||
|
result = await setup_operator_form(mock_request)
|
||||||
|
|
||||||
|
mock_templates.TemplateResponse.assert_called_once()
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert context["existing_operator"] == {"username": "admin"}
|
||||||
|
assert context["error"] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupOperatorSubmit:
|
||||||
|
"""Test operator creation submission."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_password_mismatch_shows_error(self):
|
||||||
|
"""POST with password mismatch re-renders with error."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf"
|
||||||
|
mock_request.form = AsyncMock(return_value={
|
||||||
|
"csrf_token": "test_csrf",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password1",
|
||||||
|
"confirm_password": "password2", # Mismatch
|
||||||
|
})
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.return_value = 0 # No existing operators
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||||
|
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value.csrf_secret = "testsecret"
|
||||||
|
result = await setup_operator_submit(
|
||||||
|
mock_request,
|
||||||
|
username="testuser",
|
||||||
|
password="password1",
|
||||||
|
confirm_password="password2",
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert context["error"] == "Passwords do not match"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_creates_operator_and_redirects(self):
|
||||||
|
"""POST with valid data creates operator and redirects to /setup/system."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.csrf_token = "test_csrf"
|
||||||
|
mock_request.form = AsyncMock(return_value={
|
||||||
|
"csrf_token": "test_csrf",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"confirm_password": "password123",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.return_value = 0 # No existing operators
|
||||||
|
mock_conn.fetchrow.side_effect = [
|
||||||
|
{"id": 1}, # INSERT RETURNING id
|
||||||
|
{"session_lifetime_days": 90}, # system settings
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||||
|
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value.csrf_secret = "testsecret"
|
||||||
|
with patch("central.gui.routes.hash_password", return_value="hashed"):
|
||||||
|
with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session:
|
||||||
|
mock_session.return_value = ("session_token", datetime.now(), "csrf_token")
|
||||||
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||||
|
result = await setup_operator_submit(
|
||||||
|
mock_request,
|
||||||
|
username="testuser",
|
||||||
|
password="password123",
|
||||||
|
confirm_password="password123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/system"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_when_operator_exists_shows_confirmation(self):
|
||||||
|
"""POST when operator exists returns 200 with confirmation, no insert."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.form = AsyncMock(return_value={
|
||||||
|
"csrf_token": "test_csrf",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"confirm_password": "password123",
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.return_value = 1 # Operator already exists
|
||||||
|
mock_conn.fetchrow.return_value = {"username": "existing_admin"}
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
mock_request.state.csrf_token = "test_csrf"
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
|
||||||
|
with patch("central.gui.routes.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value.csrf_secret = "testsecret"
|
||||||
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||||
|
result = await setup_operator_submit(
|
||||||
|
mock_request,
|
||||||
|
username="testuser",
|
||||||
|
password="password123",
|
||||||
|
confirm_password="password123",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 200, not 500 or redirect
|
||||||
|
assert result.status_code == 200
|
||||||
|
|
||||||
|
# Should render confirmation state
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert context["existing_operator"] == {"username": "existing_admin"}
|
||||||
|
|
||||||
|
# Should NOT call write_audit (no insert happened)
|
||||||
|
mock_audit.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupSystemForm:
|
||||||
|
"""Test system settings form (step 2)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthenticated_redirects_to_operator(self):
|
||||||
|
"""GET /setup/system without auth redirects to /setup/operator."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = None
|
||||||
|
result = await setup_system_form(mock_request)
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/operator"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticated_returns_form(self):
|
||||||
|
"""GET /setup/system with auth returns the form."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchrow.return_value = {
|
||||||
|
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
|
||||||
|
"map_attribution": "© OpenStreetMap contributors",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
result = await setup_system_form(mock_request)
|
||||||
|
|
||||||
|
mock_templates.TemplateResponse.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupSystemSubmit:
|
||||||
|
"""Test system settings submission."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_placeholders_shows_error(self):
|
||||||
|
"""POST without {z},{x},{y} placeholders shows error."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
|
||||||
|
form_data = MagicMock()
|
||||||
|
form_data.get = lambda k, default="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"map_tile_url": "https://example.com/tiles",
|
||||||
|
"map_attribution": "Test",
|
||||||
|
}.get(k, default)
|
||||||
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchrow.return_value = {
|
||||||
|
"map_tile_url": "",
|
||||||
|
"map_attribution": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
result = await setup_system_submit(mock_request)
|
||||||
|
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert "map_tile_url" in context["errors"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_updates_and_redirects(self):
|
||||||
|
"""POST with valid data updates system and redirects to /setup/keys."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
|
||||||
|
form_data = MagicMock()
|
||||||
|
form_data.get = lambda k, default="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"map_tile_url": "https://example.com/{z}/{x}/{y}.png",
|
||||||
|
"map_attribution": "Test Attribution",
|
||||||
|
}.get(k, default)
|
||||||
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchrow.return_value = {
|
||||||
|
"map_tile_url": "old_url",
|
||||||
|
"map_attribution": "old_attr",
|
||||||
|
}
|
||||||
|
mock_conn.execute = AsyncMock()
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||||
|
result = await setup_system_submit(mock_request)
|
||||||
|
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/keys"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupKeysForm:
|
||||||
|
"""Test API keys form (step 3)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthenticated_redirects_to_operator(self):
|
||||||
|
"""GET /setup/keys without auth redirects to /setup/operator."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = None
|
||||||
|
result = await setup_keys_form(mock_request)
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/operator"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupKeysSubmit:
|
||||||
|
"""Test API keys submission."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_next_action_redirects_to_adapters(self):
|
||||||
|
"""POST with action=next redirects to /setup/adapters."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
|
||||||
|
form_data = MagicMock()
|
||||||
|
form_data.get = lambda k, default="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"action": "next",
|
||||||
|
}.get(k, default)
|
||||||
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
|
# No need to mock get_pool since action="next" returns before it's called
|
||||||
|
result = await setup_keys_submit(mock_request)
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/adapters"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_key_creates_and_rerenders(self):
|
||||||
|
"""POST with action=add creates key and re-renders with success."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
|
||||||
|
form_data = MagicMock()
|
||||||
|
form_data.get = lambda k, default="": {
|
||||||
|
"csrf_token": "test_csrf_token",
|
||||||
|
"action": "add",
|
||||||
|
"alias": "testkey",
|
||||||
|
"plaintext_key": "secret123",
|
||||||
|
}.get(k, default)
|
||||||
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchrow.side_effect = [
|
||||||
|
None, # No existing key
|
||||||
|
{"created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)},
|
||||||
|
]
|
||||||
|
mock_conn.fetch.side_effect = [
|
||||||
|
[], # First list
|
||||||
|
[{"alias": "testkey", "created_at": datetime(2026, 5, 18, 12, 0, tzinfo=timezone.utc)}], # After insert
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.crypto.encrypt", return_value=b"encrypted"):
|
||||||
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||||
|
result = await setup_keys_submit(mock_request)
|
||||||
|
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert context["success"] == "API key 'testkey' added successfully."
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupAdaptersForm:
|
||||||
|
"""Test adapters configuration form (step 4)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthenticated_redirects_to_operator(self):
|
||||||
|
"""GET /setup/adapters without auth redirects to /setup/operator."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = None
|
||||||
|
result = await setup_adapters_form(mock_request)
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/operator"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupFinishForm:
|
||||||
|
"""Test finish page (step 5)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthenticated_redirects_to_operator(self):
|
||||||
|
"""GET /setup/finish without auth redirects to /setup/operator."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = None
|
||||||
|
result = await setup_finish_form(mock_request)
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/setup/operator"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticated_shows_summary(self):
|
||||||
|
"""GET /setup/finish with auth shows summary."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
|
||||||
|
mock_templates = MagicMock()
|
||||||
|
mock_templates.TemplateResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.fetchval.side_effect = [1, 2] # 1 operator, 2 keys
|
||||||
|
mock_conn.fetchrow.return_value = {"map_tile_url": "https://example.com/{z}/{x}/{y}.png"}
|
||||||
|
mock_conn.fetch.return_value = [
|
||||||
|
{"name": "nws", "enabled": True, "cadence_s": 300},
|
||||||
|
{"name": "firms", "enabled": False, "cadence_s": 600},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
result = await setup_finish_form(mock_request)
|
||||||
|
|
||||||
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||||
|
assert context["operator_count"] == 1
|
||||||
|
assert context["key_count"] == 2
|
||||||
|
assert len(context["adapters"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupFinishSubmit:
|
||||||
|
"""Test setup completion."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_marks_setup_complete_and_redirects(self):
|
||||||
|
"""POST /setup/finish marks setup_complete=true and redirects to /."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.operator = MagicMock(id=1, username="admin")
|
||||||
|
mock_request.state.csrf_token = "test_csrf_token"
|
||||||
|
|
||||||
|
# Mock form with CSRF token
|
||||||
|
form_data = MagicMock()
|
||||||
|
form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default)
|
||||||
|
mock_request.form = AsyncMock(return_value=form_data)
|
||||||
|
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock()
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.acquire.return_value.__aenter__.return_value = mock_conn
|
||||||
|
mock_pool.acquire.return_value.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||||
|
result = await setup_finish_submit(mock_request)
|
||||||
|
|
||||||
|
assert result.status_code == 302
|
||||||
|
assert result.headers["location"] == "/"
|
||||||
|
mock_conn.execute.assert_called_once()
|
||||||
|
mock_audit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupGateMiddlewareWizard:
|
||||||
|
"""Test SetupGateMiddleware with wizard paths."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_setup_operator_when_incomplete(self):
|
||||||
|
"""SetupGateMiddleware allows /setup/operator when setup_complete=False."""
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/setup/operator")
|
||||||
|
async def setup_operator():
|
||||||
|
return {"message": "operator form"}
|
||||||
|
|
||||||
|
app.add_middleware(SetupGateMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/setup/operator")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirects_base_setup_to_wizard_step(self):
|
||||||
|
"""SetupGateMiddleware redirects /setup to appropriate wizard step."""
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||||
|
mock_conn.fetchval = AsyncMock(return_value=0) # No operators
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/setup")
|
||||||
|
async def setup():
|
||||||
|
return {"message": "base setup"}
|
||||||
|
|
||||||
|
@app.get("/setup/operator")
|
||||||
|
async def setup_operator():
|
||||||
|
return {"message": "operator"}
|
||||||
|
|
||||||
|
app.add_middleware(SetupGateMiddleware)
|
||||||
|
client = TestClient(app, follow_redirects=False)
|
||||||
|
|
||||||
|
response = client.get("/setup")
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "/setup/operator"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirects_login_to_setup_when_incomplete(self):
|
||||||
|
"""SetupGateMiddleware redirects /login to /setup when setup_complete=False."""
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": False})
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/login")
|
||||||
|
async def login():
|
||||||
|
return {"message": "login"}
|
||||||
|
|
||||||
|
@app.get("/setup")
|
||||||
|
async def setup():
|
||||||
|
return {"message": "setup"}
|
||||||
|
|
||||||
|
app.add_middleware(SetupGateMiddleware)
|
||||||
|
client = TestClient(app, follow_redirects=False)
|
||||||
|
|
||||||
|
response = client.get("/login")
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "/setup"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirects_all_setup_paths_when_complete(self):
|
||||||
|
"""SetupGateMiddleware redirects /setup/* to / when setup_complete=True."""
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(return_value={"setup_complete": True})
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
with patch("central.gui.middleware.get_pool", return_value=mock_pool):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def index():
|
||||||
|
return {"message": "home"}
|
||||||
|
|
||||||
|
@app.get("/setup/operator")
|
||||||
|
async def setup_operator():
|
||||||
|
return {"message": "operator"}
|
||||||
|
|
||||||
|
app.add_middleware(SetupGateMiddleware)
|
||||||
|
client = TestClient(app, follow_redirects=False)
|
||||||
|
|
||||||
|
response = client.get("/setup/operator")
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "/"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue