import torch
import torch.nn
from torch import tensor
from torchmetrics.functional.classification.confusion_matrix import (
    _multiclass_confusion_matrix_format,
)
from torchmetrics.functional.classification.hinge import (
    _hinge_loss_compute,
    _multiclass_hinge_loss_arg_validation,
    _multiclass_hinge_loss_tensor_validation,
)
from torchmetrics.utilities.data import to_onehot
from transformers import Trainer


def _custom_multiclass_hinge_loss_update(
    preds, target, alpha, squared, multiclass_mode="crammer-singer"
):
    """Accumulate the alpha-offset hinge measure over a batch.

    The per-sample measure is ``clamp(alpha + margin, min=0)``, so it grows
    as the true class outscores the other classes.

    Args:
        preds: Class scores, indexed as ``(samples, classes)`` by the code
            below. Softmax over dim 1 is applied unless every entry already
            lies in ``[0, 1]``.
        target: Integer class labels, one per row of ``preds``.
        alpha: Offset added to the margin before clamping at zero.
        squared: If true, square the clamped measures.
        multiclass_mode: ``"crammer-singer"`` compares the true-class score
            with the best other class; any other value takes the
            one-vs-all branch.

    Returns:
        A ``(measures_sum, total)`` pair: the measures summed over dim 0
        (a scalar for ``"crammer-singer"``, one value per class otherwise)
        and the number of rows as a tensor on ``target``'s device.
    """
    # Heuristic: values outside [0, 1] are treated as raw logits.
    in_unit_range = (preds >= 0) & (preds <= 1)
    if not in_unit_range.all():
        preds = preds.softmax(1)

    # Boolean mask marking the true class of each row.
    is_true_class = to_onehot(target, max(2, preds.shape[1])).bool()

    if multiclass_mode == "crammer-singer":
        true_scores = preds[is_true_class]
        # Highest score among the non-true classes of each row.
        best_rival = (
            preds[~is_true_class].view(preds.shape[0], -1).max(dim=1).values
        )
        margin = true_scores - best_rival
    else:
        # One-vs-all: +score on the true class, -score everywhere else.
        margin = torch.zeros_like(preds)
        margin[is_true_class] = preds[is_true_class]
        margin[~is_true_class] = -preds[~is_true_class]

    measures = (alpha + margin).clamp(min=0)
    if squared:
        measures = measures.pow(2)

    total = tensor(is_true_class.shape[0], device=is_true_class.device)
    return measures.sum(dim=0), total


def multiclass_hinge_loss(
    preds,
    target,
    num_classes,
    alpha=1.0,
    squared=False,
    multiclass_mode="crammer-singer",
    ignore_index=None,
    validate_args=True,
):
    """Compute the alpha-offset multiclass hinge loss.

    Runs the torchmetrics hinge-loss pipeline (validate, format, update,
    compute) but uses ``_custom_multiclass_hinge_loss_update`` for the
    update step, which is where ``alpha`` is applied.

    Args:
        preds: Class scores passed to the torchmetrics validators.
        target: Ground-truth labels matching ``preds``.
        num_classes: Number of classes, used for validation only.
        alpha: Offset forwarded to the update step. It is not checked by
            the validators called here.
        squared: Whether the update step squares the hinge measures.
        multiclass_mode: Margin variant forwarded to the update step.
        ignore_index: Label value forwarded to validation and formatting.
        validate_args: If true, validate the arguments and tensors first.

    Returns:
        Whatever ``_hinge_loss_compute`` produces from the accumulated
        measures and the sample count (presumably their ratio; confirm
        against torchmetrics).
    """
    if validate_args:
        _multiclass_hinge_loss_arg_validation(
            num_classes, squared, multiclass_mode, ignore_index
        )
        _multiclass_hinge_loss_tensor_validation(
            preds, target, num_classes, ignore_index
        )

    # Keep per-class scores; do not collapse predictions to labels.
    flat_preds, flat_target = _multiclass_confusion_matrix_format(
        preds, target, ignore_index, convert_to_labels=False
    )
    measures_sum, n_samples = _custom_multiclass_hinge_loss_update(
        flat_preds, flat_target, alpha, squared, multiclass_mode
    )
    return _hinge_loss_compute(measures_sum, n_samples)


class IHLTrainer(Trainer):
    """Trainer whose loss combines a forget-set hinge loss with a retain loss.

    Each batch must provide ``forget_*`` and ``retain_*`` variants of
    ``input_ids``, ``attention_mask`` and ``labels``. The forget example is
    scored with ``multiclass_hinge_loss`` on its next-token logits; the
    retain example contributes the loss the model itself returns.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the sum of the forget hinge loss and the retain loss.

        Args:
            model: Model called with ``input_ids``, ``attention_mask`` and
                ``labels``; its output must expose ``.logits`` and ``.loss``.
            inputs: Batch dict holding the ``forget_*`` / ``retain_*`` keys.
            return_outputs: If true, also return both model outputs.
            **kwargs: Accepted for compatibility with the caller; unused.

        Returns:
            ``loss``, or ``(loss, (forget_outputs, retain_outputs))`` when
            ``return_outputs`` is true.
        """
        # Unpack the forget / retain pair carried by the batch.
        f_inputs = {
            "input_ids": inputs["forget_input_ids"],
            "attention_mask": inputs["forget_attention_mask"],
            "labels": inputs["forget_labels"],
        }
        r_inputs = {
            "input_ids": inputs["retain_input_ids"],
            "attention_mask": inputs["retain_attention_mask"],
            "labels": inputs["retain_labels"],
        }

        # Forget pass: the logit at position t is scored against label t+1.
        f_outputs = model(**f_inputs)
        scores = f_outputs.logits
        vocab_size = scores.size(-1)
        shift_logits = scores[..., :-1, :].reshape(-1, vocab_size)  # [BN, V]
        shift_labels = f_inputs["labels"][..., 1:].reshape(-1)  # [BN,]

        # Labels equal to -100 mark positions to ignore (pad tokens).
        keep = shift_labels != -100
        if keep.any():
            f_loss = multiclass_hinge_loss(
                shift_logits[keep],
                shift_labels[keep],
                vocab_size,
            )
        else:
            # The hinge helper cannot handle zero rows, so a forget batch
            # with no supervised token contributes a zero loss instead.
            f_loss = shift_logits.new_zeros(())

        # Retain pass: use the loss the model computes from its own labels.
        r_outputs = model(**r_inputs)
        r_loss = r_outputs.loss

        loss = f_loss + r_loss

        # Log plain floats, not the tensors used for backward. Every entry
        # is reduced with mean() so a non-scalar loss (e.g. one value per
        # replica) does not make the float conversion raise.
        self.log(
            {
                "forget_loss": f_loss.detach().mean().item(),
                "retain_loss": r_loss.detach().mean().item(),
                "total_loss": loss.detach().mean().item(),
            }
        )

        if return_outputs:
            return (loss, (f_outputs, r_outputs))
        return loss
