import torch
import numpy as np
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
from torchvision import models
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from miscc.config import cfg
from GlobalAttention import GlobalAttentionGeneral as ATT_NET
from GlobalAttention import GlobalAttention_text as ATT_NET_text
from spectral import SpectralNorm

class GLU(nn.Module):
    """Gated linear unit applied along dim 1.

    The input is split in half along dim 1; the first half is multiplied
    element-wise by the sigmoid of the second half, so the output has half
    as many entries along dim 1 as the input.
    """

    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, x):
        """Return ``x[:, :nc // 2] * sigmoid(x[:, nc // 2:])``.

        :param x: tensor with at least two dimensions and an even size
            along dim 1
        :raises AssertionError: if the size of dim 1 is odd
        """
        nc = x.size(1)
        assert nc % 2 == 0, 'channels dont divide 2!'
        half = nc // 2
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return x[:, :half] * torch.sigmoid(x[:, half:])


def conv1x1(in_planes, out_planes, bias=False):
    """Build a 1x1 ``nn.Conv2d`` with stride 1 and no padding."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=bias,
    )
    return layer


def conv3x3(in_planes, out_planes, stride=1, bias=False):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
    )
    return layer

class resD(nn.Module):
    """Residual block computing ``shortcut(x) + gamma * residual(x)``.

    ``gamma`` is a learnable scalar initialised to zero, so a freshly built
    block returns only its shortcut branch.

    NOTE(review): the residual branch always halves the spatial size (its
    first convolution has stride 2), whereas the shortcut only pools when
    ``downsample`` is true - confirm that ``downsample=False`` is never used.
    """

    def __init__(self, fin, fout, downsample=True):
        super().__init__()
        self.downsample = downsample
        # A 1x1 projection is applied only when the channel count changes.
        self.learned_shortcut = (fin != fout)

        residual_layers = [
            nn.Conv2d(fin, fout, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fout, fout, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv_r = nn.Sequential(*residual_layers)

        self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, c=None):
        """Combine both branches; ``c`` is accepted but not used."""
        skip = self.shortcut(x)
        branch = self.residual(x)
        return skip + self.gamma * branch

    def shortcut(self, x):
        """Optionally project with ``conv_s``, then optionally 2x average-pool."""
        out = self.conv_s(x) if self.learned_shortcut else x
        if not self.downsample:
            return out
        return F.avg_pool2d(out, 2)

    def residual(self, x):
        """Run the two-convolution residual branch."""
        return self.conv_r(x)



# Upscale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    """Nearest-neighbour 2x upsample, then conv3x3 -> BatchNorm2d -> GLU.

    The convolution emits ``out_planes * 2`` channels, which the GLU gate
    halves back to ``out_planes``.
    """
    gated_planes = out_planes * 2
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, gated_planes),
        nn.BatchNorm2d(gated_planes),
        GLU(),
    ]
    return nn.Sequential(*layers)


# Keep the spatial size
def Block3x3_relu(in_planes, out_planes):
    """conv3x3 -> BatchNorm2d -> GLU without changing the spatial size.

    Despite the name, the non-linearity is a GLU gate: the convolution
    emits ``out_planes * 2`` channels and the gate halves them.
    """
    gated_planes = out_planes * 2
    layers = [
        conv3x3(in_planes, gated_planes),
        nn.BatchNorm2d(gated_planes),
        GLU(),
    ]
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    """Residual block: conv3x3 -> BN -> GLU -> conv3x3 -> BN, plus the input."""

    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        gated_channels = channel_num * 2
        layers = [
            # The GLU gate halves the doubled channel count again.
            conv3x3(channel_num, gated_channels),
            nn.BatchNorm2d(gated_channels),
            GLU(),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``block(x) + x`` (the sum is accumulated in place)."""
        out = self.block(x)
        out.add_(x)
        return out


# ############## Text2Image Encoder-Decoder #######
class RNN_ENCODER(nn.Module):
    """Recurrent text encoder producing per-word and per-sentence embeddings.

    Captions are embedded, passed through dropout and fed to an LSTM or GRU
    (selected by ``cfg.RNN_TYPE``).  ``nhidden`` is the total feature size:
    for a bidirectional RNN each direction gets ``nhidden // 2`` units.
    """

    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
                 nhidden=128, nlayers=1, bidirectional=True):
        super(RNN_ENCODER, self).__init__()
        self.n_steps = cfg.TEXT.WORDS_NUM
        self.ntoken = ntoken  # size of the dictionary
        self.ninput = ninput  # size of each embedding vector
        self.drop_prob = drop_prob  # probability of an element to be zeroed
        self.nlayers = nlayers  # Number of recurrent layers
        self.bidirectional = bidirectional
        self.rnn_type = cfg.RNN_TYPE
        self.num_directions = 2 if bidirectional else 1
        # number of features in the hidden state (per direction)
        self.nhidden = nhidden // self.num_directions

        self.define_module()
        self.init_weights()

    def define_module(self):
        """Create the embedding, dropout and recurrent layers.

        :raises NotImplementedError: if ``rnn_type`` is neither 'LSTM' nor 'GRU'
        """
        self.encoder = nn.Embedding(self.ntoken, self.ninput)
        self.drop = nn.Dropout(self.drop_prob)

        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if self.rnn_type not in rnn_classes:
            raise NotImplementedError

        # RNN dropout acts on the outputs of every layer except the last,
        # so it is meaningless (and triggers a PyTorch warning) for a
        # single-layer network.
        inter_layer_dropout = self.drop_prob if self.nlayers > 1 else 0
        self.rnn = rnn_classes[self.rnn_type](
            self.ninput, self.nhidden,
            self.nlayers, batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=self.bidirectional)

    def init_weights(self):
        """Initialise the embedding uniformly in [-0.1, 0.1].

        The RNN parameters keep PyTorch's default initialisation.
        """
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, bsz):
        """Return a zero initial state for a batch of ``bsz`` sequences.

        :return: ``(h_0, c_0)`` for LSTM, a single ``h_0`` otherwise; each of
            shape (nlayers * num_directions, bsz, nhidden), created with the
            dtype and device of the module's parameters
        """
        weight = next(self.parameters())
        shape = (self.nlayers * self.num_directions, bsz, self.nhidden)
        if self.rnn_type == 'LSTM':
            return weight.new_zeros(shape), weight.new_zeros(shape)
        return weight.new_zeros(shape)

    def forward(self, captions, cap_lens, hidden, mask=None):
        """Encode a batch of padded captions.

        :param captions: LongTensor, batch x n_steps
        :param cap_lens: tensor of caption lengths; ``pack_padded_sequence``
            is called with its default ``enforce_sorted=True``, so lengths
            must be in decreasing order
        :param hidden: initial state, e.g. from ``init_hidden``
        :param mask: unused
        :return: ``(words_emb, sent_emb)`` where ``words_emb`` is
            batch x (nhidden * num_directions) x seq_len and ``sent_emb`` is
            batch x (nhidden * num_directions)
        """
        # --> emb: batch x n_steps x ninput
        emb = self.drop(self.encoder(captions))
        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(emb, cap_lens, batch_first=True)

        output, hidden = self.rnn(emb, hidden)
        # PackedSequence --> (batch, seq_len, nhidden * num_directions)
        output = pad_packed_sequence(output, batch_first=True)[0]
        # --> batch x nhidden*num_directions x seq_len
        words_emb = output.transpose(1, 2)

        h_n = hidden[0] if self.rnn_type == 'LSTM' else hidden
        # h_n stacks layers along dim 0; keep only the last layer so that
        # multi-layer encoders still yield exactly one row per caption.
        last_layer = h_n[-self.num_directions:]
        sent_emb = last_layer.transpose(0, 1).contiguous()
        sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)
        return words_emb, sent_emb


class CNN_ENCODER(nn.Module):
    """Image encoder built on a frozen, pretrained Inception-v3 backbone.

    Two trainable projections map the backbone activations to ``nef``
    dimensions: a 1x1 convolution over the 768-channel feature map taken
    after ``Mixed_6e`` (region features) and a linear layer over the pooled
    2048-d vector (global image code).
    """

    def __init__(self, nef):
        super(CNN_ENCODER, self).__init__()
        # Outside training a fixed width of 256 is used (a uniform ranker).
        self.nef = nef if cfg.TRAIN.FLAG else 256

        url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
        model = models.inception_v3()
        model.load_state_dict(model_zoo.load_url(url))
        # The backbone is frozen; only the two projections are trained.
        for param in model.parameters():
            param.requires_grad = False
        print('Load pretrained model from ', url)

        self.define_module(model)
        self.init_trainable_weights()

    def define_module(self, model):
        """Adopt the backbone stages of ``model`` and add the projections."""
        backbone_stages = (
            'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
            'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
            'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
            'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
            'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
        )
        for stage_name in backbone_stages:
            setattr(self, stage_name, getattr(model, stage_name))

        self.emb_features = conv1x1(768, self.nef)
        self.emb_cnn_code = nn.Linear(2048, self.nef)

    def init_trainable_weights(self):
        """Initialise both projection weights uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        for layer in (self.emb_features, self.emb_cnn_code):
            layer.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        """Return ``(features, cnn_code)`` for a batch of images.

        :param x: image batch; it is resized to 299 x 299 before encoding
        :return: ``features`` - the ``Mixed_6e`` activation projected to
            ``nef`` channels; ``cnn_code`` - batch x ``nef`` global code
        """
        # --> fixed-size input: batch x 3 x 299 x 299
        resize = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=True)
        x = resize(x)

        # Stem: 299 x 299 x 3 --> 147 x 147 x 64
        for stem in (self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3):
            x = stem(x)
        # --> 73 x 73 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # --> 71 x 71 x 192
        x = self.Conv2d_4a_3x3(self.Conv2d_3b_1x1(x))
        # --> 35 x 35 x 192
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Mixed_5*: 35 x 35 x 288, Mixed_6*: 17 x 17 x 768
        region_stages = (
            self.Mixed_5b, self.Mixed_5c, self.Mixed_5d,
            self.Mixed_6a, self.Mixed_6b, self.Mixed_6c,
            self.Mixed_6d, self.Mixed_6e,
        )
        for stage in region_stages:
            x = stage(x)

        # image region features: 17 x 17 x 768
        features = x

        # Mixed_7*: 8 x 8 x 2048
        for stage in (self.Mixed_7a, self.Mixed_7b, self.Mixed_7c):
            x = stage(x)
        # --> 1 x 1 x 2048
        x = F.avg_pool2d(x, kernel_size=8)
        # --> 2048
        x = x.view(x.size(0), -1)

        # global image features: batch x nef
        cnn_code = self.emb_cnn_code(x)
        # region features: batch x nef x 17 x 17
        features = self.emb_features(features)
        return features, cnn_code


# ############## G networks ###################
class CA_NET(nn.Module):
    """Conditioning augmentation: sample a condition code from a sentence embedding.

    A linear layer followed by a GLU gate predicts ``mu`` and ``logvar`` of
    a diagonal Gaussian; the code is drawn with the reparametrisation trick.
    """
    # some code is modified from vae examples
    # (https://github.com/pytorch/examples/blob/master/vae/main.py)

    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = cfg.TEXT.EMBEDDING_DIM
        self.c_dim = cfg.GAN.CONDITION_DIM
        # 4 * c_dim outputs: the GLU halves them, leaving mu and logvar.
        self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
        self.relu = GLU()
        # Kept for backward compatibility; sampling no longer depends on it.
        self.device = cfg.GPU_ID

    def encode(self, text_embedding):
        """Return ``(mu, logvar)``, each batch x ``c_dim``."""
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = logvar.mul(0.5).exp_()
        # randn_like draws the noise with std's own device and dtype, so the
        # sample never ends up on a different device than the statistics.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        """Return ``(c_code, mu, logvar)`` for a batch of sentence embeddings."""
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar

class FUSE_WORD_IMAGE_MAPS(nn.Module):
    """Fold a sequence of feature maps into one map with a learned gate.

    Starting from the first map, each later map is blended in as
    ``fused * (1 - g) + next * g``, where ``g`` is a single-channel sigmoid
    gate computed by a 1x1 convolution over both maps concatenated.
    """

    def __init__(self, filter_dim):
        super(FUSE_WORD_IMAGE_MAPS, self).__init__()

        gate_layers = [
            nn.Conv2d(filter_dim*2, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.fusion_gate = nn.Sequential(*gate_layers)

    def forward(self, ngram_image_maps):
        """Fuse along dim 1.

        :param ngram_image_maps: batch x seq_len x filter_dim x h x w
        :return: batch x filter_dim x h x w
        """
        num_maps = ngram_image_maps.size(1)
        fused = ngram_image_maps[:, 0].contiguous()

        for step in range(1, num_maps):
            candidate = ngram_image_maps[:, step].contiguous()
            gate = self.fusion_gate(torch.cat((fused, candidate), 1))
            fused = fused * (1 - gate) + candidate * gate

        return fused


class INIT_STAGE_G(nn.Module):
    """First generator stage: build a feature map from noise, condition and words.

    Every word embedding is concatenated with the (condition, noise) vector
    and passed through a kernel-3 Conv1d, so each output position covers
    three consecutive words.  Each position is upsampled to its own feature
    map, and the maps are fused into a single one.
    """

    def __init__(self, ngf, nef, ncf):
        super(INIT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.in_dim = cfg.GAN.Z_DIM + ncf
        self.ef_dim = nef
        # Channel count of the Conv1d input: word embedding + condition + noise.
        self.w_dim = self.ef_dim + self.in_dim
        self.define_module()

    def define_module(self):
        """Create the Conv1d stem, four upsampling blocks and the fusion gate."""
        ngf = self.gf_dim
        # Twice the target width, because the GLU gate halves it.
        gated_channels = ngf * 4 * 4 * 2

        self.conv1d = nn.Sequential(
            nn.Conv1d(self.w_dim, gated_channels, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm1d(gated_channels),
            GLU())

        self.upsample1 = upBlock(ngf, ngf // 2)
        self.upsample2 = upBlock(ngf // 2, ngf // 4)
        self.upsample3 = upBlock(ngf // 4, ngf // 8)
        self.upsample4 = upBlock(ngf // 8, ngf // 16)

        self.fuse_word_image_maps = FUSE_WORD_IMAGE_MAPS(ngf // 16)

    def forward(self, z_code, c_code, word_embs, mask=None):
        """
        :param z_code: batch x cfg.GAN.Z_DIM
        :param c_code: batch x ncf (so that c_code and z_code together have
            ``in_dim`` channels)
        :param word_embs: batch x nef x seq_len
        :param mask: unused
        :return: batch x ngf/16 x 64 x 64
        """
        batch_size = c_code.size(0)
        seq_len = word_embs.size(-1)

        # Broadcast the (condition, noise) vector to every word position.
        cond_noise = torch.cat((c_code, z_code), 1)
        cond_noise = cond_noise.unsqueeze(-1).repeat(1, 1, seq_len)
        conv_input = torch.cat((word_embs, cond_noise), 1)

        # --> batch x positions x (ngf*4*4); padding=0 drops two positions
        out_code = self.conv1d(conv_input).permute(0, 2, 1).contiguous()
        # --> (batch*positions) x ngf x 4 x 4
        out_code = out_code.view(-1, self.gf_dim, 4, 4)

        # ngf/2 x 8 x 8 --> ngf/4 x 16 x 16 --> ngf/8 x 32 x 32 --> ngf/16 x 64 x 64
        for up in (self.upsample1, self.upsample2, self.upsample3, self.upsample4):
            out_code = up(out_code)

        _, filter_dim, height, width = out_code.size()
        # Regroup the per-position maps by sample before fusing them.
        stacked = out_code.view(batch_size, -1, filter_dim, height, width)
        return self.fuse_word_image_maps(stacked)


class Memory(nn.Module):
    """Key addressing and value reading over an 8 x 8 grid of memory cells.

    Every pixel attends over the ``sourceL`` memory slots of the grid cell
    it lies in.  Its query is the concatenation of a global, a grid-level
    and a pixel-level query vector.
    """

    def __init__(self):
        super(Memory, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.sm = nn.Softmax()
        self.mask = None

    def applyMask(self, mask):
        """Store the slot mask used by the next ``forward`` call."""
        self.mask = mask  # batch x sourceL

    def forward(self, global_query, grid_queries, local_queries, context_key, content_value, size):
        """Read refinement features from the memory.

        :param global_query: batch x c_global x 1 x 1
        :param grid_queries: batch x c_grid x 8 x 8
        :param local_queries: batch x c_local x size x size
        :param context_key: batch x (c_global + c_grid + c_local) x 8 x 8 x sourceL
        :param content_value: batch x c_value x 8 x 8 x sourceL
        :param size: spatial size of ``local_queries``
        :return: batch x c_value x size x size, where pixel
            ``(gx*u + px, gy*u + py)`` holds the read-out for pixel
            ``(px, py)`` of grid cell ``(gx, gy)`` and ``u = size / 8``

        Slots whose entry in the stored mask is true receive zero attention.
        """
        grid = 8
        # u*u = number of pixels in a grid cell
        u = int(size / grid)
        cell = u * u

        batch_size, sourceL = context_key.size(0), context_key.size(-1)

        # --> batch x c_global x (u*u)
        global_query = global_query.view(batch_size, global_query.size(1), 1).repeat(1, 1, cell)

        # Allocate on the queries' device/dtype instead of forcing CUDA.
        weight = local_queries.new_zeros(batch_size, cell, grid, grid, sourceL)
        for x in range(grid):
            for y in range(grid):
                # Pixel-level and grid-level queries of this grid cell.
                region_query = local_queries[:, :, x*u:(x+1)*u, y*u:(y+1)*u].contiguous().view(
                    batch_size, local_queries.size(1), -1)
                grid_query = grid_queries[:, :, x, y].contiguous().view(
                    batch_size, grid_queries.size(1), 1).repeat(1, 1, cell)

                # --> batch x (u*u) x (c_global + c_grid + c_local)
                query = torch.cat((global_query, grid_query, region_query), 1).transpose(1, 2)

                # Key addressing: batch x (u*u) x sourceL
                region_key = context_key[:, :, x, y, :].contiguous().view(
                    batch_size, context_key.size(1), -1)
                weight[:, :, x, y, :] = torch.bmm(query, region_key)

        # --> (batch * u*u * 8*8) x sourceL, rows grouped per sample
        rows_per_sample = cell * grid * grid
        weight = weight.view(batch_size * rows_per_sample, sourceL)
        if self.mask is not None:
            # repeat_interleave keeps each sample's mask next to that
            # sample's rows; a plain repeat() would tile the batch and
            # misalign masks whenever batch_size > 1.
            mask = self.mask.to(device=weight.device, dtype=torch.bool)
            mask = mask.repeat_interleave(rows_per_sample, dim=0)
            weight = weight.masked_fill(mask, -float('inf'))

        # Softmax over the slots, then regroup per grid cell for the read.
        weight = F.softmax(weight, dim=1)
        weight = weight.view(batch_size, cell, grid, grid, sourceL).permute(0, 2, 3, 1, 4).contiguous()
        weight = weight.view(batch_size * grid * grid, cell, sourceL).transpose(1, 2)

        content_value = content_value.permute(0, 2, 3, 1, 4).contiguous()
        channels = content_value.size(3)
        content_value = content_value.view(batch_size * grid * grid, channels, sourceL)

        # Value reading: (batch * 8*8) x c_value x (u*u)
        weightedContext = torch.bmm(content_value, weight)
        weightedContext = weightedContext.view(batch_size, grid, grid, channels, u, u)
        # Order the axes as (gx, px, gy, py) so that flattening rebuilds the
        # image row-major; flattening (gx, gy, px, py) directly would put
        # each read-out at a different pixel than the query it belongs to.
        weightedContext = weightedContext.permute(0, 3, 1, 4, 2, 5).contiguous()
        return weightedContext.view(batch_size, channels, grid * u, grid * u)


class NEXT_STAGE_G(nn.Module):
    """Refinement generator stage with six memory-based refinement heads.

    For every head a memory over an 8 x 8 grid is written from the word
    embeddings and the current image features.  The heads are then applied
    one after another: each reads from its memory and blends the read-out
    into the image features through a response gate.  Finally the features
    pass a preservation gate, residual blocks and a 2x upsampling block.
    """

    def __init__(self, ngf, nef, ncf, size):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.z_dim = cfg.GAN.Z_DIM
        self.cf_dim = ncf
        self.num_residual = cfg.GAN.R_NUM
        self.size = size  # spatial size of the incoming feature map
        self.define_module()

    def _make_layer(self, block, channel_num):
        """Stack ``cfg.GAN.R_NUM`` instances of ``block(channel_num)``."""
        layers = []
        for i in range(cfg.GAN.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        """Create the shared layers and one parameter set per refinement head."""
        ngf = self.gf_dim
        # Pools the whole map to 1 x 1 and to the 8 x 8 grid respectively.
        self.avg = nn.AvgPool2d(kernel_size=self.size)
        self.spatial_avg = nn.AvgPool2d(kernel_size=int(self.size/8))

        self.sigmoid = nn.Sigmoid()

        self.memory_operation = Memory()

        self.preservation_gate = nn.Sequential(
            nn.Conv2d(self.gf_dim * 2, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
            )

        self.residual = self._make_layer(ResBlock, ngf * 2)
        self.upsample = upBlock(ngf * 2, ngf)

        self.cos = torch.nn.CosineSimilarity()

        self.head_encoder = torch.nn.ModuleDict()
        self.A = torch.nn.ModuleDict()
        self.B = torch.nn.ModuleDict()
        self.M_r = torch.nn.ModuleDict()
        self.M_w = torch.nn.ModuleDict()
        self.key = torch.nn.ModuleDict()
        self.value = torch.nn.ModuleDict()
        self.global_query = torch.nn.ModuleDict()
        self.grid_queries = torch.nn.ModuleDict()
        self.local_queries = torch.nn.ModuleDict()
        self.response_gate = torch.nn.ModuleDict()

        # Create separate parameters for each refinement head.
        for head_idx in range(6):
            k = str(head_idx)
            # Word-side writing gate: one logit per grid cell.
            self.A[k] = nn.Linear(self.ef_dim, 8*8, bias=False)
            # Image-side writing gate: a separate linear map per grid cell.
            self.B[k] = torch.nn.ModuleDict()
            for x in range(8):
                x = str(x)
                self.B[k][x] = torch.nn.ModuleDict()
                for y in range(8):
                    y = str(y)
                    self.B[k][x][y] = nn.Linear(ngf, 1, bias=False)

            self.head_encoder[k] = nn.Sequential(
                conv3x3(self.gf_dim, ngf),
                nn.Tanh()
                )

            self.M_r[k] = nn.Sequential(
                conv3x3(self.gf_dim, ngf*2),
                nn.ReLU()
                )
            self.M_w[k] = nn.Sequential(
                nn.Conv1d(self.ef_dim, ngf * 2, kernel_size=1, stride=1, padding=0),
                nn.ReLU()
                )

            self.key[k] = nn.Sequential(
                conv3x3(self.gf_dim*2, self.gf_dim*3),
                nn.ReLU()
                )

            self.value[k] = nn.Sequential(
                conv3x3(self.gf_dim*2, self.gf_dim),
                nn.ReLU()
                )

            self.global_query[k] = nn.Sequential(
                conv3x3(self.gf_dim, self.gf_dim, stride=1),
                nn.ReLU()
                )

            self.grid_queries[k] = nn.Sequential(
                conv3x3(self.gf_dim, self.gf_dim, stride=1),
                nn.ReLU()
                )

            self.local_queries[k] = nn.Sequential(
                conv3x3(self.gf_dim, self.gf_dim, stride=1),
                nn.ReLU()
                )

            self.response_gate[k] = nn.Sequential(
                nn.Conv2d(self.gf_dim * 2, 1, kernel_size=1, stride=1, padding=0),
                nn.Sigmoid()
                )

    def forward(self, h_code, z_code, c_code, word_embs, mask, cap_lens):
        """Refine ``h_code`` and upsample it by a factor of two.

        :param h_code: image features, batch x ngf x size x size
        :param z_code: unused
        :param c_code: unused
        :param word_embs: word features, batch x nef x sourceL
        :param mask: batch x sourceL slot mask handed to the memory reader
        :param cap_lens: unused
        :return: ``(out_code, redundancy_loss)`` - the upsampled features and
            the sum of pairwise cosine similarities (one value per sample)
            between the pooled features recorded at each head
        """
        # Memory Writing
        word_embs_T = torch.transpose(word_embs, 1, 2).contiguous()
        batch_size = word_embs.size(0)
        seq_len = word_embs.size(2)
        memory = []

        # Memory creation for each refinement head.
        for head_idx in range(6):
            k = str(head_idx)
            head = self.head_encoder[k](h_code)
            # batch x ngf x 8 x 8; detached, so writing does not
            # backpropagate into the head encoder.
            head_for_memory = self.spatial_avg(head).detach()

            # Word-side gate logits: batch x 1 x 8 x 8 x sourceL
            gate1 = torch.transpose(self.A[k](word_embs_T), 1, 2).contiguous()
            gate1 = gate1.view(batch_size, 8, 8, seq_len).unsqueeze(1)

            # Image-side gate logits, one linear map per grid cell.
            # Allocated on h_code's device/dtype instead of forcing CUDA.
            gate2 = h_code.new_zeros(batch_size, 1, 8, 8)
            for x in range(8):
                for y in range(8):
                    cell_feat = head_for_memory[:, :, x, y].unsqueeze(1)
                    cell_logit = self.B[k][str(x)][str(y)](cell_feat)
                    gate2[:, :, x, y] += torch.transpose(cell_logit, 1, 2).contiguous().squeeze(1)
            gate2 = gate2.unsqueeze(-1).repeat(1, 1, 1, 1, seq_len)

            writing_gate = torch.sigmoid(gate1 + gate2)

            encoded_words = self.M_w[k](word_embs).unsqueeze(-1).unsqueeze(-1)
            encoded_img = self.M_r[k](head_for_memory).unsqueeze(-1)
            encoded_img = encoded_img.repeat(1, 1, 1, 1, encoded_words.size(2))

            # --> batch x 2*ngf x 8 x 8 x sourceL
            encoded_words = encoded_words.repeat(1, 1, 1, 8, 8)
            encoded_words = encoded_words.permute(0, 1, 3, 4, 2)
            writing_gate = writing_gate.repeat(1, self.gf_dim*2, 1, 1, 1)

            memory.append(encoded_words * writing_gate + encoded_img * (1 - writing_gate))

        self.memory_operation.applyMask(mask)

        # Iterative Key Addressing and Value Reading for each refinement head.
        prev_h_code = h_code.clone()
        avg_mem_out = []
        for head_idx in range(6):
            k = str(head_idx)
            head_memory = memory[head_idx]
            num_slots = head_memory.size(-1)
            key = head_memory.new_zeros(head_memory.size(0), self.gf_dim*3, 8, 8, num_slots)
            value = head_memory.new_zeros(head_memory.size(0), self.gf_dim, 8, 8, num_slots)

            # Encode keys and values slot by slot.
            for p in range(num_slots):
                slot = head_memory[:, :, :, :, p]
                key[:, :, :, :, p] = self.key[k](slot)
                value[:, :, :, :, p] = self.value[k](slot)

            # Create query matrices
            global_query = self.avg(self.global_query[k](h_code))
            grid_queries = self.spatial_avg(self.grid_queries[k](h_code))
            local_queries = self.local_queries[k](h_code)

            # Extract refinement features from spatial dynamic memory
            self.memory_operation.applyMask(mask)
            memory_out = self.memory_operation(global_query, grid_queries, local_queries, key, value, size=self.size)

            # Drop only the two spatial axes: a bare squeeze() would also
            # remove the batch axis when batch_size == 1 and break the
            # dim=1 cosine similarity below.
            # NOTE(review): this pools h_code (before fusion), not
            # memory_out - confirm that is the intended quantity.
            avg_mem_out.append(self.avg(h_code).squeeze(3).squeeze(2))

            # Fuse refinement features
            response_gate = self.response_gate[k](torch.cat((h_code, memory_out), 1))
            h_code = h_code * (1 - response_gate) + response_gate * memory_out

        # Add skip connection.
        preservation_gate = self.preservation_gate(torch.cat((h_code, prev_h_code), 1))
        h_code_new = h_code * (1 - preservation_gate) + preservation_gate * prev_h_code

        h_code_new = torch.cat((h_code_new, h_code_new), 1)
        out_code = self.residual(h_code_new)
        # state size ngf x 2*size x 2*size
        out_code = self.upsample(out_code)

        # Compute Redundancy Loss
        redundancy_loss = 0
        for i in range(6):
            for j in range(i + 1, 6):
                redundancy_loss += self.cos(avg_mem_out[i], avg_mem_out[j])

        return out_code, redundancy_loss


class GET_IMAGE_G(nn.Module):
    """Map an ``ngf``-channel feature map to a 3-channel image in [-1, 1]."""

    def __init__(self, ngf):
        super(GET_IMAGE_G, self).__init__()
        self.gf_dim = ngf
        to_image = [conv3x3(ngf, 3), nn.Tanh()]
        self.img = nn.Sequential(*to_image)

    def forward(self, h_code):
        """Return ``tanh(conv3x3(h_code))``."""
        return self.img(h_code)
    
class Encode_IMAGE_G(nn.Module):
    """Re-encode an ``ngf``-channel feature map with conv3x3 followed by tanh."""

    def __init__(self, ngf):
        super(Encode_IMAGE_G, self).__init__()
        self.gf_dim = ngf
        encoder_layers = [conv3x3(ngf, ngf), nn.Tanh()]
        self.img = nn.Sequential(*encoder_layers)

    def forward(self, h_code):
        """Return ``tanh(conv3x3(h_code))`` with the channel count unchanged."""
        return self.img(h_code)

class G_NET(nn.Module):
    """Multi-stage generator; ``cfg.TREE.BRANCH_NUM`` selects how many stages exist.

    Stage 1 builds the initial feature map; stages 2 and 3 each refine and
    upsample the previous stage's features.  Every stage has its own image
    head.
    """

    def __init__(self):
        super(G_NET, self).__init__()
        ngf = cfg.GAN.GF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM
        ncf = cfg.GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        if cfg.TREE.BRANCH_NUM > 0:
            self.h_net1 = INIT_STAGE_G(ngf * 16, nef, ncf)
            self.img_net1 = GET_IMAGE_G(ngf)
        # gf x 64 x 64
        if cfg.TREE.BRANCH_NUM > 1:
            self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf, 64)
            self.img_net2 = GET_IMAGE_G(ngf)
        if cfg.TREE.BRANCH_NUM > 2:
            self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf, 128)
            self.img_net3 = GET_IMAGE_G(ngf)

    def forward(self, z_code, sent_emb, word_embs, mask, cap_lens):
        """
            :param z_code: batch x cfg.GAN.Z_DIM
            :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
            :param word_embs: batch x cdf x seq_len
            :param mask: batch x seq_len
            :param cap_lens: caption lengths, forwarded to the refinement stages
            :return: ``(fake_imgs, redundancy_loss, mu, logvar)`` where
                ``fake_imgs`` holds one image batch per enabled stage and
                ``redundancy_loss`` is the mean of the summed losses of the
                enabled refinement stages (a zero scalar if there are none)
        """
        fake_imgs = []
        # Collected per stage so that configurations with fewer than three
        # branches do not reference a loss that was never computed.
        stage_losses = []
        c_code, mu, logvar = self.ca_net(sent_emb)

        if cfg.TREE.BRANCH_NUM > 0:
            h_code1 = self.h_net1(z_code, c_code, word_embs, mask)
            fake_imgs.append(self.img_net1(h_code1))
        if cfg.TREE.BRANCH_NUM > 1:
            h_code2, redundancy_loss2 = self.h_net2(h_code1, z_code, c_code, word_embs, mask, cap_lens)
            fake_imgs.append(self.img_net2(h_code2))
            stage_losses.append(redundancy_loss2)
        if cfg.TREE.BRANCH_NUM > 2:
            h_code3, redundancy_loss3 = self.h_net3(h_code2, z_code, c_code, word_embs, mask, cap_lens)
            fake_imgs.append(self.img_net3(h_code3))
            stage_losses.append(redundancy_loss3)

        if stage_losses:
            redundancy_loss = torch.mean(sum(stage_losses))
        else:
            redundancy_loss = z_code.new_zeros(())

        return fake_imgs, redundancy_loss, mu, logvar

# ############## D networks ##########################
def Block3x3_leakRelu(in_planes, out_planes):
    """Spectrally-normalised, biased conv3x3 followed by LeakyReLU(0.2)."""
    conv = SpectralNorm(conv3x3(in_planes, out_planes, bias=True))
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(conv, activation)


# Downscale the spatial size by a factor of 2
def downBlock(in_planes, out_planes):
    """Spectrally-normalised 4x4 stride-2 convolution followed by LeakyReLU(0.2)."""
    conv = SpectralNorm(nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=True))
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(conv, activation)

# Downscale the spatial size by a factor of 16
def encode_image_by_16times(ndf):
    """Four stride-2 stages mapping a 3-channel image to ``ndf * 8`` channels.

    Every stage is a spectrally-normalised 4x4 convolution followed by
    LeakyReLU(0.2); the channel widths are 3 -> ndf -> 2ndf -> 4ndf -> 8ndf.
    """
    widths = (3, ndf, ndf * 2, ndf * 4, ndf * 8)
    layers = []
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        layers.append(SpectralNorm(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=True)))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

class D_GET_LOGITS(nn.Module):
    """Conditional logit head: score 4 x 4 image features against a text code."""

    def __init__(self, ndf, nef):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf
        self.ef_dim = nef
        joint_layers = [
            nn.Conv2d(ndf * 16+nef, ndf * 2, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            # The 4x4 kernel without padding collapses a 4 x 4 map to 1 x 1.
            nn.Conv2d(ndf * 2, 1, 4, 1, 0, bias=False),
        ]
        self.joint_conv = nn.Sequential(*joint_layers)

    def forward(self, out, y):
        """
        :param out: image features, batch x (ndf * 16) x 4 x 4
        :param y: text code with ``nef`` values per sample
        :return: batch x 1 x 1 x 1 logits
        """
        # Tile the text code over the 4 x 4 grid and append it as channels.
        text_map = y.view(-1, self.ef_dim, 1, 1).repeat(1, 1, 4, 4)
        joint = torch.cat((out, text_map), 1)
        return self.joint_conv(joint)

class D_NET64(nn.Module):
    """Image encoder of the discriminator with four ``resD`` blocks.

    ``forward`` returns image features only; ``COND_DNET`` is exposed as an
    attribute and is not called here.
    """

    def __init__(self):
        super(D_NET64, self).__init__()
        ndf = cfg.GAN.DF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM

        # Stride-1 stem keeps the input resolution.
        self.conv_img = nn.Conv2d(3, ndf, 3, 1, 1)
        # Each resD block (default downsample=True) halves the resolution:
        # a 64 x 64 input becomes 32 -> 16 -> 8 -> 4.
        widths = (1, 2, 4, 8, 16)
        for idx, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, 'block%d' % idx, resD(ndf * w_in, ndf * w_out))

        self.COND_DNET = D_GET_LOGITS(ndf, nef)

    def forward(self, x):
        """Return the (ndf * 16)-channel features of the image batch ``x``."""
        out = self.conv_img(x)
        for block in (self.block0, self.block1, self.block2, self.block3):
            out = block(out)
        return out

class D_NET128(nn.Module):
    """Image encoder of the discriminator with five ``resD`` blocks.

    ``forward`` returns image features only; ``COND_DNET`` is exposed as an
    attribute and is not called here.
    """

    def __init__(self):
        super(D_NET128, self).__init__()
        ndf = cfg.GAN.DF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM

        # Stride-1 stem keeps the input resolution.
        self.conv_img = nn.Conv2d(3, ndf, 3, 1, 1)
        # Each resD block (default downsample=True) halves the resolution:
        # a 128 x 128 input becomes 64 -> 32 -> 16 -> 8 -> 4.
        widths = (1, 2, 4, 8, 16, 16)
        for idx, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, 'block%d' % idx, resD(ndf * w_in, ndf * w_out))

        self.COND_DNET = D_GET_LOGITS(ndf, nef)

    def forward(self, x):
        """Return the (ndf * 16)-channel features of the image batch ``x``."""
        out = self.conv_img(x)
        stages = (self.block0, self.block1, self.block2, self.block3,
                  self.block4)
        for block in stages:
            out = block(out)
        return out


class D_NET256(nn.Module):
    """Image encoder of the discriminator with six ``resD`` blocks.

    ``forward`` returns image features only; ``COND_DNET`` is exposed as an
    attribute and is not called here.
    """

    def __init__(self):
        super(D_NET256, self).__init__()
        ndf = cfg.GAN.DF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM

        # Stride-1 stem keeps the input resolution.
        self.conv_img = nn.Conv2d(3, ndf, 3, 1, 1)
        # Each resD block (default downsample=True) halves the resolution:
        # a 256 x 256 input becomes 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        widths = (1, 2, 4, 8, 16, 16, 16)
        for idx, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, 'block%d' % idx, resD(ndf * w_in, ndf * w_out))

        self.COND_DNET = D_GET_LOGITS(ndf, nef)

    def forward(self, x):
        """Return the (ndf * 16)-channel features of the image batch ``x``."""
        out = self.conv_img(x)
        stages = (self.block0, self.block1, self.block2, self.block3,
                  self.block4, self.block5)
        for block in stages:
            out = block(out)
        return out

