# precompute_caches.py
# pip install numpy torch tqdm kmeans-pytorch rouge
# (the imports below use kmeans_pytorch and rouge rather than scikit-learn / rouge-score)

from __future__ import annotations
import argparse, json, math, os

from llm_client import load_embed_model
import numpy as np
from tqdm import tqdm
import torch
from data_utils import get_task
from kmeans_pytorch import kmeans

# ROUGE-L via the `rouge` package (`from rouge import Rouge`), not `rouge-score`
from rouge import Rouge



def softmax_stable(x: np.ndarray, tau: float) -> np.ndarray:
    """Return a temperature-scaled softmax over all elements of *x*.

    Args:
        x: Scores to normalize. The maximum and the normalizing sum are taken
            over the whole array (no axis is given).
        tau: Temperature. Values below 1e-8 are clamped up to 1e-8 so the
            division is always defined.

    Returns:
        Array with the same shape as ``x / tau`` whose entries sum to
        (approximately) one.
    """
    scaled = x / max(tau, 1e-8)
    # Shift by the maximum so the largest exponent is exactly 0 (no overflow).
    exps = np.exp(scaled - np.max(scaled))
    # The tiny constant keeps the denominator strictly positive.
    denom = np.sum(exps) + 1e-12
    return exps / denom


def compute_T_prototypes_from_IG(
    ex_embs: np.ndarray,                  # (N, d)
    ig,
    T_ids,
    tau: float = 1.0):
    """Build one prototype vector per instruction id from example embeddings.

    For each id T the prototype is
        p^T = sum_e softmax(IG(T, e) / tau) * h_e
    where h_e is the embedding of example e.

    Args:
        ex_embs: Example embeddings, one row per example.
        ig: Iterable of per-example rows; ``row[T_id]`` must yield the gain of
            instruction ``T_id`` on that example. The rows are expected to line
            up one-to-one with the rows of ``ex_embs`` (the weights are
            broadcast against them).
        T_ids: Instruction ids to compute prototypes for.
        tau: Softmax temperature forwarded to ``softmax_stable``.

    Returns:
        Dict mapping each id in ``T_ids`` to a float32 prototype vector.
    """
    prototypes = {}
    progress = tqdm(T_ids, desc="Compute T prototypes")
    for instr_id in progress:
        # Gains of this instruction across all examples, one entry per ig row.
        gains = np.array([row[instr_id] for row in ig], dtype=np.float64)
        weights = softmax_stable(gains, tau=tau)
        weights = weights.astype(np.float32)

        # Weighted sum of the embedding rows -> a single (d,) vector.
        weighted_rows = weights[:, None] * ex_embs
        prototype = weighted_rows.sum(axis=0)

        prototypes[instr_id] = prototype.astype(np.float32)

    return prototypes


def compute_rouge_qe(
    queries,
    examples,
    out_path: str,
) -> None:
    """Score every (query, example) pair with ROUGE-L F and write the table.

    Args:
        queries: Indexable collection; ``queries[i]["input"]`` is the text
            passed as the first argument to ``Rouge.get_scores``.
        examples: Indexable collection; ``examples[j]["input"]`` is the text
            passed as the second argument.
        out_path: Destination file. The result is written as a single JSON
            array (via ``json.dump``) of ``{"q_id", "e_id", "rouge"}`` records,
            whatever the file extension is.
    """
    scorer = Rouge()

    def pair_score(q_idx: int, e_idx: int) -> float:
        # ROUGE-L F-measure of one pair, rounded to 4 decimals.
        result = scorer.get_scores(queries[q_idx]["input"], examples[e_idx]["input"])
        return round(float(result[0]['rouge-l']['f']), 4)

    # Query-major order: all examples for query 0, then query 1, ...
    table = [
        {"q_id": q_idx, "e_id": e_idx, "rouge": pair_score(q_idx, e_idx)}
        for q_idx in range(len(queries))
        for e_idx in range(len(examples))
    ]

    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(table, fh, ensure_ascii=False)


def main():
    """Precompute per-task caches for a fixed list of tasks.

    For each task this reads ``./data/{task}/{task}_instrs.json`` and
    ``./data/{task}/train_ifgain.json``, embeds the training examples, and
    writes three artifacts into ``./data/{task}``:

    1. ``centroids_k{K}.npz`` - k-means centroids of the example embeddings.
    2. ``instructions_with_proto.jsonl`` - a single JSON array (``json.dump``)
       of ``{"id", "text", "proto"}`` records, one per instruction.
    3. ``--rouge_out`` - the query/example ROUGE table produced by
       ``compute_rouge_qe``.

    Command-line options: ``--kmeans_k``, ``--proto_tau``, ``--rouge_out``.
    """
    ap = argparse.ArgumentParser()
    # kmeans
    ap.add_argument("--kmeans_k", type=int, default=10)
    # prototype
    ap.add_argument("--proto_tau", type=float, default=1.0)
    # rouge
    ap.add_argument("--rouge_out", type=str, default="test_rouge_qe.jsonl")

    args = ap.parse_args()

    # Task-independent setup: load the embedding model and pick the device
    # once, instead of repeating both for every task.
    emb_model = load_embed_model("Qwen/Qwen3-Embedding-0.6B")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    for task in ['gsm8k','gpqa','fp', 'xsum','date','salient']:

        out_dir = f"./data/{task}"
        examples = get_task(task, 'train')
        queries = get_task(task, 'test')

        with open(os.path.join(out_dir, f"{task}_instrs.json"), "r", encoding="utf-8") as f:
            instrs = json.load(f)

        with open(os.path.join(out_dir, "train_ifgain.json"), "r", encoding="utf-8") as f:
            ig = json.load(f)
        # Flatten the list of dicts into one gain row per dict value; the keys
        # are not used.
        ig_rows = [row for entry in ig for row in entry.values()]

        docs = [c["input"] for c in examples]
        ex_emb = emb_model.encode(docs)

        # ---------- 1) kmeans centroids ----------
        emb = torch.tensor(ex_emb)  # [N,d]
        # Only the second return value (the centroids) is needed here.
        _, centroids = kmeans(X=emb, num_clusters=args.kmeans_k, distance='euclidean', device=device)

        cent_path = os.path.join(out_dir, f"centroids_k{args.kmeans_k}.npz")
        np.savez(cent_path, centroids=np.array(centroids))
        print(f"[OK] saved centroids -> {cent_path}")

        # ---------- 2) T prototypes ----------
        T_ids = list(range(len(instrs)))
        protos = compute_T_prototypes_from_IG(
            ex_embs=ex_emb,
            ig=ig_rows,
            T_ids=T_ids,
            tau=args.proto_tau,
        )

        # Every id in range(len(instrs)) has a prototype, so index directly.
        instrs_out = [
            {
                "id": instr_id,
                "text": text,
                "proto": protos[instr_id].tolist(),
            }
            for instr_id, text in enumerate(instrs)
        ]
        instr_out_path = os.path.join(out_dir, "instructions_with_proto.jsonl")
        with open(instr_out_path, "w", encoding="utf-8") as f:
            json.dump(instrs_out, f, ensure_ascii=False)
        print(f"[OK] saved instructions with proto -> {instr_out_path}")

        # ---------- 3) ROUGE(q,e) ----------
        rouge_path = os.path.join(out_dir, args.rouge_out)
        compute_rouge_qe(
            queries=queries,
            examples=examples,
            out_path=rouge_path,
        )
        print(f"[OK] saved rouge -> {rouge_path}")

        print("[DONE]")


# Run the precomputation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
