mirror of
https://github.com/zvx-echo6/recon.git
synced 2026-05-20 22:54:46 +02:00
469 lines
18 KiB
Python
469 lines
18 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
migrate_domains.py — Reclassify 5 legacy domains via Gemini Flash.
|
||
|
|
|
||
|
|
Targets: Sustainment Systems, Off-Grid Systems, Defense & Tactics,
|
||
|
|
Community Coordination, Leadership
|
||
|
|
|
||
|
|
Maps each to one of the 18 approved domains. 16 parallel workers,
|
||
|
|
checkpoint file, crash-safe, incremental saves, progress every 5,000.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python3 /tmp/migrate_domains.py [--dry-run] [--workers 16] [--limit N]
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
import random
|
||
|
|
import logging
|
||
|
|
import argparse
|
||
|
|
import threading
|
||
|
|
from pathlib import Path
|
||
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
|
from collections import defaultdict
|
||
|
|
|
||
|
|
import google.generativeai as genai
|
||
|
|
from qdrant_client import QdrantClient
|
||
|
|
from qdrant_client.models import FieldCondition, MatchValue, Filter
|
||
|
|
|
||
|
|
# Suppress noisy HTTP logs
|
||
|
|
import logging as _logging
|
||
|
|
_logging.getLogger("httpx").setLevel(_logging.WARNING)
|
||
|
|
_logging.getLogger("qdrant_client").setLevel(_logging.WARNING)
|
||
|
|
|
||
|
|
LOG_FILE = Path("/opt/recon/logs/migrate_domains.log")
|
||
|
|
CHECKPOINT_FILE = Path("/opt/recon/data/migrate_domains_checkpoint.json")
|
||
|
|
|
||
|
|
logging.basicConfig(
|
||
|
|
level=logging.INFO,
|
||
|
|
format="%(asctime)s %(levelname)s %(message)s",
|
||
|
|
handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler()]
|
||
|
|
)
|
||
|
|
log = logging.getLogger("migrate_domains")
|
||
|
|
|
||
|
|
# ── Constants ───────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
VALID_DOMAINS = {
|
||
|
|
'Agriculture & Livestock', 'Civil Organization', 'Communications',
|
||
|
|
'Food Systems', 'Foundational Skills', 'Logistics', 'Medical',
|
||
|
|
'Navigation', 'Operations', 'Power Systems', 'Preservation & Storage',
|
||
|
|
'Security', 'Shelter & Construction', 'Technology', 'Tools & Equipment',
|
||
|
|
'Vehicles', 'Water Systems', 'Wilderness Skills',
|
||
|
|
}
|
||
|
|
|
||
|
|
SOURCE_DOMAINS = {
|
||
|
|
'Sustainment Systems', 'Off-Grid Systems', 'Defense & Tactics',
|
||
|
|
'Community Coordination', 'Leadership',
|
||
|
|
}
|
||
|
|
|
||
|
|
DOMAIN_LIST_STR = ', '.join(sorted(VALID_DOMAINS))
|
||
|
|
|
||
|
|
CLASSIFY_PROMPT = """\
|
||
|
|
Classify this knowledge concept into exactly one domain from this list:
|
||
|
|
Agriculture & Livestock, Civil Organization, Communications, Food Systems, Foundational Skills, Logistics, Medical, Navigation, Operations, Power Systems, Preservation & Storage, Security, Shelter & Construction, Technology, Tools & Equipment, Vehicles, Water Systems, Wilderness Skills
|
||
|
|
|
||
|
|
Return ONLY the exact domain string, nothing else. No explanation, no punctuation, no quotes.
|
||
|
|
|
||
|
|
Content: {content}
|
||
|
|
Summary: {summary}
|
||
|
|
Subdomain: {subdomain}
|
||
|
|
"""
|
||
|
|
|
||
|
|
DOMAIN_FALLBACK = 'Foundational Skills'
|
||
|
|
|
||
|
|
# ── Key management ──────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def load_gemini_keys():
|
||
|
|
keys = []
|
||
|
|
env_path = Path("/opt/recon/.env")
|
||
|
|
if not env_path.exists():
|
||
|
|
raise FileNotFoundError(f"{env_path} not found")
|
||
|
|
for line in env_path.read_text().splitlines():
|
||
|
|
if line.startswith("GEMINI_KEY_"):
|
||
|
|
keys.append(line.split("=", 1)[1].strip())
|
||
|
|
if not keys:
|
||
|
|
raise ValueError("No GEMINI_KEY_* found in .env")
|
||
|
|
return keys
|
||
|
|
|
||
|
|
|
||
|
|
class KeyRotator:
|
||
|
|
def __init__(self, keys):
|
||
|
|
self.keys = keys
|
||
|
|
self._i = 0
|
||
|
|
self._lock = threading.Lock()
|
||
|
|
|
||
|
|
def next(self):
|
||
|
|
with self._lock:
|
||
|
|
key = self.keys[self._i % len(self.keys)]
|
||
|
|
self._i += 1
|
||
|
|
return key
|
||
|
|
|
||
|
|
|
||
|
|
# ── Classification ──────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def classify_domain(content, summary, subdomains, key):
|
||
|
|
"""Call Gemini Flash to classify into one of 18 domains."""
|
||
|
|
prompt = CLASSIFY_PROMPT.format(
|
||
|
|
content=str(content)[:400] if content else "(none)",
|
||
|
|
summary=str(summary)[:200] if summary else "(none)",
|
||
|
|
subdomain=", ".join(subdomains[:10]) if subdomains else "(none)",
|
||
|
|
)
|
||
|
|
genai.configure(api_key=key)
|
||
|
|
model = genai.GenerativeModel(
|
||
|
|
"gemini-2.0-flash",
|
||
|
|
generation_config={"response_mime_type": "text/plain"}
|
||
|
|
)
|
||
|
|
|
||
|
|
for retry in range(4):
|
||
|
|
try:
|
||
|
|
resp = model.generate_content(prompt)
|
||
|
|
value = resp.text.strip().strip('"').strip("'").strip()
|
||
|
|
if value in VALID_DOMAINS:
|
||
|
|
return value
|
||
|
|
# Try case-insensitive match
|
||
|
|
for valid in VALID_DOMAINS:
|
||
|
|
if value.lower() == valid.lower():
|
||
|
|
return valid
|
||
|
|
# Partial match — Gemini sometimes returns with trailing period
|
||
|
|
clean = value.rstrip('.')
|
||
|
|
if clean in VALID_DOMAINS:
|
||
|
|
return clean
|
||
|
|
# Invalid — retry with stricter prompt
|
||
|
|
if retry < 3:
|
||
|
|
prompt = (
|
||
|
|
f"Your previous response '{value}' was invalid. "
|
||
|
|
f"You must return ONLY one of these exact strings: {DOMAIN_LIST_STR}\n\n"
|
||
|
|
f"Content: {str(content)[:300]}\n"
|
||
|
|
f"Return ONLY the exact domain string."
|
||
|
|
)
|
||
|
|
continue
|
||
|
|
except Exception as e:
|
||
|
|
err = str(e).lower()
|
||
|
|
if any(s in err for s in ["429", "quota", "rate", "503", "unavailable"]):
|
||
|
|
time.sleep(min(5 * (2 ** retry) + random.uniform(0, 3), 60))
|
||
|
|
else:
|
||
|
|
log.warning(f"Gemini error (attempt {retry+1}): {e}")
|
||
|
|
if retry >= 2:
|
||
|
|
break
|
||
|
|
|
||
|
|
return heuristic_fallback(content, summary, subdomains)
|
||
|
|
|
||
|
|
|
||
|
|
def heuristic_fallback(content, summary, subdomains):
|
||
|
|
"""Last-resort heuristic when Gemini fails or returns invalid."""
|
||
|
|
text = f"{summary or ''} {' '.join(subdomains or [])} {str(content or '')[:200]}".lower()
|
||
|
|
|
||
|
|
mapping = [
|
||
|
|
(["farming", "agriculture", "livestock", "animal husbandry", "poultry",
|
||
|
|
"cattle", "crop", "soil fertility", "irrigation for crops"], "Agriculture & Livestock"),
|
||
|
|
(["foraging", "hunting", "fishing", "bushcraft", "wilderness", "survival skill",
|
||
|
|
"fire starting", "shelter building", "trapping", "tracking"], "Wilderness Skills"),
|
||
|
|
(["food preservation", "canning", "dehydration", "smoking", "pickling",
|
||
|
|
"fermentation", "food storage", "freeze dry"], "Preservation & Storage"),
|
||
|
|
(["cooking", "recipe", "nutrition", "food preparation", "baking",
|
||
|
|
"food production", "meal"], "Food Systems"),
|
||
|
|
(["first aid", "medical", "trauma", "surgery", "anatomy", "pharmacology",
|
||
|
|
"wound", "triage", "diagnosis", "disease", "infection", "veterinary",
|
||
|
|
"herbal medicine", "medicinal plant"], "Medical"),
|
||
|
|
(["radio", "antenna", "ham radio", "communication", "signal",
|
||
|
|
"networking", "meshtastic", "comms"], "Communications"),
|
||
|
|
(["solar", "battery", "generator", "wind turbine", "hydroelectric",
|
||
|
|
"power grid", "inverter", "photovoltaic", "electricity"], "Power Systems"),
|
||
|
|
(["water purification", "water filter", "well", "rainwater",
|
||
|
|
"sanitation", "water treatment", "desalination"], "Water Systems"),
|
||
|
|
(["navigation", "compass", "map reading", "gps", "celestial",
|
||
|
|
"orienteering", "land nav"], "Navigation"),
|
||
|
|
(["security", "opsec", "perimeter", "surveillance", "threat",
|
||
|
|
"intrusion detection", "physical security"], "Security"),
|
||
|
|
(["vehicle", "engine", "motor", "aircraft", "boat", "motorcycle",
|
||
|
|
"truck", "maintenance", "diesel", "transmission"], "Vehicles"),
|
||
|
|
(["tool", "equipment", "wrench", "saw", "drill", "hammer",
|
||
|
|
"hand tool", "power tool", "blade", "sharpening"], "Tools & Equipment"),
|
||
|
|
(["construction", "building", "shelter", "carpentry", "masonry",
|
||
|
|
"roofing", "concrete", "framing", "plumbing"], "Shelter & Construction"),
|
||
|
|
(["electronics", "computer", "software", "circuit", "programming",
|
||
|
|
"technology", "digital", "engineering"], "Technology"),
|
||
|
|
(["supply chain", "logistics", "transport", "distribution",
|
||
|
|
"inventory", "supply", "stockpile"], "Logistics"),
|
||
|
|
(["governance", "civil", "community", "administration", "organization",
|
||
|
|
"council", "democratic", "municipal"], "Civil Organization"),
|
||
|
|
(["tactics", "combat", "military", "mission", "patrol", "ambush",
|
||
|
|
"defensive position", "fire team", "maneuver", "engagement",
|
||
|
|
"search and rescue", "sar", "reconnaissance"], "Operations"),
|
||
|
|
]
|
||
|
|
|
||
|
|
for keywords, domain in mapping:
|
||
|
|
if any(kw in text for kw in keywords):
|
||
|
|
return domain
|
||
|
|
|
||
|
|
return DOMAIN_FALLBACK
|
||
|
|
|
||
|
|
|
||
|
|
# ── Checkpoint ──────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class Checkpoint:
|
||
|
|
"""Thread-safe checkpoint tracker for crash recovery."""
|
||
|
|
def __init__(self, path):
|
||
|
|
self.path = path
|
||
|
|
self._lock = threading.Lock()
|
||
|
|
self._completed = set()
|
||
|
|
self._dirty = 0
|
||
|
|
self._load()
|
||
|
|
|
||
|
|
def _load(self):
|
||
|
|
if self.path.exists():
|
||
|
|
try:
|
||
|
|
data = json.loads(self.path.read_text())
|
||
|
|
self._completed = set(data.get("completed", []))
|
||
|
|
log.info(f"Loaded checkpoint: {len(self._completed):,} completed points")
|
||
|
|
except Exception:
|
||
|
|
self._completed = set()
|
||
|
|
|
||
|
|
def is_done(self, point_id):
|
||
|
|
return point_id in self._completed
|
||
|
|
|
||
|
|
def mark_done(self, point_id):
|
||
|
|
with self._lock:
|
||
|
|
self._completed.add(point_id)
|
||
|
|
self._dirty += 1
|
||
|
|
if self._dirty >= 1000:
|
||
|
|
self._flush()
|
||
|
|
|
||
|
|
def _flush(self):
|
||
|
|
tmp = self.path.with_suffix('.tmp')
|
||
|
|
tmp.write_text(json.dumps({"completed": list(self._completed)}))
|
||
|
|
tmp.rename(self.path)
|
||
|
|
self._dirty = 0
|
||
|
|
|
||
|
|
def flush(self):
|
||
|
|
with self._lock:
|
||
|
|
self._flush()
|
||
|
|
|
||
|
|
def count(self):
|
||
|
|
return len(self._completed)
|
||
|
|
|
||
|
|
|
||
|
|
# ── Per-point processing ───────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def process_point(point, qdrant, collection, key_rotator, checkpoint, dry_run, stats):
|
||
|
|
point_id = point.id
|
||
|
|
if checkpoint.is_done(point_id):
|
||
|
|
return "skipped"
|
||
|
|
|
||
|
|
payload = point.payload
|
||
|
|
content = payload.get("content", payload.get("summary", ""))
|
||
|
|
summary = payload.get("summary", "")
|
||
|
|
subdomains = payload.get("subdomain", [])
|
||
|
|
if isinstance(subdomains, str):
|
||
|
|
subdomains = [subdomains]
|
||
|
|
old_domain = payload.get("domain", [])
|
||
|
|
if isinstance(old_domain, list):
|
||
|
|
old_domain_str = old_domain[0] if old_domain else "(empty)"
|
||
|
|
else:
|
||
|
|
old_domain_str = str(old_domain)
|
||
|
|
|
||
|
|
key = key_rotator.next()
|
||
|
|
new_domain = classify_domain(content, summary, subdomains, key)
|
||
|
|
|
||
|
|
# Track the mapping
|
||
|
|
stats_key = f"{old_domain_str} -> {new_domain}"
|
||
|
|
stats[stats_key] = stats.get(stats_key, 0) + 1
|
||
|
|
|
||
|
|
if dry_run:
|
||
|
|
return f"would: {old_domain_str} -> {new_domain}"
|
||
|
|
|
||
|
|
# Write new domain as single string
|
||
|
|
qdrant.set_payload(
|
||
|
|
collection_name=collection,
|
||
|
|
payload={"domain": new_domain},
|
||
|
|
points=[point_id],
|
||
|
|
)
|
||
|
|
|
||
|
|
checkpoint.mark_done(point_id)
|
||
|
|
return "ok"
|
||
|
|
|
||
|
|
|
||
|
|
# ── Main loop ───────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
SCROLL_BATCH = 5000
|
||
|
|
|
||
|
|
|
||
|
|
def count_source_domains(qdrant, collection):
|
||
|
|
"""Count vectors with source domains."""
|
||
|
|
counts = {}
|
||
|
|
for domain in SOURCE_DOMAINS:
|
||
|
|
result = qdrant.count(
|
||
|
|
collection_name=collection,
|
||
|
|
count_filter=Filter(
|
||
|
|
must=[FieldCondition(key="domain", match=MatchValue(value=domain))]
|
||
|
|
),
|
||
|
|
exact=True,
|
||
|
|
)
|
||
|
|
counts[domain] = result.count
|
||
|
|
return counts
|
||
|
|
|
||
|
|
|
||
|
|
def stream_and_process(qdrant, collection, rotator, checkpoint, workers, limit=None, dry_run=False):
|
||
|
|
"""Scroll source domains in batches, process with thread pool."""
|
||
|
|
lock = threading.Lock()
|
||
|
|
done = 0
|
||
|
|
skipped_checkpoint = 0
|
||
|
|
start = time.time()
|
||
|
|
stats = {} # shared mapping stats
|
||
|
|
|
||
|
|
for source_domain in sorted(SOURCE_DOMAINS):
|
||
|
|
log.info(f"\n--- Processing domain: {source_domain} ---")
|
||
|
|
offset = None
|
||
|
|
domain_done = 0
|
||
|
|
|
||
|
|
while True:
|
||
|
|
scroll_results, offset = qdrant.scroll(
|
||
|
|
collection_name=collection,
|
||
|
|
limit=SCROLL_BATCH,
|
||
|
|
with_payload=True,
|
||
|
|
with_vectors=False,
|
||
|
|
offset=offset,
|
||
|
|
scroll_filter=Filter(
|
||
|
|
must=[FieldCondition(key="domain", match=MatchValue(value=source_domain))]
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
if not scroll_results:
|
||
|
|
if offset is None:
|
||
|
|
break
|
||
|
|
continue
|
||
|
|
|
||
|
|
# Filter already checkpointed
|
||
|
|
pending = [p for p in scroll_results if not checkpoint.is_done(p.id)]
|
||
|
|
skipped_checkpoint += len(scroll_results) - len(pending)
|
||
|
|
|
||
|
|
if pending:
|
||
|
|
with ThreadPoolExecutor(max_workers=workers) as ex:
|
||
|
|
futures = {
|
||
|
|
ex.submit(process_point, p, qdrant, collection, rotator,
|
||
|
|
checkpoint, dry_run, stats): p
|
||
|
|
for p in pending
|
||
|
|
}
|
||
|
|
for future in as_completed(futures):
|
||
|
|
try:
|
||
|
|
future.result()
|
||
|
|
except Exception as e:
|
||
|
|
log.error(f"Worker error: {e}")
|
||
|
|
with lock:
|
||
|
|
done += 1
|
||
|
|
domain_done += 1
|
||
|
|
if done % 5000 == 0:
|
||
|
|
elapsed = time.time() - start
|
||
|
|
rate = done / elapsed * 60
|
||
|
|
log.info(f" {done:,} done | {rate:.0f}/min | "
|
||
|
|
f"elapsed {elapsed/60:.1f}min")
|
||
|
|
checkpoint.flush()
|
||
|
|
time.sleep(0.02)
|
||
|
|
|
||
|
|
if limit and done >= limit:
|
||
|
|
break
|
||
|
|
if offset is None:
|
||
|
|
break
|
||
|
|
|
||
|
|
log.info(f" {source_domain}: {domain_done:,} vectors processed")
|
||
|
|
|
||
|
|
if limit and done >= limit:
|
||
|
|
break
|
||
|
|
|
||
|
|
checkpoint.flush()
|
||
|
|
return done, skipped_checkpoint, stats, start
|
||
|
|
|
||
|
|
|
||
|
|
def main():
|
||
|
|
parser = argparse.ArgumentParser()
|
||
|
|
parser.add_argument("--dry-run", action="store_true",
|
||
|
|
help="Classify 20 samples without writing")
|
||
|
|
parser.add_argument("--workers", type=int, default=16)
|
||
|
|
parser.add_argument("--limit", type=int, default=None)
|
||
|
|
args = parser.parse_args()
|
||
|
|
|
||
|
|
keys = load_gemini_keys()
|
||
|
|
rotator = KeyRotator(keys)
|
||
|
|
|
||
|
|
qdrant = QdrantClient(host="localhost", port=6333, timeout=120)
|
||
|
|
collection = "recon_knowledge"
|
||
|
|
checkpoint = Checkpoint(CHECKPOINT_FILE)
|
||
|
|
|
||
|
|
# Count source domains
|
||
|
|
counts = count_source_domains(qdrant, collection)
|
||
|
|
total_source = sum(counts.values())
|
||
|
|
pre_checkpoint = checkpoint.count()
|
||
|
|
|
||
|
|
log.info(f"Source domain counts:")
|
||
|
|
for domain, count in sorted(counts.items(), key=lambda x: -x[1]):
|
||
|
|
log.info(f" {domain:30s} {count:>10,}")
|
||
|
|
log.info(f" {'TOTAL':30s} {total_source:>10,}")
|
||
|
|
log.info(f"Checkpoint: {pre_checkpoint:,} already completed")
|
||
|
|
log.info(f"Workers: {args.workers} | Keys: {len(keys)}")
|
||
|
|
|
||
|
|
# Cost estimate
|
||
|
|
remaining = total_source - pre_checkpoint
|
||
|
|
input_tokens = remaining * 200
|
||
|
|
output_tokens = remaining * 5
|
||
|
|
input_cost = input_tokens / 1_000_000 * 0.10
|
||
|
|
output_cost = output_tokens / 1_000_000 * 0.40
|
||
|
|
total_cost = input_cost + output_cost
|
||
|
|
log.info(f"\nEstimated Gemini 2.0 Flash cost:")
|
||
|
|
log.info(f" Vectors to process: {remaining:,}")
|
||
|
|
log.info(f" Input: ~{input_tokens/1_000_000:.1f}M tokens = ${input_cost:.2f}")
|
||
|
|
log.info(f" Output: ~{output_tokens/1_000_000:.1f}M tokens = ${output_cost:.2f}")
|
||
|
|
log.info(f" TOTAL: ~${total_cost:.2f}")
|
||
|
|
|
||
|
|
if args.dry_run:
|
||
|
|
log.info(f"\nDRY RUN: classifying 20 samples...\n")
|
||
|
|
for source_domain in sorted(SOURCE_DOMAINS):
|
||
|
|
scroll_results, _ = qdrant.scroll(
|
||
|
|
collection_name=collection,
|
||
|
|
limit=5,
|
||
|
|
with_payload=True,
|
||
|
|
with_vectors=False,
|
||
|
|
scroll_filter=Filter(
|
||
|
|
must=[FieldCondition(key="domain", match=MatchValue(value=source_domain))]
|
||
|
|
),
|
||
|
|
)
|
||
|
|
for p in scroll_results[:4]:
|
||
|
|
pay = p.payload
|
||
|
|
title = pay.get("title", "(no title)")
|
||
|
|
content = pay.get("content", pay.get("summary", ""))
|
||
|
|
summary = pay.get("summary", "")
|
||
|
|
subdomains = pay.get("subdomain", [])
|
||
|
|
if isinstance(subdomains, str):
|
||
|
|
subdomains = [subdomains]
|
||
|
|
|
||
|
|
key = rotator.next()
|
||
|
|
new_domain = classify_domain(content, summary, subdomains, key)
|
||
|
|
|
||
|
|
old = pay.get("domain", [])
|
||
|
|
if isinstance(old, list):
|
||
|
|
old = old[0] if old else "?"
|
||
|
|
print(f" [{old:25s}] -> [{new_domain:25s}] {title[:60]}")
|
||
|
|
|
||
|
|
print(f"\nDRY RUN complete. ~{remaining:,} vectors would be migrated.")
|
||
|
|
print(f"Estimated cost: ~${total_cost:.2f}")
|
||
|
|
return
|
||
|
|
|
||
|
|
# ── Full migration ──────────────────────────────────────────────────
|
||
|
|
log.info(f"\nStarting full migration...")
|
||
|
|
|
||
|
|
done, skipped_ckpt, stats, start = stream_and_process(
|
||
|
|
qdrant, collection, rotator, checkpoint, args.workers, args.limit
|
||
|
|
)
|
||
|
|
|
||
|
|
elapsed = time.time() - start
|
||
|
|
log.info(f"\n{'='*70}")
|
||
|
|
log.info(f"MIGRATION COMPLETE in {elapsed/60:.1f}min:")
|
||
|
|
log.info(f" Processed: {done:,}")
|
||
|
|
log.info(f" Skipped (checkpoint): {skipped_ckpt:,}")
|
||
|
|
log.info(f" Rate: {done/elapsed*60:.0f}/min")
|
||
|
|
log.info(f"\nMapping distribution:")
|
||
|
|
for mapping, count in sorted(stats.items(), key=lambda x: -x[1])[:30]:
|
||
|
|
log.info(f" {mapping:<55s} {count:>8,}")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
main()
|