#!/usr/bin/env python3
"""
Lightweight empirical validation for "Procedural Library" theory (LLM navigability & hallucination decomposition)
================================================================================

What this does
--------------
Runs a tiny, controlled factual-QA experiment against a llama3.3 model (OpenAI-compatible chat API).
We evaluate 3 operator conditions:
  A) BASE    : zero-shot instruction
  B) FEWSHOT : prompt has 3 QA exemplars
  C) RAG     : retrieve a short support snippet (bag-of-words cosine over a tiny local corpus)

We report:
  - p_f (success probability = accuracy)
  - NI (Navigability Index): log improvement over BASE
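  - HR decomposition for the RAG condition, a lower bound on the hallucination rate HR:
        HR >= (1-c)*(1-alpha) + c*beta
      c     = coverage (fraction of questions whose retrieved snippet contains the gold answer string)
      alpha = abstention rate (answers matching phrases such as "I don't know", "cannot answer", "not sure")
      beta  = conditional error rate given coverage and non-abstention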

It also logs latency per call as a crude "energy per hit" proxy.

Requirements
------------
- Python 3.9+
- No external packages required (uses stdlib).
- Access to an OpenAI-compatible Chat Completions endpoint for llama3.3.

Configure via environment variables:
  LLM_API_KEY      : your API key
  LLM_API_BASE     : base URL (e.g., https://api.openai.com/v1  OR your gateway)
  LLM_MODEL        : model name (default: llama-3.3-instruct)
  LLM_PROVIDER     : "openai" (adds Bearer header) or "generic" (also Bearer, same path).

Run:
  python validate_procedural_library.py --trials 1

Output:
  - Prints a summary table to stdout
  - Writes results to validation_results.json
"""
import json
import math
import os
import re
import sys
import time
import urllib.error
import urllib.request
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

# ------------------ Config ------------------

# Endpoint settings are read from the environment once, at import time;
# each falls back to the default shown when the variable is unset.
API_KEY = os.getenv("LLM_API_KEY", "")
API_BASE = os.getenv("LLM_API_BASE", "https://api.openai.com/v1")
MODEL = os.getenv("LLM_MODEL", "llama-3.3-instruct")
PROVIDER = os.getenv("LLM_PROVIDER", "openai")  # "openai" or "generic"
# Timeout budget; the name suggests seconds.
# NOTE(review): confirm the HTTP call actually reads this constant.
TIMEOUT_S = 120

# A missing key is only warned about (on stderr); the module still loads.
if not API_KEY:
    print("WARNING: LLM_API_KEY env var not set.", file=sys.stderr)

# ------------------ Tiny QA dataset ------------------

# Evaluation items, one tuple per question: (question, gold answer, support id).
# The support id names the CORPUS entry written for that question.
QA = [
    ("What is the capital of Austria?", "Vienna", "capitals"),
    ("Who wrote the play 'Hamlet'?", "William Shakespeare", "hamlet"),
    ("What is the chemical symbol for water?", "H2O", "chem"),
    ("Which planet is known as the Red Planet?", "Mars", "mars"),
    ("Who proposed the theory of general relativity?", "Albert Einstein", "einstein"),
    ("What is the largest mammal on Earth?", "Blue whale", "whale"),
    ("What is the currency of Japan?", "Yen", "yen"),
    ("What gas do plants primarily absorb for photosynthesis?", "Carbon dioxide", "photosyn"),
    ("Which ocean is the deepest on average?", "Pacific Ocean", "ocean"),
    ("What is the primary language spoken in Brazil?", "Portuguese", "portuguese"),
    (
        "What instrument has keys, pedals, and strings and is often found in concert halls?",
        "Piano",
        "piano",
    ),
    ("What do bees collect and use to make honey?", "Nectar", "nectar"),
]

# Lookup from question text to its gold answer.
QA_MAP = {question: answer for question, answer, _support_id in QA}

# Short local "corpus" for RAG (id -> text).
# One short supporting sentence per topic, keyed by support id.
# Insertion order matters: it is the order in which entries are scored.
CORPUS = {
    "capitals": "Austria's capital and largest city is Vienna, located on the Danube.",
    "hamlet": "'Hamlet' is a tragedy written by William Shakespeare.",
    "chem": "Water is a molecule composed of hydrogen and oxygen with chemical formula H2O.",
    "mars": "Mars is known as the Red Planet due to its iron oxide-rich surface.",
    "einstein": "Albert Einstein proposed the theory of general relativity in the early 20th century.",
    "whale": "The blue whale is the largest animal known to have ever existed.",
    "yen": "The currency of Japan is the yen.",
    "photosyn": "Plants absorb carbon dioxide and release oxygen during photosynthesis.",
    "ocean": "The Pacific Ocean is the largest and also the deepest ocean on Earth on average.",
    "portuguese": "In Brazil, the primary language spoken by the population is Portuguese.",
    "piano": "A piano has keys, pedals, and strings; grand pianos are common in concert halls.",
    "nectar": "Bees collect nectar from flowers and transform it into honey in their hives.",
}

# ------------------ Mini retriever (cosine BoW) ------------------

def tokenize(s: str) -> List[str]:
    """Lowercase *s* and return its maximal runs of ``a-z``/``0-9``, in order.

    Every other character (punctuation, whitespace, non-ASCII letters) acts
    as a separator and is dropped.
    """
    lowered = s.lower()
    return [match.group(0) for match in re.finditer(r"[a-z0-9]+", lowered)]

def bow_vec(s: str) -> Counter:
    """Return the bag-of-words vector of *s*: token -> number of occurrences."""
    counts = Counter()
    counts.update(tokenize(s))
    return counts

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity of two sparse count vectors.

    Returns 0.0 when either vector is empty or when the product of the two
    norms is not positive (e.g. all stored counts are zero).
    """
    if not (a and b):
        return 0.0

    # Only terms present in both vectors contribute to the dot product.
    shared = set(a.keys()) & set(b.keys())
    dot = sum(a[term] * b[term] for term in shared)

    norm_a = math.sqrt(sum(weight * weight for weight in a.values()))
    norm_b = math.sqrt(sum(weight * weight for weight in b.values()))
    scale = norm_a * norm_b

    if scale > 0:
        return dot / scale
    return 0.0

# Bag-of-words vector of every corpus entry, computed once at import time
# and keyed by the same ids as CORPUS.
CORPUS_VECS = {doc_id: bow_vec(text) for doc_id, text in CORPUS.items()}

def retrieve(query: str, k: int = 1) -> List[Tuple[str, float]]:
    """Rank corpus entries by cosine similarity to *query*; return the first *k*.

    Returns:
        ``(corpus id, score)`` pairs in descending score order. The sort is
        stable, so entries with equal scores keep their CORPUS order.
    """
    query_vec = bow_vec(query)
    ranked = sorted(
        ((doc_id, cosine(query_vec, CORPUS_VECS[doc_id])) for doc_id in CORPUS),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]

# ------------------ Prompting ------------------

# Worked (question, answer) pairs embedded in the few-shot prompt.
FEWSHOT_EXAMPLES = [
    ("What is the capital of France?", "Paris"),
    ("Which gas do humans need to breathe for survival?", "Oxygen"),
    ("What is 5 + 7?", "12"),
]

# System prompts. Both instruct the model to reply "I don't know" instead of
# guessing; the RAG variant additionally points it at the attached SUPPORT.
SYSTEM_BASE = (
    "You are a careful, concise assistant. "
    "Answer with a short factual phrase. "
    "If unsure, say: I don't know."
)
SYSTEM_RAG = (
    "You are a careful, concise assistant. "
    "Use the attached SUPPORT to answer. "
    "If SUPPORT is insufficient, say: I don't know."
)

def make_fewshot_prompt() -> str:
    """Build the few-shot preamble: instruction, worked Q/A pairs, hand-off line.

    Each section already ends with a newline and the sections are joined with
    another one, so consecutive sections are separated by a blank line.
    """
    header = "Answer the question briefly. If unsure, say: I don't know.\n"
    footer = "Now answer the next question.\n"
    exemplars = [f"Q: {q}\nA: {a}\n" for q, a in FEWSHOT_EXAMPLES]
    return "\n".join([header, *exemplars, footer])

def rag_context(support_texts: List[str]) -> str:
    """Format retrieved snippets as the SUPPORT section of a RAG prompt.

    Each snippet becomes a ``- `` bullet; bullets are separated by a blank
    line and followed by a reminder to rely on the support.
    """
    bullets = []
    for text in support_texts:
        bullets.append(f"- {text}")
    joined = "\n\n".join(bullets)
    return f"SUPPORT:\n{joined}\n\nUse only this support if possible."

def is_abstain(ans: str) -> bool:
    """Return True when the model's answer reads as an abstention.

    The answer is lowercased and typographic apostrophes are folded to the
    ASCII one before matching, so "I don’t know" is detected the same way as
    "I don't know". An answer counts as an abstention when it contains any of
    the marker phrases below.

    Args:
        ans: Raw answer text returned by the model.
    """
    text = ans.strip().lower()
    # Models often emit curly quotes; without folding, such abstentions would
    # be scored as wrong answers instead of abstentions.
    for variant in ("\u2019", "\u2018", "\u02bc", "`"):
        text = text.replace(variant, "'")
    markers = ("i don't know", "i do not know", "cannot answer", "not sure")
    return any(marker in text for marker in markers)

def normalize(s: str) -> str:
    """Canonicalize text for comparison.

    Trims the ends, lowercases, and collapses every run of whitespace into a
    single space.
    """
    trimmed = s.strip()
    lowered = trimmed.lower()
    return re.sub(r"\s+", " ", lowered)

# Accepted alternative spellings, keyed by the normalized gold answer.
# Only aliases that do NOT already contain the gold answer are listed: any
# alias containing it is covered by the substring check in is_correct().
_ANSWER_ALIASES = {
    "vienna": ("wien",),
    "h2o": ("h₂o", "h20"),
    "yen": ("jpy",),
    "carbon dioxide": ("co2", "co₂", "carbon-dioxide"),
    "portuguese": ("português",),
    "william shakespeare": ("shakespeare",),
    "pacific ocean": ("the pacific",),
    "albert einstein": ("einstein",),
}


def is_correct(ans: str, ref: str) -> bool:
    """Return True when *ans* contains the reference answer or a known alias.

    Both strings are normalized (trimmed, lowercased, whitespace collapsed)
    and matching is by substring, so "The capital is Vienna." is correct for
    the reference "Vienna".

    Args:
        ans: Raw answer text returned by the model.
        ref: Gold answer for the question.

    Returns:
        True on a match. An empty reference never matches, because the empty
        string is a substring of every answer and would mark all of them
        correct.
    """
    a = normalize(ans)
    r = normalize(ref)
    if not r:
        return False
    if r in a:
        return True
    # Direct lookup replaces the scan over every alias entry. No equality
    # fallback is needed: a == r would already have satisfied `r in a`.
    return any(alias in a for alias in _ANSWER_ALIASES.get(r, ()))

# ------------------ API call ------------------

def chat_completion(messages: List[Dict[str, str]], temperature: float = 0.2, max_tokens: int = 64) -> str:
    """POST one chat request to ``{API_BASE}/chat/completions``; return the reply text.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature forwarded in the request payload.
        max_tokens: Completion length cap forwarded in the request payload.

    Returns:
        The ``content`` of the first choice's message, or ``""`` when the
        response has no choices, no message, or null content.

    Raises:
        urllib.error.URLError: Propagated unchanged from ``urlopen``; this
            includes ``HTTPError`` for non-2xx responses.
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    # Tolerate a trailing slash in the configured base URL so the request
    # path never becomes ".../v1//chat/completions".
    url = f"{API_BASE.rstrip('/')}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "n": 1,
    }
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(url, data=data, headers=headers, method="POST")
    # Use the module-level timeout rather than a second hard-coded copy of it.
    with urllib.request.urlopen(req, timeout=TIMEOUT_S) as resp:
        res = json.loads(resp.read().decode("utf-8"))
    # `or` fallbacks cover keys that are present but empty or null, which a
    # plain .get(key, default) does not: "choices": [] or "content": null.
    choices = res.get("choices") or [{}]
    message = choices[0].get("message") or {}
    return message.get("content") or ""

# ------------------ Conditions ------------------

def run_base(q: str) -> Tuple[str, float]:
    """Run the BASE (zero-shot) condition for one question.

    Args:
        q: Question text, sent as the only user message.

    Returns:
        ``(answer text, elapsed seconds of the API call)``.
    """
    msgs = [
        {"role": "system", "content": SYSTEM_BASE},
        {"role": "user", "content": q},
    ]
    # perf_counter is monotonic; time.time() can jump when the system clock
    # is adjusted, which would corrupt the latency measurement.
    t0 = time.perf_counter()
    out = chat_completion(msgs)
    dt = time.perf_counter() - t0
    return out, dt

def run_fewshot(q: str) -> Tuple[str, float]:
    """Run the FEWSHOT condition for one question.

    The user message is the few-shot preamble followed by the question in
    the same ``Q: ... / A:`` layout as the exemplars.

    Args:
        q: Question text.

    Returns:
        ``(answer text, elapsed seconds of the API call)``.
    """
    msgs = [
        {"role": "system", "content": SYSTEM_BASE},
        {"role": "user", "content": make_fewshot_prompt() + f"\nQ: {q}\nA:"},
    ]
    # perf_counter is monotonic; time.time() can jump when the system clock
    # is adjusted, which would corrupt the latency measurement.
    t0 = time.perf_counter()
    out = chat_completion(msgs)
    dt = time.perf_counter() - t0
    return out, dt

def run_rag(q: str, k: int = 1) -> Tuple[str, float, List[str], float]:
    """Run the RAG condition for one question.

    Retrieves the top-*k* corpus snippets for *q*, sends them as SUPPORT
    together with the question, and reports whether the snippets cover the
    gold answer.

    Args:
        q: Question text.
        k: Number of snippets to retrieve.

    Returns:
        ``(answer text, elapsed seconds of the API call, retrieved snippets,
        coverage)``. Coverage is 1.0 when any retrieved snippet contains the
        normalized gold answer and 0.0 otherwise. It is also 0.0 when *q* has
        no entry in QA_MAP, since coverage cannot be established without a
        gold answer.
    """
    top = retrieve(q, k=k)
    supports = [CORPUS[cid] for cid, _ in top]
    msgs = [
        {"role": "system", "content": SYSTEM_RAG},
        {"role": "user", "content": rag_context(supports) + f"\nQ: {q}\nA:"},
    ]
    # Only the API call is timed (retrieval is excluded). perf_counter is
    # monotonic, unlike time.time(), which can jump on clock adjustments.
    t0 = time.perf_counter()
    out = chat_completion(msgs)
    dt = time.perf_counter() - t0

    # .get() avoids a KeyError, after the API call has already been made,
    # for a question that is not part of the dataset.
    gold = QA_MAP.get(q)
    if gold is None:
        cov = 0.0
    else:
        target = normalize(gold)
        cov = 1.0 if any(target in normalize(s) for s in supports) else 0.0
    return out, dt, supports, cov

# ------------------ Runner ------------------

def _score_answer(answer: str, ref: str) -> Tuple[bool, bool]:
    """Classify one model answer as ``(abstained, correct)``.

    An abstention is never counted as correct, even if its text happens to
    contain the reference answer.
    """
    abstained = is_abstain(answer)
    return abstained, (not abstained) and is_correct(answer, ref)


def _navigability_index(p_success: float, p_base: float) -> Optional[float]:
    """Return ``log(p_success) - log(p_base)``, or None when it is not finite.

    The log-ratio is undefined or infinite as soon as either probability is
    zero. Returning None keeps those cases distinguishable from a real value
    and keeps the JSON output free of the non-standard ``Infinity`` token.
    """
    if p_success <= 0 or p_base <= 0:
        return None
    return math.log(p_success) - math.log(p_base)


def main(trials: int = 1, k: int = 1):
    """Run every QA item through BASE, FEWSHOT and RAG, then report the metrics.

    Args:
        trials: Number of full passes over the QA set. Values below 1 are
            treated as 1. All passes are pooled, so every rate is computed
            over ``trials * len(QA)`` samples.
        k: Number of corpus snippets retrieved per question in the RAG
            condition.

    Side effects:
        Prints the summary as JSON to stdout and writes summary plus
        per-sample details to ``validation_results.json`` in the current
        working directory.
    """
    trials = max(1, trials)

    results = []
    correct = {"BASE": 0, "FEWSHOT": 0, "RAG": 0}
    latency = {"BASE": [], "FEWSHOT": [], "RAG": []}

    # Inputs of the HR decomposition; all are taken from the RAG condition.
    cov_list = []
    abst_list = []
    beta_count = 0  # covered, non-abstaining answers that were wrong
    beta_denom = 0  # covered, non-abstaining answers

    for trial in range(1, trials + 1):
        for q, ref, _support_id in QA:
            # BASE
            b_ans, b_dt = run_base(q)
            latency["BASE"].append(b_dt)
            b_abst, b_ok = _score_answer(b_ans, ref)
            correct["BASE"] += b_ok

            # FEWSHOT
            f_ans, f_dt = run_fewshot(q)
            latency["FEWSHOT"].append(f_dt)
            f_abst, f_ok = _score_answer(f_ans, ref)
            correct["FEWSHOT"] += f_ok

            # RAG
            r_ans, r_dt, supports, cov = run_rag(q, k=k)
            latency["RAG"].append(r_dt)
            r_abst, r_ok = _score_answer(r_ans, ref)
            correct["RAG"] += r_ok

            cov_list.append(cov)
            abst_list.append(1.0 if r_abst else 0.0)
            if cov >= 0.5 and not r_abst:
                beta_denom += 1
                if not r_ok:
                    beta_count += 1

            results.append({
                "trial": trial,
                "question": q,
                "gold": ref,
                "base": {"answer": b_ans, "secs": b_dt, "abstain": b_abst, "correct": b_ok},
                "fewshot": {"answer": f_ans, "secs": f_dt, "abstain": f_abst, "correct": f_ok},
                "rag": {"answer": r_ans, "secs": r_dt, "abstain": r_abst, "correct": r_ok, "coverage": cov, "supports": supports},
            })

    # Number of evaluated samples; equals len(QA) when trials == 1.
    n = len(results)
    p_f = {name: hits / n for name, hits in correct.items()}

    c = sum(cov_list) / n
    alpha = sum(abst_list) / n
    # With no covered, non-abstaining answer there is nothing to condition on.
    beta = (beta_count / beta_denom) if beta_denom > 0 else 0.0
    hr_lb = (1 - c) * (1 - alpha) + c * beta

    summary = {
        "N": n,
        "trials": trials,
        "p_f": p_f,
        "NI": {
            "FEWSHOT_vs_BASE": _navigability_index(p_f["FEWSHOT"], p_f["BASE"]),
            "RAG_vs_BASE": _navigability_index(p_f["RAG"], p_f["BASE"]),
        },
        "latency_sec_avg": {name: sum(secs) / n for name, secs in latency.items()},
        "HR_decomposition_RAG": {"coverage_c": c, "abstention_alpha": alpha, "beta_error_given_coverage": beta, "HR_lower_bound": hr_lb},
    }

    print("\n=== SUMMARY ===")
    print(json.dumps(summary, indent=2))
    with open("validation_results.json", "w", encoding="utf-8") as f:
        json.dump({"summary": summary, "details": results}, f, indent=2, ensure_ascii=False)
    print("\nWrote validation_results.json")


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--trials", type=int, default=1, help="number of passes over the QA set; samples are pooled (default 1)")
    ap.add_argument("--k", type=int, default=1, help="RAG top-k (default 1)")
    args = ap.parse_args()
    main(trials=args.trials, k=args.k)
