diff --git a/lib/query_router.py b/lib/query_router.py new file mode 100644 index 0000000..dda14a2 --- /dev/null +++ b/lib/query_router.py @@ -0,0 +1,161 @@ +"""Semantic query router for Aurora. + +Classifies user queries into routes (nav_route, nav_reverse_geocode, +direct_answer, rag_search) by comparing query embeddings against +pre-computed route centroids from example queries. + +TEI endpoint: http://100.64.0.14:8090/embed (cortex via Tailscale) +""" + +import math +import threading +import requests + +# ── Route examples ──────────────────────────────────────────────────────────── +ROUTE_EXAMPLES = { + "nav_route": [ + "how do I get to Boise", + "directions to Twin Falls", + "how do I get from Buhl to Boise", + "drive from Jerome to Sun Valley", + "route from Boise to McCall", + "what's the fastest way to Sun Valley", + "how far is it to Twin Falls", + "take me to Shoshone", + "navigate to the airport", + "how do I drive to Salt Lake City", + "walking directions to the park", + "bike route to downtown", + ], + "nav_reverse_geocode": [ + "what town is at 42.5, -114.7", + "where am I right now", + "what is at coordinates 43.6, -116.2", + "what location is 42.574, -114.607", + "where is this place 44.0, -114.3", + "what city is near 42.7, -114.5", + "reverse geocode 43.0, -115.0", + "what's at this location 42.9, -114.8", + ], + "direct_answer": [ + "hello", + "hey aurora", + "good morning", + "thanks", + "thank you", + "what's your name", + "who are you", + "tell me a joke", + "how are you", + "hi there", + ], + "rag_search": [ + "what does the survival manual say about water", + "how to purify water in the field", + "how to treat a gunshot wound", + "what is the ranger handbook chapter on patrolling", + "field manual water purification", + "how to build a shelter in the wilderness", + "tactical combat casualty care procedures", + "what does FM 21-76 say about fire starting", + ], +} + +# ── Module-level cache ──────────────────────────────────────────────────────── +_ROUTE_CENTROIDS: dict | None = None +_LOCK = threading.Lock() + + +def _embed_batch(texts: list[str], tei_url: str) -> list[list[float]]: + """Embed a batch of texts via TEI.""" + resp = requests.post(tei_url, json={"inputs": texts}, timeout=30) + resp.raise_for_status() + return resp.json() + + +def _compute_centroid(vectors: list[list[float]]) -> list[float]: + """Element-wise mean of vectors.""" + n = len(vectors) + dim = len(vectors[0]) + centroid = [0.0] * dim + for vec in vectors: + for i in range(dim): + centroid[i] += vec[i] + for i in range(dim): + centroid[i] /= n + return centroid + + +def _cosine_similarity(a: list[float], b: list[float]) -> float: + """Cosine similarity between two vectors (pure Python).""" + dot = 0.0 + norm_a = 0.0 + norm_b = 0.0 + for i in range(len(a)): + dot += a[i] * b[i] + norm_a += a[i] * a[i] + norm_b += b[i] * b[i] + denom = math.sqrt(norm_a) * math.sqrt(norm_b) + if denom == 0: + return 0.0 + return dot / denom + + +def _ensure_centroids(tei_url: str) -> dict[str, list[float]]: + """Lazy-init: embed all examples in one batch, compute centroids, cache.""" + global _ROUTE_CENTROIDS + if _ROUTE_CENTROIDS is not None: + return _ROUTE_CENTROIDS + + with _LOCK: + if _ROUTE_CENTROIDS is not None: + return _ROUTE_CENTROIDS + + # Flatten all examples into one batch + all_texts = [] + route_ranges: dict[str, tuple[int, int]] = {} + offset = 0 + for route, examples in ROUTE_EXAMPLES.items(): + route_ranges[route] = (offset, offset + len(examples)) + all_texts.extend(examples) + offset += len(examples) + + all_vectors = _embed_batch(all_texts, tei_url) + + centroids = {} + for route, (start, end) in route_ranges.items(): + centroids[route] = _compute_centroid(all_vectors[start:end]) + + _ROUTE_CENTROIDS = centroids + return _ROUTE_CENTROIDS + + +def classify( + query: str, + tei_url: str = "http://100.64.0.14:8090/embed", + threshold: float = 0.45, +) -> tuple[str, float]: + """Classify a query into a route. + + Returns (route_name, confidence). If no route exceeds the threshold, + returns ("rag_search", best_score) as the safe default. + """ + centroids = _ensure_centroids(tei_url) + + # Embed the query + vecs = _embed_batch([query], tei_url) + query_vec = vecs[0] + + # Compare against all centroids + best_route = "rag_search" + best_score = 0.0 + for route, centroid in centroids.items(): + sim = _cosine_similarity(query_vec, centroid) + if sim > best_score: + best_score = sim + best_route = route + + if best_score < threshold: + return ("rag_search", best_score) + + return (best_route, best_score) diff --git a/lib/query_router_test.py b/lib/query_router_test.py new file mode 100644 index 0000000..27ccefd --- /dev/null +++ b/lib/query_router_test.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +"""Test suite for the semantic query router.""" + +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from lib.query_router import classify + +TEST_QUERIES = [ + ("how do I get from Buhl to Boise", "nav_route"), + ("what does the survival manual say about water", "rag_search"), + ("what town is at 42.5, -114.7", "nav_reverse_geocode"), + ("hey aurora", "direct_answer"), + ("what's the fastest way to Sun Valley", "nav_route"), + ("how to purify water in the field", "rag_search"), + ("good morning", "direct_answer"), +] + + +def main(): + print("Query Router Test Suite") + print("=" * 70) + + passed = 0 + failed = 0 + + for query, expected in TEST_QUERIES: + route, confidence = classify(query) + status = "PASS" if route == expected else "FAIL" + if status == "PASS": + passed += 1 + else: + failed += 1 + print(f" [{status}] {query!r}") + print(f" → {route} ({confidence:.3f}) expected={expected}") + + print("=" * 70) + print(f"Results: {passed}/{passed + failed} passed") + if failed: + print(f" {failed} FAILED") + sys.exit(1) + else: + print(" All tests passed!") + + +if __name__ == "__main__": + main()