import torch
from transformers.trainer import Trainer, _is_peft_model


def compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
    """No-op placeholder that mirrors the ``Trainer.compute_loss`` signature.

    Every argument is ignored and ``None`` is returned.

    NOTE(review): nothing visible here calls this function; ``CustomLossTrainer``
    defines its own ``compute_loss`` method. Consider removing it once no
    external caller depends on it.
    """
    return None


def per_token_loss(model_outputs, batch_input, ignore_index=-100):
    """Compute next-token cross-entropy for every target position of every row.

    Logit position ``t`` is scored against label position ``t + 1`` (causal-LM
    shift), so every returned tensor has one fewer column than ``labels``.

    Args:
        model_outputs: Object exposing ``.logits`` indexed as
            ``[batch, position, vocab]``.
        batch_input: Mapping with ``labels`` and ``attention_mask`` indexed as
            ``[batch, position]``.
        ignore_index: Label value excluded from the loss (same convention as
            ``torch.nn.CrossEntropyLoss``; defaults to ``-100``).

    Returns:
        ``(loss, label_pad_mask, size_per_sample)``:
        ``loss`` is ``[batch, position - 1]`` and is exactly 0 wherever the
        mask is False. ``label_pad_mask`` is True for targets that are both
        attended and not ``ignore_index``. ``size_per_sample`` counts those
        valid targets per row.
    """
    logits = model_outputs.logits
    shift_logits = logits[:, :-1, :]
    shift_labels = batch_input["labels"][:, 1:]
    # A target is valid only if it is attended AND not ignored. CE already
    # returns 0 for ignored labels, so counting them would dilute the
    # per-example mean computed downstream.
    label_pad_mask = (batch_input["attention_mask"][:, 1:] != 0) & (
        shift_labels != ignore_index
    )

    bsz, n_targets, vocab_size = shift_logits.shape
    # reshape (not view) so non-contiguous slices are handled without an
    # explicit .contiguous() copy.
    loss = torch.nn.functional.cross_entropy(
        shift_logits.reshape(-1, vocab_size),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
        reduction="none",
    ).view(bsz, n_targets)
    # Padded positions may carry real label ids; zero them so that sums over
    # a row only cover valid targets.
    loss = loss.masked_fill(~label_pad_mask, 0.0)
    size_per_sample = label_pad_mask.sum(dim=1)

    return loss, label_pad_mask, size_per_sample


def per_example_loss(model_outputs, batch_input):
    """Average the token-level loss over each example's valid target tokens.

    Args:
        model_outputs: Model output forwarded to ``per_token_loss``.
        batch_input: Batch mapping forwarded to ``per_token_loss``.

    Returns:
        ``(loss, label_pad_mask)``: ``loss`` has one value per row (the mean
        over positions where ``label_pad_mask`` is True); ``label_pad_mask``
        is the mask returned by ``per_token_loss``.
    """
    token_loss, label_pad_mask, size_per_sample = per_token_loss(
        model_outputs, batch_input
    )
    # Mask explicitly so padded positions can never contribute to the sum,
    # whatever per_token_loss did with them.
    masked_loss = token_loss.masked_fill(~label_pad_mask, 0.0)
    # clamp(min=1): a row with no valid target yields 0 instead of 0/0 = NaN,
    # which would otherwise poison any batch-level reduction.
    loss = masked_loss.sum(dim=1) / size_per_sample.clamp(min=1)
    return loss, label_pad_mask


def minmax_loss(model_outputs, batch_input):
    """Return the larger of the two group-mean losses in the batch.

    Groups come from ``batch_input["gender_bin"]`` (values 0 and 1); rows
    with any other value are ignored. This is a per-batch worst-group
    objective.

    Args:
        model_outputs: Model output forwarded to ``per_example_loss``.
        batch_input: Batch mapping with the ``per_example_loss`` keys plus
            ``gender_bin``.

    Returns:
        A scalar tensor: the worst present group's mean loss. When only one
        group is present, its mean is returned. When neither is present, a
        zero that is still attached to the autograd graph is returned, so
        ``backward()`` does not fail on a constant.
    """
    batch_loss = per_example_loss(model_outputs, batch_input)[0]
    gender_mask = batch_input["gender_bin"]
    group_means = [
        members.mean()
        for members in (batch_loss[gender_mask == 0], batch_loss[gender_mask == 1])
        if members.numel() > 0
    ]
    if not group_means:
        # No sample in either group: zero loss, but differentiable.
        return batch_loss.sum() * 0.0
    if len(group_means) == 1:
        return group_means[0]
    return torch.max(group_means[0], group_means[1])


def ema_minmax_loss(model_outputs, batch_input, ema_factor=0.2, state=None):
    """Minimize the loss of the group whose *historical* loss is highest.

    Groups come from ``batch_input["gender_bin"]``: 0 is "male", 1 is
    "female". ``state`` maps those names to an exponential moving average of
    each group's detached batch-mean loss. It is updated in place, so callers
    must pass the same dict on every step for the history to persist.

    Args:
        model_outputs: Model output forwarded to ``per_example_loss``.
        batch_input: Batch mapping with the ``per_example_loss`` keys plus
            ``gender_bin``.
        ema_factor: Weight kept on the previous average; ``1 - ema_factor``
            goes to the current batch.
        state: Mutable dict holding the EMA history. ``None`` (the default)
            uses a fresh dict, so no history survives the call.

    Returns:
        The differentiable batch-mean loss of the selected group.
        Outside grad mode, ``minmax_loss`` is returned instead and ``state``
        is left untouched.
    """
    # Evaluation / no-grad: don't touch the history, use the per-batch worst group.
    if not torch.is_grad_enabled():
        return minmax_loss(model_outputs, batch_input)
    # Fresh dict per call instead of a mutable default shared across all calls.
    if state is None:
        state = {}

    batch_loss = per_example_loss(model_outputs, batch_input)[0]
    gender_mask = batch_input["gender_bin"]
    # Only groups that actually occur in this batch get a (differentiable) mean;
    # an absent group must neither drag its EMA toward 0 nor be selected.
    group_losses = {}
    for name, code in (("male", 0), ("female", 1)):
        members = batch_loss[gender_mask == code]
        if members.numel() > 0:
            group_losses[name] = members.mean()

    if not group_losses:
        # Neither group present: zero loss that stays attached to the graph.
        return batch_loss.sum() * 0.0

    # Store detached values only; keeping a live tensor in `state` would keep
    # this step's autograd graph alive across all later steps.
    for name, group_loss in group_losses.items():
        current = group_loss.detach()
        previous = state.get(name)
        if previous is None:
            state[name] = current
        else:
            state[name] = ema_factor * previous + (1 - ema_factor) * current

    if len(group_losses) == 1:
        # Only one group has samples (and therefore gradients) in this batch.
        return next(iter(group_losses.values()))
    # Ties go to "female", matching a strict ">" comparison.
    if state["male"] > state["female"]:
        return group_losses["male"]
    return group_losses["female"]


def reweighted_loss(model_outputs, batch_input):
    """Mean of the per-example losses after scaling each by its sample weight.

    ``batch_input["weight"]`` is multiplied element-wise with the per-example
    loss vector from ``per_example_loss``. NOTE(review): presumably one weight
    per example; confirm the shape at the caller.
    """
    losses, _ = per_example_loss(model_outputs, batch_input)
    scaled = losses * batch_input["weight"]
    return scaled.mean()


class CustomLossTrainer(Trainer):
    """``Trainer`` subclass whose training loss is selected by ``loss_name``.

    Whenever the batch contains ``labels``, the loss is computed from the
    model's logits by one of the module-level loss functions:

    * ``"default"``    -- mean of the per-example losses.
    * ``"minmax"``     -- worst ``gender_bin`` group mean in the batch.
    * ``"reweighted"`` -- per-example losses scaled by the ``weight`` column.
    * ``"ema_minmax"`` -- group with the highest EMA of past losses.

    The ``gender_bin`` / ``weight`` columns are consumed by the loss only and
    are removed before the model is called. NOTE(review): they must survive
    the data pipeline (presumably ``remove_unused_columns=False``); confirm
    for your setup.

    ``loss_name`` and ``gender_weights`` precede ``*args``, so pass every
    ``Trainer`` argument (``model``, ``args``, ...) by keyword; a positional
    ``model`` would otherwise be bound to ``loss_name``.
    """

    # Batch columns read by the custom losses but not meant for the model.
    _LOSS_ONLY_KEYS = ("gender_bin", "weight")
    _KNOWN_LOSSES = ("default", "minmax", "reweighted", "ema_minmax")

    def __init__(
        self,
        loss_name="default",
        gender_weights=None,
        *args,
        **kwargs,
    ):
        # Fail fast instead of on the first training step.
        if loss_name not in self._KNOWN_LOSSES:
            raise ValueError(f"Unknown loss function {loss_name}")
        super().__init__(*args, **kwargs)
        self.loss_name = loss_name
        # Stored for callers; not read by this class.
        self.gender_weights = gender_weights
        # EMA history for "ema_minmax". Deliberately NOT `self.state`: the base
        # Trainer keeps its TrainerState under that attribute and reassigns it
        # when training starts, which would clobber (or be clobbered by) a dict.
        self._ema_loss_state = {}

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Run the model on ``inputs`` and return the training loss.

        If ``inputs`` contains ``labels``, the loss selected by
        ``self.loss_name`` is computed from the model's logits. Otherwise the
        model's own ``loss`` output is used, as in the base ``Trainer``.
        ``label_smoother`` and ``compute_loss_func`` are not applied by the
        custom losses.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            ValueError: if ``loss_name`` is unknown, or if the model returns
                no loss for a batch without labels.
        """
        # The model only gets its own inputs; loss-only columns stay in `inputs`
        # for the custom loss functions below.
        model_inputs = {
            key: value
            for key, value in inputs.items()
            if key not in self._LOSS_ONLY_KEYS
        }
        if self.model_accepts_loss_kwargs and num_items_in_batch is not None:
            model_inputs["num_items_in_batch"] = num_items_in_batch
        outputs = model(**model_inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if inputs.get("labels") is not None:
            # `inputs` still holds labels, attention_mask and the loss-only columns.
            loss = self._custom_loss(outputs, inputs)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(model_inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            # Base-Trainer rescaling for the model's own token-normalised loss;
            # it must not be applied to the per-batch custom losses above.
            if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
                loss = loss * self.accelerator.num_processes

        return (loss, outputs) if return_outputs else loss

    def _custom_loss(self, outputs, inputs):
        """Dispatch to the module-level loss function named by ``self.loss_name``."""
        if self.loss_name == "default":
            return per_example_loss(outputs, inputs)[0].mean()
        if self.loss_name == "minmax":
            return minmax_loss(outputs, inputs)
        if self.loss_name == "reweighted":
            return reweighted_loss(outputs, inputs)
        if self.loss_name == "ema_minmax":
            return ema_minmax_loss(outputs, inputs, state=self._ema_loss_state)
        raise ValueError(f"Unknown loss function {self.loss_name}")
