import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
import torch.distributed as dist
from torch.nn.parameter import Parameter

import math


class vg_decoder_wrapper(nn.Module):
    """Build a decoder from a config mapping and reorder its output dims.

    ``cfg`` must contain a ``'type'`` entry naming a class registered in
    ``_MODULES``; all remaining entries are passed to that class as keyword
    arguments.
    """

    def __init__(self, cfg):
        super().__init__()
        # Pop 'type' from a (shallow) copy rather than from the caller's cfg.
        decoder_kwargs = cfg.copy()
        decoder_cls = _MODULES[decoder_kwargs.pop('type')]
        self.decoder = decoder_cls(**decoder_kwargs)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        multi_dim_params = (param for param in self.parameters() if param.dim() > 1)
        for param in multi_dim_params:
            nn.init.xavier_uniform_(param)

    def forward(self, img_feat, mask, pos_embed, word_feat, word_mask):
        """Run the wrapped decoder and return its output with dims 1 and 2 swapped."""
        decoded = self.decoder(img_feat, mask, pos_embed, word_feat, word_mask)
        return decoded.transpose(1, 2)


class MultiStageDecoderLayer(nn.Module):
    """Single decoder stage: word attention, image attention, then an FFN.

    ``word_attn_args`` and ``img_attn_args`` are config mappings whose
    ``'type'`` entry selects a class from ``MULTIHEAD_ATTNS``; the remaining
    entries are that class's keyword arguments. Both default to ``None`` in
    the signature but are copied unconditionally, so they must be provided.
    ``img_feat_chunk_num`` is only stored on the instance.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1,
                 word_attn_args=None, img_attn_args=None, img_feat_chunk_num=2):
        super().__init__()
        self.word_attn = self._build_attn(word_attn_args)
        self.img_attn = self._build_attn(img_attn_args)

        # Position-wise feed-forward network.
        ffn_stages = [
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        ]
        self.ffn = nn.Sequential(*ffn_stages)

        # One LayerNorm / Dropout pair per sub-step of forward().
        self.norm = _get_clones(nn.LayerNorm(d_model), 3)
        self.dropout = _get_clones(nn.Dropout(dropout), 3)

        self.img_feat_chunk_num = img_feat_chunk_num

    @staticmethod
    def _build_attn(attn_args):
        """Instantiate the attention module described by a config mapping."""
        attn_kwargs = attn_args.copy()
        attn_cls = MULTIHEAD_ATTNS[attn_kwargs.pop('type')]
        return attn_cls(**attn_kwargs)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, vis_query, vis_query_pos, text_query_pos,
                img_feat=None, img_key_padding_mask=None, img_pos=None,
                word_feat=None, word_key_padding_mask=None, word_pos=None, layer_idx=None):
        """Update ``vis_query`` with language-guided visual information.

        ``layer_idx`` is accepted for interface compatibility and not used.
        """
        add_pos = self.with_pos_embed

        # Step 1: attend over the word features to collect linguistic info.
        # The result replaces the query here (no residual connection).
        word_attn_out = self.word_attn(
            query=add_pos(vis_query, vis_query_pos),
            key=add_pos(word_feat, word_pos),
            value=word_feat,
            key_padding_mask=word_key_padding_mask,
        )
        text_query = self.norm[0](self.dropout[0](word_attn_out[0]))

        # Step 2: use that text-derived query to gather image features.
        img_attn_out = self.img_attn(
            query=add_pos(text_query, text_query_pos),
            key=add_pos(img_feat, img_pos),
            value=img_feat,
            key_padding_mask=img_key_padding_mask,
        )
        attended = vis_query + self.dropout[1](img_attn_out[0])
        vis_query = self.norm[1](attended)

        # Step 3: residual feed-forward update.
        refined = vis_query + self.dropout[2](self.ffn(vis_query))
        return self.norm[2](refined)


class DecoderWithExtraEncoder(nn.Module):
    """Stack of decoder layers driven by learned visual and text query embeddings.

    Args:
        num_queries: Number of rows in each learned query embedding table.
        query_dim: Width of each query embedding.
        layer: Config mapping for one decoder layer. Its ``'type'`` entry
            selects a class from ``_MODULES``; the remaining entries are
            passed to that class as keyword arguments.
        num_layers: Number of (deep-copied) decoder layers to stack.
        norm_dim: Feature size given to the output ``nn.LayerNorm``.
        return_intermediate: If True, ``forward`` returns the normalized
            output of every layer instead of only the last one.
        extra_layer: Accepted for config compatibility; not used by this class.
        num_extra_layers: Accepted for config compatibility; not used by this
            class.
    """

    def __init__(self, num_queries, query_dim,
                 layer, num_layers, norm_dim, return_intermediate=False,
                 extra_layer=None, num_extra_layers=1):
        super().__init__()
        # Pop 'type' from a (shallow) copy rather than from the caller's config.
        args = layer.copy()
        layer_type = args.pop('type')
        decoder_layer = _MODULES[layer_type](**args)
        self.layers = _get_clones(decoder_layer, num_layers)

        self.norm = nn.LayerNorm(norm_dim)
        self.return_intermediate = return_intermediate
        self.vis_query_embed = nn.Embedding(num_queries, query_dim)
        self.text_query_embed = nn.Embedding(num_queries, query_dim)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _apply_norm(self, tensor):
        """Apply ``self.norm`` when it is set; otherwise return ``tensor`` as is."""
        return tensor if self.norm is None else self.norm(tensor)

    def forward(self, img_feat, img_key_padding_mask=None, pos=None,
                word_feat=None, word_key_padding_mask=None):
        """Run every decoder layer over the learned queries.

        Args:
            img_feat: 3-D tensor whose second dimension is the batch size.
            img_key_padding_mask: Passed through to each layer.
            pos: Positional encoding for ``img_feat``, passed to each layer.
            word_feat: Word features, passed through to each layer.
            word_key_padding_mask: Passed through to each layer.

        Returns:
            If ``return_intermediate`` is True, the normalized outputs of all
            layers stacked along a new leading dimension (a single entry, the
            normalized initial query, when there are no layers). Otherwise the
            normalized output of the last layer with a leading dimension of
            size 1.
        """
        # Unpacking also enforces that img_feat is exactly 3-D.
        _, bs, _ = img_feat.shape
        vis_query_embed = self.vis_query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        text_query_embed = self.text_query_embed.weight.unsqueeze(1).repeat(1, bs, 1)

        # Initial target query
        vis_query = torch.zeros_like(vis_query_embed)

        # Multi-stage decoder. The last layer's output is normalized once,
        # after the loop, so it is skipped here.
        intermediate = []
        last_idx = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            vis_query = layer(vis_query, vis_query_embed, text_query_embed,
                              img_feat, img_key_padding_mask, pos,
                              word_feat, word_key_padding_mask, None, idx)
            if self.return_intermediate and idx != last_idx:
                intermediate.append(self._apply_norm(vis_query))

        output = self._apply_norm(vis_query)

        if self.return_intermediate:
            intermediate.append(output)
            return torch.stack(intermediate)

        return output.unsqueeze(0)

# Registry that resolves the 'type' entry of a config mapping to the class to
# instantiate; each class is registered under its own name.
_MODULES = {
    module_cls.__name__: module_cls
    for module_cls in (DecoderWithExtraEncoder, MultiStageDecoderLayer)
}

def build_vg_decoder(args):
    """Create the decoder wrapper from ``args.model_config['decoder']``."""
    decoder_cfg = args.model_config['decoder']
    return vg_decoder_wrapper(decoder_cfg)

def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

def _get_activation_fn(activation):
    """Return the activation selected by the string ``activation``.

    Supported names:
        * ``"relu_inplace"`` -> a new ``nn.ReLU(inplace=True)`` module
        * ``"relu"`` -> ``F.relu``
        * ``"gelu"`` -> ``F.gelu``
        * ``"glu"`` -> ``F.glu``

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    if activation == "relu_inplace":
        # Built per call so every caller gets its own module instance.
        return nn.ReLU(inplace=True)

    functional_activations = {
        "relu": F.relu,
        "gelu": F.gelu,
        "glu": F.glu,
    }
    try:
        return functional_activations[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are unsupported too.
        raise RuntimeError(
            f"activation should be one of relu_inplace/relu/gelu/glu, not {activation}."
        ) from None

# Registry that resolves the 'type' entry of an attention config mapping to
# the attention class to instantiate.
MULTIHEAD_ATTNS = dict(MultiheadAttention=nn.MultiheadAttention)

class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Stacks ``num_layers`` linear layers with a ReLU after every layer except
    the last. The first layer maps ``input_dim`` to ``hidden_dim`` and the
    last maps to ``output_dim``; with a single layer it maps ``input_dim``
    straight to ``output_dim``. With ``num_layers <= 0`` the module holds no
    layers and ``forward`` returns its input unchanged.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        if num_layers > 0:
            h = [hidden_dim] * (num_layers - 1)
            self.layers = nn.ModuleList(
                nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
            )
        else:
            # Keep the attribute an nn.ModuleList (just an empty one) so its
            # type does not depend on the requested depth.
            self.layers = nn.ModuleList()

    def forward(self, x):
        """Apply each layer in order, with ReLU between consecutive layers."""
        last_idx = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last_idx:
                x = F.relu(x)
        return x
