from typing import Callable, Dict, List, Optional, Tuple, Union, Any
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments


def create_topk_mask(grad, k, largest=True):
    """Build a boolean mask flagging the ``k`` extreme entries of ``grad`` per row.

    Selection happens along the last dimension. ``k`` is capped at the size of
    that dimension, so asking for more entries than exist selects them all.

    Args:
        grad: tensor of scores; the mask is computed along its last dim.
        k: number of entries to flag in each row.
        largest: flag the largest values when True, the smallest when False.

    Returns:
        A ``torch.bool`` tensor shaped like ``grad`` (same device) with True at
        the selected positions.
    """
    num_candidates = grad.size(-1)
    effective_k = k if k < num_candidates else num_candidates
    selected = torch.topk(grad, effective_k, dim=-1, largest=largest).indices
    keep = torch.zeros_like(grad, dtype=torch.bool)
    keep.scatter_(-1, selected, True)
    return keep


class GradTrainer(Trainer):
    """``Trainer`` that adds a gradient-guided attention regularizer to the LM loss.

    Each training step runs a forward pass with ``output_attentions=True``,
    back-propagates the LM loss (keeping the graph), reads per-layer attention
    gradients from the model via ``get_attn_grad``, turns them into a target
    for the attention maps and back-propagates a weighted attention loss.
    Both backward passes happen inside ``compute_loss``; ``training_step``
    therefore does not call backward itself.

    The (unwrapped) model must provide ``get_attn_grad(use_abs=..., grad_type=...)``
    and ``clear_attn_grad()``; ``model_args`` must provide the ``rg_*`` and
    ``grad_norm`` attributes read below.
    """

    def __init__(self, model_args=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Experiment options (rg_* settings, grad_norm) read by compute_loss
        # and norm_grad.
        self.model_args = model_args

    def backward_loss(self, loss, retain_graph=False, no_update=False):
        """Back-propagate ``loss`` through the backend configured for this run.

        Args:
            loss: scalar tensor to differentiate.
            retain_graph: keep the autograd graph so a later backward can reuse it.
            no_update: call the DeepSpeed engine's ``backward`` directly instead
                of ``accelerator.backward`` (presumably to avoid the optimizer
                step accelerate performs after backward under DeepSpeed — confirm).
                Requires a DeepSpeed-wrapped accelerator. Ignored when gradient
                scaling is active.
        """
        # ``do_grad_scaling`` may not exist on newer transformers Trainers;
        # treat a missing attribute as "no gradient scaling".
        if getattr(self, "do_grad_scaling", False):
            self.scaler.scale(loss).backward(retain_graph=retain_graph)
        elif no_update:
            self.accelerator.deepspeed_engine_wrapped.engine.backward(
                loss, retain_graph=retain_graph
            )
        else:
            self.accelerator.backward(loss, retain_graph=retain_graph)

    def cl_loss(self, attn_weights, token_mask):
        """Negative log-attention on the positions selected by ``token_mask``.

        Args:
            attn_weights: attention probabilities; the last dim is the key axis.
            token_mask: boolean mask with as many elements as ``attn_weights``
                and the same last dim; True marks positions whose attention
                should be pushed up.

        Returns:
            Scalar mean of ``-log(attn + eps)`` taken over ALL entries, where
            unselected entries contribute 0. The loss scale therefore depends
            on how many positions are selected.
        """
        attn_flat = attn_weights.reshape(-1, attn_weights.size(-1))
        # reshape, not view: the mask is often an ``expand``-ed tensor whose
        # zero strides make ``view`` raise.
        unselected = ~token_mask.reshape(-1, token_mask.size(-1))
        # eps keeps log() finite where the attention probability is 0.
        log_attn = torch.log(attn_flat + torch.finfo(attn_flat.dtype).eps)
        neg_log_selected = -log_attn.masked_fill(unselected, 0)
        return neg_log_selected.mean()

    def mse_loss(self, student_attn, teacher_attn):
        """Element-wise squared error (no reduction) between two attention tensors."""
        mse_loss_fn = nn.MSELoss(reduction="none")
        loss = mse_loss_fn(student_attn, teacher_attn)
        return loss

    def get_distill_loss(self, student_attn, teacher_attn, ele_mask):
        """Masked mean of the squared error between student and teacher attention.

        Args:
            student_attn: attention tensor being trained.
            teacher_attn: target tensor (same shape as ``student_attn``).
            ele_mask: boolean mask broadcastable to the attention shape; True
                marks elements included in the mean. ``None`` averages every
                element.

        Returns:
            Scalar loss. 0 when the mask selects nothing.
        """
        loss = self.mse_loss(student_attn, teacher_attn)
        if ele_mask is None:
            return loss.mean()
        # Broadcast the mask to the loss shape so the denominator counts
        # exactly the elements summed in the numerator.
        ele_mask = ele_mask.expand_as(loss)
        denom = ele_mask.sum().clamp_min(1)
        return loss.masked_fill(~ele_mask, 0).sum() / denom

    def norm_grad(self, grad):
        """Negate ``grad`` and normalize it along the last dim.

        The method is selected by ``model_args.grad_norm``:
        "l2"/"l1" (unit norm), "minmax" (scaled to [0, 1]) or "z" (zero mean,
        unit std).

        Raises:
            ValueError: for any other ``grad_norm`` value.
        """
        grad = -grad
        mode = self.model_args.grad_norm
        if mode == "l2":
            return F.normalize(grad, p=2, dim=-1)
        if mode == "l1":
            return F.normalize(grad, p=1, dim=-1)
        # Guards divisions against constant rows (zero range / zero std).
        tiny = torch.finfo(grad.dtype).eps
        if mode == "minmax":
            # amin/amax return tensors; min/max(dim=...) return (values, indices).
            grad_min = grad.amin(dim=-1, keepdim=True)
            grad_max = grad.amax(dim=-1, keepdim=True)
            return (grad - grad_min) / (grad_max - grad_min).clamp_min(tiny)
        if mode == "z":
            grad_mean = grad.mean(dim=-1, keepdim=True)
            grad_std = grad.std(dim=-1, keepdim=True)
            return (grad - grad_mean) / grad_std.clamp_min(tiny)
        raise ValueError(f"Unsupported grad_norm: {mode!r}")

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Compute the LM loss plus the gradient-guided attention regularizer.

        When gradients are enabled, this method performs the backward passes
        itself. First the LM loss is back-propagated (graph retained, through
        the DeepSpeed engine directly), then the weighted attention loss.
        The returned value must not be back-propagated again.

        Args:
            model: the (possibly wrapped) model; the unwrapped model must
                provide ``get_attn_grad``.
            inputs: batch dict with ``input_ids``, ``attention_mask``,
                ``labels`` and an optional ``token_mask``.
            return_outputs: also return ``{"outputs": model_outputs}``.
            num_items_in_batch: accepted for compatibility with newer
                ``Trainer`` versions that pass it; unused.

        Returns:
            ``lm_loss + attn_loss`` (or ``(loss, {"outputs": outputs})``).
            When gradients are disabled (e.g. evaluation), only the LM loss is
            returned, because the regularizer requires backward passes.

        Raises:
            ValueError: if ``model_args.rg_distill_loss`` is not "cl" or "mse".
            RuntimeError: if the model yields no attention gradients.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        labels = inputs.get("labels")
        token_mask = inputs.get("token_mask", None)
        if token_mask is not None:
            # Insert head/query axes so the mask broadcasts over attention maps;
            # presumably token_mask is a (batch, seq_len) bool tensor — confirm.
            token_mask = token_mask[:, None, None, :]
            token_mask = ~token_mask
            # TODO equal to all q_mask, not question_solution
            # token_mask = token_mask[None, :, None, None, :]

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=True,
            return_dict=True,
        )
        lm_loss = outputs.loss

        if not torch.is_grad_enabled():
            # Evaluation runs under no_grad: backward is impossible, so report
            # the LM loss only instead of failing inside backward().
            return (lm_loss, {"outputs": outputs}) if return_outputs else lm_loss

        loss_type = self.model_args.rg_distill_loss
        if loss_type not in ("cl", "mse"):
            raise ValueError(f"Unsupported rg_distill_loss: {loss_type!r}")

        # One entry per layer along dim 0.
        attentions = torch.stack(outputs.attentions, dim=0)
        # This backward populates the attention gradients read just below; the
        # graph is retained because attn_loss is back-propagated through it next.
        self.backward_loss(lm_loss, retain_graph=True, no_update=True)
        # Unwrap DeepSpeed/DDP containers; fall back to the bare model.
        base_model = getattr(model, "module", model)
        grads = base_model.get_attn_grad(
            use_abs=self.model_args.rg_grad_abs, grad_type=self.model_args.rg_grad
        )

        attn_loss = 0.0
        for grad, attention in zip(grads, attentions):
            if loss_type == "cl":
                topk_mask = create_topk_mask(
                    grad,
                    k=self.model_args.rg_grad_topk,
                    # NOTE(review): if rg_grad_abs is a bool this is always True
                    # (False is not None); confirm the intended condition.
                    largest=self.model_args.rg_grad_abs is not None,
                )
                if topk_mask.dim() == 3:
                    # Head-less gradient: share the same target across heads.
                    topk_mask = topk_mask[:, None, :, :].expand(
                        -1, attention.shape[1], -1, -1
                    )
                if token_mask is not None:
                    # Drop top-k positions where the (inverted) token_mask is True.
                    topk_mask = topk_mask.masked_fill(token_mask, False)
                attn_loss += self.cl_loss(attention, topk_mask.detach())
            else:  # loss_type == "mse"
                if grad.dim() == 3:
                    grad = grad[:, None, :, :]
                # Teacher: attention shifted against ``grad`` by a fixed step of 10.
                teacher_attn = attention - 10 * grad
                # NOTE(review): here the inverted token_mask selects the
                # elements that are KEPT, while the "cl" branch uses it to DROP
                # positions — confirm which polarity is intended.
                attn_loss += self.get_distill_loss(
                    attention, teacher_attn.detach(), token_mask
                )

        if not torch.is_tensor(attn_loss):
            raise RuntimeError("get_attn_grad returned no attention gradients")

        attn_loss = attn_loss * self.model_args.rg_weight
        self.backward_loss(attn_loss)
        loss = lm_loss + attn_loss
        print(
            {
                "lm": lm_loss.item(),
                "attn": attn_loss.item(),
                "total": loss.item(),
            }
        )

        return (loss, {"outputs": outputs}) if return_outputs else loss

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Both backward passes (LM loss and attention regularizer) run inside
        `compute_loss`, so this override deliberately does not call backward.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            num_items_in_batch:
                Newer `transformers` training loops pass this argument; it is
                accepted for compatibility and otherwise unused.

        Return:
            `torch.Tensor`: The detached training loss on this batch divided by
            `gradient_accumulation_steps` (used for logging).
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Presumably drops attention gradients captured on the previous step so
        # get_attn_grad only reflects this batch — confirm in the model.
        self.model.clear_attn_grad()

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        return loss.detach() / self.args.gradient_accumulation_steps
