# distillation/kd.py
import math, torch, torch.nn as nn, torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

# -----------------------------
# Configuration for distillation
# -----------------------------
@dataclass
class DistillConfig:
    """Hyper-parameters of the distillation objective.

    The ``alpha_*`` fields weight the three top-level loss terms (hard-label
    cross-entropy, logits KD, feature matching) when they are summed into the
    total loss. The ``w_*`` fields weight the individual per-layer terms that
    are accumulated inside the feature-matching loss.
    """
    T: float = 2.0                  # Temperature applied to logits in the KD term
    alpha_ce: float = 0.1           # Weight of ground-truth cross-entropy loss
    alpha_kd: float = 0.5           # Weight of logits KD loss (KL or MSE)
    alpha_feat: float = 1.0         # Weight of the aggregated feature-matching loss
    w_qkv: float = 1.0              # Weight for each of the Q/K/V projection taps
    w_att: float = 0.25             # Weight for attention-probabilities reconstruction loss
    w_head: float = 0.5             # Weight for attention-head output (context) loss
    w_mha: float = 1.0              # Weight for post-attention (projected) output loss
    w_ff1: float = 0.5              # Weight for the MLP feature fed INTO the down-projection (ff1)
    w_ff2: float = 1.0              # Weight for the MLP feature coming OUT of the down-projection (ff2)
    layers: Optional[List[int]] = None  # Indices of the transformer layers to tap; None -> all
    use_kl: bool = False            # If True: KL-div on T-softened distributions; else: MSE(student_logits / T, teacher_logits)

# -------------------------------------------------------
# Utility: get the effective weight of linear projections
# Supports both dense (..._proj) and low-rank (..._u_proj @ ..._v_proj)
# -------------------------------------------------------
def effective_linear_weight(module, kind: str):
    """Return the weight matrix ``module`` uses for the ``kind`` projection.

    A dense ``<kind>_proj`` attribute takes precedence and its ``weight`` is
    returned unchanged. Otherwise, when both ``<kind>_u_proj`` and
    ``<kind>_v_proj`` exist, the matrix product of their weights is returned.

    Raises:
        RuntimeError: if ``module`` exposes neither layout for ``kind``.
    """
    dense_name = f"{kind}_proj"
    if hasattr(module, dense_name):
        return getattr(module, dense_name).weight

    factor_names = (f"{kind}_u_proj", f"{kind}_v_proj")
    if not all(hasattr(module, name) for name in factor_names):
        raise RuntimeError(f"Unknown attention module type for kind={kind}")

    # Presumably u is [out, r] and v is [r, in], so u @ v is [out, in].
    u_weight, v_weight = (getattr(module, name).weight for name in factor_names)
    return u_weight @ v_weight

def get_num_heads_and_dim(config):
    """Return ``(num_attention_heads, hidden_size // num_attention_heads)`` read from ``config``."""
    heads = config.num_attention_heads
    return heads, config.hidden_size // heads

# -----------------------------
# FeatureTap: hook-based grabber
# -----------------------------
from typing import Optional, List
import torch

class FeatureTap:
    """
    Capture intermediate activations of a decoder stack through module hooks.

    For every selected layer ``i`` of ``model.model.layers`` the tap records:
      - "<i>/att/Q", "<i>/att/K", "<i>/att/V": outputs of the q/k/v projections
      - "<i>/att/MHA": output of the attention output projection
      - "<i>/ff1": input to the MLP down-projection (pre-down features)
      - "<i>/ff2": output of the MLP down-projection
    Low-rank modules (``*_u_proj`` / ``*_v_proj``) are preferred over the dense
    ``*_proj`` module when both exist. Captured tensors live in ``self.data``
    and are overwritten every time the hooked module runs; call ``clear()``
    before a forward pass and ``close()`` once the tap is no longer needed.

    Args:
        model: object exposing its decoder layers as ``model.model.layers``.
        layers: indices of the layers to tap; ``None`` taps every layer.
        store_device: device the captured tensors are moved to.
        detach: detach captured tensors from the autograd graph
            (True for a frozen teacher, False to keep student gradients).
    """
    def __init__(self, model,  layers: Optional[List[int]] = None,
                 store_device: str = "cpu", detach: bool = True):
        self.model = model
        self.store_device = store_device        # where to store tapped tensors
        self.detach = detach                    # detach from graph (True for teacher)

        self.layers = model.model.layers

        self.keep_layers = set(range(len(self.layers))) if layers is None else set(layers)
        self.clear()

        self._cfg = getattr(model, "config", None)
        # Register hooks
        self.handles = []
        for i, layer in enumerate(self.layers):
            if i not in self.keep_layers:
                continue
            # Try both naming styles (HF LLaMA vs custom)
            attn = self._first_attr(layer, "self_attention", "self_attn")
            mlp = self._first_attr(layer, "mlp", "ffn")
            if attn is not None:
                self._tap_attention(i, attn)
            if mlp is not None:
                self._tap_mlp(i, mlp)

    @staticmethod
    def _first_attr(obj, *names):
        """Return the first attribute among ``names`` that exists and is not None."""
        # Explicit None checks: a container module that defines __len__ can be
        # falsy, so `a or b` would wrongly skip it.
        for name in names:
            found = getattr(obj, name, None)
            if found is not None:
                return found
        return None

    def _store(self, key, tensor):
        """Detach (if configured) and move ``tensor`` to the store device under ``key``."""
        if self.detach:
            tensor = tensor.detach()
        # A non_blocking copy that lands on the CPU may be read before the
        # transfer has finished, so only stay asynchronous for non-CPU targets.
        non_blocking = torch.device(self.store_device).type != "cpu"
        self.data[key] = tensor.to(self.store_device, non_blocking=non_blocking)

    def _output_hook(self, key):
        """Build a forward hook that stores the module output under ``key``."""
        def _hook(mod, inp, out):
            # Some linear implementations return (output, bias) style tuples;
            # the activation is taken to be the first element.
            if isinstance(out, (tuple, list)):
                out = out[0]
            self._store(key, out)
        return _hook

    def _input_hook(self, key):
        """Build a forward pre-hook that stores the first positional input under ``key``."""
        def _pre_hook(mod, inp):
            self._store(key, inp[0])
        return _pre_hook

    def _tap_attention(self, idx, attn):
        """Hook the Q/K/V and output projections of one attention module."""
        for tag, stem in (("Q", "q"), ("K", "k"), ("V", "v"), ("MHA", "o")):
            proj = self._first_attr(attn, f"{stem}_u_proj", f"{stem}_proj")
            if proj is not None:
                hook = self._output_hook(f"{idx}/att/{tag}")
                self.handles.append(proj.register_forward_hook(hook))

    def _tap_mlp(self, idx, mlp):
        """Hook the input (ff1) and output (ff2) of the MLP down-projection."""
        down_v = getattr(mlp, "down_v_proj", None)
        down_u = getattr(mlp, "down_u_proj", None)
        dense_down = self._first_attr(mlp, "down_proj", "dense_4h_to_h")
        # ff1: features BEFORE the down-projection (input of the first factor)
        ff1_src = down_v if down_v is not None else dense_down
        if ff1_src is not None:
            self.handles.append(
                ff1_src.register_forward_pre_hook(self._input_hook(f"{idx}/ff1")))
        # ff2: features AFTER the down-projection (output of the last factor)
        ff2_src = down_u if down_u is not None else dense_down
        if ff2_src is not None:
            self.handles.append(
                ff2_src.register_forward_hook(self._output_hook(f"{idx}/ff2")))

    def clear(self):
        """Clear captured features for the next forward pass."""
        self.data = {}

    def close(self):
        """Remove all hooks; removal is best-effort and never raises."""
        for h in self.handles:
            try:
                h.remove()
            except Exception:
                # Deliberately tolerant: a handle that cannot be removed must
                # not prevent the remaining hooks from being released.
                pass
        self.handles = []
# ---------------------------------------------------------
# Reconstruct attention probs/heads from flat Q, K, V taps
# ---------------------------------------------------------
def compute_attn_from_qkv(Q, K, V, config, attention_mask: Optional[torch.Tensor] = None):
    """
    Rebuild causal attention from flat Q/K/V projections.

    Q is [B, T, C] with C = nH * d, where (nH, d) come from ``config``.
    K and V are [B, T, n_kv * d]; n_kv may equal nH or divide it
    (grouped-query attention), in which case every K/V head is repeated for
    nH // n_kv consecutive query heads.

    Scores are computed in float32 directly from the given projections (no
    positional transform is applied here), masked causally and, when
    ``attention_mask`` is given (non-zero = keep), masked on both the key and
    the query axis. Query rows with no visible key (e.g. padded positions)
    yield all-zero probabilities and therefore all-zero head outputs.

    Returns:
      - attn_probs: [B, nH, T, T]
      - heads:      [B, T, C] (concatenated across heads)
    both cast back to ``Q.dtype``.

    Raises:
      ValueError: if the K/V head count does not divide nH.
    """
    B, T, C = Q.shape
    nH, d = get_num_heads_and_dim(config)

    def _split_heads(x, n):
        # [B, T, n*d] -> [B, n, T, d]; n == -1 infers the head count from x
        return x.reshape(B, T, n, d).transpose(1, 2).contiguous().to(torch.float32)

    def _expand_kv(x):
        n_kv = x.size(1)
        if n_kv == nH:
            return x
        if n_kv == 0 or nH % n_kv != 0:
            raise ValueError(
                f"K/V head count {n_kv} is incompatible with {nH} query heads")
        return torch.repeat_interleave(x, nH // n_kv, dim=1)

    q = _split_heads(Q, nH)
    k = _expand_kv(_split_heads(K, -1))
    v = _expand_kv(_split_heads(V, -1))

    scale = 1.0 / math.sqrt(float(d))
    attn_logits = torch.matmul(q, k.transpose(-1, -2)) * scale  # [B,nH,T,T]

    # Boolean visibility mask: causal (lower triangular) ...
    allowed = torch.ones((T, T), device=attn_logits.device, dtype=torch.bool).tril()
    allowed = allowed.view(1, 1, T, T)
    # ... combined with the padding mask on both keys and queries
    if attention_mask is not None:
        keep = (attention_mask != 0).to(attn_logits.device)
        allowed = allowed & keep.reshape(B, 1, 1, T) & keep.reshape(B, 1, T, 1)

    # Decide which rows are fully masked from the mask itself. Testing the
    # logits after subtracting the row max can never fire, because the row max
    # is always shifted to exactly zero.
    row_live = allowed.any(dim=-1, keepdim=True)                  # [B|1,1,T,1]

    attn_logits = attn_logits.masked_fill(~allowed, float("-inf"))
    # Give dead rows finite logits so softmax does not produce NaN; their
    # probabilities are zeroed right after.
    safe_logits = torch.where(row_live, attn_logits, torch.zeros_like(attn_logits))

    attn_probs = torch.softmax(safe_logits, dim=-1, dtype=torch.float32)
    attn_probs = torch.where(row_live, attn_probs, torch.zeros_like(attn_probs))

    heads = torch.matmul(attn_probs, v)                           # [B,nH,T,d]
    heads = heads.transpose(1, 2).reshape(B, T, C)                # [B,T,C]

    out_dtype = Q.dtype
    return attn_probs.to(out_dtype), heads.to(out_dtype)

# -----------------------------
# Loss helpers
# -----------------------------
def mse(a, b):
    """Mean-squared error between two tensors of identical shape.

    Raises:
        ValueError: if the shapes differ (no broadcasting is attempted).
    """
    shape_a, shape_b = tuple(a.shape), tuple(b.shape)
    if shape_a == shape_b:
        return F.mse_loss(a, b)
    raise ValueError(f"MSE shape mismatch: {shape_a} vs {shape_b}")

def logits_kd_loss(student_logits, teacher_logits, T: float, use_kl: bool):
    """
    Distillation loss between student and teacher logits.

      - use_kl=True:  F.kl_div(log_softmax(student / T), softmax(teacher / T),
                      reduction='batchmean'), multiplied by T^2
      - use_kl=False: MSE(student / T, teacher); only the student logits are
                      divided by T
    """
    scaled_student = student_logits / T
    if not use_kl:
        return F.mse_loss(scaled_student, teacher_logits)

    log_p_student = F.log_softmax(scaled_student, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    kd = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
    return kd * (T * T)

import math
import torch
# ---------------------------------------------------------
# Feature-level losses (Q/K/V, attention probs, heads, MLP)
# ---------------------------------------------------------
def feature_losses(cfg: DistillConfig,
                   stu_feats: Dict[str, torch.Tensor],
                   tea_feats: Dict[str, torch.Tensor],
                   config, attention_mask: Optional[torch.Tensor]=None):
    """
    Aggregate the weighted feature-matching loss over the tapped layers.

    For every layer index found in ``stu_feats`` (keys follow the pattern
    "<layer_idx>/att/Q", "<layer_idx>/ff1", ...), and only for features present
    in BOTH dicts, the following terms are summed:
      - cfg.w_qkv  * MSE for each of Q, K, V
      - cfg.w_att  * MSE(attention probabilities) and
        cfg.w_head * MSE(attention head outputs), both rebuilt from Q/K/V
        with ``compute_attn_from_qkv`` (only when all three are available)
      - cfg.w_mha  * MSE(attention output projection)
      - cfg.w_ff1 / cfg.w_ff2 * MSE(MLP ff1 / ff2)
    Teacher tensors are moved to the device of the matching student tensor.

    Layers are visited in ascending index order so the floating-point
    accumulation order, and hence the result, is reproducible across runs.

    Returns:
        The accumulated loss tensor, or the float 0.0 when no feature matched.

    Raises:
        ValueError: if a student/teacher feature pair differs in shape.
    """
    loss = 0.0

    def _mse_align(a, b):
        b = b.to(a.device, non_blocking=True)
        if a.shape != b.shape:
            raise ValueError(f"MSE shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
        return F.mse_loss(a, b)

    # Sorting replaces iteration over a set of strings, whose order depends on
    # per-process hash randomisation.
    layer_ids = sorted({int(k.split('/')[0]) for k in stu_feats})

    for Li in layer_ids:
        # Q/K/V projection feature alignment. Each teacher tensor is moved to
        # the student device once and reused for the attention reconstruction.
        qkv = {}
        for name in ("Q", "K", "V"):
            key = f"{Li}/att/{name}"
            if key in stu_feats and key in tea_feats:
                s_t = stu_feats[key]
                t_t = tea_feats[key].to(s_t.device, non_blocking=True)
                qkv[name] = (s_t, t_t)
                loss = loss + cfg.w_qkv * _mse_align(s_t, t_t)
        # Reconstruct attention probs & head outputs when Q/K/V are all present
        if len(qkv) == 3 and (cfg.w_att > 0 or cfg.w_head > 0):
            (sQ, tQ), (sK, tK), (sV, tV) = qkv["Q"], qkv["K"], qkv["V"]
            s_att, s_heads = compute_attn_from_qkv(sQ, sK, sV, config, attention_mask)
            t_att, t_heads = compute_attn_from_qkv(tQ, tK, tV, config, attention_mask)
            if cfg.w_att > 0:
                loss = loss + cfg.w_att * _mse_align(s_att, t_att)
            if cfg.w_head > 0:
                loss = loss + cfg.w_head * _mse_align(s_heads, t_heads)
        # Post-attention projection output, then MLP features
        for suffix, weight in (("att/MHA", cfg.w_mha), ("ff1", cfg.w_ff1), ("ff2", cfg.w_ff2)):
            key = f"{Li}/{suffix}"
            if key in stu_feats and key in tea_feats:
                loss = loss + weight * _mse_align(stu_feats[key], tea_feats[key])
    return loss


# -----------------------------
# Distiller module
# -----------------------------
class Distiller(nn.Module):
    """
    Wraps teacher & student models, collects features with FeatureTap,
    computes CE / KD / feature losses, and returns a dict of losses.

    The teacher is frozen and always kept in eval mode, including when
    ``train()`` is called on this wrapper. Call ``close()`` to remove the
    hooks once training is finished.

    Args:
        teacher: frozen reference model.
        student: model being trained; its ``config`` supplies the head
            geometry used for attention reconstruction.
        args: accepted for call-site compatibility; not used here.
        kd_cfg: loss weights, temperature and tapped-layer selection.
    """
    def __init__(self, teacher, student, args,kd_cfg: DistillConfig):
        super().__init__()
        self.teacher = teacher.eval()
        self.student = student.train()
        for p in self.teacher.parameters():
            p.requires_grad_(False)         # freeze teacher
        self.hf_config = student.config     # HuggingFace config for head dims
        self.kd_cfg = kd_cfg
        # Teacher features: store on CPU, detached
        self.t_tap = FeatureTap(self.teacher, layers=kd_cfg.layers, store_device="cpu", detach=True)
        # Student features: store on same device as student, keep grads
        self.s_tap = FeatureTap(self.student, layers=kd_cfg.layers,
                                store_device=self._device_of(self.student), detach=False)

    @staticmethod
    def _device_of(module, default: str = "cuda") -> str:
        """Return the device of ``module`` as a string.

        Uses ``module.device`` when available, otherwise the device of the
        first parameter, otherwise ``default``. The lookup goes through
        ``getattr`` so a module without a ``device`` attribute reaches the
        fallbacks instead of raising AttributeError.
        """
        dev = getattr(module, "device", None)
        if dev is None:
            first_param = next(module.parameters(), None)
            if first_param is None:
                return default
            dev = first_param.device
        return str(dev)

    def train(self, mode: bool = True):
        """Set the training mode of the student while keeping the teacher in eval mode."""
        super().train(mode)
        # nn.Module.train() recurses into every registered submodule, which
        # would otherwise switch the frozen teacher into training mode.
        self.teacher.eval()
        return self

    def _normalize_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Flatten packed batches that may have shape [B, micro, T] to [B*micro, T]
        for compatibility across different dataloaders. Tensors with fewer than
        three dimensions and non-tensor values are passed through unchanged.
        """
        out = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor) and v.dim() >= 3:
                out[k] = v.reshape(-1, v.size(-1))
            else:
                out[k] = v
        return out

    def forward(self, batch):
        """
        Run teacher and student on ``batch`` and compute the distillation losses.

        Returns a dict with the enabled terms ("ce", "kd", "feat") and their
        weighted sum under "total".
        """
        # Normalize batch shapes (if needed)
        batch = self._normalize_batch(batch)
        # --- Forward teacher (no grad), collect features ---
        with torch.no_grad():
            self.t_tap.clear()
            t_out = self.teacher(**batch)
            t_feats = dict(self.t_tap.data)
        # --- Forward student (with grad), collect features ---
        self.s_tap.clear()
        s_out = self.student(**batch)
        s_feats = dict(self.s_tap.data)
        losses = {}
        # 1) CE loss with hard labels (next-token targets taken from input_ids)
        if self.kd_cfg.alpha_ce > 0:
            labels = batch["input_ids"][:, 1:].contiguous()
            slogits = s_out.logits[:, :-1, :].contiguous()
            if "attention_mask" in batch:
                # Padded targets are excluded from the CE term
                lbl_mask = batch["attention_mask"][:, 1:].contiguous()
                labels = labels.masked_fill(lbl_mask == 0, -100)
            losses["ce"] = F.cross_entropy(
                slogits.view(-1, slogits.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )
        # 2) Logits KD loss (KL or MSE), shared with the module-level helper
        # NOTE(review): this term covers every position, padded ones included,
        # and the KL variant uses reduction='batchmean' on the un-flattened
        # logits -- confirm that scaling is intended relative to the CE term.
        if self.kd_cfg.alpha_kd > 0:
            losses["kd"] = logits_kd_loss(
                s_out.logits, t_out.logits, self.kd_cfg.T, self.kd_cfg.use_kl)
        # 3) Feature-level losses
        if self.kd_cfg.alpha_feat > 0:
            amask = batch.get("attention_mask", None)
            # NOTE(review): the student's config is used for both models'
            # head geometry -- assumes teacher and student share it.
            losses["feat"] = feature_losses(self.kd_cfg, s_feats, t_feats, self.hf_config, amask)
        # Total weighted loss
        total = 0.0
        total += self.kd_cfg.alpha_ce   * losses.get("ce", 0.0)
        total += self.kd_cfg.alpha_kd   * losses.get("kd", 0.0)
        total += self.kd_cfg.alpha_feat * losses.get("feat", 0.0)
        losses["total"] = total
        return losses

    def close(self):
        """Remove hooks and release tap resources."""
        self.t_tap.close()
        self.s_tap.close()
