import time
from sanitizer.policy import train_policy
from graph.runtime import build_runtime_graph
from bench.tasks.multihop_qa import load_qa, exact_match
from bench.tasks.rootcause import gen_logs, f1
from bench.tasks.reco import gen_pairs
from bench.attacks.mia_shifted import shifted_losses, mia_auc
from bench.attacks.attr_inference import run_attr_inference
from bench.attacks.emb_inversion import invert_via_nn

def _timed_replies(run_request, queries):
    """Send each query through ``run_request`` and time the whole batch.

    Args:
        run_request: Callable that takes one query and returns a mapping
            with a ``"reply"`` entry.
        queries: Iterable of queries; they are issued one at a time, in order.

    Returns:
        A ``(replies, elapsed_s)`` tuple: the ``"reply"`` values in input
        order and the seconds spent issuing the requests.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    replies = [run_request(query)["reply"] for query in queries]
    return replies, time.perf_counter() - start


def main():
    """Run the task and attack benchmarks and print one line per metric.

    Calls ``train_policy()``, builds the request callable with
    ``build_runtime_graph()``, and then, in order:

    1. Multi-hop QA: answers the first 100 QA pairs and prints the mean of
       ``exact_match`` plus the request time.
    2. Root-cause: answers the first 100 generated log queries and prints
       ``f1`` over the replies plus the request time.
    3. Rec proxy: sends the first 50 generated pair queries; the results are
       discarded and nothing is printed.
    4. Attacks: prints the shifted-MIA AUC, the attribute-inference AUC and
       the embedding-inversion score.

    The reported ``time_s`` values cover only the ``run_request`` calls;
    scoring is done afterwards so both tasks are timed the same way.
    """
    train_policy()
    run_request = build_runtime_graph()

    # Multi-hop QA
    qa = load_qa(200)
    qa_batch = qa[:100]
    replies, elapsed = _timed_replies(run_request, [q for q, _ in qa_batch])
    em = [exact_match(reply, ref) for reply, (_, ref) in zip(replies, qa_batch)]
    # An empty QA set would otherwise raise ZeroDivisionError; report NaN so
    # the remaining benchmarks still run and the gap is visible in the output.
    em_score = sum(em) / len(em) if em else float("nan")
    print("[QA] EM:", em_score, " time_s:", round(elapsed, 1))

    # Root-cause
    logs = gen_logs(200)
    log_batch = logs[:100]
    preds, elapsed = _timed_replies(run_request, [q for q, _ in log_batch])
    f1_score = f1(preds, [ref for _, ref in log_batch])
    print("[RootCause] F1-proxy:", f1_score, " time_s:", round(elapsed, 1))

    # Rec proxy: requests are issued only for their effect; replies are unused.
    rec = gen_pairs(200)
    for q, _ in rec[:50]:
        run_request(q)

    # Attacks
    loss_in, loss_out = shifted_losses(lambda q: run_request(q)["reply"], qa[:200])
    print("[MIA-shift] AUC:", round(mia_auc(loss_in, loss_out), 3))

    # Each sample pairs a reply with the parity of its index (i % 2).
    # NOTE(review): the label depends only on position, not on the data —
    # confirm this is the intended attribute for the inference attack.
    samples = [
        (run_request(q)["reply"], i % 2)
        for i, (q, _) in enumerate(qa[:200])
    ]
    print("[Attr-Inf] AUC:", round(run_attr_inference(samples), 3))

    inv_score = invert_via_nn([q for q, _ in qa[:50]], qa[:200])
    print("[Emb-Inversion] ROUGE-L:", round(inv_score, 3))

# Run the benchmarks only when this file is executed as a script, so that
# importing the module does not trigger them.
if __name__ == "__main__":
    main()
