from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from typing import Optional

import torch
import math
from torch import nn, Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from npz_dataset import MAX_VIS_LEN, MAX_TEXT_LEN, MAX_POS_LEN
import torch.nn.functional as F


def init_net(m):
    """Recursively (re-)initialise the parameters of ``m`` and its sub-modules.

    Exactly one rule is applied to ``m`` itself; the first match wins:

    * ``nn.Linear`` / ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Embedding``:
      Kaiming-normal weights (a bias, if any, is left untouched).
    * ``nn.LayerNorm``: weight set to 1 and bias set to 0, when present.
    * ``nn.LSTM`` / ``nn.GRU``: every parameter whose name contains
      ``'weight'`` is drawn from N(0, 0.02); biases are left untouched.
    * any other module exposing a tensor ``bias``: the bias is zeroed.
    * any other module with a non-``None`` ``no_initialization`` attribute:
      return immediately, so its children are not visited either.

    Afterwards the function recurses into the direct children of ``m``, except
    children whose class name is one of the T5 model classes listed below.

    Note that ``no_initialization`` is only consulted when none of the earlier
    rules matched, so e.g. a flagged ``nn.Linear`` is still re-initialised.

    Args:
        m: the ``nn.Module`` to initialise in place.
    """
    # Children with these class names are skipped entirely (not re-initialised).
    skipped_class_names = ('T5Model', 'T5ForConditionalGeneration',
                           'T5EncoderModel', 'T5Stack')

    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Embedding)):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, nn.LayerNorm):
        # A LayerNorm built without affine parameters has weight/bias set to
        # None; guard so such layers are skipped instead of raising.
        if m.bias is not None:
            m.bias.data.zero_()
        if m.weight is not None:
            m.weight.data.fill_(1.0)
    elif isinstance(m, (nn.LSTM, nn.GRU)):
        # Public-API walk over the recurrent layer's own parameters.
        for name, param in m.named_parameters(recurse=False):
            if 'weight' in name:
                nn.init.normal_(param, 0.0, 0.02)
    elif isinstance(getattr(m, 'bias', None), Tensor):
        # Only tensors can be filled; some modules keep a plain bool ``bias`` flag.
        nn.init.constant_(m.bias, 0)
    elif getattr(m, 'no_initialization', None) is not None:
        return

    for child in m.children():
        if type(child).__name__ in skipped_class_names:
            continue
        init_net(child)


def _get_activation_fn(activation):
    """Map an activation name to the matching ``torch.nn.functional`` callable.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        RuntimeError: if ``activation`` is neither of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu))
    for known_name, fn in supported:
        if activation == known_name:
            return fn

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def _norm2(x: Tensor, dim: int = -1, eps: float =1e-8) -> Tensor:
    """Scale ``x`` to (approximately) unit L2 norm along ``dim``.

    ``eps`` is added to the sum of squares before the square root, so an
    all-zero slice yields zeros rather than a division by zero.

    Args:
        x: input tensor.
        dim: dimension along which the norm is computed.
        eps: stabiliser added under the square root.

    Returns:
        A tensor with the same shape as ``x``.
    """
    sum_of_squares = (x * x).sum(dim=dim, keepdim=True)
    denominator = torch.sqrt(sum_of_squares + eps)
    return x / denominator



class CompletionModal(nn.Module):
    """Per-position predictor over visual features, optionally text-conditioned.

    ``forward`` performs these steps:

    1. L2-normalise the visual features.
    2. If ``text_att_type >= 3``: cross-attend from the visual features
       (queries) to the L2-normalised text features (keys/values) and apply a
       feed-forward block. If ``text_att_type == 0`` the visual features are
       used as they are.
    3. Project to ``hidden_dim``, optionally overwrite begin/end positions with
       learned embeddings, and run the pretrained T5 v1.1-large encoder stack
       on the result (passed as ``inputs_embeds``).
    4. Apply dropout and project to ``output_dim``.

    Args:
        input_dim: feature size of the visual (and text) features.
        output_dim: feature size of the prediction.
        hidden_dim: size fed into the T5 encoder.
            NOTE(review): presumably this must equal the encoder's model width
            for 'google/t5-v1_1-large' — confirm before changing the default.
        num_layers: stored on the instance only; not used to build any layer
            in this class.
        text_att_type: ``0`` = visual features only; ``>= 3`` = cross-attention
            to text, where exactly ``3`` disables the residual connection
            around the attention. Any other value makes ``forward`` raise.
        text_feedforward_dim: inner size of the feed-forward block.
        text_att_n_head: number of heads of the cross-attention.
        dropout: dropout probability used by every dropout layer created here.
        layer_norm_eps: epsilon of the layer norms created here.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_dim: int = 1024,
                 num_layers: int = 24, text_att_type: int = 4,
                 text_feedforward_dim: int = 512, text_att_n_head: int = 16,
                 dropout: float = 0.3, layer_norm_eps: float = 1e-8):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.input2hidden = nn.Linear(input_dim, hidden_dim)

        # Index 0 is substituted at begin positions, index 1 at end positions.
        self.start_end_embedding = nn.Embedding(2, hidden_dim)
        # NOTE(review): init_net(self) at the end of this constructor
        # re-initialises every nn.Embedding with kaiming_normal_, so this
        # N(1, 0.02) draw is overwritten — confirm which one is intended.
        self.start_end_embedding.weight.data.normal_(1, 0.02)  # Initialise scale at N(1, 0.02)

        self.n_multihead = text_att_n_head

        # T5 v1.1 encoder (imported lazily so the dependency is only needed
        # when a model is actually constructed).
        from transformers import T5ForConditionalGeneration
        self.encoder = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-large').get_encoder()
        # forward() feeds ``inputs_embeds`` directly, so the token table is dropped.
        self.encoder.embed_tokens = None
        self.encoder.is_decoder = False

        self.hidden2output = nn.Linear(hidden_dim, output_dim)

        self.text_att_type = text_att_type

        if self.text_att_type >= 3:
            # Cross-attention block: visual queries over text keys/values.
            self.multihead_attn = nn.MultiheadAttention(input_dim, text_att_n_head,
                                                        dropout=dropout, batch_first=True)
            self.multihead_attn_dropout = nn.Dropout(dropout)
            self.multihead_attn_norm = nn.LayerNorm(input_dim, eps=layer_norm_eps)

            # Position-wise feed-forward block applied after the attention.
            self.ff_linear1 = nn.Linear(input_dim, text_feedforward_dim)
            self.ff_activation1 = _get_activation_fn('gelu')
            self.ff_dropout1 = nn.Dropout(dropout)
            self.ff_linear2 = nn.Linear(text_feedforward_dim, input_dim)
            self.ff_dropout2 = nn.Dropout(dropout)
            self.ff_norm = nn.LayerNorm(input_dim, eps=layer_norm_eps)

        self.dropout = nn.Dropout(dropout)
        # Re-initialise the layers created above; init_net skips children whose
        # class name is 'T5Stack' (and the other T5 classes it lists).
        init_net(self)

    @staticmethod
    def _length_mask(seq_len: Tensor, max_len: int) -> Tensor:
        """Return a bool mask that is True at positions ``< seq_len``.

        Args:
            seq_len: 1-D tensor of sequence lengths, one per batch element.
            max_len: number of positions (columns) in the mask.

        Returns:
            Bool tensor of shape ``(len(seq_len), max_len)`` on ``seq_len``'s device.
        """
        positions = torch.arange(max_len, device=seq_len.device).unsqueeze(0)
        return positions < seq_len.unsqueeze(-1)

    # multihead attention block
    def _mha_block(self, q: Tensor, kv: Tensor,
                   attn_mask: Optional[Tensor],
                   key_padding_mask: Optional[Tensor],
                   with_skip: bool = True) -> Tensor:
        """Attend from ``q`` to ``kv`` (used as both key and value).

        The attention output goes through dropout, is optionally added back to
        ``q`` (``with_skip``), and is layer-normalised.
        """
        out = self.multihead_attn(q, kv, kv,
                                  attn_mask=attn_mask,
                                  key_padding_mask=key_padding_mask,
                                  need_weights=False)[0]
        out = self.multihead_attn_dropout(out)
        if with_skip:
            out = q + out
        return self.multihead_attn_norm(out)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """Two-layer feed-forward network with residual connection and layer norm."""
        out = self.ff_linear2(self.ff_dropout1(self.ff_activation1(self.ff_linear1(x))))
        return self.ff_norm(x + self.ff_dropout2(out))

    def forward(self, vis_feat: Tensor, vis_seq_len: Tensor,
                begin_mask: Optional[Tensor], end_mask: Optional[Tensor],
                text_feat: Optional[Tensor], text_seq_len: Optional[Tensor]) -> Tensor:
        """Compute the prediction for every visual position.

        Args:
            vis_feat: visual features, batch first: (B, vis_len, input_dim).
            vis_seq_len: (B,) number of valid visual positions per sample.
            begin_mask: optional bool mask (B, vis_len); True positions are
                replaced by the learned "begin" embedding.
            end_mask: optional bool mask (B, vis_len); True positions are
                replaced by the learned "end" embedding. The substitution is
                only applied when both masks are given.
            text_feat: text features (B, text_len, input_dim); required when
                ``text_att_type >= 3``.
            text_seq_len: (B,) number of valid text positions per sample;
                required when ``text_att_type >= 3``.

        Returns:
            Tensor of shape (B, vis_len, output_dim). No output non-linearity
            is applied.

        Raises:
            ValueError: text inputs are missing although ``text_att_type >= 3``.
            NotImplementedError: ``text_att_type`` is neither 0 nor >= 3.
        """
        vis_feat = _norm2(vis_feat)
        # Masks are sized from the tensors themselves, so inputs need not be
        # padded to one fixed global length.
        vis_len = vis_feat.size(1)

        if self.text_att_type == 0:
            input_seq = vis_feat
        elif self.text_att_type >= 3:
            if text_feat is None or text_seq_len is None:
                raise ValueError('text_feat and text_seq_len are required when text_att_type >= 3')

            # feed the text feature to the transformer as key/value
            text_feat = _norm2(text_feat)
            text_len = text_feat.size(1)

            # vis_feat:  (B, vis_len, E)  // (N, L, E)
            # text_feat: (B, text_len, E) // (N, S, E)

            # (N, S): True marks padded text positions that must not be attended.
            key_padding_mask = ~self._length_mask(text_seq_len, text_len)

            # (N * num_heads, L, S): the key padding mask replicated for every
            # head and every query row. Query-side padding is deliberately not
            # masked here, because that would mask out entire rows.
            attn_mask = (key_padding_mask.unsqueeze(1).unsqueeze(1)
                         .repeat(1, self.n_multihead, vis_len, 1)
                         .reshape(-1, vis_len, text_len))

            # text_att_type == 3 runs the attention without the residual skip.
            att_text_feat = self._mha_block(vis_feat, text_feat, attn_mask,
                                            key_padding_mask, self.text_att_type != 3)
            input_seq = self._ff_block(att_text_feat)
        else:
            raise NotImplementedError('not implemented')

        embedded = self.input2hidden(input_seq)

        if (begin_mask is not None) and (end_mask is not None) and (self.start_end_embedding is not None):
            # Embedding row 0 replaces begin positions, row 1 replaces end positions.
            begin_embedding = self.start_end_embedding(torch.zeros_like(begin_mask, dtype=torch.int32))
            embedded = torch.where(begin_mask.unsqueeze(-1), begin_embedding, embedded)
            end_embedding = self.start_end_embedding(torch.ones_like(end_mask, dtype=torch.int32))
            embedded = torch.where(end_mask.unsqueeze(-1), end_embedding, embedded)

        # 1.0 at valid visual positions, 0.0 at padding.
        vis_padding_mask = self._length_mask(vis_seq_len, vis_len).to(torch.float16)

        output_bth = self.encoder(inputs_embeds=embedded,
                                  attention_mask=vis_padding_mask,
                                  output_attentions=False,
                                  output_hidden_states=False).last_hidden_state

        output_bth = self.dropout(output_bth)
        # output_bth = [batch, vis_len, hidden_dim]

        prediction_bto = self.hidden2output(output_bth)

        return prediction_bto  # removed tanh from here for more flexibility.
