import math

import torch
from torch import nn
import torch.nn.functional as F

# from model.modeling_bart import BartDecoder
# from transformers.modeling_bart import BartDecoder


from transformers.modeling_bart import BartConfig, BartEncoder, PretrainedBartModel
from transformers.modeling_bart import _reorder_buffer, _make_linear_from_emb, _prepare_bart_decoder_inputs
from transformers.generation_utils import BeamHypotheses
from transformers.utils import logging
from torch.distributions import kl_divergence
from transformers.modeling_bart import Seq2SeqLMOutput
from transformers.generation_utils import top_k_top_p_filtering
from model.modeling_bart_v2 import BartDecoder


# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)

def _gelu_python(x):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
    torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_new(x):
    """Tanh approximation of GELU (the variant used by OpenAI GPT and current Google BERT).

    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    Reference: https://arxiv.org/abs/1606.08415
    """
    cubic_correction = x + 0.044715 * torch.pow(x, 3.0)
    scale = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(scale * cubic_correction))

class LinearLayer(nn.Module):
    """Affine projection optionally followed by an element-wise activation.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        nonlinear: optional callable applied to the projection; ``None`` means identity.
    """

    def __init__(self, input_size, output_size, nonlinear=None):
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)
        self.nonlinear = nonlinear

    def forward(self, input_data):
        projected = self.linear(input_data)
        if self.nonlinear is not None:
            projected = self.nonlinear(projected)
        return projected

class RecurrentLayer(nn.Module):
    """One GRU step: maps ``(input_data, prev_state)`` to the next hidden state.

    Args:
        input_size: feature size of ``input_data``.
        hidden_size: feature size of ``prev_state`` and of the returned state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn_cell = nn.GRUCell(input_size=input_size, hidden_size=hidden_size)

    def forward(self, input_data, prev_state):
        next_state = self.rnn_cell(input_data, prev_state)
        return next_state

class BartModel(PretrainedBartModel):
    """BART encoder plus one shared decoder that is run four times per forward pass.

    The four passes (context, knowledge, positive, negative) all use ``self.decoder``.
    They differ only in the encoder padding mask they receive (``history_attention_mask``
    for the context pass, ``knowledge_attention_mask`` for the others), in the cache they
    resume from, and in the ``adapter_type`` forwarded to the decoder
    (``None``, ``None``, ``'pos'``, ``'neg'``).
    """

    def __init__(self, config: BartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        self.init_weights()

    def _run_decoder(
        self,
        decoder_input_ids,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        causal_mask,
        past_key_values,
        use_cache,
        adapter_type,
    ):
        """Run one decoder pass and return ``(last_hidden_state, past_key_values)``.

        All four passes go through this single call site, so every pass receives the
        same keyword arguments. In particular, ``use_cache`` cannot be silently dropped
        by a misspelled kwarg.
        """
        outputs = self.decoder(
            decoder_input_ids, encoder_hidden_states, encoder_padding_mask, decoder_padding_mask,
            decoder_causal_mask=causal_mask, past_key_values=past_key_values, use_cache=use_cache,
            output_attentions=False, output_hidden_states=False, return_dict=True, adapter_type=adapter_type,
        )
        return outputs.last_hidden_state, outputs.past_key_values

    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        history_attention_mask=None,
        knowledge_attention_mask=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        context_past_key_values=None,
        knowledge_past_key_values=None,
        positive_past_key_values=None,
        negative_past_key_values=None,
        use_cache=True,
        debug=False
    ):
        """Encode (if needed) and run the four decoder passes.

        Args:
            use_cache: falsy -> training path. Decoder inputs and padding/causal masks
                are prepared with transformers' ``_prepare_bart_decoder_inputs``.
                Truthy -> incremental decoding; no decoder masks are built.
            debug: accepted for interface compatibility; unused.

        Returns:
            Tuple ``(context_hidden_state, knowledge_hidden_state, positive_hidden_state,
            negative_hidden_state, encoder_last_hidden_state, context_past, knowledge_past,
            positive_past, negative_past)``.
        """
        if not use_cache:  # training
            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype
            )
        else:  # incremental decoding
            decoder_padding_mask, causal_mask = None, None

        assert decoder_input_ids is not None

        if encoder_outputs is None:  # encoder outputs not supplied by the caller (e.g. training)
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False, return_dict=True
            )
        encoder_hidden_states = encoder_outputs[0]

        # Context module: attends to the encoder under the history mask.
        context_hidden_state, context_past = self._run_decoder(
            decoder_input_ids, encoder_hidden_states, history_attention_mask, decoder_padding_mask,
            causal_mask, context_past_key_values, use_cache, adapter_type=None,
        )
        # Knowledge module: attends to the encoder under the knowledge mask.
        knowledge_hidden_state, knowledge_past = self._run_decoder(
            decoder_input_ids, encoder_hidden_states, knowledge_attention_mask, decoder_padding_mask,
            causal_mask, knowledge_past_key_values, use_cache, adapter_type=None,
        )
        # Positive / negative modules: same inputs as the knowledge pass, with adapters selected.
        positive_hidden_state, positive_past = self._run_decoder(
            decoder_input_ids, encoder_hidden_states, knowledge_attention_mask, decoder_padding_mask,
            causal_mask, positive_past_key_values, use_cache, adapter_type='pos',
        )
        negative_hidden_state, negative_past = self._run_decoder(
            decoder_input_ids, encoder_hidden_states, knowledge_attention_mask, decoder_padding_mask,
            causal_mask, negative_past_key_values, use_cache, adapter_type='neg',
        )

        return (
            context_hidden_state, knowledge_hidden_state, positive_hidden_state, negative_hidden_state,
            encoder_hidden_states, context_past, knowledge_past, positive_past, negative_past,
        )

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        # Keep encoder and decoder tied to the same embedding table.
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.shared)  # make it on the fly


class BartForConditionalGeneration(PretrainedBartModel):
    """BART generator that selects one of four decoder passes ("modules") per target token.

    Module ids, as combined in ``forward``:
        0 -> context/history pass, 1 -> knowledge pass,
        2 -> positive-adapter pass, 3 -> negative-adapter pass.

    A small recurrent prior keeps a ``module_state``. ``prior_module`` predicts the next
    module from it. ``prior_boundary`` is a 2-way classifier deciding whether the state is
    updated at a step. During training, the gold ``labels_z`` (module) and ``labels_m``
    (boundary) drive the mixture. At inference, the argmax predictions drive it.
    """

    base_model_prefix = "model"
    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        base_model = BartModel(config)
        self.model = base_model
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))

        self.num_modules = 4
        self.hidden_size = 128

        self.embed_module = nn.Embedding(num_embeddings=self.num_modules, embedding_dim=self.hidden_size)
        # dense_module / prior_boundary read [decoder hidden state ; module embedding].
        self.dense_module = LinearLayer(input_size=config.d_model + self.hidden_size, output_size=self.hidden_size, nonlinear=gelu_new)
        # Initial module state is computed from the mean-pooled encoder output.
        self.init_module_state = LinearLayer(input_size=config.d_model, output_size=self.hidden_size, nonlinear=gelu_new)
        self.update_module_state = RecurrentLayer(input_size=self.hidden_size, hidden_size=self.hidden_size)
        self.prior_module = LinearLayer(input_size=self.hidden_size, output_size=self.num_modules, nonlinear=None)
        # Class 1 = update the module state at this step, class 0 = keep it.
        self.prior_boundary = LinearLayer(input_size=config.d_model + self.hidden_size, output_size=2, nonlinear=None)

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        """Resize the shared embeddings and keep ``final_logits_bias`` the same width."""
        old_num_tokens = self.model.shared.num_embeddings
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self.model.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        """Truncate, or zero-pad, the logits bias buffer to ``new_num_tokens`` columns."""
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def calc_copy_dist(self, attn_dist, vocab_size, input_ids):
        """Project attention over source positions onto the vocabulary.

        Args:
            attn_dist: attention weights, ``[bs, tgt_len, src_len]``.
            vocab_size: output vocabulary size.
            input_ids: source token ids, ``[bs, src_len]``.

        Returns:
            ``[bs, tgt_len, vocab_size]`` tensor. Entry ``v`` is the total attention mass of
            all source positions holding token ``v``.
        """
        index = input_ids.unsqueeze(1).repeat(1, attn_dist.size(1), 1)  # [bs, tgt_len, src_len]
        output_shape = attn_dist.size()[:-1] + (vocab_size, )
        output = torch.zeros(output_shape, dtype=attn_dist.dtype, layout=attn_dist.layout, device=attn_dist.device)
        # scatter_add_ (not scatter_): a token that occurs several times in the source must
        # accumulate all of its attention instead of one occurrence overwriting the others.
        output.scatter_add_(dim=2, index=index, src=attn_dist)
        return output

    def one_hot(self, indices, depth, dtype=None):
        """One-hot encode integer ``indices`` along a new last dim of size ``depth``.

        ``dtype`` defaults to ``indices.dtype``.
        """
        output_shape = indices.size() + (depth,)
        dtype = dtype if dtype is not None else indices.dtype
        output = torch.zeros(output_shape, dtype=dtype, device=indices.device)
        index = indices.unsqueeze(-1)
        src = torch.ones_like(index, dtype=dtype)
        output.scatter_(dim=-1, index=index, src=src)
        return output

    @staticmethod
    def _masked_ce(logits, labels, mask):
        """Mean cross-entropy over the positions selected by ``mask``.

        ``logits`` is ``[..., C]``; ``labels`` and ``mask`` have shape ``logits.shape[:-1]``.
        If ``mask`` selects nothing, returns a graph-attached zero (the sum of an empty
        tensor). ``nn.CrossEntropyLoss`` would return NaN there and poison the total loss.
        """
        num_classes = logits.size(-1)
        selected_logits = logits.masked_select(mask.unsqueeze(-1).expand_as(logits)).contiguous().view(-1, num_classes)
        selected_labels = labels.masked_select(mask)
        if selected_labels.numel() == 0:
            return selected_logits.sum()
        return nn.CrossEntropyLoss()(selected_logits, selected_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        history_attention_mask=None,
        knowledge_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        labels=None,
        labels_m=None,
        labels_z=None,
        use_cache=None,
        debug=False,
        **unused,
    ):
        """Run the four decoder passes and combine them per token.

        Training (``self.training``): requires ``labels``, ``labels_z`` (gold module ids,
        ``[bs, tgt_len]``) and ``labels_m`` (gold boundary flags, ``[bs, tgt_len]``).
        Returns ``(lm_loss, positive_loss, negative_loss, kl_abs_loss, kl_mask_loss)``.

        Inference: picks the module for the newest step and returns a ``Seq2SeqLMOutput``.
        Its ``past_key_values`` is the 9-tuple ``(context_past, knowledge_past,
        positive_past, negative_past, prev_module, prev_module_state, prev_hidden_state,
        prev_boundary, prev_boundary_state)``.
        """
        if past_key_values is None:  # training, or first inference step
            past_key_values = (None,) * 9
        (context_past_key_values, knowledge_past_key_values, positive_past_key_values, negative_past_key_values,
         prev_module, prev_module_state, prev_hidden_state, prev_boundary, prev_boundary_state) = past_key_values[:9]

        context_hidden_state, knowledge_hidden_state, positive_hidden_state, negative_hidden_state, encoder_last_hidden_state, \
        context_past, knowledge_past, positive_past, negative_past = self.model(
            input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids,
            history_attention_mask=history_attention_mask, knowledge_attention_mask=knowledge_attention_mask, encoder_outputs=encoder_outputs,
            context_past_key_values=context_past_key_values, knowledge_past_key_values=knowledge_past_key_values,
            positive_past_key_values=positive_past_key_values, negative_past_key_values=negative_past_key_values,
            use_cache=use_cache, debug=debug,
        )
        # All four passes share the tied output projection and bias.
        history_logits = F.linear(context_hidden_state, self.model.shared.weight, bias=self.final_logits_bias)
        knowledge_logits = F.linear(knowledge_hidden_state, self.model.shared.weight, bias=self.final_logits_bias)
        positive_logits = F.linear(positive_hidden_state, self.model.shared.weight, bias=self.final_logits_bias)
        negative_logits = F.linear(negative_hidden_state, self.model.shared.weight, bias=self.final_logits_bias)

        if self.training:
            # Teacher forcing: each token uses the logits of its gold module.
            action_prob = self.one_hot(labels_z, depth=self.num_modules, dtype=knowledge_logits.dtype)  # [bs, tgt_len, 4]
            final_logits = action_prob[:, :, 0:1] * history_logits + action_prob[:, :, 1:2] * knowledge_logits + \
                           action_prob[:, :, 2:3] * positive_logits + action_prob[:, :, 3:4] * negative_logits
            # The prior only sees detached encoder/decoder states, so its losses do not reach BART.
            encoder_last_hidden_state = encoder_last_hidden_state.detach()

            prior_module_list = []
            prior_boundary_list = []
            update_mask_list = []

            tgt_len = decoder_input_ids.size(1)

            module_state = self.init_module_state(torch.mean(encoder_last_hidden_state, dim=1).detach())
            for t in range(tgt_len):
                # Predict the module for step t from the current state.
                module_logits = self.prior_module(module_state)
                prior_module_list.append(module_logits)

                read_data = labels_m[:, t:t+1].float()  # [bs, 1] gold boundary flag at t
                cur_module = labels_z[:, t:t+1]  # [bs, 1] gold module at t
                last_module_state = self.embed_module(cur_module).squeeze(1)

                # Hidden state produced by the gold module at step t.
                last_obs_state = action_prob[:, t, 0:1] * context_hidden_state[:, t].detach() + action_prob[:, t, 1:2] * knowledge_hidden_state[:, t].detach() + \
                                 action_prob[:, t, 2:3] * positive_hidden_state[:, t].detach() + action_prob[:, t, 3:4] * negative_hidden_state[:, t].detach()
                step_feat = torch.cat([last_obs_state, last_module_state], dim=-1)
                module_feat = self.dense_module(step_feat)
                # GRU update only where the boundary flag is 1; otherwise keep the state.
                module_state = read_data * self.update_module_state(module_feat, module_state) + (1 - read_data) * module_state

                boundary_logits = self.prior_boundary(step_feat)
                prior_boundary_list.append(boundary_logits)

                # Column 0 scores the first prediction (from the initial state).
                # Column t + 1 scores a prediction only if the state was updated at step t.
                if len(update_mask_list) == 0:
                    update_mask_list.append((read_data == 1.) | (read_data == 0.))
                if t != tgt_len - 1:
                    update_mask_list.append(read_data == 1.)

            update_mask = torch.cat(update_mask_list, dim=1)

            prior_module_logits = torch.stack(prior_module_list, dim=1)  # [bs, tgt_len, num_modules]
            prior_boundary_logits = torch.stack(prior_boundary_list, dim=1)  # [bs, tgt_len, 2]

            padding_mask = labels != self.config.pad_token_id
            lm_mask = padding_mask & ((labels_z == 0) | (labels_z == 1))
            positive_mask = padding_mask & (labels_z == 2)
            negative_mask = padding_mask & (labels_z == 3)

            lm_loss = self._masked_ce(final_logits, labels, lm_mask)
            positive_loss = self._masked_ce(final_logits, labels, positive_mask)
            negative_loss = self._masked_ce(final_logits, labels, negative_mask)
            kl_abs_loss = self._masked_ce(prior_module_logits, labels_z, padding_mask & update_mask)
            kl_mask_loss = self._masked_ce(prior_boundary_logits, labels_m, padding_mask)

            return (lm_loss, positive_loss, negative_loss, kl_abs_loss, kl_mask_loss)
        else:
            if decoder_input_ids.size(1) == 1:  # first step
                # No boundary decision yet; initialise the module state from the encoder.
                boundary_state = None
                cur_boundary = None
                module_state = self.init_module_state(torch.mean(encoder_last_hidden_state, dim=1))
                module_logits = self.prior_module(module_state)
                cur_module = torch.max(module_logits, dim=1)[1].unsqueeze(1)
            else:
                # Features of the previous step: [last combined hidden state ; previous module embedding].
                boundary_state = torch.cat([prev_hidden_state[:, -1], self.embed_module(prev_module).squeeze(1)], dim=-1)
                boundary_logits = self.prior_boundary(boundary_state)
                cur_boundary = torch.max(boundary_logits, dim=1)[1].unsqueeze(1)  # [bs, 1], 0 or 1
                module_feat = self.dense_module(boundary_state)
                update = cur_boundary.float()
                module_state = update * self.update_module_state(module_feat, prev_module_state) + (1 - update) * prev_module_state
                module_logits = self.prior_module(module_state)
                cur_module = torch.max(module_logits, dim=1)[1].unsqueeze(1)
                # NOTE(review): this maps every non-zero prediction (1, 2 or 3) to module 3
                # (negative pass) after the first step, overriding the prior. It looks like an
                # experiment/ablation switch; confirm before relying on it.
                cur_module = cur_module.bool().long() * 3
            module_prob = self.one_hot(cur_module, depth=self.num_modules).float()  # [bs, 1, 4]
            final_logits = module_prob[:, :, 0:1] * history_logits + module_prob[:, :, 1:2] * knowledge_logits + \
                           module_prob[:, :, 2:3] * positive_logits + module_prob[:, :, 3:4] * negative_logits
            final_hidden_state = module_prob[:, :, 0:1] * context_hidden_state + module_prob[:, :, 1:2] * knowledge_hidden_state + \
                                 module_prob[:, :, 2:3] * positive_hidden_state + module_prob[:, :, 3:4] * negative_hidden_state

            # State carried to the next step.
            prev_module = cur_module
            prev_module_state = module_state
            prev_boundary = cur_boundary
            prev_boundary_state = boundary_state
            prev_hidden_state = final_hidden_state if prev_hidden_state is None else (
                torch.cat([prev_hidden_state, final_hidden_state], dim=1)
            )

            return Seq2SeqLMOutput(
                loss=None,
                logits=final_logits,
                past_key_values=(context_past, knowledge_past, positive_past, negative_past,
                                 prev_module, prev_module_state, prev_hidden_state, prev_boundary, prev_boundary_state),
                decoder_hidden_states=None,
                decoder_attentions=None,
                encoder_last_hidden_state=encoder_last_hidden_state,
                encoder_hidden_states=None,
                encoder_attentions=None,
            )

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
    ):
        """Build ``forward`` kwargs for one generation step.

        The history/knowledge masks are optional; when absent they are passed as ``None``.
        """
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "history_attention_mask": kwargs.get("history_attention_mask"),
            "knowledge_attention_mask": kwargs.get("knowledge_attention_mask"),
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        """Force BOS at step 1 (if configured) and EOS at the last allowed step."""
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder the 9-tuple cache built by ``forward`` along the batch/beam dim.

        Elements 0-3 are the four decoder caches. Each is assumed to be a list of per-layer
        dicts ``{attn_key: buffer_dict}``, the layout the single-cache implementation
        expected; confirm against ``model.modeling_bart_v2``. Elements 4-8 are tracking
        tensors with the batch on dim 0, or ``None`` (e.g. no boundary on the first step).
        """
        def reorder_decoder_cache(decoder_past):
            if decoder_past is None:
                return None
            return [
                {attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()}
                for layer_past in decoder_past
            ]

        def reorder_state(state):
            return None if state is None else state.index_select(0, beam_idx)

        return tuple(reorder_decoder_cache(p) for p in past[:4]) + tuple(reorder_state(s) for s in past[4:])

    def get_encoder(self):
        return self.model.encoder

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.model.shared)  # make it on the fly

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        attention_mask,
        use_cache,
        model_kwargs,
    ):
        """Generate sequences without beam search (``num_beams == 1``).

        Adapted from transformers' generation loop. It also collects, per step, the
        module and boundary predictions stored in the cache.

        Returns:
            ``(input_ids, predicts_z, predicts_m)``.
            ``predicts_z`` concatenates ``past[4]`` (chosen module) over steps.
            ``predicts_m`` concatenates ``past[7]`` (boundary decision). ``forward`` stores
            ``None`` there on a first step, so when decoding starts from a single decoder
            token, ``predicts_m`` is presumably one column shorter than ``predicts_z``
            (``None`` after a single step).
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = None
        predicts_z, predicts_m = None, None
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
            )

            outputs = self(**model_inputs, return_dict=True)
            next_token_logits = outputs.logits[:, -1, :]

            scores = self.postprocess_next_token_scores(
                scores=next_token_logits,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=1,
            )

            # if model has past, then set the past variable to speed up decoding
            if "past_key_values" in outputs:
                past = outputs.past_key_values
            elif "mems" in outputs:
                past = outputs.mems

            predicts_z = past[4] if predicts_z is None else torch.cat([predicts_z, past[4]], dim=-1)
            predicts_m = past[7] if predicts_m is None else torch.cat([predicts_m, past[7]], dim=-1)

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logscores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding over the post-processed scores (min-length / banned-token
                # constraints applied).
                next_token = torch.argmax(scores, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exist
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
        return input_ids, predicts_z, predicts_m
