from trainer.utils import compute_dpo_loss
from trainer.unlearn.grad_diff import GradDiff

import torch.nn.functional as F
import torch

class ME(GradDiff):
    """Unlearning trainer with a maximum-entropy forget objective.

    On the forget split the model's next-token distributions are pushed
    toward the uniform distribution over the vocabulary (see
    ``get_me_loss``). On the retain split the loss comes from
    ``self.compute_retain_loss``. The two terms are mixed using
    ``self.gamma`` and ``self.alpha``; those attributes and
    ``compute_retain_loss`` are not defined in this class and are expected
    to be provided by ``GradDiff``.
    """

    def __init__(self, beta=1.0, *args, **kwargs):
        """Initialise the trainer.

        Args:
            beta: Stored on the instance as ``self.beta``. It is not read by
                any method defined in this class.
            *args: Forwarded unchanged to ``GradDiff.__init__``.
            **kwargs: Forwarded unchanged to ``GradDiff.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.beta = beta
        # Make sure a reference model exists even when the base class did
        # not create one.
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

    def get_me_loss(self, model, inputs, labels):
        """Compute the maximum-entropy loss for one batch.

        For every supervised position the loss is
        ``KL(uniform || softmax(logits))``, averaged over the positions whose
        shifted label is not ``-100``.

        Args:
            model: Callable invoked as ``model(**inputs)``; its result must
                expose ``.logits`` with the vocabulary on the last axis.
            inputs: Keyword arguments for the model call.
            labels: Label tensor whose shape must equal
                ``logits.shape[:-1]``; ``-100`` marks ignored positions.

        Returns:
            A ``(loss, outputs)`` tuple: the scalar loss tensor and the raw
            model outputs. The loss is ``0`` (not NaN) when the batch has no
            supervised position.

        Raises:
            ValueError: If ``labels`` does not match ``logits.shape[:-1]``.
        """
        outputs = model(**inputs)
        logits = outputs.logits
        vocab_size = logits.shape[-1]

        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits and labels must have compatible shapes.")

        # Next-token alignment: the logits at position t are scored against
        # the label at position t + 1, so drop the first label and the last
        # logit.
        shifted_labels = labels[:, 1:]  # (bs, seq_len - 1)
        shifted_logits = logits[:, :-1, :]  # (bs, seq_len - 1, vocab_size)

        # log_softmax is numerically stable; log(softmax(x) + eps) underflows
        # and caps the log-probabilities at log(eps).
        log_probs = F.log_softmax(shifted_logits, dim=-1).reshape(-1, vocab_size)
        uniform_dist = torch.full_like(log_probs, 1.0 / vocab_size)

        loss_mask = (shifted_labels != -100).reshape(-1)  # (bs*(seq_len - 1))

        # F.kl_div(input, target) evaluates target * (log(target) - input)
        # element-wise; summing over the vocabulary gives KL(uniform || p).
        kl_div = F.kl_div(log_probs, uniform_dist, reduction="none").sum(-1)

        # masked_fill (rather than multiplying by the mask) keeps a NaN/inf
        # at an ignored position from leaking into the sum.
        masked_kl_div = kl_div.masked_fill(~loss_mask, 0.0)

        # Guard against 0/0 when every label in the batch is ignored.
        num_supervised = loss_mask.sum().clamp(min=1)
        loss = masked_kl_div.sum() / num_supervised

        return loss, outputs

    def compute_loss(self, model, inputs, return_outputs=False):
        """Combine the forget (maximum-entropy) loss and the retain loss.

        Args:
            model: Model being trained.
            inputs: Mapping with ``"forget"`` and ``"retain"`` entries, each
                itself a mapping providing ``"input_ids"``,
                ``"attention_mask"`` and ``"labels"``.
            return_outputs: When true, also return the model outputs of the
                forget forward pass.

        Returns:
            ``gamma * forget_loss + alpha * retain_loss``, or the tuple
            ``(loss, forget_outputs)`` when ``return_outputs`` is true.
        """
        # Only these entries are forwarded; any other keys in the batch are
        # dropped.
        model_keys = ("input_ids", "attention_mask", "labels")

        forget_batch = inputs["forget"]
        forget_inputs = {key: forget_batch[key] for key in model_keys}
        forget_loss, forget_outputs = self.get_me_loss(
            model, forget_inputs, forget_inputs["labels"]
        )

        retain_batch = inputs["retain"]
        retain_inputs = {key: retain_batch[key] for key in model_keys}
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss = self.gamma * forget_loss + self.alpha * retain_loss
        return (loss, forget_outputs) if return_outputs else loss