import json,os
import numpy as np
import torch
from smile import build_store, ChannelComputer, DSFExpert, EDSFSurrogate, sim01_cos
from data_utils import get_task
from typing import Dict, List, Tuple, Optional, Iterable
from llm_client import load_embed_model, generate, resolve_model_id
from metrics import evaluate_batched_metric
from prompts import build_prompt
from tqdm import tqdm


# Directory that receives the per-task JSONL result records written by main().
SAVE_PATH = "./smile_results/"

def append_jsonl(path: str, record: dict) -> None:
    """Append ``record`` as a single JSON line to the file at ``path``.

    The parent directory is created when missing, so callers can target a
    results folder that does not exist yet. Non-ASCII characters are written
    as UTF-8 text instead of being escaped (``ensure_ascii=False``).

    Args:
        path: Destination JSONL file; opened in append mode.
        record: JSON-serializable mapping written on its own line.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so only create a directory when the path names one.
        os.makedirs(parent, exist_ok=True)
    with open(path, "a", encoding="utf-8") as f:
        # Leaving the `with` block closes the file, which also flushes it.
        f.write(json.dumps(record, ensure_ascii=False) + "\n")



# -------------------------
# Greedy selection for inference
# -------------------------


@torch.no_grad()
def greedy_select_smile(
    model: EDSFSurrogate,
    chan: ChannelComputer,
    q_id: str,
    T_id: str,
    candidate_ids: List[str],
    K: int,
    device: str,
) -> List[str]:
    """
    Pick up to K demonstrations from ``candidate_ids`` by greedy ascent on the
    hard-min surrogate score f(q, T, S).

    f(q, T, S) = min(experts[0](m_ss(S)), experts[1](m_sq(S)), experts[2](m_st(S))),
    where each m_*(S) is the sum over S of per-candidate channel vectors from
    ``chan``. At every step the still-available candidate with the largest score
    increase is added; ties go to the earliest candidate (``torch.argmax``
    returns the first maximal index).

    NOTE(review): each expert is a DSF, but a minimum of submodular functions is
    not submodular in general, so greedy here is a heuristic and the usual
    (1 - 1/e) guarantee does not automatically apply.

    Args:
        model: surrogate whose ``experts[0..2]`` score the ss / sq / st channel sums.
        chan: source of per-candidate channel vectors.
        q_id: query identifier forwarded to ``chan.sq_channels``.
        T_id: instruction identifier forwarded to ``chan.st_channels``.
        candidate_ids: pool to choose from.
        K: maximum number of demonstrations to return.
        device: torch device used for scoring.

    Returns:
        The selected candidate ids in pick order. The list is empty when the
        pool is empty or K <= 0.
    """
    model.eval()
    if len(candidate_ids) == 0 or K <= 0:
        return []

    # Gather every candidate's channel vectors first, then move them to `device` once.
    raw_rows = (
        [chan.ss_modular_s(cid) for cid in candidate_ids],
        [chan.sq_channels(q_id, cid) for cid in candidate_ids],
        [chan.st_channels(T_id, cid) for cid in candidate_ids],
    )
    feat_ss, feat_sq, feat_st = (
        torch.tensor(np.stack(rows), device=device, dtype=torch.float32)
        for rows in raw_rows
    )  # (N, k), (N, 2), (N, 2)
    num_cand, width = feat_ss.shape

    def score(acc_ss: torch.Tensor, acc_sq: torch.Tensor, acc_st: torch.Tensor) -> torch.Tensor:
        """Hard min over the three experts for a batch of channel sums: (B, .) -> (B,)."""
        experts = model.experts
        per_expert = torch.stack(
            [experts[0](acc_ss), experts[1](acc_sq), experts[2](acc_st)], dim=-1
        )
        return per_expert.min(dim=-1).values

    # Running channel sums for the current selection; the empty set is all zeros.
    sum_ss = torch.zeros(width, device=device, dtype=torch.float32)
    sum_sq = torch.zeros(2, device=device, dtype=torch.float32)
    sum_st = torch.zeros(2, device=device, dtype=torch.float32)
    baseline = float(score(sum_ss[None], sum_sq[None], sum_st[None]).item())

    available = torch.ones(num_cand, dtype=torch.bool, device=device)
    picked: List[str] = []
    limit = min(K, num_cand)
    while len(picked) < limit:
        pool = available.nonzero(as_tuple=False).squeeze(-1)
        if pool.numel() == 0:
            break

        # Score S + {e} for every still-available e in one batched pass.
        trial = score(sum_ss + feat_ss[pool], sum_sq + feat_sq[pool], sum_st + feat_st[pool])
        winner = pool[torch.argmax(trial - baseline)].item()

        picked.append(candidate_ids[winner])
        available[winner] = False
        sum_ss = sum_ss + feat_ss[winner]
        sum_sq = sum_sq + feat_sq[winner]
        sum_st = sum_st + feat_st[winner]
        baseline = float(score(sum_ss[None], sum_sq[None], sum_st[None]).item())

    return picked

@torch.no_grad()
def hard_min_score(model, m_list):
    """Aggregate expert scores by taking their element-wise minimum.

    ``model.experts[i]`` is applied to ``m_list[i]``. If the two sequences have
    different lengths, pairing stops at the shorter one (``zip`` semantics).
    The per-expert outputs are stacked on a new last axis, and the minimum
    across that axis is returned.
    """
    per_expert = [expert(m) for expert, m in zip(model.experts, m_list)]
    return torch.stack(per_expert, dim=-1).min(dim=-1).values

@torch.no_grad()
def sQI(store, q_id: str, T_id: str) -> float:
    """Query-instruction affinity term added to the surrogate score.

    Returns ``sim01_cos`` applied to the query's embedding and the instruction's
    prototype. Returns 0.0 when the instruction has no prototype (``proto is None``).
    """
    query_vec = store.queries[q_id].emb
    prototype = store.instructions[T_id].proto
    return 0.0 if prototype is None else sim01_cos(query_vec, prototype)

def main():
    """Run SMILE instruction/demonstration selection and LLM evaluation.

    For each task in ``--tasks``, this:
      1. Loads the trained EDSF surrogate checkpoint.
      2. For every test query, greedily selects a demonstration set S_T for
         each candidate instruction T.
      3. Keeps the pair maximizing ``sQI(q, T) + f(q, T, S_T)``.
      4. Sends the resulting prompt to the LLM.
      5. Scores the predictions with ``evaluate_batched_metric`` and appends a
         summary record to ``SAVE_PATH/<model>_<task>_results.jsonl``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--tasks", type=str, nargs="+", default=["gsm8k"], help="task names: gsm8k gpqa fp xsum date salient")
    parser.add_argument("--model", type=str, default="qwen3-4b", help="LLM name [qwen3-4b,llama3.1-8b]")
    parser.add_argument("--K", type=int, default=10, help="#ICL examples")

    parser.add_argument("--tau", type=float, default=1.0, help="softmin temperature")
    parser.add_argument("--phi", type=str, default="log1p", choices=["log1p", "cap"])
    parser.add_argument("--cap_alpha", type=float, default=1.0)
    parser.add_argument("--kmeans_k", type=int, default=10)

    parser.add_argument("--ss_tau", type=float, default=0.1)
    parser.add_argument("--subset", type=int, default=100)

    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=15)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--l2_beta", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    backend_id = resolve_model_id(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The embedding model does not depend on the task; load it once, not once per task.
    emb_model = load_embed_model("Qwen/Qwen3-Embedding-0.6B")
    # append_jsonl opens files under SAVE_PATH; make sure the directory exists.
    os.makedirs(SAVE_PATH, exist_ok=True)

    for task in args.tasks:
        store = build_store(task=task, query_type='test')

        z = np.load(os.path.join(f"./data/{task}", f"centroids_k{args.kmeans_k}.npz"))
        ss_centroids = z["centroids"].astype(np.float32)

        chan = ChannelComputer(
            store=store,
            ss_centroids=ss_centroids,
            ss_tau=args.ss_tau,
            normalize_ig_to_nonneg=True,
        )

        # Experts in fixed order [ss, sq, st]; EDSF = min over them (trained via softmin).
        experts = [
            DSFExpert(num_channels=ss_centroids.shape[0], phi_kind=args.phi, cap_alpha=args.cap_alpha),
            DSFExpert(num_channels=2, phi_kind=args.phi, cap_alpha=args.cap_alpha),  # sq: sim, rouge
            DSFExpert(num_channels=2, phi_kind=args.phi, cap_alpha=args.cap_alpha),  # st: IG, proto
        ]
        model = EDSFSurrogate(experts=experts, tau=args.tau)

        ckpt_path = os.path.join(f"./model/{task}", f"{args.model}_ckpt_seed{args.seed}.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        model = model.to(device)

        demoset = get_task(task, 'train')
        testset = get_task(task, 'test')

        instpath = f"./data/{task}/{task}_instrs.json"
        with open(instpath, "r", encoding="utf-8") as f:
            instructs = json.load(f)
        if not instructs:
            # Without at least one instruction, no (T, S) pair can be selected.
            raise ValueError(f"No instructions found in {instpath}")

        demo_emb = emb_model.encode([c["input"] for c in demoset])
        q_emb = emb_model.encode([c["input"] for c in testset])

        preds = []
        for q_id in tqdm(range(len(testset))):
            # Candidate pool for this query. Ids index the demonstration pool
            # (demoset / demo_emb), never the test set.
            if args.subset == -1:
                cand_ids = list(range(len(demoset)))
            else:
                sims = emb_model.similarity(q_emb[q_id], demo_emb).cpu().numpy()[0]
                cand_ids = np.argsort(-sims)[:args.subset].tolist()

            best = None
            for T_id in range(len(instructs)):
                S_T = greedy_select_smile(model, chan, q_id, T_id, cand_ids, args.K, device=device)
                m_list = [
                    chan.set_modular_sum_ss(S_T).unsqueeze(0).to(device),
                    chan.set_modular_sum_sq(q_id, S_T).unsqueeze(0).to(device),
                    chan.set_modular_sum_st(T_id, S_T).unsqueeze(0).to(device),
                ]

                f_val = float(hard_min_score(model, m_list).item())
                total = sQI(store, q_id, T_id) + f_val

                if best is None or total > best["total"]:
                    best = {"T_id": T_id, "S": S_T, "total": total, "f": f_val}

            chosen = [demoset[i] for i in best["S"]]
            prompt = build_prompt(task=task, instruction=instructs[best["T_id"]], query=testset[q_id]["input"], examples=chosen)
            pred = generate(backend_id, prompt, max_new_tokens=512)
            preds.append(pred)

        # Evaluate once; the same score is both printed and persisted.
        performance = evaluate_batched_metric(task, testset, preds)
        print({"task": task, "method": "smile", "seed": args.seed, "performance": performance})
        record = {
            "task": task,
            "seed": int(args.seed),
            "subset": int(args.subset),
            "performance": performance,
        }

        out_file = os.path.join(SAVE_PATH, f"{args.model}_{task}_results.jsonl")
        append_jsonl(out_file, record)



if __name__ == "__main__":
    # Run only when executed as a script, not when this module is imported.
    main()
