import copy
import math
import os
import warnings
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import ACT2FN
from transformers.models.t5.modeling_t5 import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from transformers.models.t5.modeling_t5 import PreTrainedModel
from transformers.models.t5.modeling_t5 import (
    ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
)
from transformers.models.t5.modeling_t5 import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    logging,
    replace_return_docstrings,
)
from transformers.models.t5.modeling_t5 import assert_device_map, get_device_map
from transformers.models.t5.modeling_t5 import T5Config
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import MSELoss, BCEWithLogitsLoss

from .modeling_t5 import ReduT5ForConditionalGeneration


class DistillWrapper(ReduT5ForConditionalGeneration):
    """Student model trained with an extra KL term against a frozen teacher.

    The wrapper behaves like its parent class, except that while it is in
    training mode and ``do_distill`` is true it also runs ``teacher`` on the
    same inputs and adds the KL divergence between the student and teacher
    output distributions to the returned loss.

    Attributes:
        teacher: The teacher model. Its parameters are frozen
            (``requires_grad = False``) and it is kept in eval mode.
        do_distill: Set to ``False`` to disable the distillation term.
        kl: The ``nn.KLDivLoss`` module used for the distillation term; it
            expects log-probabilities for both input and target.
    """

    def __init__(self,
                 config: T5Config,
                 teacher: T5ForConditionalGeneration,
                 ):
        """Build the student from ``config`` and attach the frozen ``teacher``."""
        super().__init__(config)
        self.teacher = teacher
        self.do_distill = True
        self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

        # frozen teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        # The teacher only provides targets, so its dropout must stay off.
        self.teacher.eval()

    def train(self, mode: bool = True):
        """Set the training mode of the student while keeping the teacher in eval.

        ``nn.Module.train`` recurses into every registered submodule, which
        includes ``self.teacher``; without this override the teacher would run
        with dropout enabled and produce noisy distillation targets.
        """
        super().train(mode)
        # Guarded in case train() is invoked before the teacher is attached.
        teacher = getattr(self, "teacher", None)
        if teacher is not None:
            teacher.eval()
        return self

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            decoder_input_ids: Optional[torch.LongTensor] = None,
            decoder_attention_mask: Optional[torch.BoolTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            decoder_head_mask: Optional[torch.FloatTensor] = None,
            cross_attn_head_mask: Optional[torch.Tensor] = None,
            encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run the student and, when distilling, add the teacher KL term to the loss.

        All arguments are forwarded unchanged, in this order, to both the
        parent ``forward`` and the teacher.

        Note:
            ``return_dict`` is accepted for signature compatibility only; both
            models are always called with ``return_dict=True`` and the
            student's output object is what gets returned.

        Returns:
            The student's output. While training with ``do_distill`` enabled,
            its ``loss`` is the student loss plus the KL term, or the KL term
            alone when the student produced no loss (e.g. no ``labels``).
        """
        # Built once so the student and the teacher are guaranteed to receive
        # exactly the same positional arguments.
        shared_args = (
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            head_mask,
            decoder_head_mask,
            cross_attn_head_mask,
            encoder_outputs,
            past_key_values,
            inputs_embeds,
            decoder_inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
        )
        outputs = super().forward(*shared_args, return_dict=True)

        if not (self.training and self.do_distill):
            return outputs

        # The teacher only provides targets; no graph is needed for it.
        with torch.no_grad():
            teacher_outputs = self.teacher(*shared_args, return_dict=True)

        # "batchmean" divides the summed KL by the size of dim 0; the extra
        # division by dim 1 turns it into a per-position average.
        # NOTE(review): assumes logits are (batch, seq_len, vocab) - confirm.
        # Positions are not masked, so padded targets contribute as well.
        kl_loss = self.kl(
            torch.log_softmax(outputs.logits, dim=-1),
            torch.log_softmax(teacher_outputs.logits, dim=-1),
        ) / outputs.logits.shape[1]

        # The student loss is None when no labels were supplied; fall back to
        # pure distillation instead of failing on `None += tensor`. The sum is
        # out-of-place so the student's loss tensor is not mutated.
        if outputs.loss is None:
            outputs.loss = kl_loss
        else:
            outputs.loss = outputs.loss + kl_loss

        return outputs
