Merge refactor/a3a-generic-wizard: generic wizard with Literal types

- Add Literal type support to form_descriptors (select/checkboxes widgets)
- Refactor wizard to use wizard_order for adapter filtering
- Replace hardcoded adapter lists with dynamic discovery
- Move contact_email validation to Pydantic pattern
- Add generic api_key_field mechanism
- Remove all field.name hardcoded branches
- 335 tests passing
This commit is contained in:
Matt Johnson 2026-05-19 01:08:35 +00:00
commit 43bf973caf
12 changed files with 953 additions and 239 deletions

View file

@ -34,6 +34,10 @@ class SourceAdapter(ABC):
description: str description: str
settings_schema: type[BaseModel] settings_schema: type[BaseModel]
requires_api_key: str | None = None requires_api_key: str | None = None
api_key_field: str | None = None
"""Names the settings_schema field that holds an api_key alias reference, if any.
The GUI renders this field as a select populated from config.api_keys;
the wizard validates it against staged api_keys state."""
wizard_order: int | None = None wizard_order: int | None = None
default_cadence_s: int default_cadence_s: int

View file

@ -7,7 +7,7 @@ from collections.abc import AsyncIterator
from datetime import datetime, timezone from datetime import datetime, timezone
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Literal
import aiohttp import aiohttp
from tenacity import ( from tenacity import (
@ -54,7 +54,7 @@ SEVERITY_MAP = {
class FIRMSSettings(BaseModel): class FIRMSSettings(BaseModel):
"""Settings schema for FIRMS adapter.""" """Settings schema for FIRMS adapter."""
api_key_alias: str = "firms" api_key_alias: str = "firms"
satellites: list[str] = ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"] satellites: list[Literal["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT", "VIIRS_NOAA21_NRT"]] = ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"]
region: RegionConfig | None = None region: RegionConfig | None = None
@ -66,6 +66,7 @@ class FIRMSAdapter(SourceAdapter):
description = "Near-real-time satellite-detected fire hotspots from NASA FIRMS." description = "Near-real-time satellite-detected fire hotspots from NASA FIRMS."
settings_schema = FIRMSSettings settings_schema = FIRMSSettings
requires_api_key = "firms" requires_api_key = "firms"
api_key_field = "api_key_alias"
wizard_order = 2 wizard_order = 2
default_cadence_s = 300 default_cadence_s = 300

View file

@ -19,7 +19,7 @@ from tenacity import (
from central import __version__ from central import __version__
from central.adapter import SourceAdapter from central.adapter import SourceAdapter
from pydantic import BaseModel from pydantic import BaseModel, Field
from central.config_models import AdapterConfig, RegionConfig from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore from central.config_store import ConfigStore
@ -193,7 +193,11 @@ def _build_regions(same_codes: list[str], ugc_codes: list[str]) -> list[str]:
class NWSSettings(BaseModel): class NWSSettings(BaseModel):
"""Settings schema for NWS adapter.""" """Settings schema for NWS adapter."""
contact_email: str = "" contact_email: str = Field(
default="",
pattern=r"^$|^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
description="Contact email for NWS API User-Agent header",
)
region: RegionConfig | None = None region: RegionConfig | None = None

View file

@ -5,7 +5,7 @@ import sqlite3
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Literal
import aiohttp import aiohttp
from shapely.geometry import Point, box as shapely_box from shapely.geometry import Point, box as shapely_box
@ -64,7 +64,7 @@ def magnitude_to_severity(mag: float) -> int:
class USGSQuakeSettings(BaseModel): class USGSQuakeSettings(BaseModel):
"""Settings schema for USGS quake adapter.""" """Settings schema for USGS quake adapter."""
feed: str = "all_hour" feed: Literal["all_hour", "all_day", "all_week", "all_month"] = "all_hour"
region: RegionConfig | None = None region: RegionConfig | None = None

View file

@ -247,18 +247,37 @@ def _create_app() -> FastAPI:
except Exception: except Exception:
pass pass
# Import helper functions for valid values # Add field descriptors to adapters
from central.gui.routes import _get_valid_satellites, _get_valid_feeds from central.gui.routes import _adapter_classes
from central.gui.form_descriptors import describe_fields
adapter_classes = _adapter_classes()
wizard_adapters = sorted(
[(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None],
key=lambda nc: nc[1].wizard_order
)
# Rebuild adapters with fields
enriched_adapters = []
for name, cls in wizard_adapters:
adapter_data = next((a for a in adapters if a["name"] == name), None)
if adapter_data:
settings_dict = adapter_data.get("settings", {})
fields = describe_fields(cls.settings_schema, settings_dict)
enriched_adapters.append({
"name": name,
"display_name": cls.display_name,
"enabled": adapter_data.get("enabled", False),
"cadence_s": adapter_data.get("cadence_s", 300),
"settings": settings_dict,
"fields": fields,
})
response = templates.TemplateResponse( response = templates.TemplateResponse(
request=request, request=request,
name="setup_adapters.html", name="setup_adapters.html",
context={ context={
"csrf_token": csrf_token, "csrf_token": csrf_token,
"adapters": adapters, "adapters": enriched_adapters,
"api_keys": api_keys, "api_keys": api_keys,
"valid_satellites": _get_valid_satellites(),
"valid_feeds": sorted(_get_valid_feeds()),
"tile_url": tile_url, "tile_url": tile_url,
"tile_attribution": tile_attribution, "tile_attribution": tile_attribution,
"error": error_msg, "error": error_msg,

View file

@ -4,8 +4,8 @@ If a second nested settings type beyond RegionConfig appears,
refactor this helper to recurse over nested models. refactor this helper to recurse over nested models.
""" """
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, Union, get_args, get_origin from typing import Any, Literal, Union, get_args, get_origin
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -19,15 +19,30 @@ class FieldDescriptor:
"""Describes a form field for rendering.""" """Describes a form field for rendering."""
name: str name: str
label: str label: str
widget: str # "text", "number", "checkbox", "csv", "region" widget: str # "text", "number", "checkbox", "csv", "select", "checkboxes", "region"
current_value: Any current_value: Any
default: Any default: Any
description: str description: str
required: bool required: bool
options: list[str] | None = None # For select/checkboxes widgets
def _type_to_widget(field_name: str, field_type: type) -> str: def _is_literal(tp: type) -> bool:
"""Map a Python type to a widget type.""" """Check if a type is a Literal type."""
return get_origin(tp) is Literal
def _get_literal_values(tp: type) -> list[str]:
"""Extract the literal values from a Literal type."""
return list(get_args(tp))
def _type_to_widget_and_options(field_name: str, field_type: type) -> tuple[str, list[str] | None]:
"""Map a Python type to a widget type and optional options list.
Returns:
Tuple of (widget_type, options_list_or_none)
"""
# Handle Optional/Union types # Handle Optional/Union types
origin = get_origin(field_type) origin = get_origin(field_type)
args = get_args(field_type) args = get_args(field_type)
@ -39,24 +54,38 @@ def _type_to_widget(field_name: str, field_type: type) -> str:
if non_none_args: if non_none_args:
inner_type = non_none_args[0] inner_type = non_none_args[0]
# Recursively determine widget for the inner type # Recursively determine widget for the inner type
return _type_to_widget(field_name, inner_type) return _type_to_widget_and_options(field_name, inner_type)
# Check for Literal type (single select)
if _is_literal(field_type):
options = _get_literal_values(field_type)
return "select", [str(o) for o in options]
# Direct type checks # Direct type checks
if field_type is str: if field_type is str:
return "text" return "text", None
if field_type is int: if field_type is int:
return "number" return "number", None
if field_type is bool: if field_type is bool:
return "checkbox" return "checkbox", None
if field_type is RegionConfig: if field_type is RegionConfig:
return "region" return "region", None
# Check for list[str] # Check for list types
if origin is list: if origin is list:
if args and args[0] is str: inner_type = args[0] if args else None
return "csv"
# list[Literal[...]] -> checkboxes
if inner_type is not None and _is_literal(inner_type):
options = _get_literal_values(inner_type)
return "checkboxes", [str(o) for o in options]
# list[str] -> csv
if inner_type is str:
return "csv", None
raise NotImplementedError( raise NotImplementedError(
f"Field '{field_name}' has unsupported list type: list[{args[0].__name__ if args else '?'}]" f"Field '{field_name}' has unsupported list type: list[{inner_type.__name__ if inner_type else '?'}]"
) )
# Check if it's a BaseModel subclass (nested model other than RegionConfig) # Check if it's a BaseModel subclass (nested model other than RegionConfig)
@ -98,8 +127,8 @@ def describe_fields(model_cls: type[BaseModel], current: dict) -> list[FieldDesc
# Get the field type # Get the field type
field_type = field_info.annotation field_type = field_info.annotation
# Determine widget # Determine widget and options
widget = _type_to_widget(field_name, field_type) widget, options = _type_to_widget_and_options(field_name, field_type)
# Get current value, falling back to default # Get current value, falling back to default
if field_name in current: if field_name in current:
@ -128,6 +157,7 @@ def describe_fields(model_cls: type[BaseModel], current: dict) -> list[FieldDesc
default=default, default=default,
description=description, description=description,
required=required, required=required,
options=options,
)) ))
return descriptors return descriptors

View file

@ -73,18 +73,6 @@ ALIAS_REGEX = re.compile(r"^[a-zA-Z0-9_]+$")
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
def _get_valid_satellites() -> list[str]:
"""Get valid satellite identifiers from firms adapter."""
from central.adapters.firms import SATELLITE_SHORT
return list(SATELLITE_SHORT.keys())
def _get_valid_feeds() -> set[str]:
"""Get valid feed values from usgs_quake adapter."""
from central.adapters.usgs_quake import VALID_FEEDS
return VALID_FEEDS
def _get_templates(): def _get_templates():
"""Get templates instance (deferred import to avoid circular).""" """Get templates instance (deferred import to avoid circular)."""
from central.gui import templates from central.gui import templates
@ -647,18 +635,36 @@ async def setup_adapters_form(request: Request) -> HTMLResponse:
templates = _get_templates() templates = _get_templates()
pool = get_pool() pool = get_pool()
# Get wizard adapters (filtered by wizard_order)
adapter_classes = _adapter_classes()
wizard_adapters = sorted(
[(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None],
key=lambda nc: nc[1].wizard_order
)
# Pre-fill from cookie state or DB defaults # Pre-fill from cookie state or DB defaults
if state.adapters: if state.adapters:
adapters = [] adapters = []
for name in ["firms", "nws", "usgs_quake"]: for name, cls in wizard_adapters:
if name in state.adapters: if name in state.adapters:
a = state.adapters[name] a = state.adapters[name]
adapters.append({ settings_dict = a["settings"]
"name": name, else:
"enabled": a["enabled"], settings_dict = {}
"cadence_s": a["cadence_s"], fields = describe_fields(cls.settings_schema, settings_dict)
"settings": a["settings"], # Swap widget for api_key_field to api_key_select
}) if cls.api_key_field is not None:
for f in fields:
if f.name == cls.api_key_field:
f.widget = "api_key_select"
adapters.append({
"name": name,
"display_name": cls.display_name,
"enabled": a["enabled"] if name in state.adapters else False,
"cadence_s": a["cadence_s"] if name in state.adapters else 300,
"settings": settings_dict,
"fields": fields,
})
else: else:
async with pool.acquire() as conn: async with pool.acquire() as conn:
rows = await conn.fetch( rows = await conn.fetch(
@ -668,15 +674,33 @@ async def setup_adapters_form(request: Request) -> HTMLResponse:
ORDER BY name ORDER BY name
""" """
) )
adapters = [] db_adapters = {row["name"]: row for row in rows}
for row in rows:
settings_data = row["settings"] or {} adapters = []
adapters.append({ for name, cls in wizard_adapters:
"name": row["name"], if name in db_adapters:
"enabled": row["enabled"], row = db_adapters[name]
"cadence_s": row["cadence_s"], settings_dict = row["settings"] or {}
"settings": settings_data, enabled = row["enabled"]
}) cadence_s = row["cadence_s"]
else:
settings_dict = {}
enabled = False
cadence_s = 300
fields = describe_fields(cls.settings_schema, settings_dict)
# Swap widget for api_key_field to api_key_select
if cls.api_key_field is not None:
for f in fields:
if f.name == cls.api_key_field:
f.widget = "api_key_select"
adapters.append({
"name": name,
"display_name": cls.display_name,
"enabled": enabled,
"cadence_s": cadence_s,
"settings": settings_dict,
"fields": fields,
})
# Get API keys from wizard state (not DB) # Get API keys from wizard state (not DB)
api_keys = [{"alias": k["alias"]} for k in state.api_keys] api_keys = [{"alias": k["alias"]} for k in state.api_keys]
@ -701,8 +725,6 @@ async def setup_adapters_form(request: Request) -> HTMLResponse:
"csrf_token": csrf_token, "csrf_token": csrf_token,
"adapters": adapters, "adapters": adapters,
"api_keys": api_keys, "api_keys": api_keys,
"valid_satellites": _get_valid_satellites(),
"valid_feeds": sorted(_get_valid_feeds()),
"tile_url": tile_url, "tile_url": tile_url,
"tile_attribution": tile_attribution, "tile_attribution": tile_attribution,
"error": None, "error": None,
@ -755,7 +777,14 @@ async def setup_adapters_submit(request: Request) -> Response:
"settings": row["settings"] or {}, "settings": row["settings"] or {},
} }
for adapter_name in ["firms", "nws", "usgs_quake"]: # Get wizard adapters (filtered by wizard_order)
adapter_classes = _adapter_classes()
wizard_adapters = sorted(
[(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None],
key=lambda nc: nc[1].wizard_order
)
for adapter_name, adapter_cls in wizard_adapters:
current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}}) current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}})
current_settings = current.get("settings", {}) current_settings = current.get("settings", {})
new_settings = dict(current_settings) new_settings = dict(current_settings)
@ -777,73 +806,94 @@ async def setup_adapters_submit(request: Request) -> Response:
errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer"
cadence_s = current.get("cadence_s", 300) cadence_s = current.get("cadence_s", 300)
# Adapter-specific validation # Generic field parsing using describe_fields
if adapter_name == "nws": fields = describe_fields(adapter_cls.settings_schema, current_settings)
contact_email = form.get(f"{adapter_name}_contact_email", "").strip() for field in fields:
if enabled: form_key = f"{adapter_name}_{field.name}"
if not contact_email:
errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" if field.widget == "text":
elif not EMAIL_REGEX.match(contact_email): value = form.get(form_key, "").strip()
errors[f"{adapter_name}_contact_email"] = "Invalid email format" new_settings[field.name] = value if value else current_settings.get(field.name)
elif field.widget == "api_key_select":
# API key alias field - stored as text, validated post-loop
value = form.get(form_key, "").strip()
new_settings[field.name] = value if value else None
elif field.widget == "number":
value_str = form.get(form_key, "").strip()
if value_str:
try:
new_settings[field.name] = int(value_str)
except ValueError:
errors[form_key] = f"{field.label} must be a valid number"
else: else:
new_settings["contact_email"] = contact_email new_settings[field.name] = current_settings.get(field.name)
else:
new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email")
elif adapter_name == "firms": elif field.widget == "checkbox":
api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() new_settings[field.name] = form_key in form
satellites = form.getlist(f"{adapter_name}_satellites")
if api_key_alias: elif field.widget == "csv":
# Validate against wizard state keys value = form.get(form_key, "").strip()
if not any(k["alias"] == api_key_alias for k in state.api_keys): if value:
errors[f"{adapter_name}_api_key_alias"] = f"API key alias does not exist" new_settings[field.name] = [v.strip() for v in value.split(",") if v.strip()]
else: else:
new_settings["api_key_alias"] = api_key_alias new_settings[field.name] = []
else:
new_settings["api_key_alias"] = None
# Validate satellites elif field.widget == "select":
valid_sats = set(_get_valid_satellites()) value = form.get(form_key, "").strip()
invalid_sats = [s for s in satellites if s not in valid_sats] if value and field.options and value not in field.options:
if invalid_sats: errors[form_key] = f"Invalid {field.label.lower()}"
errors[f"{adapter_name}_satellites"] = f"Invalid satellites: " + ", ".join(invalid_sats) else:
else: new_settings[field.name] = value
new_settings["satellites"] = satellites
elif adapter_name == "usgs_quake": elif field.widget == "checkboxes":
feed = form.get(f"{adapter_name}_feed", "").strip() # Use getlist for checkbox groups - absence means empty list
valid_feeds = _get_valid_feeds() values = form.getlist(form_key)
if feed not in valid_feeds: if field.options:
errors[f"{adapter_name}_feed"] = "Invalid feed" invalid = [v for v in values if v not in field.options]
else: if invalid:
new_settings["feed"] = feed errors[form_key] = f"Invalid values: {', '.join(invalid)}"
else:
new_settings[field.name] = values
else:
new_settings[field.name] = values
# Region validation (all adapters) elif field.widget == "region":
region_north_str = form.get(f"{adapter_name}_region_north", "").strip() # Region validation via RegionConfig model
region_south_str = form.get(f"{adapter_name}_region_south", "").strip() from central.config_models import RegionConfig
region_east_str = form.get(f"{adapter_name}_region_east", "").strip() region_north_str = form.get(f"{adapter_name}_{field.name}_north", "").strip()
region_west_str = form.get(f"{adapter_name}_region_west", "").strip() region_south_str = form.get(f"{adapter_name}_{field.name}_south", "").strip()
region_east_str = form.get(f"{adapter_name}_{field.name}_east", "").strip()
region_west_str = form.get(f"{adapter_name}_{field.name}_west", "").strip()
try:
region_model = RegionConfig(
north=float(region_north_str),
south=float(region_south_str),
east=float(region_east_str),
west=float(region_west_str),
)
new_settings[field.name] = region_model.model_dump()
except (ValueError, ValidationError) as e:
errors[f"{adapter_name}_{field.name}"] = str(e)
# Run Pydantic validation on assembled settings to catch Literal violations etc.
try: try:
region_north = float(region_north_str) adapter_cls.settings_schema(**new_settings)
region_south = float(region_south_str) except ValidationError as e:
region_east = float(region_east_str) for err in e.errors():
region_west = float(region_west_str) loc = err["loc"][0] if err["loc"] else "unknown"
errors[f"{adapter_name}_{loc}"] = err["msg"]
if not (-90 <= region_south < region_north <= 90): # Generic api_key_field validation against wizard state
errors[f"{adapter_name}_region"] = "Invalid latitude: south < north, both -90 to 90" if adapter_cls.api_key_field is not None:
elif not (-180 <= region_west < region_east <= 180): field_value = new_settings.get(adapter_cls.api_key_field)
errors[f"{adapter_name}_region"] = "Invalid longitude: west < east, both -180 to 180" if field_value:
else: if not any(k["alias"] == field_value for k in state.api_keys):
new_settings["region"] = { errors[f"{adapter_name}_{adapter_cls.api_key_field}"] = (
"north": region_north, "API key alias does not exist"
"south": region_south, )
"east": region_east,
"west": region_west,
}
except ValueError:
errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers"
new_adapters[adapter_name] = { new_adapters[adapter_name] = {
"enabled": enabled, "enabled": enabled,
@ -853,12 +903,23 @@ async def setup_adapters_submit(request: Request) -> Response:
# If errors, re-render # If errors, re-render
if errors: if errors:
adapters = [ adapters = []
{"name": name, "enabled": new_adapters[name]["enabled"], for name, cls in wizard_adapters:
"cadence_s": new_adapters[name]["cadence_s"], settings_dict = new_adapters[name]["settings"]
"settings": new_adapters[name]["settings"]} fields = describe_fields(cls.settings_schema, settings_dict)
for name in ["firms", "nws", "usgs_quake"] # Swap widget for api_key_field to api_key_select
] if cls.api_key_field is not None:
for f in fields:
if f.name == cls.api_key_field:
f.widget = "api_key_select"
adapters.append({
"name": name,
"display_name": cls.display_name,
"enabled": new_adapters[name]["enabled"],
"cadence_s": new_adapters[name]["cadence_s"],
"settings": settings_dict,
"fields": fields,
})
api_keys = [{"alias": k["alias"]} for k in state.api_keys] api_keys = [{"alias": k["alias"]} for k in state.api_keys]
if state.system: if state.system:
@ -876,8 +937,6 @@ async def setup_adapters_submit(request: Request) -> Response:
"csrf_token": csrf_token, "csrf_token": csrf_token,
"adapters": adapters, "adapters": adapters,
"api_keys": api_keys, "api_keys": api_keys,
"valid_satellites": _get_valid_satellites(),
"valid_feeds": sorted(_get_valid_feeds()),
"tile_url": tile_url, "tile_url": tile_url,
"tile_attribution": tile_attribution, "tile_attribution": tile_attribution,
"error": "Please fix the errors below.", "error": "Please fix the errors below.",
@ -918,10 +977,20 @@ async def setup_finish_form(request: Request) -> HTMLResponse:
adapters = [] adapters = []
if state.adapters: if state.adapters:
for name in ["firms", "nws", "usgs_quake"]: adapter_classes = _adapter_classes()
wizard_adapters = sorted(
[(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None],
key=lambda nc: nc[1].wizard_order
)
for name, cls in wizard_adapters:
if name in state.adapters: if name in state.adapters:
a = state.adapters[name] a = state.adapters[name]
adapters.append({"name": name, "enabled": a["enabled"], "cadence_s": a["cadence_s"]}) adapters.append({
"name": name,
"display_name": cls.display_name,
"enabled": a["enabled"],
"cadence_s": a["cadence_s"],
})
csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret)
response = templates.TemplateResponse( response = templates.TemplateResponse(
@ -1338,6 +1407,17 @@ async def adapters_edit_form(
fields = [] fields = []
if adapter_cls and hasattr(adapter_cls, "settings_schema"): if adapter_cls and hasattr(adapter_cls, "settings_schema"):
fields = describe_fields(adapter_cls.settings_schema, settings) fields = describe_fields(adapter_cls.settings_schema, settings)
# Swap widget for api_key_field to api_key_select
if adapter_cls.api_key_field is not None:
for f in fields:
if f.name == adapter_cls.api_key_field:
f.widget = "api_key_select"
# Fetch API keys for api_key_select widget
api_keys = []
async with pool.acquire() as conn:
api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias")
api_keys = [{"alias": r["alias"]} for r in api_key_rows]
csrf_token = request.state.csrf_token csrf_token = request.state.csrf_token
response = templates.TemplateResponse( response = templates.TemplateResponse(
@ -1348,6 +1428,7 @@ async def adapters_edit_form(
"csrf_token": csrf_token, "csrf_token": csrf_token,
"adapter": adapter, "adapter": adapter,
"fields": fields, "fields": fields,
"api_keys": api_keys,
"errors": None, "errors": None,
"form_data": None, "form_data": None,
"tile_url": tile_url, "tile_url": tile_url,
@ -1441,6 +1522,28 @@ async def adapters_edit_submit(
parsed_values[field.name] = [v.strip() for v in raw.split(",") if v.strip()] parsed_values[field.name] = [v.strip() for v in raw.split(",") if v.strip()]
else: else:
parsed_values[field.name] = [] parsed_values[field.name] = []
elif field.widget == "select":
value = raw.strip() if raw else None
if value and field.options and value not in field.options:
errors[field.name] = f"Invalid {field.label.lower()}"
else:
parsed_values[field.name] = value
elif field.widget == "checkboxes":
# Use getlist for checkbox groups
values = form.getlist(field.name)
form_data[field.name] = values # Override raw value
if field.options:
invalid = [v for v in values if v not in field.options]
if invalid:
errors[field.name] = f"Invalid values: {', '.join(invalid)}"
else:
parsed_values[field.name] = values
else:
parsed_values[field.name] = values
elif field.widget == "api_key_select":
# API key select - validate against existing keys
value = raw.strip() if raw else None
parsed_values[field.name] = value
elif field.widget == "region": elif field.widget == "region":
# Region handled separately below # Region handled separately below
pass pass
@ -1521,6 +1624,15 @@ async def adapters_edit_submit(
fields = [] fields = []
if adapter_cls and hasattr(adapter_cls, "settings_schema"): if adapter_cls and hasattr(adapter_cls, "settings_schema"):
fields = describe_fields(adapter_cls.settings_schema, current_settings) fields = describe_fields(adapter_cls.settings_schema, current_settings)
# Swap widget for api_key_field to api_key_select
if adapter_cls.api_key_field is not None:
for f in fields:
if f.name == adapter_cls.api_key_field:
f.widget = "api_key_select"
# Fetch API keys for api_key_select widget
api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias")
api_keys = [{"alias": r["alias"]} for r in api_key_rows]
csrf_token = request.state.csrf_token csrf_token = request.state.csrf_token
response = templates.TemplateResponse( response = templates.TemplateResponse(
@ -1531,6 +1643,7 @@ async def adapters_edit_submit(
"csrf_token": csrf_token, "csrf_token": csrf_token,
"adapter": adapter, "adapter": adapter,
"fields": fields, "fields": fields,
"api_keys": api_keys,
"errors": errors, "errors": errors,
"form_data": form_data, "form_data": form_data,
"tile_url": tile_url, "tile_url": tile_url,

View file

@ -100,6 +100,58 @@
{% if errors and errors[field.name] %} {% if errors and errors[field.name] %}
<small style="color: var(--pico-color-red-500);">{{ errors[field.name] }}</small> <small style="color: var(--pico-color-red-500);">{{ errors[field.name] }}</small>
{% endif %} {% endif %}
{% elif field.widget == "select" %}
<label for="{{ field.name }}">{{ field.label }}</label>
<select id="{{ field.name }}" name="{{ field.name }}">
{% for opt in field.options %}
<option value="{{ opt }}"
{% if (form_data[field.name] if form_data and field.name in form_data else field.current_value) == opt %}selected{% endif %}>
{{ opt }}
</option>
{% endfor %}
</select>
{% if field.description %}
<small>{{ field.description }}</small>
{% endif %}
{% if errors and errors[field.name] %}
<small style="color: var(--pico-color-red-500);">{{ errors[field.name] }}</small>
{% endif %}
{% elif field.widget == "checkboxes" %}
<label>{{ field.label }}</label>
{% set current_values = form_data.getlist(field.name) if form_data and form_data.getlist else (field.current_value or []) %}
{% for opt in field.options %}
<label style="display: inline-block; margin-right: 1rem;">
<input type="checkbox" name="{{ field.name }}" value="{{ opt }}"
{% if opt in current_values %}checked{% endif %}>
{{ opt }}
</label>
{% endfor %}
{% if field.description %}
<small style="display: block;">{{ field.description }}</small>
{% endif %}
{% if errors and errors[field.name] %}
<small style="color: var(--pico-color-red-500); display: block;">{{ errors[field.name] }}</small>
{% endif %}
{% elif field.widget == "api_key_select" %}
<label for="{{ field.name }}">{{ field.label }}</label>
<select id="{{ field.name }}" name="{{ field.name }}">
<option value="">(none)</option>
{% for key in api_keys %}
<option value="{{ key.alias }}"
{% if (form_data[field.name] if form_data and field.name in form_data else field.current_value) == key.alias %}selected{% endif %}>
{{ key.alias }}
</option>
{% endfor %}
</select>
{% if field.description %}
<small>{{ field.description }}</small>
{% endif %}
{% if errors and errors[field.name] %}
<small style="color: var(--pico-color-red-500);">{{ errors[field.name] }}</small>
{% endif %}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</fieldset> </fieldset>

View file

@ -29,7 +29,7 @@
{% for adapter in adapters %} {% for adapter in adapters %}
<details open style="margin-bottom: 2rem;"> <details open style="margin-bottom: 2rem;">
<summary><strong>{{ adapter.name }}</strong></summary> <summary><strong>{{ adapter.display_name or adapter.name }}</strong></summary>
<div style="padding: 1rem; border-left: 3px solid var(--pico-primary);"> <div style="padding: 1rem; border-left: 3px solid var(--pico-primary);">
<label> <label>
@ -44,100 +44,158 @@
<label for="{{ adapter.name }}_cadence_s">Cadence (seconds)</label> <label for="{{ adapter.name }}_cadence_s">Cadence (seconds)</label>
<input type="number" id="{{ adapter.name }}_cadence_s" name="{{ adapter.name }}_cadence_s" <input type="number" id="{{ adapter.name }}_cadence_s" name="{{ adapter.name }}_cadence_s"
value="{{ form_data.get(adapter.name + '_cadence_s') if form_data else adapter.cadence_s }}" value="{{ form_data.get(adapter.name + '_cadence_s') if form_data else adapter.cadence_s }}">
>
{% if errors and errors.get(adapter.name + '_cadence_s') %} {% if errors and errors.get(adapter.name + '_cadence_s') %}
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_cadence_s'] }}</small> <small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_cadence_s'] }}</small>
{% endif %} {% endif %}
{% if adapter.name == 'nws' %} {% for field in adapter.fields %}
<label for="{{ adapter.name }}_contact_email">Contact Email</label> {% set form_key = adapter.name + '_' + field.name %}
<input type="email" id="{{ adapter.name }}_contact_email" name="{{ adapter.name }}_contact_email"
value="{{ form_data.get(adapter.name + '_contact_email') if form_data else adapter.settings.contact_email }}">
{% if errors and errors.get(adapter.name + '_contact_email') %}
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_contact_email'] }}</small>
{% endif %}
{% endif %}
{% if adapter.name == 'firms' %} {% if field.widget == "text" %}
<label for="{{ adapter.name }}_api_key_alias">API Key Alias</label> <label for="{{ form_key }}">{{ field.label }}</label>
<select id="{{ adapter.name }}_api_key_alias" name="{{ adapter.name }}_api_key_alias"> <input type="text" id="{{ form_key }}" name="{{ form_key }}"
<option value="">(none)</option> value="{{ form_data.get(form_key) if form_data else field.current_value or '' }}"
{% for key in api_keys %} {% if field.required %}required{% endif %}>
<option value="{{ key.alias }}" {% if field.description %}
{% if (form_data.get(adapter.name + '_api_key_alias') if form_data else adapter.settings.api_key_alias) == key.alias %}selected{% endif %}> <small>{{ field.description }}</small>
{{ key.alias }} {% endif %}
</option> {% if errors and errors.get(form_key) %}
{% endfor %} <small style="color: var(--pico-color-red-500);">{{ errors[form_key] }}</small>
</select> {% endif %}
{% if errors and errors.get(adapter.name + '_api_key_alias') %}
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_api_key_alias'] }}</small>
{% endif %}
<label>Satellites</label> {% elif field.widget == "api_key_select" %}
{% for sat in valid_satellites %} <label for="{{ form_key }}">{{ field.label }}</label>
<label style="display: inline-block; margin-right: 1rem;"> <select id="{{ form_key }}" name="{{ form_key }}">
<input type="checkbox" name="{{ adapter.name }}_satellites" value="{{ sat }}" <option value="">(none)</option>
{% if sat in (form_data.getlist(adapter.name + '_satellites') if form_data else adapter.settings.satellites or []) %}checked{% endif %}> {% for key in api_keys %}
{{ sat }} <option value="{{ key.alias }}"
</label> {% if (form_data.get(form_key) if form_data else field.current_value) == key.alias %}selected{% endif %}>
{% endfor %} {{ key.alias }}
{% endif %} </option>
{% endfor %}
</select>
{% if field.description %}
<small>{{ field.description }}</small>
{% endif %}
{% if errors and errors.get(form_key) %}
<small style="color: var(--pico-color-red-500);">{{ errors[form_key] }}</small>
{% endif %}
{% if adapter.name == 'usgs_quake' %} {% elif field.widget == "number" %}
<label for="{{ adapter.name }}_feed">Feed</label> <label for="{{ form_key }}">{{ field.label }}</label>
<select id="{{ adapter.name }}_feed" name="{{ adapter.name }}_feed"> <input type="number" id="{{ form_key }}" name="{{ form_key }}"
{% for f in valid_feeds %} value="{{ form_data.get(form_key) if form_data else field.current_value or '' }}"
<option value="{{ f }}" {% if field.required %}required{% endif %}>
{% if (form_data.get(adapter.name + '_feed') if form_data else adapter.settings.feed) == f %}selected{% endif %}> {% if field.description %}
{{ f }} <small>{{ field.description }}</small>
</option> {% endif %}
{% endfor %} {% if errors and errors.get(form_key) %}
</select> <small style="color: var(--pico-color-red-500);">{{ errors[form_key] }}</small>
{% if errors and errors.get(adapter.name + '_feed') %} {% endif %}
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_feed'] }}</small>
{% endif %}
{% endif %}
<h4>Region</h4> {% elif field.widget == "checkbox" %}
{% set region = form_data if form_data else adapter.settings.region %} <label>
<div id="region-picker-{{ adapter.name }}" <input type="checkbox" name="{{ form_key }}"
data-adapter="{{ adapter.name }}" {% if form_data and form_data.get(form_key) %}checked
data-north="{{ form_data.get(adapter.name + '_region_north') if form_data else (adapter.settings.region.north if adapter.settings.region else 49.5) }}" {% elif not form_data and field.current_value %}checked{% endif %}>
data-south="{{ form_data.get(adapter.name + '_region_south') if form_data else (adapter.settings.region.south if adapter.settings.region else 31.0) }}" {{ field.label }}
data-east="{{ form_data.get(adapter.name + '_region_east') if form_data else (adapter.settings.region.east if adapter.settings.region else -102.0) }}" </label>
data-west="{{ form_data.get(adapter.name + '_region_west') if form_data else (adapter.settings.region.west if adapter.settings.region else -124.5) }}" {% if field.description %}
data-tile-url="{{ tile_url }}" <small>{{ field.description }}</small>
data-tile-attr="{{ tile_attribution }}"> {% endif %}
{% if errors and errors.get(form_key) %}
<small style="color: var(--pico-color-red-500);">{{ errors[form_key] }}</small>
{% endif %}
<div id="region-map-{{ adapter.name }}" style="height: 300px; margin-bottom: 1rem;"></div> {% elif field.widget == "csv" %}
<label for="{{ form_key }}">{{ field.label }}</label>
<input type="text" id="{{ form_key }}" name="{{ form_key }}"
value="{{ form_data.get(form_key) if form_data else (field.current_value | join(',') if field.current_value else '') }}"
{% if field.required %}required{% endif %}>
<small>Comma-separated values{% if field.description %} — {{ field.description }}{% endif %}</small>
{% if errors and errors.get(form_key) %}
<small style="color: var(--pico-color-red-500);">{{ errors[form_key] }}</small>
{% endif %}
<div class="grid"> {% elif field.widget == "select" %}
<div> <label for="{{ form_key }}">{{ field.label }}</label>
<label>North</label> <select id="{{ form_key }}" name="{{ form_key }}">
<input type="number" name="{{ adapter.name }}_region_north" step="0.0001" min="-90" max="90" readonly {% for opt in field.options %}
value="{{ form_data.get(adapter.name + '_region_north') if form_data else (adapter.settings.region.north if adapter.settings.region else 49.5) }}"> <option value="{{ opt }}"
{% if (form_data.get(form_key) if form_data else field.current_value) == opt %}selected{% endif %}>
{{ opt }}
</option>
{% endfor %}
</select>
{% if field.description %}
<small>{{ field.description }}</small>
{% endif %}
{% if errors and errors.get(form_key) %}
<small style="color: var(--pico-color-red-500);">{{ errors[form_key] }}</small>
{% endif %}
{% elif field.widget == "checkboxes" %}
<label>{{ field.label }}</label>
{% set current_values = form_data.getlist(form_key) if form_data else (field.current_value or []) %}
{% for opt in field.options %}
<label style="display: inline-block; margin-right: 1rem;">
<input type="checkbox" name="{{ form_key }}" value="{{ opt }}"
{% if opt in current_values %}checked{% endif %}>
{{ opt }}
</label>
{% endfor %}
{% if field.description %}
<small style="display: block;">{{ field.description }}</small>
{% endif %}
{% if errors and errors.get(form_key) %}
<small style="color: var(--pico-color-red-500); display: block;">{{ errors[form_key] }}</small>
{% endif %}
{% elif field.widget == "region" %}
<h4>Region</h4>
{% set region_key = adapter.name + '_' + field.name %}
{% set region = field.current_value or {} %}
<div id="region-picker-{{ adapter.name }}"
data-adapter="{{ adapter.name }}"
data-field="{{ field.name }}"
data-north="{{ form_data.get(region_key + '_north') if form_data else (region.north if region else 49.5) }}"
data-south="{{ form_data.get(region_key + '_south') if form_data else (region.south if region else 31.0) }}"
data-east="{{ form_data.get(region_key + '_east') if form_data else (region.east if region else -102.0) }}"
data-west="{{ form_data.get(region_key + '_west') if form_data else (region.west if region else -124.5) }}"
data-tile-url="{{ tile_url }}"
data-tile-attr="{{ tile_attribution }}">
<div id="region-map-{{ adapter.name }}" style="height: 300px; margin-bottom: 1rem;"></div>
<div class="grid">
<div>
<label>North</label>
<input type="number" name="{{ region_key }}_north" step="0.0001" min="-90" max="90" readonly
value="{{ form_data.get(region_key + '_north') if form_data else (region.north if region else 49.5) }}">
</div>
<div>
<label>South</label>
<input type="number" name="{{ region_key }}_south" step="0.0001" min="-90" max="90" readonly
value="{{ form_data.get(region_key + '_south') if form_data else (region.south if region else 31.0) }}">
</div>
<div>
<label>East</label>
<input type="number" name="{{ region_key }}_east" step="0.0001" min="-180" max="180" readonly
value="{{ form_data.get(region_key + '_east') if form_data else (region.east if region else -102.0) }}">
</div>
<div>
<label>West</label>
<input type="number" name="{{ region_key }}_west" step="0.0001" min="-180" max="180" readonly
value="{{ form_data.get(region_key + '_west') if form_data else (region.west if region else -124.5) }}">
</div>
</div>
{% if errors and errors.get(region_key) %}
<small style="color: var(--pico-color-red-500);">{{ errors[region_key] }}</small>
{% endif %}
</div> </div>
<div>
<label>South</label>
<input type="number" name="{{ adapter.name }}_region_south" step="0.0001" min="-90" max="90" readonly
value="{{ form_data.get(adapter.name + '_region_south') if form_data else (adapter.settings.region.south if adapter.settings.region else 31.0) }}">
</div>
<div>
<label>East</label>
<input type="number" name="{{ adapter.name }}_region_east" step="0.0001" min="-180" max="180" readonly
value="{{ form_data.get(adapter.name + '_region_east') if form_data else (adapter.settings.region.east if adapter.settings.region else -102.0) }}">
</div>
<div>
<label>West</label>
<input type="number" name="{{ adapter.name }}_region_west" step="0.0001" min="-180" max="180" readonly
value="{{ form_data.get(adapter.name + '_region_west') if form_data else (adapter.settings.region.west if adapter.settings.region else -124.5) }}">
</div>
</div>
{% if errors and errors.get(adapter.name + '_region') %}
<small style="color: var(--pico-color-red-500);">{{ errors[adapter.name + '_region'] }}</small>
{% endif %} {% endif %}
</div> {% endfor %}
</div> </div>
</details> </details>
{% endfor %} {% endfor %}
@ -151,11 +209,12 @@
<script> <script>
document.addEventListener('DOMContentLoaded', function() { document.addEventListener('DOMContentLoaded', function() {
const adapters = ['nws', 'firms', 'usgs_quake']; // Find all region pickers dynamically
const regionPickers = document.querySelectorAll('[id^="region-picker-"]');
adapters.forEach(function(adapterName) { regionPickers.forEach(function(container) {
const container = document.getElementById('region-picker-' + adapterName); const adapterName = container.dataset.adapter;
if (!container) return; const fieldName = container.dataset.field || 'region';
const savedNorth = parseFloat(container.dataset.north); const savedNorth = parseFloat(container.dataset.north);
const savedSouth = parseFloat(container.dataset.south); const savedSouth = parseFloat(container.dataset.south);
@ -215,10 +274,11 @@ document.addEventListener('DOMContentLoaded', function() {
rectangle.editing.enable(); rectangle.editing.enable();
const northInput = container.querySelector('input[name="' + adapterName + '_region_north"]'); const inputPrefix = adapterName + '_' + fieldName;
const southInput = container.querySelector('input[name="' + adapterName + '_region_south"]'); const northInput = container.querySelector('input[name="' + inputPrefix + '_north"]');
const eastInput = container.querySelector('input[name="' + adapterName + '_region_east"]'); const southInput = container.querySelector('input[name="' + inputPrefix + '_south"]');
const westInput = container.querySelector('input[name="' + adapterName + '_region_west"]'); const eastInput = container.querySelector('input[name="' + inputPrefix + '_east"]');
const westInput = container.querySelector('input[name="' + inputPrefix + '_west"]');
function updateInputs() { function updateInputs() {
const b = rectangle.getBounds(); const b = rectangle.getBounds();

View file

@ -417,3 +417,47 @@ class TestAdaptersJsonbRegression:
assert isinstance(captured_audit["after"], dict), f"after should be dict, got {type(captured_audit['after'])}" assert isinstance(captured_audit["after"], dict), f"after should be dict, got {type(captured_audit['after'])}"
assert isinstance(captured_audit["before"]["settings"], dict), "before.settings should be dict" assert isinstance(captured_audit["before"]["settings"], dict), "before.settings should be dict"
assert isinstance(captured_audit["after"]["settings"], dict), "after.settings should be dict" assert isinstance(captured_audit["after"]["settings"], dict), "after.settings should be dict"
@pytest.mark.asyncio
async def test_adapters_edit_fetches_api_keys_into_context(self):
"""GET /adapters/firms includes api_keys from database in context."""
from central.gui.routes import adapters_edit_form
mock_request = MagicMock()
mock_request.state.operator = MagicMock(id=1, username="testop")
mock_request.state.csrf_token = "test_csrf_token"
mock_conn = MagicMock()
mock_conn.fetchrow = AsyncMock(side_effect=[
# Adapter row
{"name": "firms", "enabled": True, "cadence_s": 300, "settings": {},
"paused_at": None, "updated_at": None, "last_error": None},
# System row
{"map_tile_url": "https://tile.example.com", "map_attribution": "Test"},
])
mock_conn.fetch = AsyncMock(return_value=[
{"alias": "firms_key"},
{"alias": "other_key"},
])
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
result = await adapters_edit_form(mock_request, "firms")
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert "api_keys" in context
assert len(context["api_keys"]) == 2
assert context["api_keys"][0]["alias"] == "firms_key"
assert context["api_keys"][1]["alias"] == "other_key"

View file

@ -4,7 +4,7 @@ import pytest
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
from central.gui.form_descriptors import describe_fields, FieldDescriptor, _type_to_widget from central.gui.form_descriptors import describe_fields, FieldDescriptor, _type_to_widget_and_options
from central.config_models import RegionConfig from central.config_models import RegionConfig
@ -33,39 +33,39 @@ class SettingsWithRegion(BaseModel):
class TestTypeToWidget: class TestTypeToWidget:
"""Tests for _type_to_widget function.""" """Tests for _type_to_widget_and_options function."""
def test_str_maps_to_text(self): def test_str_maps_to_text(self):
assert _type_to_widget("field", str) == "text" assert _type_to_widget_and_options("field", str) == ("text", None)
def test_int_maps_to_number(self): def test_int_maps_to_number(self):
assert _type_to_widget("field", int) == "number" assert _type_to_widget_and_options("field", int) == ("number", None)
def test_bool_maps_to_checkbox(self): def test_bool_maps_to_checkbox(self):
assert _type_to_widget("field", bool) == "checkbox" assert _type_to_widget_and_options("field", bool) == ("checkbox", None)
def test_list_str_maps_to_csv(self): def test_list_str_maps_to_csv(self):
assert _type_to_widget("field", list[str]) == "csv" assert _type_to_widget_and_options("field", list[str]) == ("csv", None)
def test_region_config_maps_to_region(self): def test_region_config_maps_to_region(self):
assert _type_to_widget("field", RegionConfig) == "region" assert _type_to_widget_and_options("field", RegionConfig) == ("region", None)
def test_optional_region_maps_to_region(self): def test_optional_region_maps_to_region(self):
assert _type_to_widget("field", Optional[RegionConfig]) == "region" assert _type_to_widget_and_options("field", Optional[RegionConfig]) == ("region", None)
def test_optional_str_maps_to_text(self): def test_optional_str_maps_to_text(self):
"""Optional[str] should map to text widget.""" """Optional[str] should map to text widget."""
assert _type_to_widget("field", Optional[str]) == "text" assert _type_to_widget_and_options("field", Optional[str]) == ("text", None)
def test_optional_int_maps_to_number(self): def test_optional_int_maps_to_number(self):
"""Optional[int] should map to number widget.""" """Optional[int] should map to number widget."""
assert _type_to_widget("field", Optional[int]) == "number" assert _type_to_widget_and_options("field", Optional[int]) == ("number", None)
def test_unsupported_type_raises(self): def test_unsupported_type_raises(self):
class CustomType: class CustomType:
pass pass
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
_type_to_widget("field", CustomType) _type_to_widget_and_options("field", CustomType)
class TestDescribeFields: class TestDescribeFields:
@ -172,15 +172,17 @@ class TestRealAdapterSchemas:
fields = describe_fields(FIRMSSettings, { fields = describe_fields(FIRMSSettings, {
"api_key_alias": "firms_key", "api_key_alias": "firms_key",
"satellites": ["VIIRS_SNPP"] "satellites": ["VIIRS_SNPP_NRT"]
}) })
key_field = next(f for f in fields if f.name == "api_key_alias") key_field = next(f for f in fields if f.name == "api_key_alias")
assert key_field.widget == "text" assert key_field.widget == "text"
sat_field = next(f for f in fields if f.name == "satellites") sat_field = next(f for f in fields if f.name == "satellites")
assert sat_field.widget == "csv" assert sat_field.widget == "checkboxes"
assert sat_field.current_value == ["VIIRS_SNPP"] assert sat_field.current_value == ["VIIRS_SNPP_NRT"]
assert sat_field.options is not None
assert "VIIRS_SNPP_NRT" in sat_field.options
def test_usgs_quake_settings(self): def test_usgs_quake_settings(self):
"""USGSQuakeSettings generates correct field descriptors.""" """USGSQuakeSettings generates correct field descriptors."""
@ -189,8 +191,11 @@ class TestRealAdapterSchemas:
fields = describe_fields(USGSQuakeSettings, {"feed": "all_hour"}) fields = describe_fields(USGSQuakeSettings, {"feed": "all_hour"})
feed_field = next(f for f in fields if f.name == "feed") feed_field = next(f for f in fields if f.name == "feed")
assert feed_field.widget == "text" assert feed_field.widget == "select"
assert feed_field.current_value == "all_hour" assert feed_field.current_value == "all_hour"
assert feed_field.options is not None
assert "all_hour" in feed_field.options
assert "all_day" in feed_field.options
def test_all_adapters_have_region_field(self): def test_all_adapters_have_region_field(self):
"""All adapter settings schemas include region field.""" """All adapter settings schemas include region field."""
@ -203,3 +208,31 @@ class TestRealAdapterSchemas:
region_field = next((f for f in fields if f.name == "region"), None) region_field = next((f for f in fields if f.name == "region"), None)
assert region_field is not None, f"{schema.__name__} should have region field" assert region_field is not None, f"{schema.__name__} should have region field"
assert region_field.widget == "region" assert region_field.widget == "region"
class TestLiteralTypes:
"""Tests for Literal type support."""
def test_literal_maps_to_select(self):
"""Literal type maps to select widget with options."""
from typing import Literal
widget, options = _type_to_widget_and_options("field", Literal["a", "b", "c"])
assert widget == "select"
assert options == ["a", "b", "c"]
def test_list_literal_maps_to_checkboxes(self):
"""list[Literal] maps to checkboxes widget with options."""
from typing import Literal
widget, options = _type_to_widget_and_options("field", list[Literal["x", "y", "z"]])
assert widget == "checkboxes"
assert options == ["x", "y", "z"]
def test_optional_literal_maps_to_select(self):
"""Optional[Literal] maps to select widget."""
from typing import Literal, Optional
widget, options = _type_to_widget_and_options("field", Optional[Literal["one", "two"]])
assert widget == "select"
assert options == ["one", "two"]

View file

@ -199,3 +199,357 @@ class TestSetupGateMiddlewareWizard:
response = client.get("/setup/operator") response = client.get("/setup/operator")
assert response.status_code == 302 assert response.status_code == 302
assert response.headers["location"] == "/" assert response.headers["location"] == "/"
class TestSetupAdaptersErrorRerender:
"""Test wizard adapters form error re-render path."""
@pytest.mark.asyncio
async def test_invalid_cadence_rerenders_with_error(self):
"""POST /setup/adapters with cadence_s=5 re-renders form with error, no DB write."""
from central.gui.routes import setup_adapters_submit
mock_request = MagicMock()
mock_request.cookies = {}
mock_request.state = MagicMock()
# Mock form data with invalid cadence
mock_form = MagicMock()
mock_form.get.side_effect = lambda k, d="": {
"csrf_token": "test_csrf_token",
"nws_enabled": "on",
"nws_cadence_s": "5", # Invalid: below ge=10
"nws_contact_email": "test@example.com",
"nws_region_north": "49.0",
"nws_region_south": "31.0",
"nws_region_east": "-102.0",
"nws_region_west": "-124.0",
"firms_cadence_s": "300",
"firms_region_north": "49.0",
"firms_region_south": "31.0",
"firms_region_east": "-102.0",
"firms_region_west": "-124.0",
"usgs_quake_cadence_s": "300",
"usgs_quake_feed": "all_hour",
"usgs_quake_region_north": "49.0",
"usgs_quake_region_south": "31.0",
"usgs_quake_region_east": "-102.0",
"usgs_quake_region_west": "-124.0",
}.get(k, d)
mock_form.getlist.side_effect = lambda k: {
"firms_satellites": ["VIIRS_SNPP_NRT"],
}.get(k, [])
mock_form.__contains__ = lambda self, k: k in ["nws_enabled"]
mock_request.form = AsyncMock(return_value=mock_form)
# Mock wizard state
mock_state = MagicMock()
mock_state.operator = {"username": "test", "password_hash": "hash"}
mock_state.api_keys = []
mock_state.adapters = None
mock_state.system = None
# Mock pool with no actual DB access (should not be called for writes)
mock_pool = MagicMock()
mock_conn = MagicMock()
mock_conn.fetch = AsyncMock(return_value=[
{"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}},
])
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.wizard.get_wizard_state", return_value=mock_state):
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)):
result = await setup_adapters_submit(mock_request)
# Should return 200 (re-render), not 302 (redirect)
assert result.status_code == 200
# Check that template was called with errors
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["error"] == "Please fix the errors below."
assert "errors" in context
assert context["errors"] is not None
assert "nws_cadence_s" in context["errors"]
assert "10" in context["errors"]["nws_cadence_s"] # Should mention min value
# Verify adapters have correct shape (with fields)
assert "adapters" in context
for adapter in context["adapters"]:
assert "name" in adapter
assert "display_name" in adapter
assert "enabled" in adapter
assert "cadence_s" in adapter
assert "settings" in adapter
assert "fields" in adapter
# Verify no DB execute was called (no writes)
mock_conn.execute.assert_not_called()
@pytest.mark.asyncio
async def test_invalid_region_bounds_shows_pydantic_error(self):
"""POST /setup/adapters with inverted region bounds shows RegionConfig error."""
from central.gui.routes import setup_adapters_submit
mock_request = MagicMock()
mock_request.cookies = {}
mock_request.state = MagicMock()
# Mock form data with inverted region (south > north)
mock_form = MagicMock()
mock_form.get.side_effect = lambda k, d="": {
"csrf_token": "test_csrf_token",
"nws_cadence_s": "300",
"nws_contact_email": "test@example.com",
"nws_region_north": "10.0", # Invalid: north < south
"nws_region_south": "20.0",
"nws_region_east": "-102.0",
"nws_region_west": "-124.0",
"firms_cadence_s": "300",
"firms_region_north": "49.0",
"firms_region_south": "31.0",
"firms_region_east": "-102.0",
"firms_region_west": "-124.0",
"usgs_quake_cadence_s": "300",
"usgs_quake_feed": "all_hour",
"usgs_quake_region_north": "49.0",
"usgs_quake_region_south": "31.0",
"usgs_quake_region_east": "-102.0",
"usgs_quake_region_west": "-124.0",
}.get(k, d)
mock_form.getlist.side_effect = lambda k: {
"firms_satellites": ["VIIRS_SNPP_NRT"],
}.get(k, [])
mock_form.__contains__ = lambda self, k: False
mock_request.form = AsyncMock(return_value=mock_form)
mock_state = MagicMock()
mock_state.operator = {"username": "test", "password_hash": "hash"}
mock_state.api_keys = []
mock_state.adapters = None
mock_state.system = None
mock_pool = MagicMock()
mock_conn = MagicMock()
mock_conn.fetch = AsyncMock(return_value=[
{"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}},
])
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.wizard.get_wizard_state", return_value=mock_state):
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)):
result = await setup_adapters_submit(mock_request)
assert result.status_code == 200
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["errors"] is not None
assert "nws_region" in context["errors"]
# Error should come from RegionConfig validator, mentioning bounds
assert "north" in context["errors"]["nws_region"].lower() or "south" in context["errors"]["nws_region"].lower()
@pytest.mark.asyncio
async def test_invalid_contact_email_via_pydantic_pattern(self):
"""POST /setup/adapters with NWS contact_email='not-an-email' shows Pydantic pattern error."""
from central.gui.routes import setup_adapters_submit
mock_request = MagicMock()
mock_request.cookies = {}
mock_request.state = MagicMock()
mock_form = MagicMock()
mock_form.get.side_effect = lambda k, d="": {
"csrf_token": "test_csrf_token",
"nws_enabled": "on",
"nws_cadence_s": "300",
"nws_contact_email": "not-an-email", # Invalid email format
"nws_region_north": "49.0",
"nws_region_south": "31.0",
"nws_region_east": "-102.0",
"nws_region_west": "-124.0",
"firms_cadence_s": "300",
"firms_region_north": "49.0",
"firms_region_south": "31.0",
"firms_region_east": "-102.0",
"firms_region_west": "-124.0",
"usgs_quake_cadence_s": "300",
"usgs_quake_feed": "all_hour",
"usgs_quake_region_north": "49.0",
"usgs_quake_region_south": "31.0",
"usgs_quake_region_east": "-102.0",
"usgs_quake_region_west": "-124.0",
}.get(k, d)
mock_form.getlist.side_effect = lambda k: {
"firms_satellites": ["VIIRS_SNPP_NRT"],
}.get(k, [])
mock_form.__contains__ = lambda self, k: k in ["nws_enabled"]
mock_request.form = AsyncMock(return_value=mock_form)
mock_state = MagicMock()
mock_state.operator = {"username": "test", "password_hash": "hash"}
mock_state.api_keys = []
mock_state.adapters = None
mock_state.system = None
mock_pool = MagicMock()
mock_conn = MagicMock()
mock_conn.fetch = AsyncMock(return_value=[
{"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}},
])
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.wizard.get_wizard_state", return_value=mock_state):
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)):
result = await setup_adapters_submit(mock_request)
assert result.status_code == 200
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["errors"] is not None
assert "nws_contact_email" in context["errors"]
# Error should be from Pydantic pattern validation
error_msg = context["errors"]["nws_contact_email"].lower()
assert "pattern" in error_msg or "string" in error_msg or "match" in error_msg
@pytest.mark.asyncio
async def test_invalid_api_key_alias_generic(self):
"""POST /setup/adapters with FIRMS api_key_alias='bogus' shows generic error."""
from central.gui.routes import setup_adapters_submit
mock_request = MagicMock()
mock_request.cookies = {}
mock_request.state = MagicMock()
mock_form = MagicMock()
mock_form.get.side_effect = lambda k, d="": {
"csrf_token": "test_csrf_token",
"nws_cadence_s": "300",
"nws_contact_email": "test@example.com",
"nws_region_north": "49.0",
"nws_region_south": "31.0",
"nws_region_east": "-102.0",
"nws_region_west": "-124.0",
"firms_cadence_s": "300",
"firms_api_key_alias": "bogus-alias-not-in-state", # Invalid alias
"firms_region_north": "49.0",
"firms_region_south": "31.0",
"firms_region_east": "-102.0",
"firms_region_west": "-124.0",
"usgs_quake_cadence_s": "300",
"usgs_quake_feed": "all_hour",
"usgs_quake_region_north": "49.0",
"usgs_quake_region_south": "31.0",
"usgs_quake_region_east": "-102.0",
"usgs_quake_region_west": "-124.0",
}.get(k, d)
mock_form.getlist.side_effect = lambda k: {
"firms_satellites": ["VIIRS_SNPP_NRT"],
}.get(k, [])
mock_form.__contains__ = lambda self, k: False
mock_request.form = AsyncMock(return_value=mock_form)
mock_state = MagicMock()
mock_state.operator = {"username": "test", "password_hash": "hash"}
mock_state.api_keys = [{"alias": "valid_key"}] # Only valid_key exists
mock_state.adapters = None
mock_state.system = None
mock_pool = MagicMock()
mock_conn = MagicMock()
mock_conn.fetch = AsyncMock(return_value=[
{"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}},
{"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}},
])
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
mock_conn.__aexit__ = AsyncMock()
mock_pool.acquire = MagicMock(return_value=mock_conn)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.routes.get_settings") as mock_settings:
mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab"
with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True):
with patch("central.gui.wizard.get_wizard_state", return_value=mock_state):
with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)):
result = await setup_adapters_submit(mock_request)
assert result.status_code == 200
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["errors"] is not None
assert "firms_api_key_alias" in context["errors"]
assert "API key alias does not exist" in context["errors"]["firms_api_key_alias"]
@pytest.mark.asyncio
async def test_api_key_field_none_no_check(self):
"""Adapters with api_key_field=None do not trigger the api_key check."""
# Verify that NWSAdapter has api_key_field=None
from central.adapters.nws import NWSAdapter
from central.adapters.firms import FIRMSAdapter
from central.adapters.usgs_quake import USGSQuakeAdapter
# NWS and USGS should have api_key_field=None
assert NWSAdapter.api_key_field is None
assert USGSQuakeAdapter.api_key_field is None
# FIRMS should have api_key_field set
assert FIRMSAdapter.api_key_field == "api_key_alias"