
import json, sys, csv, numpy as np
from collections import Counter

def normalize(s):
    """Return ``s`` stripped of surrounding whitespace and lower-cased.

    ``None`` is treated as the empty string. Only case and outer whitespace are
    normalised; punctuation, articles and inner whitespace are left untouched.
    """
    # Per-character lowering (rather than s.lower()) never applies
    # context-dependent mappings such as the Greek final sigma.
    return "".join(map(str.lower, (s or "").strip()))


def em(gold, pred):
    """Exact match: 1.0 if ``gold`` equals ``pred`` after :func:`normalize`, else 0.0."""
    return float(normalize(gold) == normalize(pred))


def f1(gold, pred):
    """Token-level F1 between ``gold`` and ``pred`` after :func:`normalize`.

    Tokens are whitespace-delimited and overlap is counted as a multiset
    intersection. Two empty answers score 1.0; exactly one empty scores 0.0.
    """
    g = normalize(gold).split()
    p = normalize(pred).split()
    if not g and not p:
        return 1.0
    if not g or not p:
        return 0.0
    overlap = sum((Counter(g) & Counter(p)).values())
    if overlap == 0:
        return 0.0
    # 2PR/(P+R) with P = overlap/len(p) and R = overlap/len(g) simplifies to
    # this exact form, so a perfect match scores exactly 1.0 (no epsilon bias).
    return 2 * overlap / (len(p) + len(g))


def token_support_rate(answer, retrieved):
    """Fraction of ``answer`` tokens that also occur as tokens of ``retrieved``.

    Both strings go through :func:`normalize` and are split on whitespace, so a
    token counts as supported only when it appears as a whole word, never as a
    substring of a longer word. An empty answer is vacuously supported (1.0).
    ``retrieved`` may be ``None``.
    """
    answer_tokens = normalize(answer).split()
    if not answer_tokens:
        return 1.0
    # Tokenising the context (instead of searching for " tok " substrings)
    # also matches words adjacent to newlines/tabs and makes lookups O(1).
    context_tokens = set(normalize(retrieved).split())
    hits = sum(tok in context_tokens for tok in answer_tokens)
    return hits / len(answer_tokens)

def load_preds(path):
    """Load a JSON Lines file of prediction records.

    Each non-blank line is decoded with :func:`json.loads`. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list of the decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON Lines is UTF-8 by definition; don't depend on the platform default.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def sweep_threshold(data, n=200):
    """Sweep an abstention threshold over ``u_score`` and summarise each setting.

    A prediction is answered (kept) when ``u_score <= threshold`` and abstained
    otherwise; a missing ``u_score`` counts as 0.0. ``n`` thresholds are spaced
    evenly from just below the lowest score to just above the highest, so with
    ``n >= 2`` the first row has coverage 0 and the last has coverage 1.

    Args:
        data: Prediction records (dicts), e.g. as returned by ``load_preds``.
        n: Number of thresholds to evaluate.

    Returns:
        One dict per threshold with keys ``threshold``, ``coverage`` (fraction
        of predictions kept) and the mean over kept predictions of ``EM``,
        ``F1``, ``SCR`` (token support rate of ``pred_answer`` in
        ``retrieved_text``), ``docs_scored``, ``rerank_depth``,
        ``context_tokens`` and ``latency_ms``. Means are 0.0 when nothing is
        kept. Empty ``data`` yields ``[]``.

    Raises:
        ValueError: If any ``u_score`` is null or NaN.
    """
    if not data:
        return []

    def column(key, default):
        # One float per record; a missing key falls back to ``default``.
        return np.array([d.get(key, default) for d in data], dtype=float)

    u = column("u_score", 0.0)
    if np.isnan(u).any():
        # A null/NaN score would otherwise turn every threshold into NaN.
        raise ValueError("u_score must be a number for every prediction")

    # Per-prediction metrics do not depend on the threshold: compute them once
    # instead of once per threshold.
    gold = [d.get("gold_answer", "") for d in data]
    pred = [d.get("pred_answer", "") for d in data]
    per_item = {
        "EM": np.array([em(g, p) for g, p in zip(gold, pred)], dtype=float),
        "F1": np.array([f1(g, p) for g, p in zip(gold, pred)], dtype=float),
        "SCR": np.array(
            [token_support_rate(p, d.get("retrieved_text", "")) for p, d in zip(pred, data)],
            dtype=float,
        ),
        "docs_scored": column("docs_scored", 0),
        "rerank_depth": column("rerank_depth", 0),
        "context_tokens": column("context_tokens", 0),
        "latency_ms": column("latency_ms", 0),
    }

    # The 1e-9 padding keeps the first threshold strictly below every score and
    # the last strictly above; nextafter takes over when scores are so large
    # that a 1e-9 offset is lost to floating-point rounding.
    lo = min(u.min() - 1e-9, np.nextafter(u.min(), -np.inf))
    hi = max(u.max() + 1e-9, np.nextafter(u.max(), np.inf))

    rows = []
    for t in np.linspace(lo, hi, n):
        kept = u <= t  # True = answer, False = abstain
        n_kept = int(kept.sum())
        row = {"threshold": float(t), "coverage": n_kept / len(data)}
        for name, values in per_item.items():
            row[name] = float(values[kept].mean()) if n_kept else 0.0
        rows.append(row)
    return rows

if __name__ == "__main__":
    # CLI: rc_eval.py <predictions.jsonl> <output.csv>
    if len(sys.argv) < 3:
        print("Usage: rc_eval.py preds.jsonl out.csv", file=sys.stderr)
        sys.exit(1)
    preds, out = sys.argv[1], sys.argv[2]
    data = load_preds(preds)
    rows = sweep_threshold(data)
    if not rows:
        # Without rows there is no header to derive; fail clearly instead of
        # raising IndexError on rows[0].
        print(f"No predictions found in {preds}; nothing written.", file=sys.stderr)
        sys.exit(1)
    with open(out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    print("Wrote", out)
