from transformers import Trainer


class GATrainer(Trainer):
    """Trainer that performs gradient-ascent unlearning on a "forget" set.

    Each batch must carry ``forget_input_ids``, ``forget_attention_mask`` and
    ``forget_labels``. These are passed to the model as ``input_ids``,
    ``attention_mask`` and ``labels``. The model's loss on that batch is
    negated, so the optimizer minimizing the returned value maximizes the
    loss on the forget data.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the negated model loss on the forget portion of ``inputs``.

        Args:
            model: Model being trained. It is called with ``input_ids``,
                ``attention_mask`` and ``labels`` and must return an object
                exposing a ``loss`` attribute.
            inputs: Batch mapping containing the ``forget_*`` entries listed
                on the class. Any other keys are ignored.
            return_outputs: When true, also return the raw model outputs.
            **kwargs: Extra arguments the ``Trainer`` may pass. They are
                accepted for signature compatibility and ignored.

        Returns:
            ``-loss``, or the tuple ``(-loss, outputs)`` when
            ``return_outputs`` is true.
        """
        forget_batch = {
            "input_ids": inputs["forget_input_ids"],
            "attention_mask": inputs["forget_attention_mask"],
            "labels": inputs["forget_labels"],
        }
        outputs = model(**forget_batch)
        # Negate so the optimizer's descent step becomes ascent on the forget loss.
        forget_loss = -outputs.loss
        # Log the original (un-negated) loss value.
        # NOTE(review): this runs on every compute_loss call, presumably once
        # per micro-batch and not gated by the Trainer's logging interval.
        # `.item()` also synchronizes with the device if the loss lives on an
        # accelerator. Consider throttling if log volume or speed matters.
        self.log({"forget_loss": outputs.loss.detach().item()})
        if return_outputs:
            return forget_loss, outputs
        return forget_loss
