
from __future__ import annotations
import json
import math
import random,os
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from data_utils import get_task
from llm_client import load_embed_model

def group_by_qid(data_path):
    """Load a JSON list of scored prompt records and bucket them by query id.

    Each record in the file is a 6-element row
    ``[q_id, T_id, K, S_idx, score, n_token]``. Only the instruction id, the
    demonstration-index list and the score are kept, giving
    ``{q_id: [[T_id, S_idx, score], ...]}`` with rows in file order.
    """
    with open(data_path, "r", encoding="utf-8") as handle:
        rows = json.load(handle)

    buckets = {}
    for row in rows:
        q_id, t_id, _k, s_idx, score, _n_token = row
        buckets.setdefault(q_id, []).append([t_id, s_idx, score])
    return buckets

def build_pairs_for_small_candidate_lists(train_dict: dict, margin: float = 0.5):
    """
    Expand per-query candidate lists into pairwise preference examples.

    train_dict[q_id] = [[T_id, S_ids(list), score], ...]
    Returns [{q_id, T_pos, S_pos, T_neg, S_neg}, ...]: one dict for every pair
    of candidates of the same query where score_pos > score_neg + margin.
    Queries with fewer than two candidates contribute nothing.
    """
    pairs = []
    for q_id, cands in train_dict.items():
        if len(cands or ()) < 2:
            continue

        ranked = sorted(cands, key=lambda cand: float(cand[2]), reverse=True)
        for pos_idx, pos in enumerate(ranked):
            pos_score = float(pos[2])
            # Every later entry has a score no higher than `pos` (descending
            # order); keep the pair only when the gap clears the margin.
            for neg in ranked[pos_idx + 1:]:
                neg_score = float(neg[2])
                if pos_score <= neg_score + margin:
                    continue
                pairs.append({
                    "q_id": q_id,
                    "T_pos": pos[0],
                    "S_pos": list(pos[1]),
                    "T_neg": neg[0],
                    "S_neg": list(neg[1]),
                })
    return pairs

# -------------------------
# Utilities
# -------------------------

def cosine_sim(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    """Cosine similarity of two vectors as a Python float.

    ``eps`` is added to the product of norms so a zero vector yields 0.0
    instead of a division by zero.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b) + eps
    return float(np.dot(a, b) / denom)

def sim01_cos(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity affinely rescaled from [-1, 1] to [0, 1], i.e. (1 + cos) / 2 (Eq. 4)."""
    cos = cosine_sim(a, b)
    return (cos + 1.0) / 2.0

def softmax_np(x: np.ndarray, tau: float = 1.0) -> np.ndarray:
    """Temperature-scaled, numerically stable softmax over all entries of ``x``.

    ``tau`` is floored at 1e-8 to avoid division by zero. The maximum logit is
    subtracted before exponentiating so large inputs do not overflow.
    """
    safe_tau = max(tau, 1e-8)
    logits = x / safe_tau
    weights = np.exp(logits - np.max(logits))
    total = np.sum(weights) + 1e-12
    return weights / total


# -------------------------
# Data formats
# -------------------------

@dataclass
class Example:
    """A candidate demonstration (pool example) together with its embedding.

    NOTE(review): ``id`` is annotated ``str``, but ``build_store`` in this file
    creates examples with integer positions as ids and passes the raw task
    record (not a string) as ``text`` — confirm the intended types.
    """
    id: str
    emb: np.ndarray              # h_i: embedding of the example
    text: Optional[str] = None   # for ROUGE-L if needed

@dataclass
class Instruction:
    """An instruction T and, optionally, its prototype embedding p^T.

    When ``proto`` is None, ChannelComputer.st_channels reports a
    prototype-similarity channel of 0 for this instruction.
    """
    id: str
    text: str
    # optional per-instruction prototype p^T (computed from IG weights)
    proto: Optional[np.ndarray] = None

@dataclass
class Query:
    """A query q together with its embedding.

    NOTE(review): ``id`` is annotated ``str``, but ``build_store`` in this file
    creates queries with integer positions as ids and passes the raw task
    record (not a string) as ``text`` — confirm the intended types.
    """
    id: str
    emb: np.ndarray              # h_q: embedding of the query
    text: Optional[str] = None   # for ROUGE-L if needed


class PreferencePairDataset(Dataset):
    """Map-style dataset that hands back preference-pair dicts unchanged.

    Items are expected to carry ``q_id``, ``T_pos``, ``S_pos`` (list of demo
    ids), ``T_neg`` and ``S_neg`` (list of demo ids); no validation is done here.
    """

    def __init__(self, pairs: List[dict]):
        self.pairs = pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        pair = self.pairs[idx]
        return pair


def collate_pairs(batch: List[dict]) -> dict:
    """Transpose a list of pair dicts into a dict of per-field lists.

    Ids are left as plain Python objects (no tensors); the trainer maps them
    to channel features later. Extra keys in the input dicts are ignored.
    """
    fields = ("q_id", "T_pos", "S_pos", "T_neg", "S_neg")
    return {name: [item[name] for item in batch] for name in fields}


# -------------------------
# Modular channels (precomputed signals)
# -------------------------

class FeatureStore:
    """
    Lookup tables behind the modular channel scores s_u(e) >= 0.

    Required:
      - examples: id -> Example(emb, text)
      - queries:  id -> Query(emb, text)
      - instructions: id -> Instruction(text, proto optional)
    Optional precomputed tables (a missing entry reads back as None):
      - rouge_qe[(q_id, e_id)]: ROUGE-L between query and example, in [0,1]
      - IG_Te[(T_id, e_id)]: IG of example e under instruction T; any real
        value, so consumers may need to clamp/shift it to be nonnegative
    """
    def __init__(
        self,
        examples: Dict[str, Example],
        queries: Dict[str, Query],
        instructions: Dict[str, Instruction],
        rouge_qe: Optional[Dict[Tuple[str, str], float]] = None,
        IG_Te: Optional[Dict[Tuple[str, str], float]] = None,
    ):
        self.examples = examples
        self.queries = queries
        self.instructions = instructions
        # A missing (None) or empty table is normalized to a fresh dict.
        self.rouge_qe = rouge_qe if rouge_qe else {}
        self.IG_Te = IG_Te if IG_Te else {}

    def get_ex_emb(self, e_id: str) -> np.ndarray:
        example = self.examples[e_id]
        return example.emb

    def get_q_emb(self, q_id: str) -> np.ndarray:
        query = self.queries[q_id]
        return query.emb

    def get_T_proto(self, T_id: str) -> Optional[np.ndarray]:
        instruction = self.instructions[T_id]
        return instruction.proto

    def get_rouge(self, q_id: str, e_id: str) -> Optional[float]:
        key = (q_id, e_id)
        return self.rouge_qe.get(key)

    def get_IG(self, T_id: str, e_id: str) -> Optional[float]:
        key = (T_id, e_id)
        return self.IG_Te.get(key)


# -------------------------
# DSF building blocks
# -------------------------

class ConcaveActivation(nn.Module):
    """
    Elementwise nondecreasing concave activation phi, applied after clamping
    inputs at 0:

      kind="log1p": phi(x) = log(1 + max(x, 0))
      kind="cap":   phi(x) = min(max(x, 0), alpha)
    """
    def __init__(self, kind: str = "log1p", alpha: float = 1.0):
        super().__init__()
        assert kind in ("log1p", "cap")
        self.kind = kind
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        nonneg = torch.clamp_min(x, 0.0)
        if self.kind != "log1p":
            ceiling = torch.tensor(self.alpha, device=x.device, dtype=x.dtype)
            return torch.minimum(nonneg, ceiling)
        return torch.log1p(nonneg)


class DSFExpert(nn.Module):
    """
    One deep-submodular-function expert:

        f(S) = sum_u w_u * phi(m_u(S)),   m_u(S) = sum_{e in S} s_u(e) >= 0,

    with channel weights kept nonnegative through w = softplus(psi).
    """
    def __init__(self, num_channels: int, phi_kind: str = "log1p", cap_alpha: float = 1.0):
        super().__init__()
        self.num_channels = num_channels
        self.phi = ConcaveActivation(kind=phi_kind, alpha=cap_alpha)
        # Unconstrained parameters; psi = 0 starts every weight at softplus(0) = log 2.
        self.psi = nn.Parameter(torch.zeros(num_channels))

    def weights(self) -> torch.Tensor:
        """Nonnegative channel weights w = softplus(psi), shape (U,)."""
        return F.softplus(self.psi)

    def forward(self, m: torch.Tensor) -> torch.Tensor:
        """
        m: (B, U) modular sums for a batch of sets
        returns f: (B,)
        """
        w_row = self.weights().unsqueeze(0)  # (1, U), broadcast over the batch
        weighted = w_row * self.phi(m)
        return weighted.sum(dim=-1)


class EDSFSurrogate(nn.Module):
    """
    f(q,T,S) = min_r f^(r)(q,T,S) (EDSF as min over experts).

    Training uses the smooth relaxation softmin(f) = -tau * log sum_r exp(-f_r / tau),
    which tends to the hard minimum as tau -> 0. ``tau`` is floored at 1e-8.
    """
    def __init__(self, experts: List[DSFExpert], tau: float = 0.2):
        super().__init__()
        assert len(experts) >= 1
        self.experts = nn.ModuleList(experts)
        self.tau = tau

    def softmin(self, fr: torch.Tensor) -> torch.Tensor:
        """Smooth minimum over the last axis; fr: (B, R) -> (B,)."""
        # Use the same floored temperature for the divisor and the scale.
        # Mixing the raw tau (scale) with the floored one (divisor) collapsed
        # the output to 0 for tau <= 0 instead of approaching the hard min.
        tau = max(self.tau, 1e-8)
        return -tau * torch.logsumexp(-fr / tau, dim=-1)

    def forward(self, m_list: List[torch.Tensor]) -> torch.Tensor:
        """
        m_list: list of modular-sum tensors, one per expert (same order as
        ``self.experts``); each (B, U_r).
        returns: (B,) softmin score

        Raises ValueError if the number of tensors does not match the number of
        experts. Plain ``zip`` would silently drop experts or pair channels
        with the wrong expert.
        """
        m_list = list(m_list)
        if len(m_list) != len(self.experts):
            raise ValueError(
                f"expected {len(self.experts)} modular-sum tensors (one per expert), got {len(m_list)}"
            )
        fr = torch.stack(
            [expert(m) for expert, m in zip(self.experts, m_list)], dim=-1
        )  # (B, R)
        return self.softmin(fr)


# -------------------------
# Channel construction for the three experts (sample-sample, sample-query, sample-instruction)
# -------------------------

class ChannelComputer:
    """
    Produces per-expert modular sums m_u(S) = sum_{e in S} s_u(e).

    Channel families:
      - Sample-Sample expert (ss): k soft coverage buckets, one per row of
        ``ss_centroids`` (e.g. k-means centers).
      - Sample-Query expert (sq): embedding similarity + ROUGE-L.
      - Sample-Instruction expert (st): IG + prototype similarity.

    Per-example values come from ``store``; each ``set_modular_sum_*`` returns
    a CPU float32 tensor built with ``torch.from_numpy``.
    """

    def __init__(
        self,
        store: FeatureStore,
        ss_centroids: Optional[np.ndarray] = None,  # (k, d)
        ss_tau: float = 0.1,
        normalize_ig_to_nonneg: bool = True,
        ig_shift_eps: float = 1e-6,
    ):
        self.store = store
        self.ss_centroids = ss_centroids
        self.ss_tau = ss_tau
        self.normalize_ig_to_nonneg = normalize_ig_to_nonneg
        self.ig_shift_eps = ig_shift_eps

    # ---- expert 1: sample-sample (coverage) ----
    def ss_modular_s(self, e_id: str) -> Optional[np.ndarray]:
        """Soft assignment of example ``e_id`` to the k centroids, or None if unset.

        Returns softmax_k(cos(h_e, c_k) / ss_tau) as a (k,) vector. It is
        nonnegative and normalized across centroids (sums to ~1), so each
        example spreads one unit of coverage mass over the buckets.
        """
        if self.ss_centroids is None:
            return None
        hi = np.asarray(self.store.get_ex_emb(e_id))
        centroids = np.asarray(self.ss_centroids)
        # Cosine similarity against all centroids in one vectorized step: the
        # same formula and eps (1e-8) as cosine_sim, without a Python loop.
        denom = np.linalg.norm(centroids, axis=1) * np.linalg.norm(hi) + 1e-8
        sims = ((centroids @ hi) / denom).astype(np.float32)
        # Softmax output is already >= 0, so no extra clipping is required.
        return softmax_np(sims, tau=self.ss_tau)  # s_u(e) for u=1..k

    # ---- expert 2: sample-query ----
    def sq_channels(self, q_id: str, e_id: str) -> np.ndarray:
        """[(1 + cos(h_e, h_q)) / 2, ROUGE-L(q, e)]; a missing ROUGE entry counts as 0."""
        hq = self.store.get_q_emb(q_id)
        hi = self.store.get_ex_emb(e_id)
        sim = sim01_cos(hi, hq)  # [0,1]
        rouge = self.store.get_rouge(q_id, e_id)
        if rouge is None:
            rouge = 0.0
        return np.array([sim, rouge], dtype=np.float32)

    # ---- expert 3: sample-instruction ----
    def st_channels(self, T_id: str, e_id: str) -> np.ndarray:
        """[IG(T, e) + ig_shift_eps, (1 + cos(h_e, p^T)) / 2].

        A missing IG entry counts as 0. IG is clamped at 0 only when
        ``normalize_ig_to_nonneg`` is set; ``ig_shift_eps`` is added either
        way. The prototype channel is 0 when the instruction has no prototype.
        """
        # (i) IG channel
        ig = self.store.get_IG(T_id, e_id)
        if ig is None:
            ig = 0.0

        # DSF channels must be nonnegative: clamp negative IG to 0.
        if self.normalize_ig_to_nonneg:
            ig = max(0.0, ig)

        # (ii) proto similarity channel
        proto = self.store.get_T_proto(T_id)
        if proto is None:
            proto_sim = 0.0
        else:
            hi = self.store.get_ex_emb(e_id)
            proto_sim = sim01_cos(hi, proto)

        return np.array([ig + self.ig_shift_eps, proto_sim], dtype=np.float32)

    # ---- set -> modular sums m_u(S) ----
    def set_modular_sum_ss(self, S_ids: List[str]) -> torch.Tensor:
        """m(S) for the coverage expert: (k,) sum of soft assignments (zeros for empty S)."""
        assert self.ss_centroids is not None
        k = self.ss_centroids.shape[0]
        m = np.zeros((k,), dtype=np.float32)
        for e_id in S_ids:
            m += self.ss_modular_s(e_id)
        return torch.from_numpy(m)

    def set_modular_sum_sq(self, q_id: str, S_ids: List[str]) -> torch.Tensor:
        """m(S) for the sample-query expert: (2,) = [sum sim, sum rouge]."""
        m = np.zeros((2,), dtype=np.float32)
        for e_id in S_ids:
            m += self.sq_channels(q_id, e_id)
        return torch.from_numpy(m)

    def set_modular_sum_st(self, T_id: str, S_ids: List[str]) -> torch.Tensor:
        """m(S) for the sample-instruction expert: (2,) = [sum IG, sum proto sim]."""
        m = np.zeros((2,), dtype=np.float32)
        for e_id in S_ids:
            m += self.st_channels(T_id, e_id)
        return torch.from_numpy(m)


# -------------------------
# Training (pairwise ranking)
# -------------------------

@dataclass
class TrainConfig:
    """Hyperparameters consumed by SMILETrainer (and by main for the DataLoaders).

    NOTE: the ``device`` default is evaluated once, when this class is defined.
    ``seed`` is stored but not read by SMILETrainer; seeding is done by the
    caller (``main`` seeds random/numpy/torch from its ``--seed`` argument).
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 64          # DataLoader batch size (read by main)
    lr: float = 5e-3              # Adam learning rate
    weight_decay: float = 0.0     # Adam weight decay
    l2_beta: float = 1e-3         # coefficient of SMILETrainer.l2_reg() in the loss
    epochs: int = 15
    seed: int = 13


class SMILETrainer:
    """Pairwise-ranking trainer for an EDSFSurrogate.

    For each preference pair the surrogate scores the preferred prompt
    (T_pos, S_pos) and the dispreferred one (T_neg, S_neg) for the same query.
    Training uses the logistic ranking loss plus ``cfg.l2_beta`` times an L2
    penalty on the experts' nonnegative weights, optimized with Adam.
    """

    def __init__(self, model: EDSFSurrogate, chan: ChannelComputer, cfg: TrainConfig):
        self.model = model.to(cfg.device)
        self.chan = chan
        self.cfg = cfg
        self.opt = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    def score_batch(self, q_ids: List[str], T_ids: List[str], S_list: List[List[str]]) -> torch.Tensor:
        """Score B (query, instruction, demo-set) triples; returns a (B,) tensor.

        The modular sums are built per expert on the CPU, stacked to (B, U_r),
        moved to ``cfg.device`` and passed to the model as a list.
        """
        use_ss = self.chan.ss_centroids is not None
        m_ss, m_sq, m_st = [], [], []
        for q_id, T_id, S in zip(q_ids, T_ids, S_list):
            if use_ss:
                m_ss.append(self.chan.set_modular_sum_ss(S))
            m_sq.append(self.chan.set_modular_sum_sq(q_id, S))
            m_st.append(self.chan.set_modular_sum_st(T_id, S))

        # Order must match EDSFSurrogate.experts: [ss (only with centroids), sq, st].
        groups = ([m_ss] if use_ss else []) + [m_sq, m_st]
        m_list = [torch.stack(g, dim=0).to(self.cfg.device) for g in groups]
        return self.model(m_list)  # (B,)

    def pairwise_loss(self, f_pos: torch.Tensor, f_neg: torch.Tensor) -> torch.Tensor:
        """Lrank = mean log(1 + exp(-(f_pos - f_neg)))."""
        return torch.mean(F.softplus(-(f_pos - f_neg)))

    def l2_reg(self) -> torch.Tensor:
        """Sum of squared expert weights (each expert's ``weights()``) over all experts."""
        reg = torch.tensor(0.0, device=self.cfg.device)
        for expert in self.model.experts:
            reg = reg + torch.sum(expert.weights() ** 2)
        return reg

    def train_epoch(self, loader: DataLoader) -> Dict[str, float]:
        """Run one optimization pass over ``loader``.

        Returns the sample-weighted mean total loss ("loss") and mean ranking
        loss ("rank") over the epoch.
        """
        self.model.train()
        total_loss, total_rank = 0.0, 0.0
        n = 0

        for batch in loader:
            q_id = batch["q_id"]
            T_pos, S_pos = batch["T_pos"], batch["S_pos"]
            T_neg, S_neg = batch["T_neg"], batch["S_neg"]

            f_pos = self.score_batch(q_id, T_pos, S_pos)
            f_neg = self.score_batch(q_id, T_neg, S_neg)

            rank = self.pairwise_loss(f_pos, f_neg)
            loss = rank + self.cfg.l2_beta * self.l2_reg()

            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            self.opt.step()

            bs = len(q_id)
            total_loss += loss.item() * bs
            total_rank += rank.item() * bs
            n += bs

        return {"loss": total_loss / max(n, 1), "rank": total_rank / max(n, 1)}

    @torch.no_grad()
    def evaluate_pair_acc(self, loader: DataLoader) -> float:
        """Fraction of pairs with f_pos strictly greater than f_neg (0.0 for an empty loader)."""
        self.model.eval()
        correct, total = 0, 0
        for batch in loader:
            q_id = batch["q_id"]
            T_pos, S_pos = batch["T_pos"], batch["S_pos"]
            T_neg, S_neg = batch["T_neg"], batch["S_neg"]
            f_pos = self.score_batch(q_id, T_pos, S_pos)
            f_neg = self.score_batch(q_id, T_neg, S_neg)
            correct += int(torch.sum(f_pos > f_neg).item())
            total += f_pos.shape[0]
        return correct / max(total, 1)




def _load_json_or_jsonl(path: str):
    """Parse ``path`` as one JSON document, falling back to JSON Lines.

    Some feature files carry a ``.jsonl`` extension but are read as a single
    JSON document. Accepting both layouts keeps single-document files working
    and also supports true one-record-per-line files, which are returned as a
    list of records.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return [json.loads(line) for line in content.splitlines() if line.strip()]


def build_store(
    task: str,
    query_type: str = 'val',
    emb_model: str = "Qwen/Qwen3-Embedding-0.6B",
) -> FeatureStore:
    """Assemble the FeatureStore for ``task`` from raw task data and ``./data/<task>`` files.

    Args:
        task: task name passed to ``get_task``; also names the feature directory.
        query_type: split whose inputs serve as queries; also selects the
            ``<query_type>_rouge_qe.jsonl`` file.
        emb_model: embedding-model name passed to ``load_embed_model``.

    Returns:
        A FeatureStore whose examples/queries are keyed by integer position in
        their split and whose instructions are keyed by the file-provided ``id``.

    Files read from ``./data/<task>``:
      - instructions_with_proto.jsonl: records with ``id``, ``text``, optional ``proto``
      - <query_type>_rouge_qe.jsonl: records with ``q_id``, ``e_id``, ``rouge``
      - train_ifgain.json: list of dicts whose values are rows of IG scores
    """
    ex_raw = get_task(task, 'train')
    q_raw = get_task(task, query_type)

    encoder = load_embed_model(emb_model)
    ex_emb = encoder.encode([c["input"] for c in ex_raw])
    q_emb = encoder.encode([c["input"] for c in q_raw])

    feat_dir = f"./data/{task}"

    t_raw = _load_json_or_jsonl(os.path.join(feat_dir, "instructions_with_proto.jsonl"))

    # NOTE(review): `text` receives the whole raw task record, not r["input"];
    # kept as-is because nothing visible here reads it — confirm before relying on it.
    examples = {
        idx: Example(id=idx, emb=ex_emb[idx], text=r) for idx, r in enumerate(ex_raw)
    }
    queries = {
        idx: Query(id=idx, emb=q_emb[idx], text=r) for idx, r in enumerate(q_raw)
    }

    instructions = {}
    for r in t_raw:
        proto = r.get("proto")
        instructions[r["id"]] = Instruction(
            id=r["id"],
            text=r["text"],
            proto=None if proto is None else np.array(proto, dtype=np.float32),
        )

    rouge_raw = _load_json_or_jsonl(os.path.join(feat_dir, f"{query_type}_rouge_qe.jsonl"))
    rouge_qe = {(r["q_id"], r["e_id"]): float(r["rouge"]) for r in rouge_raw}

    # IG rows are flattened in file order. Row index i is used as the example
    # id and column index j as the instruction id, so IG_Te[(j, i)] = rows[i][j].
    # NOTE(review): these keys only match instruction ids that are integers
    # equal to the column position — confirm against the file producer.
    ig = _load_json_or_jsonl(os.path.join(feat_dir, "train_ifgain.json"))
    ig_rows = [row for d in ig for row in d.values()]
    IG_Te = {(j, i): v for i, row in enumerate(ig_rows) for j, v in enumerate(row)}

    return FeatureStore(examples, queries, instructions, rouge_qe=rouge_qe, IG_Te=IG_Te)


# -------------------------
# Main
# -------------------------

def main():
    """Command-line entry point: train one EDSF surrogate per task.

    For each task this:
      1. builds the feature store and loads the k-means centroids;
      2. builds the [ss, sq, st] experts;
      3. splits the preference pairs 70/30 into train/dev;
      4. trains for ``--epochs`` epochs, keeping the checkpoint with the best
         dev pairwise accuracy under ``./model/<task>/``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--tasks", type=str, nargs="+", default=["gsm8k"], help="task names: gsm8k gpqa fp xsum date salient")
    parser.add_argument("--model", type=str, default="qwen3-4b", help="LLM name [qwen3-4b,llama3.1-8b]")

    parser.add_argument("--tau", type=float, default=1.0, help="softmin temperature")
    parser.add_argument("--phi", type=str, default="log1p", choices=["log1p", "cap"])
    parser.add_argument("--cap_alpha", type=float, default=0.8)
    parser.add_argument("--kmeans_k", type=int, default=10)

    parser.add_argument("--ss_tau", type=float, default=0.1)

    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--l2_beta", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    for task in args.tasks:

        store = build_store(task=task)

        # Precomputed centroids for the sample-sample coverage expert.
        z = np.load(os.path.join(f"./data/{task}", f"centroids_k{args.kmeans_k}.npz"))
        ss_centroids = z["centroids"].astype(np.float32)

        chan = ChannelComputer(
            store=store,
            ss_centroids=ss_centroids,
            ss_tau=args.ss_tau,
            normalize_ig_to_nonneg=True,
        )

        # Experts [ss, sq, st]; this order must match SMILETrainer.score_batch.
        # EDSF = min over them (trained via softmin).
        experts = [
            DSFExpert(num_channels=ss_centroids.shape[0], phi_kind=args.phi, cap_alpha=args.cap_alpha),  # ss: one channel per centroid
            DSFExpert(num_channels=2, phi_kind=args.phi, cap_alpha=args.cap_alpha),  # sq: sim, rouge
            DSFExpert(num_channels=2, phi_kind=args.phi, cap_alpha=args.cap_alpha),  # st: IG, proto
        ]
        model = EDSFSurrogate(experts=experts, tau=args.tau)

        cfg = TrainConfig(
            batch_size=args.batch_size,
            epochs=args.epochs,
            lr=args.lr,
            l2_beta=args.l2_beta,
            seed=args.seed,
        )

        data_path = f"./data/{task}/{args.model}_data.json"
        train_dict = group_by_qid(data_path)
        # Minimum score gap for a (pos, neg) pair; fp with qwen3 uses 0.5 instead of 2.0.
        margin = 0.5 if (task == 'fp' and "qwen3" in data_path) else 2.0
        all_pairs = build_pairs_for_small_candidate_lists(train_dict, margin=margin)

        # Seeded 70/30 train/dev split of the pairs.
        rng = random.Random(args.seed)
        rng.shuffle(all_pairs)
        n_train = int(len(all_pairs) * 0.7)
        train_pairs, dev_pairs = all_pairs[:n_train], all_pairs[n_train:]

        train_loader = DataLoader(PreferencePairDataset(train_pairs), batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_pairs)
        dev_loader = DataLoader(PreferencePairDataset(dev_pairs), batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_pairs)

        trainer = SMILETrainer(model, chan, cfg)

        best_acc = -1.0
        save_dir = f"./model/{task}"
        # torch.save does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)
        save_model_path = os.path.join(save_dir, f"{args.model}_ckpt_seed{args.seed}.pt")
        for ep in range(cfg.epochs):
            stats = trainer.train_epoch(train_loader)
            msg = f"[epoch {ep+1}/{cfg.epochs}] loss={stats['loss']:.4f} rank={stats['rank']:.4f}"
            acc = trainer.evaluate_pair_acc(dev_loader)
            msg += f" dev_pair_acc={acc:.3f}"
            # `>=` lets a later epoch replace an earlier one on ties.
            if acc >= best_acc:
                best_acc = acc
                torch.save({"model": model.state_dict()}, save_model_path)
                msg += "  (saved best)"
            print(msg)

        # Report the checkpoint that was actually written (previously a
        # hard-coded, wrong file name was printed).
        if best_acc >= 0.0:
            print(f"Saved: {save_model_path} (best dev_pair_acc={best_acc:.3f})")
        else:
            print(f"No checkpoint saved for {task}: 0 epochs were run.")


# Script entry point: parse CLI arguments and train one surrogate per task.
if __name__ == "__main__":
    main()
