import torch
from transformers import Trainer


class FinetuneTrainer(Trainer):
    """Trainer that fine-tunes on paired "forget" / "retain" batches.

    Use to train the model on the content it is meant to forget later. The
    training loss is the sum of the next-token cross-entropy on the forget
    batch and the next-token cross-entropy on the retain batch.
    """

    @staticmethod
    def _select_split(inputs, prefix):
        """Return the model kwargs stored in ``inputs`` under ``<prefix>_*`` keys."""
        return {
            field: inputs[f"{prefix}_{field}"]
            for field in ("input_ids", "attention_mask", "labels")
        }

    @staticmethod
    def _next_token_loss(logits, labels):
        """Return the mean cross-entropy of ``logits`` against shifted ``labels``.

        The logit at sequence position ``t`` is scored against the label at
        position ``t + 1``; positions labelled ``-100`` are ignored.
        """
        # Drop the last logit (nothing left to predict) and the first label
        # (nothing predicts it) so the two line up position by position.
        shift_logits = logits[..., :-1, :]
        # Keep the labels on the logits' device: the two can differ when the
        # model is spread over several devices.
        shift_labels = labels[..., 1:].to(shift_logits.device)
        return torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the combined forget + retain language-modelling loss.

        Args:
            model: Callable accepting ``input_ids``, ``attention_mask`` and
                ``labels`` keyword arguments and returning an object that
                exposes ``.logits``.
            inputs: Mapping with the keys ``forget_input_ids``,
                ``forget_attention_mask``, ``forget_labels``,
                ``retain_input_ids``, ``retain_attention_mask`` and
                ``retain_labels``.
            return_outputs: When true, also return both model outputs.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The scalar loss, or ``(loss, (forget_outputs, retain_outputs))``
            when ``return_outputs`` is true.
        """
        forget_inputs = self._select_split(inputs, "forget")
        retain_inputs = self._select_split(inputs, "retain")

        # Forget pass first, then retain pass (same order as the two batches
        # are listed in the returned outputs tuple).
        forget_outputs = model(**forget_inputs)
        forget_loss = self._next_token_loss(
            forget_outputs.logits, forget_inputs["labels"]
        )

        retain_outputs = model(**retain_inputs)
        retain_loss = self._next_token_loss(
            retain_outputs.logits, retain_inputs["labels"]
        )

        loss = forget_loss + retain_loss

        if return_outputs:
            return (loss, (forget_outputs, retain_outputs))
        return loss
