# -*- coding:utf-8 _*-
# @License: MIT Licence

# @Time: 24/5/2023

from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from transformers import Trainer
from transformers.modeling_utils import unwrap_model, PreTrainedModel
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.training_args import TrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import EvalPrediction
from transformers.trainer_callback import TrainerCallback
from transformers.modeling_outputs import SequenceClassifierOutput


def get_normalized_probs(logits, log_probs):
    """Normalize logits over their last dimension.

    Args:
        logits: tensor of unnormalized scores; converted with ``.float()``
            (float32) before normalizing.
        log_probs: when truthy, return log-probabilities (``log_softmax``);
            otherwise return probabilities (``softmax``).

    Returns:
        A float32 tensor with the same shape as ``logits``.
    """
    as_float = logits.float()
    if not log_probs:
        return F.softmax(as_float, dim=-1)
    return F.log_softmax(as_float, dim=-1)


def compute_kl_loss_Modified(src_logits, tgt_logits, pad_mask=None, reduction="batchmean"):
    """Symmetric KL divergence between the distributions induced by two sets of logits.

    With ``P = softmax(src_logits)`` and ``Q = softmax(tgt_logits)`` (softmax over
    the last dimension), returns ``(KL(Q || P) + KL(P || Q)) / 2``.

    Args:
        src_logits: unnormalized scores for the first distribution.
        tgt_logits: unnormalized scores for the second distribution; expected
            to have the same shape as ``src_logits``.
        pad_mask: masking is not supported yet; must be ``None``.
        reduction: forwarded to ``F.kl_div`` (e.g. ``"batchmean"``, ``"sum"``,
            ``"none"``).

    Returns:
        The averaged divergence, reduced according to ``reduction``.

    Raises:
        NotImplementedError: if ``pad_mask`` is not ``None``.
    """
    if pad_mask is not None:
        # Masking is not implemented; fail before doing any work.
        raise NotImplementedError("pad_mask not implemented yet")

    log_p = get_normalized_probs(src_logits, log_probs=True)
    log_q = get_normalized_probs(tgt_logits, log_probs=True)

    # F.kl_div(input, target) computes KL(target || exp(input)). With
    # log_target=True both arguments are log-probabilities, so no separate
    # softmax is needed for the targets.
    p_loss = F.kl_div(log_p, log_q, reduction=reduction, log_target=True)  # KL(Q || P)
    q_loss = F.kl_div(log_q, log_p, reduction=reduction, log_target=True)  # KL(P || Q)
    return (p_loss + q_loss) / 2


def compute_js_loss_Modified(src_logits, tgt_logits, pad_mask=None, reduction="batchmean"):
    """Jensen-Shannon divergence between the distributions induced by two sets of logits.

    With ``P = softmax(src_logits)``, ``Q = softmax(tgt_logits)`` (softmax over the
    last dimension) and the mixture ``M = (P + Q) / 2``, returns
    ``(KL(P || M) + KL(Q || M)) / 2``. Per distribution this lies in ``[0, log 2]``.

    Args:
        src_logits: unnormalized scores for the first distribution.
        tgt_logits: unnormalized scores for the second distribution; expected
            to have the same shape as ``src_logits``.
        pad_mask: masking is not supported yet; must be ``None``.
        reduction: forwarded to ``F.kl_div`` (e.g. ``"batchmean"``, ``"sum"``,
            ``"none"``).

    Returns:
        The JS divergence, reduced according to ``reduction``.

    Raises:
        NotImplementedError: if ``pad_mask`` is not ``None``.
    """
    import math

    if pad_mask is not None:
        # Masking is not implemented; fail before doing any work.
        raise NotImplementedError("pad_mask not implemented yet")

    log_p = get_normalized_probs(src_logits, log_probs=True)
    log_q = get_normalized_probs(tgt_logits, log_probs=True)

    # log M = log((P + Q) / 2), computed in log space so it stays finite even
    # when individual probabilities underflow to 0 in float32.
    log_m = torch.logaddexp(log_p, log_q) - math.log(2.0)

    # F.kl_div(input, target) computes KL(target || exp(input)). The mixture
    # must therefore be the *input* to obtain KL(P || M) and KL(Q || M).
    # Passing M as the target would give KL(M || P) + KL(M || Q), which is
    # not the JS divergence.
    p_loss = F.kl_div(log_m, log_p, reduction=reduction, log_target=True)  # KL(P || M)
    q_loss = F.kl_div(log_m, log_q, reduction=reduction, log_target=True)  # KL(Q || M)
    return (p_loss + q_loss) / 2


class Trainer_Modified(Trainer):
    """``transformers.Trainer`` that adds a weighted auxiliary divergence term to the loss.

    The auxiliary term compares ``hidden_states[-1][:, 0, :]`` from the regular
    forward pass with the same slice from a second, gradient-free forward pass on
    the same inputs. The slice is presumably the first token of each sequence,
    assuming a (batch, seq, hidden) layout; TODO confirm for the models used.
    The strategy is selected by ``modified_aug_loss``:

    * ``"none"``: no auxiliary term (contributes 0).
    * ``"kl"``: ``compute_kl_loss_Modified``. The second pass runs in the model's
      current train/eval mode.
    * ``"js"``: ``compute_js_loss_Modified``. The second pass runs in eval mode,
      and the previous mode is restored afterwards.

    The auxiliary term is multiplied by ``modified_aug_loss_weight`` and added to
    the model's loss.
    """

    def __init__(
            self,
            model: Union[PreTrainedModel, nn.Module] = None,
            args: TrainingArguments = None,
            data_collator: Optional[DataCollator] = None,
            train_dataset: Optional[Dataset] = None,
            eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
            tokenizer: Optional[PreTrainedTokenizerBase] = None,
            model_init: Optional[Callable[[], PreTrainedModel]] = None,
            compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
            callbacks: Optional[List[TrainerCallback]] = None,
            optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
            preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
            modified_aug_loss: str = "none",
            modified_aug_loss_weight: float = 0.0,
    ):
        """Forward the standard Trainer arguments and store the auxiliary-loss config.

        Args:
            modified_aug_loss: auxiliary strategy, one of ``"none"``, ``"kl"``,
                ``"js"``. It is validated lazily, on the first loss computation.
            modified_aug_loss_weight: multiplier applied to the auxiliary term.
            All other arguments are passed unchanged to ``Trainer.__init__``.
        """
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics
        )

        self.modified_aug_loss = modified_aug_loss
        self.modified_aug_loss_weight = modified_aug_loss_weight

    @staticmethod
    def _modified_target_forward(model, inputs, force_eval):
        """Run a gradient-free forward pass with hidden states, optionally in eval mode.

        When ``force_eval`` is true, the model is switched to eval mode for the
        pass. Its previous train/eval mode is restored afterwards, even if the
        forward pass raises.
        """
        was_training = model.training
        if force_eval:
            model.eval()
        try:
            with torch.no_grad():
                return model(**inputs, output_hidden_states=True)
        finally:
            if force_eval:
                model.train(was_training)

    def aug_loss_Modified(self, model, inputs, src_outs, ):
        '''
        Compute the (unweighted) auxiliary divergence loss.

        Args:
            model: the model being trained/evaluated. It is run a second time on
                ``inputs`` under ``torch.no_grad()`` to produce the target outputs.
            inputs: keyword inputs for ``model`` (the batch used for ``src_outs``).
            src_outs: outputs of the main forward pass. They must contain
                ``"hidden_states"``, i.e. be produced with ``output_hidden_states=True``.

        Returns:
            A tensor holding the divergence (``torch.tensor(0.)`` for ``"none"``).

        Raises:
            NotImplementedError: if ``self.modified_aug_loss`` is not one of
                ``"none"``, ``"kl"``, ``"js"``.
        '''
        strategy = self.modified_aug_loss
        if strategy == "none":
            return torch.tensor(0.)

        if strategy == "kl":
            divergence = compute_kl_loss_Modified
            force_eval = False  # second pass uses the model's current mode
        elif strategy == "js":
            divergence = compute_js_loss_Modified
            force_eval = True  # second pass always runs in eval mode
        else:
            raise NotImplementedError(f"Unknown modified_aug_loss strategy: {strategy!r}")

        tgt_outs = self._modified_target_forward(model, inputs, force_eval=force_eval)
        return divergence(src_outs["hidden_states"][-1][:, 0, :],
                          tgt_outs["hidden_states"][-1][:, 0, :], reduction="batchmean")

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.

        This mirrors ``Trainer.compute_loss`` with these changes:
        1. The model is always called with ``output_hidden_states=True``.
        2. ``aug_loss_Modified(...) * modified_aug_loss_weight`` is added to the loss.
        3. In eval mode, the returned outputs are rebuilt as a
           ``SequenceClassifierOutput`` without hidden states.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        # Hidden states are always requested because the auxiliary loss reads them.
        outputs = model(**inputs, output_hidden_states=True)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # type_as matches the auxiliary term's dtype/device to the main loss
        # (the "none" strategy returns a default CPU float tensor).
        aug_loss = self.aug_loss_Modified(model, inputs, src_outs=outputs).type_as(loss)
        loss = loss + aug_loss * self.modified_aug_loss_weight

        if not model.training:
            # Drop hidden_states from evaluation outputs. Presumably this keeps
            # the evaluation loop from collecting them; they were only needed
            # for the auxiliary loss. Note: outputs.loss is the model's own
            # loss, while the returned `loss` includes the auxiliary term.
            outputs = SequenceClassifierOutput(
                loss=outputs.loss,
                logits=outputs.logits,
                hidden_states=None,
                attentions=outputs.attentions,
            )
        return (loss, outputs) if return_outputs else loss
