import torch
from torch import nn

# Prefer CUDA when a GPU is visible to torch; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ------------------------------------
# Safe activation hooks for Qwen2.5
# ------------------------------------
def register_activation_hook(model):
    """
    Attach a forward hook to the ``mlp`` submodule of every decoder layer.

    Layers are looked up at ``model.model.layers`` (Qwen2.5 layout in HF).
    Layers without an ``mlp`` attribute are skipped. A model that has no
    ``model`` attribute, or whose ``model`` has no ``layers``, gets no hooks.

    Each hook writes the MLP's forward output into a shared dict under the key
    ``"model.layers.{i}.mlp"``, overwriting the value from any earlier call.

    Returns:
        (activations, handles): the dict the hooks populate, and the list of
        hook handles; call ``.remove()`` on each when done.
    """
    captured = {}
    hook_handles = []

    backbone = getattr(model, "model", None)
    if backbone is None:
        return captured, hook_handles

    for idx, block in enumerate(getattr(backbone, "layers", [])):
        if not hasattr(block, "mlp"):
            continue

        # Bind the key as a default argument so each hook keeps its own layer index.
        def _store(module, inputs, output, key=f"model.layers.{idx}.mlp"):
            captured[key] = output

        hook_handles.append(block.mlp.register_forward_hook(_store))

    return captured, hook_handles


class MMD_loss(nn.Module):
    """
    Biased MMD^2 estimate with a multi-bandwidth Gaussian kernel.

    The kernel is the sum of ``kernel_num`` Gaussians ``exp(-d^2 / bw_i)`` on
    squared Euclidean distances, where
    ``bw_i = base / kernel_mul ** (kernel_num // 2) * kernel_mul ** i``.
    ``base`` is ``fix_sigma`` when given. Otherwise it is the mean off-diagonal
    squared distance of the pooled samples.

    Args:
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of Gaussian kernels summed.
        fix_sigma: optional fixed base bandwidth (on the squared-distance scale).
    """

    # Lower bound for the data-driven bandwidth. If every pooled sample is
    # identical, the mean distance is 0 and exp(-0 / 0) would give NaN.
    _MIN_BANDWIDTH = 1e-12

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super().__init__()
        self.kernel_mul = kernel_mul
        self.kernel_num = kernel_num
        self.fix_sigma = fix_sigma

    def _gaussian_kernel(self, source, target):
        """
        Return the summed Gaussian kernel matrix over ``cat([source, target])``.

        Args:
            source: [B1, D] tensor.
            target: [B2, D] tensor.

        Returns:
            [B1 + B2, B1 + B2] kernel matrix, computed in at least float32.
        """
        total = torch.cat([source, target], dim=0)  # [N, D]
        # Half-precision activations lose most of their accuracy when squared
        # and summed over a large D, so compute in at least float32.
        total = total.to(torch.promote_types(total.dtype, torch.float32))
        n = total.size(0)

        # Squared pairwise distances via the Gram-matrix formulation. This
        # avoids materialising an [N, N, D] difference tensor, which is huge
        # for flattened (seq * hidden) activations, in forward and backward.
        L2 = torch.cdist(
            total, total, p=2.0, compute_mode="use_mm_for_euclid_dist"
        ).pow(2)  # [N, N]

        if self.fix_sigma is not None:
            bandwidth = self.fix_sigma
        else:
            # Mean pairwise distance excluding the diagonal. Detached so that
            # the bandwidth heuristic itself is not optimised through.
            L2_const = L2.detach()
            off_diag = L2_const.sum() - L2_const.diagonal().sum()
            bandwidth = (off_diag / (n * (n - 1))).clamp_min(self._MIN_BANDWIDTH)
        bandwidth = bandwidth / (self.kernel_mul ** (self.kernel_num // 2))

        kernels = 0.0
        for i in range(self.kernel_num):
            kernels = kernels + torch.exp(-L2 / (bandwidth * (self.kernel_mul**i)))
        return kernels

    def forward(self, source, target):
        """
        Return the biased MMD^2 between ``source`` [B1, D] and ``target`` [B2, D].

        Each kernel block is averaged separately, so B1 and B2 may differ.
        When they are equal, this equals ``mean(XX + YY - XY - YX)``.
        """
        b = source.size(0)
        K = self._gaussian_kernel(source, target)
        XX = K[:b, :b]
        YY = K[b:, b:]
        XY = K[:b, b:]
        YX = K[b:, :b]
        return XX.mean() + YY.mean() - XY.mean() - YX.mean()


def token_ce_loss(logits, labels):
    """
    Mean next-token cross-entropy over non-ignored positions.

    Position ``t`` of ``logits`` is scored against ``labels[..., t + 1]``.
    Labels equal to -100 are ignored.

    Args:
        logits: [..., T, V] tensor of unnormalised scores.
        labels: [..., T] integer tensor of target token ids (-100 = ignore).

    Returns:
        Scalar float32 tensor. When every shifted label is -100, the result is
        0 instead of NaN, so an all-masked batch cannot poison a summed loss.
    """
    # Upcast so the log-softmax is computed in float32 even for bf16/fp16 logits.
    shift_logits = logits[..., :-1, :].contiguous().float()
    shift_labels = labels[..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    # With reduction="none", ignored positions contribute exactly 0.
    per_tok = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

    # Clamp to 1 so an all-ignored batch yields 0 rather than 0 / 0 = NaN.
    valid_counts = (shift_labels != -100).sum().clamp_min(1).float()
    return per_tok.sum() / valid_counts


def rep_noise_loss(model, harmful_batch, harmless_batch):
    """
    Negative-only representation-noising loss for a harmful batch.

    Two terms are combined:

    * ``noise_term``: the MMD between each hooked MLP block's output (flattened
      per example) and fresh standard-normal noise of the same shape, averaged
      over the hooked layers. It is 0 when no layer could be hooked.
    * ``harmful_losses``: next-token CE of the final logits plus the CE of every
      entry of ``hidden_states`` decoded through the final norm and the output
      embeddings. These terms are averaged and then shifted by +1.

    The return value is ``beta * noise_term - alpha * log(harmful_losses)``
    (alpha = 1, beta = 1e-3). Minimising it drives MLP activations toward
    Gaussian noise while increasing the harmful CE.

    Args:
        model: causal LM exposing ``model.model.layers[i].mlp``,
            ``model.base_model.norm`` and ``model.get_output_embeddings()``
            (Qwen2.5-style HF layout).
        harmful_batch: mapping with ``"input_ids"``, ``"labels"`` and an
            optional ``"attention_mask"``.
        harmless_batch: not used by this implementation; kept so the call
            signature stays compatible.

    Returns:
        Scalar loss tensor.
    """
    alpha = 1.0
    beta = 1e-3

    mmd_loss = MMD_loss()
    activations, handles = register_activation_hook(model)
    try:
        # Harmful forward pass; hidden_states are needed for the per-layer CE terms.
        harmful_outputs = model(
            harmful_batch["input_ids"],
            attention_mask=harmful_batch.get("attention_mask", None),
            output_hidden_states=True,
        )
    finally:
        # Always remove the hooks, even if the forward raises, so they never
        # stay registered on the model. Only the forward above runs the hooked
        # MLP blocks, so the captured activations are unaffected.
        for handle in handles:
            handle.remove()

    # Harmful MLP outputs captured by the hooks, in layer order.
    base = getattr(model, "model", None)
    layers = getattr(base, "layers", []) if base is not None else []
    layer_keys = [f"model.layers.{i}.mlp" for i in range(len(layers))]
    harmful_activations = [activations[k] for k in layer_keys if k in activations]

    # ---- Representation noise: MMD of MLP outputs against N(0, 1) ----
    # NOTE(review): no position masking is applied (harmless_batch is unused),
    # so every token position of the harmful batch contributes.
    if harmful_activations:
        noise_term = 0.0
        for hidden in harmful_activations:
            gaussian = torch.randn_like(hidden)
            noise_term = noise_term + mmd_loss(
                hidden.view(hidden.size(0), -1), gaussian.view(gaussian.size(0), -1)
            )
        noise_term = noise_term / len(harmful_activations)
    else:
        noise_term = torch.tensor(0.0, device=harmful_outputs.logits.device)

    # ---- Harmful losses: final logits + every hidden state via norm + lm_head ----
    labels = harmful_batch["labels"]
    output_embeddings = model.get_output_embeddings()  # lm_head
    final_norm = model.base_model.norm

    # NOTE(review): if hidden_states ends with the already-normalised final
    # state (as HF decoder models usually return it), the last term re-applies
    # the norm and roughly duplicates the final-logit CE. Confirm intended.
    harmful_losses = token_ce_loss(harmful_outputs.logits, labels)
    count = 1
    for h in harmful_outputs.hidden_states:
        harmful_losses = harmful_losses + token_ce_loss(
            output_embeddings(final_norm(h)), labels
        )
        count += 1
    # Each CE term is non-negative, so +1 keeps log(harmful_losses) >= 0.
    harmful_losses = harmful_losses / count + 1.0

    # Final negative-only rep-noise objective.
    return beta * noise_term - alpha * torch.log(harmful_losses)
