From 08eb729979b8771128d5600cad6cb30f741ec58d Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Tue, 19 May 2026 00:38:06 +0000 Subject: [PATCH 1/3] refactor(wizard): generic adapter handling with Literal types - Add Literal type support to form_descriptors.py - Literal fields map to select widget - list[Literal] fields map to checkboxes widget - Options list extracted from Literal type args - Update FIRMS adapter: satellites is now list[Literal[...]] - Update USGS adapter: feed is now Literal[...] - Refactor wizard to use wizard_order for adapter filtering - Replace hardcoded adapter lists with dynamic discovery - Remove _get_valid_satellites() and _get_valid_feeds() helpers - Generic field parsing using describe_fields() pattern - Update templates for generic widget rendering - Add select/checkboxes widgets to adapters_edit.html - Update tests for new widget types Co-Authored-By: Claude Opus 4.5 --- src/central/adapters/firms.py | 4 +- src/central/adapters/usgs_quake.py | 4 +- src/central/gui/__init__.py | 29 +- src/central/gui/form_descriptors.py | 62 +++-- src/central/gui/routes.py | 259 +++++++++++------- src/central/gui/templates/adapters_edit.html | 34 +++ src/central/gui/templates/setup_adapters.html | 238 ++++++++++------ tests/test_form_descriptors.py | 63 ++++- 8 files changed, 470 insertions(+), 223 deletions(-) diff --git a/src/central/adapters/firms.py b/src/central/adapters/firms.py index 7538d96..0b1647b 100644 --- a/src/central/adapters/firms.py +++ b/src/central/adapters/firms.py @@ -7,7 +7,7 @@ from collections.abc import AsyncIterator from datetime import datetime, timezone from io import StringIO from pathlib import Path -from typing import Any +from typing import Any, Literal import aiohttp from tenacity import ( @@ -54,7 +54,7 @@ SEVERITY_MAP = { class FIRMSSettings(BaseModel): """Settings schema for FIRMS adapter.""" api_key_alias: str = "firms" - satellites: list[str] = ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"] + satellites: list[Literal["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT", "VIIRS_NOAA21_NRT"]] = ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"] region: RegionConfig | None = None diff --git a/src/central/adapters/usgs_quake.py b/src/central/adapters/usgs_quake.py index e73148f..63009ee 100644 --- a/src/central/adapters/usgs_quake.py +++ b/src/central/adapters/usgs_quake.py @@ -5,7 +5,7 @@ import sqlite3 from collections.abc import AsyncIterator from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import Any, Literal import aiohttp from shapely.geometry import Point, box as shapely_box @@ -64,7 +64,7 @@ def magnitude_to_severity(mag: float) -> int: class USGSQuakeSettings(BaseModel): """Settings schema for USGS quake adapter.""" - feed: str = "all_hour" + feed: Literal["all_hour", "all_day", "all_week", "all_month"] = "all_hour" region: RegionConfig | None = None diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 71a302b..79703cb 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -247,18 +247,37 @@ def _create_app() -> FastAPI: except Exception: pass - # Import helper functions for valid values - from central.gui.routes import _get_valid_satellites, _get_valid_feeds + # Add field descriptors to adapters + from central.gui.routes import _adapter_classes + from central.gui.form_descriptors import describe_fields + adapter_classes = _adapter_classes() + wizard_adapters = sorted( + [(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None], + key=lambda nc: nc[1].wizard_order + ) + # Rebuild adapters with fields + enriched_adapters = [] + for name, cls in wizard_adapters: + adapter_data = next((a for a in adapters if a["name"] == name), None) + if adapter_data: + settings_dict = adapter_data.get("settings", {}) + fields = describe_fields(cls.settings_schema, settings_dict) + enriched_adapters.append({ + "name": name, + "display_name": cls.display_name, + "enabled": adapter_data.get("enabled", False), + "cadence_s": adapter_data.get("cadence_s", 300), + "settings": settings_dict, + "fields": fields, + }) response = templates.TemplateResponse( request=request, name="setup_adapters.html", context={ "csrf_token": csrf_token, - "adapters": adapters, + "adapters": enriched_adapters, "api_keys": api_keys, - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, "tile_attribution": tile_attribution, "error": error_msg, diff --git a/src/central/gui/form_descriptors.py b/src/central/gui/form_descriptors.py index 2f1f14d..ef7588e 100644 --- a/src/central/gui/form_descriptors.py +++ b/src/central/gui/form_descriptors.py @@ -4,8 +4,8 @@ If a second nested settings type beyond RegionConfig appears, refactor this helper to recurse over nested models. """ -from dataclasses import dataclass -from typing import Any, Union, get_args, get_origin +from dataclasses import dataclass, field +from typing import Any, Literal, Union, get_args, get_origin from pydantic import BaseModel from pydantic.fields import FieldInfo @@ -19,15 +19,30 @@ class FieldDescriptor: """Describes a form field for rendering.""" name: str label: str - widget: str # "text", "number", "checkbox", "csv", "region" + widget: str # "text", "number", "checkbox", "csv", "select", "checkboxes", "region" current_value: Any default: Any description: str required: bool + options: list[str] | None = None # For select/checkboxes widgets -def _type_to_widget(field_name: str, field_type: type) -> str: - """Map a Python type to a widget type.""" +def _is_literal(tp: type) -> bool: + """Check if a type is a Literal type.""" + return get_origin(tp) is Literal + + +def _get_literal_values(tp: type) -> list[str]: + """Extract the literal values from a Literal type.""" + return list(get_args(tp)) + + +def _type_to_widget_and_options(field_name: str, field_type: type) -> tuple[str, list[str] | None]: + """Map a Python type to a widget type and optional options list. + + Returns: + Tuple of (widget_type, options_list_or_none) + """ # Handle Optional/Union types origin = get_origin(field_type) args = get_args(field_type) @@ -39,24 +54,38 @@ def _type_to_widget(field_name: str, field_type: type) -> str: if non_none_args: inner_type = non_none_args[0] # Recursively determine widget for the inner type - return _type_to_widget(field_name, inner_type) + return _type_to_widget_and_options(field_name, inner_type) + + # Check for Literal type (single select) + if _is_literal(field_type): + options = _get_literal_values(field_type) + return "select", [str(o) for o in options] # Direct type checks if field_type is str: - return "text" + return "text", None if field_type is int: - return "number" + return "number", None if field_type is bool: - return "checkbox" + return "checkbox", None if field_type is RegionConfig: - return "region" + return "region", None - # Check for list[str] + # Check for list types if origin is list: - if args and args[0] is str: - return "csv" + inner_type = args[0] if args else None + + # list[Literal[...]] -> checkboxes + if inner_type is not None and _is_literal(inner_type): + options = _get_literal_values(inner_type) + return "checkboxes", [str(o) for o in options] + + # list[str] -> csv + if inner_type is str: + return "csv", None + raise NotImplementedError( - f"Field '{field_name}' has unsupported list type: list[{args[0].__name__ if args else '?'}]" + f"Field '{field_name}' has unsupported list type: list[{inner_type.__name__ if inner_type else '?'}]" ) # Check if it's a BaseModel subclass (nested model other than RegionConfig) @@ -98,8 +127,8 @@ def describe_fields(model_cls: type[BaseModel], current: dict) -> list[FieldDesc # Get the field type field_type = field_info.annotation - # Determine widget - widget = _type_to_widget(field_name, field_type) + # Determine widget and options + widget, options = _type_to_widget_and_options(field_name, field_type) # Get current value, falling back to default if field_name in current: @@ -128,6 +157,7 @@ def describe_fields(model_cls: type[BaseModel], current: dict) -> list[FieldDesc default=default, description=description, required=required, + options=options, )) return descriptors diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 70e88e9..96c9791 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -73,18 +73,6 @@ ALIAS_REGEX = re.compile(r"^[a-zA-Z0-9_]+$") EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") -def _get_valid_satellites() -> list[str]: - """Get valid satellite identifiers from firms adapter.""" - from central.adapters.firms import SATELLITE_SHORT - return list(SATELLITE_SHORT.keys()) - - -def _get_valid_feeds() -> set[str]: - """Get valid feed values from usgs_quake adapter.""" - from central.adapters.usgs_quake import VALID_FEEDS - return VALID_FEEDS - - def _get_templates(): """Get templates instance (deferred import to avoid circular).""" from central.gui import templates @@ -647,18 +635,31 @@ async def setup_adapters_form(request: Request) -> HTMLResponse: templates = _get_templates() pool = get_pool() + # Get wizard adapters (filtered by wizard_order) + adapter_classes = _adapter_classes() + wizard_adapters = sorted( + [(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None], + key=lambda nc: nc[1].wizard_order + ) + # Pre-fill from cookie state or DB defaults if state.adapters: adapters = [] - for name in ["firms", "nws", "usgs_quake"]: + for name, cls in wizard_adapters: if name in state.adapters: a = state.adapters[name] - adapters.append({ - "name": name, - "enabled": a["enabled"], - "cadence_s": a["cadence_s"], - "settings": a["settings"], - }) + settings_dict = a["settings"] + else: + settings_dict = {} + fields = describe_fields(cls.settings_schema, settings_dict) + adapters.append({ + "name": name, + "display_name": cls.display_name, + "enabled": a["enabled"] if name in state.adapters else False, + "cadence_s": a["cadence_s"] if name in state.adapters else 300, + "settings": settings_dict, + "fields": fields, + }) else: async with pool.acquire() as conn: rows = await conn.fetch( @@ -668,15 +669,28 @@ async def setup_adapters_form(request: Request) -> HTMLResponse: ORDER BY name """ ) - adapters = [] - for row in rows: - settings_data = row["settings"] or {} - adapters.append({ - "name": row["name"], - "enabled": row["enabled"], - "cadence_s": row["cadence_s"], - "settings": settings_data, - }) + db_adapters = {row["name"]: row for row in rows} + + adapters = [] + for name, cls in wizard_adapters: + if name in db_adapters: + row = db_adapters[name] + settings_dict = row["settings"] or {} + enabled = row["enabled"] + cadence_s = row["cadence_s"] + else: + settings_dict = {} + enabled = False + cadence_s = 300 + fields = describe_fields(cls.settings_schema, settings_dict) + adapters.append({ + "name": name, + "display_name": cls.display_name, + "enabled": enabled, + "cadence_s": cadence_s, + "settings": settings_dict, + "fields": fields, + }) # Get API keys from wizard state (not DB) api_keys = [{"alias": k["alias"]} for k in state.api_keys] @@ -701,8 +715,6 @@ async def setup_adapters_form(request: Request) -> HTMLResponse: "csrf_token": csrf_token, "adapters": adapters, "api_keys": api_keys, - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, "tile_attribution": tile_attribution, "error": None, @@ -755,7 +767,14 @@ async def setup_adapters_submit(request: Request) -> Response: "settings": row["settings"] or {}, } - for adapter_name in ["firms", "nws", "usgs_quake"]: + # Get wizard adapters (filtered by wizard_order) + adapter_classes = _adapter_classes() + wizard_adapters = sorted( + [(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None], + key=lambda nc: nc[1].wizard_order + ) + + for adapter_name, adapter_cls in wizard_adapters: current = current_adapters.get(adapter_name, {"enabled": False, "cadence_s": 300, "settings": {}}) current_settings = current.get("settings", {}) new_settings = dict(current_settings) @@ -777,73 +796,101 @@ async def setup_adapters_submit(request: Request) -> Response: errors[f"{adapter_name}_cadence_s"] = "Cadence must be a valid integer" cadence_s = current.get("cadence_s", 300) - # Adapter-specific validation - if adapter_name == "nws": - contact_email = form.get(f"{adapter_name}_contact_email", "").strip() - if enabled: - if not contact_email: - errors[f"{adapter_name}_contact_email"] = "Contact email is required when enabled" - elif not EMAIL_REGEX.match(contact_email): - errors[f"{adapter_name}_contact_email"] = "Invalid email format" + # Generic field parsing using describe_fields + fields = describe_fields(adapter_cls.settings_schema, current_settings) + for field in fields: + form_key = f"{adapter_name}_{field.name}" + + if field.widget == "text": + value = form.get(form_key, "").strip() + # Special validation for contact_email + if field.name == "contact_email": + if enabled: + if not value: + errors[form_key] = "Contact email is required when enabled" + elif not EMAIL_REGEX.match(value): + errors[form_key] = "Invalid email format" + else: + new_settings[field.name] = value + else: + new_settings[field.name] = value if value else current_settings.get(field.name) + # Special validation for api_key_alias + elif field.name == "api_key_alias": + if value: + if not any(k["alias"] == value for k in state.api_keys): + errors[form_key] = "API key alias does not exist" + else: + new_settings[field.name] = value + else: + new_settings[field.name] = None else: - new_settings["contact_email"] = contact_email - else: - new_settings["contact_email"] = contact_email if contact_email else current_settings.get("contact_email") + new_settings[field.name] = value if value else current_settings.get(field.name) - elif adapter_name == "firms": - api_key_alias = form.get(f"{adapter_name}_api_key_alias", "").strip() - satellites = form.getlist(f"{adapter_name}_satellites") - - if api_key_alias: - # Validate against wizard state keys - if not any(k["alias"] == api_key_alias for k in state.api_keys): - errors[f"{adapter_name}_api_key_alias"] = f"API key alias does not exist" + elif field.widget == "number": + value_str = form.get(form_key, "").strip() + if value_str: + try: + new_settings[field.name] = int(value_str) + except ValueError: + errors[form_key] = f"{field.label} must be a valid number" else: - new_settings["api_key_alias"] = api_key_alias - else: - new_settings["api_key_alias"] = None + new_settings[field.name] = current_settings.get(field.name) - # Validate satellites - valid_sats = set(_get_valid_satellites()) - invalid_sats = [s for s in satellites if s not in valid_sats] - if invalid_sats: - errors[f"{adapter_name}_satellites"] = f"Invalid satellites: " + ", ".join(invalid_sats) - else: - new_settings["satellites"] = satellites + elif field.widget == "checkbox": + new_settings[field.name] = form_key in form - elif adapter_name == "usgs_quake": - feed = form.get(f"{adapter_name}_feed", "").strip() - valid_feeds = _get_valid_feeds() - if feed not in valid_feeds: - errors[f"{adapter_name}_feed"] = "Invalid feed" - else: - new_settings["feed"] = feed + elif field.widget == "csv": + value = form.get(form_key, "").strip() + if value: + new_settings[field.name] = [v.strip() for v in value.split(",") if v.strip()] + else: + new_settings[field.name] = [] - # Region validation (all adapters) - region_north_str = form.get(f"{adapter_name}_region_north", "").strip() - region_south_str = form.get(f"{adapter_name}_region_south", "").strip() - region_east_str = form.get(f"{adapter_name}_region_east", "").strip() - region_west_str = form.get(f"{adapter_name}_region_west", "").strip() + elif field.widget == "select": + value = form.get(form_key, "").strip() + if value and field.options and value not in field.options: + errors[form_key] = f"Invalid {field.label.lower()}" + else: + new_settings[field.name] = value - try: - region_north = float(region_north_str) - region_south = float(region_south_str) - region_east = float(region_east_str) - region_west = float(region_west_str) + elif field.widget == "checkboxes": + # Use getlist for checkbox groups - absence means empty list + values = form.getlist(form_key) + if field.options: + invalid = [v for v in values if v not in field.options] + if invalid: + errors[form_key] = f"Invalid values: {', '.join(invalid)}" + else: + new_settings[field.name] = values + else: + new_settings[field.name] = values - if not (-90 <= region_south < region_north <= 90): - errors[f"{adapter_name}_region"] = "Invalid latitude: south < north, both -90 to 90" - elif not (-180 <= region_west < region_east <= 180): - errors[f"{adapter_name}_region"] = "Invalid longitude: west < east, both -180 to 180" - else: - new_settings["region"] = { - "north": region_north, - "south": region_south, - "east": region_east, - "west": region_west, - } - except ValueError: - errors[f"{adapter_name}_region"] = "Region coordinates must be valid numbers" + elif field.widget == "region": + # Region validation + region_north_str = form.get(f"{adapter_name}_{field.name}_north", "").strip() + region_south_str = form.get(f"{adapter_name}_{field.name}_south", "").strip() + region_east_str = form.get(f"{adapter_name}_{field.name}_east", "").strip() + region_west_str = form.get(f"{adapter_name}_{field.name}_west", "").strip() + + try: + region_north = float(region_north_str) + region_south = float(region_south_str) + region_east = float(region_east_str) + region_west = float(region_west_str) + + if not (-90 <= region_south < region_north <= 90): + errors[f"{adapter_name}_{field.name}"] = "Invalid latitude: south < north, both -90 to 90" + elif not (-180 <= region_west < region_east <= 180): + errors[f"{adapter_name}_{field.name}"] = "Invalid longitude: west < east, both -180 to 180" + else: + new_settings[field.name] = { + "north": region_north, + "south": region_south, + "east": region_east, + "west": region_west, + } + except ValueError: + errors[f"{adapter_name}_{field.name}"] = "Region coordinates must be valid numbers" new_adapters[adapter_name] = { "enabled": enabled, @@ -918,10 +965,20 @@ async def setup_finish_form(request: Request) -> HTMLResponse: adapters = [] if state.adapters: - for name in ["firms", "nws", "usgs_quake"]: + adapter_classes = _adapter_classes() + wizard_adapters = sorted( + [(name, cls) for name, cls in adapter_classes.items() if cls.wizard_order is not None], + key=lambda nc: nc[1].wizard_order + ) + for name, cls in wizard_adapters: if name in state.adapters: a = state.adapters[name] - adapters.append({"name": name, "enabled": a["enabled"], "cadence_s": a["cadence_s"]}) + adapters.append({ + "name": name, + "display_name": cls.display_name, + "enabled": a["enabled"], + "cadence_s": a["cadence_s"], + }) csrf_token, signed_token = reuse_or_generate_pre_auth_csrf(request, settings.csrf_secret) response = templates.TemplateResponse( @@ -1441,6 +1498,24 @@ async def adapters_edit_submit( parsed_values[field.name] = [v.strip() for v in raw.split(",") if v.strip()] else: parsed_values[field.name] = [] + elif field.widget == "select": + value = raw.strip() if raw else None + if value and field.options and value not in field.options: + errors[field.name] = f"Invalid {field.label.lower()}" + else: + parsed_values[field.name] = value + elif field.widget == "checkboxes": + # Use getlist for checkbox groups + values = form.getlist(field.name) + form_data[field.name] = values # Override raw value + if field.options: + invalid = [v for v in values if v not in field.options] + if invalid: + errors[field.name] = f"Invalid values: {', '.join(invalid)}" + else: + parsed_values[field.name] = values + else: + parsed_values[field.name] = values elif field.widget == "region": # Region handled separately below pass diff --git a/src/central/gui/templates/adapters_edit.html b/src/central/gui/templates/adapters_edit.html index bc6cdfa..4397467 100644 --- a/src/central/gui/templates/adapters_edit.html +++ b/src/central/gui/templates/adapters_edit.html @@ -100,6 +100,40 @@ {% if errors and errors[field.name] %} {{ errors[field.name] }} {% endif %} + + {% elif field.widget == "select" %} + + + {% if field.description %} + {{ field.description }} + {% endif %} + {% if errors and errors[field.name] %} + {{ errors[field.name] }} + {% endif %} + + {% elif field.widget == "checkboxes" %} + + {% set current_values = form_data.getlist(field.name) if form_data and form_data.getlist else (field.current_value or []) %} + {% for opt in field.options %} + + {% endfor %} + {% if field.description %} + {{ field.description }} + {% endif %} + {% if errors and errors[field.name] %} + {{ errors[field.name] }} + {% endif %} {% endif %} {% endfor %} diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html index f80c6e2..c92d0c1 100644 --- a/src/central/gui/templates/setup_adapters.html +++ b/src/central/gui/templates/setup_adapters.html @@ -29,7 +29,7 @@ {% for adapter in adapters %}
- {{ adapter.name }} + {{ adapter.display_name or adapter.name }}
{% endfor %} @@ -151,11 +205,12 @@