import math
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import extract_model_from_parallel as unwrap_model
from transformers import Trainer


# ---------- STE gumbel ----------
def gumbel_softmax_STE(
    logits: torch.Tensor, tau: float = 1.0, hard: bool = True, dim: int = -1
):
    """Draw a Gumbel-softmax sample along ``dim``, optionally straight-through.

    Args:
        logits: unnormalized scores of any shape.
        tau: softmax temperature applied to ``logits + gumbel_noise``.
        hard: if True, the forward value is the one-hot argmax of the soft
            sample while gradients flow through the soft sample (STE).
        dim: axis over which the categorical distribution is defined.

    Returns:
        Tensor shaped like ``logits``; a soft sample when ``hard`` is False,
        otherwise a one-hot (forward) / soft (backward) tensor.
    """
    # Uniform noise is drawn in float32 for numerical stability of the
    # double log, then cast to the logits' dtype before mixing.
    u = torch.rand_like(logits, dtype=torch.float32)
    gumbel = torch.log(u.clamp_min(1e-20)).neg().log().neg().to(logits.dtype)
    soft = torch.softmax((logits + gumbel) / tau, dim=dim)
    if hard:
        winners = soft.argmax(dim=dim, keepdim=True)
        one_hot = torch.zeros_like(logits).scatter_(dim, winners, 1.0)
        # Forward sees the one-hot; backward sees d(soft).
        return soft + (one_hot - soft).detach()
    return soft


def _answer_mask(labels: torch.Tensor, attn: torch.Tensor) -> torch.Tensor:
    return (labels != -100) & (attn == 1)


def _masked_mean(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    m = m.to(device=x.device, dtype=x.dtype).unsqueeze(-1)
    denom = m.sum(1).clamp_min(torch.finfo(x.dtype).eps)
    return (x * m).sum(1) / denom


class STEEnergyHead(nn.Module):
    """
    Scalar energy for a (prompt, answer) example, built from:
      - prompt rep: masked mean of the last hidden state over attended,
        non-answer positions, projected by ``q``
      - answer rep: straight-through Gumbel one-hots sampled from the model's
        next-token predictions for the answer tokens, embedded with the input
        embedding matrix (E @ one-hots), mean-pooled per example, projected
        by ``k``
      - |q - k| plus three scalars: prompt length, answer length, cos(q, k)
    The concatenated features go through a small ReLU MLP to one scalar.
    """

    def __init__(
        self,
        hidden_size: int,
        proj_size: int = 256,
        tau_gumbel: float = 1.0,
        topdown_mlp: int = 128,
    ):
        super().__init__()
        self.tau = tau_gumbel
        d = min(proj_size, hidden_size)
        self.q = nn.Linear(hidden_size, d, bias=False)
        self.k = nn.Linear(hidden_size, d, bias=False)
        self.head = nn.Sequential(
            nn.Linear(3 * d + 3, topdown_mlp),
            nn.ReLU(),
            nn.Linear(topdown_mlp, 1),
        )

    @torch.no_grad()
    def _cheap_lengths(self, prm_m: torch.Tensor, ans_m: torch.Tensor, dtype):
        """Return per-example prompt and answer token counts, each [B,1] in ``dtype``."""
        lp = prm_m.sum(1).to(dtype).unsqueeze(-1)
        la = ans_m.sum(1).to(dtype).unsqueeze(-1)
        return lp, la

    def _ste_answer_rep(
        self,
        logits: torch.Tensor,  # [B,T,V]
        labels: torch.Tensor,  # [B,T]
        attn: torch.Tensor,  # [B,T]
        emb_weight: torch.Tensor,  # [V,H]
    ) -> torch.Tensor:
        """Mean embedding of STE-sampled answer tokens per example, [B,H].

        Examples without any answer token get a zero vector.
        """
        B = logits.size(0)
        H = emb_weight.size(1)

        # Causal-LM alignment: logits[:, t] predicts the token at t+1, so the
        # model's prediction for answer token t is read from logits[:, t-1].
        tgt_m = _answer_mask(labels, attn)[:, 1:]  # [B,T-1]
        if not bool(tgt_m.any().item()):
            return logits.new_zeros(B, H)

        logits_ans = logits[:, :-1][tgt_m]  # [N,V]
        y_hat = gumbel_softmax_STE(logits_ans, tau=self.tau, hard=True, dim=-1)  # [N,V]
        emb = y_hat.to(emb_weight.dtype) @ emb_weight  # [N,H]

        # Example (row) index of each selected position. nonzero() and boolean
        # indexing both enumerate in row-major order, so rows line up with
        # `emb`; examples with no answer tokens simply receive no rows.
        row_idx = tgt_m.nonzero(as_tuple=True)[0]  # [N]
        sum_rep = logits.new_zeros(B, H, dtype=emb.dtype)
        sum_rep.index_add_(0, row_idx, emb)
        counts = tgt_m.sum(1).clamp_min(1).to(sum_rep.dtype).unsqueeze(1)
        return (sum_rep / counts).to(logits.dtype)

    def forward(
        self,
        last_hidden: torch.Tensor,  # [B,T,H]
        attn: torch.Tensor,  # [B,T]
        labels: torch.Tensor,  # [B,T]
        logits: torch.Tensor,  # [B,T,V]
        emb_weight: torch.Tensor,  # [V,H]
    ) -> torch.Tensor:
        """Return the scalar energy per example, shape [B]."""
        ans_m = _answer_mask(labels, attn)
        prm_m = (attn == 1) & (~ans_m)

        h_prompt = _masked_mean(last_hidden, prm_m)  # [B,H]
        h_answer = self._ste_answer_rep(logits, labels, attn, emb_weight)  # [B,H]

        qp = self.q(h_prompt)  # [B,d]
        ka = self.k(h_answer)  # [B,d]
        delta = (qp - ka).abs()  # [B,d]

        lp, la = self._cheap_lengths(prm_m, ans_m, dtype=qp.dtype)
        # TODO: similarity to the golden answer of the forget set
        cos = (
            torch.cosine_similarity(qp, ka, dim=-1, eps=1e-8)
            .unsqueeze(-1)
            .to(qp.dtype)
        )

        feats = torch.cat([qp, ka, delta, lp, la, cos], dim=-1)  # [B,3d+3]
        return self.head(feats).squeeze(-1)  # [B]


class SELUTrainer(Trainer):
    """
    ``Trainer`` that adds an energy-based objective on top of the retain LM loss.

    Each batch must provide ``retain_input_ids`` / ``retain_attention_mask`` /
    ``retain_labels`` and the matching ``forget_*`` tensors. An
    ``STEEnergyHead`` (attached to the model as ``model.ste_head``) assigns a
    scalar energy to every example, and the loss is

        lambda_ce_retain * CE_retain
        + lambda_e * (mean relu(E_r - tau_low) + mean relu(tau_high - E_f)
                      + mean relu(margin + E_r - E_f))
        + lambda_couple * mean((sigmoid(E_f) - p_gold_forget) ** 2)
        [+ lambda_calib * MSE(E_r, alpha_calib * NLL_retain)   if calib_in_loss]

    Here ``p_gold_forget`` and ``NLL_retain`` are per-example means over the
    supervised (label != -100) tokens. They use the causal-LM next-token
    alignment (logits[:, t] scored against labels[:, t+1]).
    """

    def __init__(
        self,
        *args,
        tau_low: float = 0.0,
        tau_high: float = 1.0,
        margin: float = 0.5,
        lambda_e: float = 1.0,
        lambda_ce_retain: float = 1.0,
        lambda_couple: float = 0.5,
        lambda_calib: float = 0.1,
        alpha_calib: float = 1.0,
        proj_size: int = 256,
        tau_gumbel: float = 1.0,
        calib_in_loss: bool = False,
        **kwargs,
    ):
        """Build the trainer and attach an ``STEEnergyHead`` to ``self.model``.

        ``calib_in_loss`` turns on the retain calibration term. It is off by
        default, and the term (a full-vocabulary cross-entropy) is only
        computed when it is enabled and ``lambda_calib > 0``.

        Raises:
            ValueError: if the hidden size cannot be read from ``model.config``.
        """
        super().__init__(*args, **kwargs)

        cfg = self.model.config
        hidden_size = (
            getattr(cfg, "hidden_size", None)
            or getattr(cfg, "d_model", None)
            or getattr(cfg, "n_embd", None)
        )
        if hidden_size is None:
            raise ValueError("Cannot infer hidden size from model.config")

        self.ste_head = STEEnergyHead(
            hidden_size, proj_size=proj_size, tau_gumbel=tau_gumbel
        )
        # register under the model so DeepSpeed sees the params
        setattr(self.model, "ste_head", self.ste_head)
        for p in self.model.ste_head.parameters():
            p.requires_grad_(True)

        self.tau_low = tau_low
        self.tau_high = tau_high
        self.margin = margin
        self.lambda_e = lambda_e
        self.lambda_ce_retain = lambda_ce_retain
        self.lambda_couple = lambda_couple
        self.lambda_calib = lambda_calib
        self.alpha_calib = alpha_calib
        self.calib_in_loss = calib_in_loss

        # Best-effort tweaks: failures here are deliberately ignored.
        try:
            self.model.config.use_cache = False
        except Exception:
            pass
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.set_float32_matmul_precision("high")
        except Exception:
            pass

    def save_model(
        self, output_dir: Optional[str] = None, _internal_call: bool = False
    ):
        """Save via the parent Trainer, then also write the head to ``ste_head.pt``."""
        super().save_model(output_dir, _internal_call)
        if not self.args.should_save:
            return
        out = output_dir or self.args.output_dir
        os.makedirs(out, exist_ok=True)
        torch.save(self.model.ste_head.state_dict(), os.path.join(out, "ste_head.pt"))

    # ---- helpers ----
    def _forward_lm(self, ids, mask, labels):
        """Run the LM and return ``(last_hidden [B,T,H], logits [B,T,V], loss)``."""
        # NOTE(review): this calls `self.model` rather than the (possibly
        # DDP/DeepSpeed-wrapped) `model` handed to compute_loss; confirm that
        # is intended for the distributed setup in use.
        out = self.model(
            input_ids=ids,
            attention_mask=mask,
            labels=labels,
            output_hidden_states=True,  # we need last_hidden
            return_dict=True,
        )
        last_hidden = out.hidden_states[-1]  # [B,T,H]
        return last_hidden, out.logits, out.loss

    def _ensure_head_device(self, ref: torch.Tensor):
        """Move the head to ``ref``'s device and dtype if its parameters differ."""
        p = next(self.model.ste_head.parameters(), None)
        if (p is None) or (p.device != ref.device or p.dtype != ref.dtype):
            self.model.ste_head.to(device=ref.device, dtype=ref.dtype)

    @staticmethod
    def _gold_token_nll(logits: torch.Tensor, labels: torch.Tensor):
        """Per-token NLL of the gold next token under causal-LM alignment.

        Returns ``(nll, mask)``. ``mask`` is the [B,T-1] bool mask of
        supervised targets ``labels[:, 1:] != -100``, and ``nll`` is the flat
        [N] NLL at those positions, in row-major order.
        """
        targets = labels[:, 1:]
        mask = targets != -100
        nll = F.cross_entropy(logits[:, :-1][mask], targets[mask], reduction="none")
        return nll, mask

    @staticmethod
    def _per_example_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Fold flat per-token ``values`` [N] back into per-row means [B] using ``mask`` [B,T']."""
        # nonzero() enumerates rows in the same row-major order as boolean
        # indexing, so rows without any selected token just stay at zero.
        rows = mask.nonzero(as_tuple=True)[0]
        sums = values.new_zeros(mask.size(0)).index_add_(0, rows, values)
        counts = mask.sum(1).clamp_min(1).to(sums.dtype)
        return sums / counts

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the combined retain-CE + energy loss (see class docstring)."""
        # Unwrap DS/DP/FSDP to access the embedding matrix
        base = unwrap_model(model)
        get_emb = getattr(base, "get_input_embeddings", None)
        if get_emb is None:
            raise RuntimeError("Underlying model has no get_input_embeddings()")
        emb_weight = get_emb().weight  # [V,H]

        r_labels = inputs["retain_labels"]
        f_labels = inputs["forget_labels"]

        # ---------- RETAIN ----------
        R_h, R_logits, ce_retain = self._forward_lm(
            inputs["retain_input_ids"],
            inputs["retain_attention_mask"],
            r_labels,
        )
        self._ensure_head_device(R_h)
        E_r = self.model.ste_head(
            R_h,
            inputs["retain_attention_mask"],
            r_labels,
            R_logits,
            emb_weight,
        )  # [B]

        # ---------- FORGET ----------
        F_h, F_logits, _ = self._forward_lm(
            inputs["forget_input_ids"],
            inputs["forget_attention_mask"],
            f_labels,
        )
        self._ensure_head_device(F_h)
        E_f = self.model.ste_head(
            F_h,
            inputs["forget_attention_mask"],
            f_labels,
            F_logits,
            emb_weight,
        )  # [B]

        # (1) unary push-down/up
        L_down = torch.relu(E_r - self.tau_low).mean()
        L_up = torch.relu(self.tau_high - E_f).mean()

        # (2) pairwise margin
        L_pair = torch.relu(self.margin + E_r - E_f).mean()

        # (3b) forget coupling: sigmoid(E_f) tracks the mean gold-token probability
        L_cpl = torch.tensor(0.0, device=F_h.device)
        if self.lambda_couple > 0:
            f_nll_tok, f_mask = self._gold_token_nll(F_logits, f_labels)
            if bool(f_mask.any().item()):
                # exp(-CE) == softmax prob of the gold token, without a [N,V] softmax
                p_gold_mean = self._per_example_mean(torch.exp(-f_nll_tok), f_mask)
                L_cpl = (torch.sigmoid(E_f) - p_gold_mean).pow(2).mean()

        L_energy = L_down + L_up + L_pair
        total = (
            self.lambda_ce_retain * ce_retain
            + self.lambda_e * L_energy
            + self.lambda_couple * L_cpl
        )

        # (3a) retain calibration to token-NLL (opt-in)
        if self.calib_in_loss and self.lambda_calib > 0:
            r_nll_tok, r_mask = self._gold_token_nll(R_logits, r_labels)
            if bool(r_mask.any().item()):
                r_nll = self._per_example_mean(r_nll_tok, r_mask)
                L_cal = F.mse_loss(E_r, self.alpha_calib * r_nll)
                total = total + self.lambda_calib * L_cal

        if return_outputs:
            return total, {"E_r": E_r.detach(), "E_f": E_f.detach()}
        return total
