import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.checkpoint import checkpoint
import ipdb

from transformers import AutoModel,BertConfig,AutoTokenizer

from ..models.transformer_decoder import *
# from transformer_decoder import *


class TQN_Model(nn.Module):
    """Query-based fusion head: text queries cross-attend to image features.

    A ``TransformerDecoder`` built from ``TransformerDecoderWoSelfAttenLayer``
    layers (both from ``models.transformer_decoder``) lets each text query
    attend to the image features. The decoded query features are then mapped
    to logits by one or two MLP heads, selected by substrings of
    ``cfg.experiment_name``:

    * ``mlp_head``      -> ``cfg.model.fusion.class_num`` outputs.
    * ``mlp_head_plus`` -> 3 outputs; only built when the experiment name
      contains one of ``_PLUS_HEAD_MARKERS``.
    """

    # Substrings of ``experiment_name`` that require ``mlp_head_plus``.
    # Every forward() mode that uses the plus head ("plus" and "both", see
    # ``_head_mode``) is triggered by one of these markers. Sharing this tuple
    # keeps head creation (__init__) and head use (forward) in agreement.
    _PLUS_HEAD_MARKERS = (
        'sent_label_plus_with_CL_loss',
        'sent_label_plus_gl',
        'CL_sent_label_plus',
    )

    def __init__(self, cfg = None):
        """Build the decoder and MLP head(s) from ``cfg``.

        Args:
            cfg: config object. Reads ``cfg.model.fusion.d_model``,
                ``cfg.model.fusion.class_num``,
                ``cfg.model.fusion.decoder_number_layer`` and
                ``cfg.experiment_name``. It is required despite the ``None``
                default.

        Raises:
            ValueError: if ``cfg`` is None.
        """
        super().__init__()
        if cfg is None:
            raise ValueError('TQN_Model requires a cfg object; got None')
        embed_dim = cfg.model.fusion.d_model
        class_num = cfg.model.fusion.class_num
        class_num_plus = 3
        decoder_number_layer = cfg.model.fusion.decoder_number_layer
        self.experiment_name = cfg.experiment_name

        self.d_model = embed_dim
        # Positional args presumably (d_model, nhead=4, dim_feedforward=1024,
        # dropout=0.1, activation) -- confirm against models.transformer_decoder.
        decoder_layer = TransformerDecoderWoSelfAttenLayer(self.d_model, 4, 1024, 0.1, 'relu', normalize_before=True)
        # The same LayerNorm instance is passed to the decoder and is also
        # applied to both inputs in forward().
        self.decoder_norm = nn.LayerNorm(self.d_model)
        self.decoder = TransformerDecoder(decoder_layer, decoder_number_layer, self.decoder_norm, return_intermediate=False)
        self.dropout_feas = nn.Dropout(0.1)
        # Keep this construction order. It fixes the state_dict key order and
        # the RNG draws for default bias init (_init_weights only touches weights).
        self.mlp_head = self._build_mlp_head(embed_dim, class_num)
        if self._needs_plus_head(self.experiment_name):
            self.mlp_head_plus = self._build_mlp_head(embed_dim, class_num_plus)
        self.apply(self._init_weights)

    @staticmethod
    def _build_mlp_head(in_dim, out_dim):
        """Return the MLP in_dim -> 1024 -> 512 -> 256 -> out_dim used by both heads.

        Each hidden Linear is followed by ReLU and Dropout(0.1). The Linear
        layers sit at Sequential indices 0, 3, 6 and 9. Keep this layout
        stable so existing checkpoints still load.
        """
        return nn.Sequential(
            nn.Linear(in_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(256, out_dim),
        )

    @classmethod
    def _needs_plus_head(cls, experiment_name):
        """Return True if ``experiment_name`` requires ``mlp_head_plus``."""
        return any(marker in experiment_name for marker in cls._PLUS_HEAD_MARKERS)

    @staticmethod
    def _head_mode(experiment_name):
        """Select which head(s) forward() applies for ``experiment_name``.

        Returns:
            'plus': only ``mlp_head_plus``;
            'both': ``mlp_head`` and ``mlp_head_plus``;
            'base': only ``mlp_head``.
        """
        name = experiment_name
        if ('sent_label_plus_with_CL_loss' in name and 'CARZero' not in name) \
                or 'sent_label_plus_gl' in name or 'CL_sent_label_plus' in name:
            return 'plus'
        if 'CARZero_sent_label_plus_with_CL_loss' in name:
            return 'both'
        return 'base'

    @staticmethod
    def _init_weights(module):
        """Re-initialise weights with N(0, 0.02). Biases are left unchanged.

        Applied to every submodule through ``self.apply``.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)

        elif isinstance(module, nn.MultiheadAttention):
            if module.in_proj_weight is not None:
                module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            else:
                # When kdim/vdim differ from embed_dim, PyTorch registers
                # in_proj_weight as None and keeps separate q/k/v projections.
                for proj_weight in (module.q_proj_weight, module.k_proj_weight, module.v_proj_weight):
                    proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)

        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, image_features, text_features, pos=None, return_atten = False, inside_repeat=True, use_MLP=True):
        """Decode text queries against image features and apply the MLP head(s).

        Args:
            image_features: batch-first tensor; dim 0 is the batch. Presumably
                (batch, num_patches, dim) -- confirm with callers.
            text_features: query embeddings. With ``inside_repeat`` it should
                presumably be 2-D (query_num, dim); it is tiled across the
                batch to (query_num, batch, dim). Otherwise it is passed to
                the decoder as given.
            pos: optional positional encoding forwarded to the decoder.
            return_atten: also return the decoder's attention map. This is
                ignored when ``use_MLP`` is falsy.
            inside_repeat: tile ``text_features`` across the batch (see above).
            use_MLP: if falsy, return the decoded features without any head.

        Returns:
            * ``use_MLP`` falsy: the decoded features with dims 0 and 1
              swapped (presumably (batch, query_num, dim)).
            * mode 'plus': ``mlp_head_plus`` output.
            * mode 'both': tuple ``(mlp_head output, mlp_head_plus output)``.
            * mode 'base': ``mlp_head`` output.
            With ``return_atten`` the attention map is appended, giving
            ``(out, atten_map)`` or ``(out_1, out_3, atten_map)``.
        """
        batch_size = image_features.shape[0]
        image_features = image_features.transpose(0, 1)  # batch moved to dim 1
        if inside_repeat:
            text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
        image_features = self.decoder_norm(image_features)
        text_features = self.decoder_norm(text_features)
        features, atten_map = self.decoder(text_features, image_features, memory_key_padding_mask=None, pos=pos, query_pos=None)
        features = self.dropout_feas(features).transpose(0, 1)  # back to batch-first
        if not use_MLP:
            return features

        mode = self._head_mode(self.experiment_name)
        if mode == 'plus':
            outputs = (self.mlp_head_plus(features),)
        elif mode == 'both':
            # mlp_head runs before mlp_head_plus; this also fixes dropout RNG order.
            outputs = (self.mlp_head(features), self.mlp_head_plus(features))
        else:
            outputs = (self.mlp_head(features),)

        if return_atten:
            return (*outputs, atten_map)
        return outputs[0] if len(outputs) == 1 else outputs
        

# if __name__ == "__main__":
#     x_global = torch.ones(64, 768).cuda()
#     x_local = torch.ones(64, 64, 768).cuda()

#     model = TQN_Model().cuda()
#     with torch.no_grad():
#         model.eval()
#         out = model(x_local, x_global)
#         ipdb.set_trace()
