# External imports
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

# Project imports
from nps.beam import Beam
from nps.reinforce import Rolls
from nps.data import IMG_SIZE
from nps.network import *

class Query1Dec(nn.Module):
    '''
    Query decoder that aggregates the IO embeddings with an LSTM.

    The per-IO embeddings are fed to ``self.rnn`` as a sequence (one step per
    IO pair). The last layer's final hidden state is decoded into a
    16-channel 16x16 grid. That grid is discretised with straight-through
    one-hot estimators and padded by one cell on every side.

    Shape comments assume MapModule(m, 3) applies ``m`` to the trailing three
    dims and keeps the leading ones (nps.network; TODO confirm).
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        '''
        kernel_size, conv_stack: not used by this decoder (the other
            QueryNDec classes take the same constructor arguments).
        fc_stack: only fc_stack[-1] is read; it is the LSTM input size, i.e.
            the size of each IO embedding.
        '''
        super(Query1Dec, self).__init__()

        # NOTE(review): never used in this class. Removing it would change the
        # state_dict keys (breaking strict loading of existing checkpoints),
        # so it is kept.
        self.lstm_io = nn.LSTM(512, 512)

        # 512 -> flattened 128 x 8 x 8 feature map.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, 128*8*8),
            nn.ReLU(inplace=True)
        )
        # 128 x 8 x 8 -> 16 x 16 x 16. The trailing Sigmoid puts the logits on
        # the same (0, 1) scale as the constant 0.5 logit added in
        # decode_process. align_corners=False is the implicit default for
        # bilinear upsampling; stating it avoids the warning older PyTorch emits.
        self.decoder2 = MapModule(nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.Sigmoid(),
        ), 3)

        self.io_emb_size = fc_stack[-1]
        self.lstm_input_size = self.io_emb_size
        self.lstm_hidden_size = 512
        self.nb_layers = 1
        self.rnn = nn.LSTM(
            self.lstm_input_size,
            self.lstm_hidden_size,
            self.nb_layers,
        )

        # Learned initial LSTM state; broadcast over the batch in decode_process.
        self.initial_h = nn.Parameter(torch.Tensor(self.nb_layers, 1, self.lstm_hidden_size))
        self.initial_c = nn.Parameter(torch.Tensor(self.nb_layers, 1, self.lstm_hidden_size))

        self.init_weights()

    def init_weights(self):
        '''Initialise the learned initial LSTM state uniformly in [-0.1, 0.1].'''
        initrange = 0.1
        self.initial_h.data.uniform_(-initrange, initrange)
        self.initial_c.data.uniform_(-initrange, initrange)

    def my_hook(self, module, grad_input, grad_output):
        '''
        Backward-hook stub that returns ``grad_input`` unchanged (identity).

        It is not registered anywhere in this class.
        '''
        # To modify the gradient, return a new tuple instead, for example
        #   return (grad_input[0] * scale,)
        # The value returned from a backward hook must be a tuple.
        return grad_input

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x fc_stack[-1] tensor.
        input_grids: only its batch size (dim 0) is read.
        output_grids: unused.

        Returns a batch_size x 1 x 16 x 18 x 18 tensor. Its forward values
        are (up to float rounding) one-hot, and gradients flow through the
        softmax probabilities (straight-through estimator):
          * channels 0-3: a single 1 over all 4 x 16 x 16 positions;
          * channels 4-15: at most one 1 per cell. A cell is all-zero when
            the constant 0.5 logit beats the 12 decoded logits;
          * a one-cell zero border on every side, except that channel 5's
            border is set to 1.
        '''
        batch_size = input_grids.size(0)

        # nn.LSTM is not batch_first, so the IOs become the time dimension.
        seq = joint_emb.permute(1, 0, 2).contiguous()

        state_size = (self.nb_layers, batch_size, self.lstm_hidden_size)
        initial_state = (
            self.initial_h.expand(state_size).contiguous(),
            self.initial_c.expand(state_size).contiguous(),
        )
        _, (h_n, _) = self.rnn(seq, initial_state)

        # Final hidden state of the last LSTM layer: batch_size x lstm_hidden_size.
        emb = h_n[-1]

        out = self.decoder1(emb)
        out = out.view(out.shape[0], 1, 128, 8, 8)
        out = self.decoder2(out)  # batch_size x 1 x 16 x 16 x 16

        # Constant 0.5 logit appended as an extra map class. It is created on
        # out's device and dtype, instead of with a hard-coded .cuda(), so the
        # decoder also runs on CPU or on a non-default GPU.
        tmp = out.new_full((out.shape[0], out.shape[1], 1, out.shape[3], out.shape[4]), 0.5)
        out = torch.cat([out, tmp], 2)

        # Channels 0-3: one softmax over every (channel, row, col) position,
        # then a straight-through one-hot.
        hero = out[:, :, :4, :, :].view(out.shape[0], out.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(out.shape[0], out.shape[1], 4, out.shape[3], out.shape[4])

        # Remaining channels: per-cell softmax over the 12 decoded logits plus
        # the constant one, then a straight-through one-hot. The constant
        # channel is then dropped.
        out = out[:, :, 4:, :, :]
        out = F.softmax(out, dim=2)
        index = out.max(2, keepdim=True)[1]
        out_hard = torch.zeros_like(out, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        out_hard = (out_hard - out).detach() + out
        out_hard = out_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, out_hard], 2)

        # One-cell zero border on each side, then set channel 5's border to 1
        # (presumably the wall channel -- TODO confirm the channel layout).
        out_hard = F.pad(out_hard, (1, 1, 1, 1))
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query2Dec(nn.Module):
    '''
    Query decoder that decodes every IO embedding and averages over the IOs.

    Each IO embedding goes through decoder1, and the results are averaged
    over the IO dimension. The average is decoded to a 16-channel 16x16
    grid, which is discretised with straight-through one-hot estimators and
    padded by one cell on every side.

    Shape comments assume MapModule(m, 3) applies ``m`` to the trailing three
    dims and keeps the leading ones (nps.network; TODO confirm).
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        '''
        kernel_size, conv_stack, fc_stack: not used by this decoder (the
            other QueryNDec classes take the same constructor arguments).
        '''
        super(Query2Dec, self).__init__()

        # 512 -> flattened 128 x 8 x 8 feature map.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, 128*8*8),
            nn.ReLU(inplace=True)
        )
        # 128 x 8 x 8 -> 16 x 16 x 16, with logits in (0, 1) because of the
        # Sigmoid. align_corners=False is the implicit default, stated
        # explicitly.
        self.decoder2 = MapModule(nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.Sigmoid(),
        ), 3)

        self.init_weights()

    def init_weights(self):
        '''No custom initialisation: every layer keeps PyTorch's default init.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids, output_grids: unused.

        Returns a batch_size x 1 x 16 x 18 x 18 tensor. Its forward values
        are (up to float rounding) one-hot, and gradients flow through the
        softmax probabilities (straight-through estimator):
          * channels 0-3: a single 1 over all 4 x 16 x 16 positions;
          * channels 4-15: at most one 1 per cell. A cell is all-zero when
            the constant 0.5 logit beats the 12 decoded logits;
          * a one-cell zero border on every side, except that channel 5's
            border is set to 1.
        '''
        out = self.decoder1(joint_emb)  # batch_size x nb_ios x 8192
        out = out.mean(1)               # average over the IOs
        out = out.view(out.shape[0], -1, 128, 8, 8)
        out = self.decoder2(out)
        out = out.view(out.shape[0], -1, 16, 16, 16)

        # Constant 0.5 logit appended as an extra map class, created on out's
        # device and dtype instead of with a hard-coded .cuda().
        tmp = out.new_full((out.shape[0], out.shape[1], 1, out.shape[3], out.shape[4]), 0.5)
        out = torch.cat([out, tmp], 2)

        # Channels 0-3: one softmax over every (channel, row, col) position,
        # then a straight-through one-hot.
        hero = out[:, :, :4, :, :].view(out.shape[0], out.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(out.shape[0], out.shape[1], 4, out.shape[3], out.shape[4])

        # Remaining channels: per-cell softmax (12 decoded logits plus the
        # constant one), then a straight-through one-hot. The constant channel
        # is then dropped.
        out = out[:, :, 4:, :, :]
        out = F.softmax(out, dim=2)
        index = out.max(2, keepdim=True)[1]
        out_hard = torch.zeros_like(out, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        out_hard = (out_hard - out).detach() + out
        out_hard = out_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, out_hard], 2)

        # One-cell zero border, then set channel 5's border to 1 (presumably
        # the wall channel -- TODO confirm).
        out_hard = F.pad(out_hard, (1, 1, 1, 1))
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query3Dec(nn.Module):
    '''
    Soft query decoder.

    Each IO embedding is decoded and the results are max-pooled over the IOs.
    The decoder returns the softmax probabilities themselves, not a
    straight-through one-hot. ``hard`` converts a soft output into a detached
    one-hot tensor.

    Shape comments assume MapModule(m, 3) applies ``m`` to the trailing three
    dims and keeps the leading ones (nps.network; TODO confirm).
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        '''
        kernel_size, conv_stack, fc_stack: not used by this decoder (the
            other QueryNDec classes take the same constructor arguments).
        '''
        super(Query3Dec, self).__init__()

        # 512 -> flattened 128 x 8 x 8 feature map.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, 128*8*8),
            nn.ReLU(inplace=True)
        )
        # 128 x 8 x 8 -> 16 x 16 x 16, with logits in (0, 1) because of the
        # Sigmoid. align_corners=False is the implicit default, stated
        # explicitly.
        self.decoder2 = MapModule(nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.Sigmoid(),
        ), 3)

        self.init_weights()

    def init_weights(self):
        '''No custom initialisation: every layer keeps PyTorch's default init.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids, output_grids: unused.

        Returns a batch_size x 1 x 16 x 18 x 18 tensor of probabilities:
          * channels 0-3: one softmax over all 4 x 16 x 16 positions;
          * channels 4-15: a per-cell softmax over 13 classes (12 decoded
            logits plus a constant 0.5). The constant class is dropped, so
            these channels sum to at most 1 per cell;
          * a one-cell zero border on every side, except that channel 5's
            border is set to 1.
        '''
        out = self.decoder1(joint_emb)  # batch_size x nb_ios x 8192
        out, _ = out.max(1)             # max-pool over the IOs
        out = out.view(out.shape[0], -1, 128, 8, 8)
        out = self.decoder2(out)
        out = out.view(out.shape[0], -1, 16, 16, 16)

        # Constant 0.5 logit appended as an extra map class, created on out's
        # device and dtype instead of with a hard-coded .cuda().
        tmp = out.new_full((out.shape[0], out.shape[1], 1, out.shape[3], out.shape[4]), 0.5)
        out = torch.cat([out, tmp], 2)

        hero = out[:, :, :4, :, :].view(out.shape[0], out.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        hero = hero.view(out.shape[0], out.shape[1], 4, out.shape[3], out.shape[4])

        out = out[:, :, 4:, :, :]
        out = F.softmax(out, dim=2)
        out = out[:, :, :-1, :, :]

        out = torch.cat([hero, out], 2)

        # One-cell zero border, then set channel 5's border to 1 (presumably
        # the wall channel -- TODO confirm).
        out = F.pad(out, (1, 1, 1, 1))
        out[:, :, 5, :, 0] = 1
        out[:, :, 5, 0, :] = 1
        out[:, :, 5, :, -1] = 1
        out[:, :, 5, -1, :] = 1

        return out

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

    def hard(self, out):
        '''
        Return a detached one-hot counterpart of a soft B x N x C x H x W tensor.

        * channels 0-3: a single 1 at the argmax over all 4 x H x W positions;
        * channels 4..C-1: a 1 at the per-cell argmax, unless the constant 0.5
          logit appended here wins, in which case the cell is all-zero.

        The result has the same shape as ``out`` and no autograd history.
        '''
        # Created on out's device and dtype instead of with a hard-coded .cuda().
        tmp = out.new_full((out.shape[0], out.shape[1], 1, out.shape[3], out.shape[4]), 0.5)
        out = torch.cat([out, tmp], 2)

        hero = out[:, :, :4, :, :].view(out.shape[0], out.shape[1], -1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = hero_hard.detach()
        hero_hard = hero_hard.view(out.shape[0], out.shape[1], 4, out.shape[3], out.shape[4])

        out = out[:, :, 4:, :, :]
        index = out.max(2, keepdim=True)[1]
        out_hard = torch.zeros_like(out, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        out_hard = out_hard.detach()
        out_hard = out_hard[:, :, :-1, :, :]

        return torch.cat([hero_hard, out_hard], 2)


class Query4Dec(nn.Module):
    '''
    Query decoder with two output heads.

    The IO embeddings are max-pooled over the IOs and decoded to a 32-channel
    16x16 feature map. A linear ``hero`` head and a convolutional ``map``
    head turn that map into logits. The logits are discretised with
    straight-through one-hot estimators and padded by one cell per side.

    Shape comments assume MapModule(m, 3) applies ``m`` to the trailing three
    dims and keeps the leading ones (nps.network; TODO confirm).
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        '''
        kernel_size, conv_stack, fc_stack: not used by this decoder (the
            other QueryNDec classes take the same constructor arguments).
        '''
        super(Query4Dec, self).__init__()
        # 512 -> flattened 128 x 8 x 8 feature map.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, 128*8*8),
            nn.ReLU(inplace=True)
        )
        # 128 x 8 x 8 -> 32 x 16 x 16 shared features. align_corners=False
        # is the implicit default, stated explicitly.
        self.decoder2 = MapModule(nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
        ), 3)

        # Flattened 32 x 16 x 16 features -> 4 x 16 x 16 hero logits.
        self.hero = MapModule(nn.Sequential(
            nn.Linear(32*16*16, 4*16*16),
        ), 3)

        # 32 -> 13 per-cell map logits. The last class is dropped after
        # discretisation.
        self.map = MapModule(nn.Sequential(
            nn.Conv2d(32, 13, 3, padding=1),
        ), 3)

        self.init_weights()

    def init_weights(self):
        '''No custom initialisation: every layer keeps PyTorch's default init.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids, output_grids: unused.

        Returns a batch_size x 1 x 16 x 18 x 18 tensor. Its forward values
        are (up to float rounding) one-hot, and gradients flow through the
        softmax probabilities (straight-through estimator):
          * channels 0-3: a single 1 over all 4 x 16 x 16 positions;
          * channels 4-15: at most one 1 per cell (all-zero when the 13th map
            class wins);
          * a one-cell zero border on every side, except that channel 5's
            border is set to 1.
        '''
        pooled, _ = joint_emb.max(1)  # max-pool over the IOs: batch_size x 512
        feats = self.decoder1(pooled)
        feats = feats.view(feats.shape[0], -1, 128, 8, 8)
        feats = self.decoder2(feats)  # batch_size x 1 x 32 x 16 x 16

        map_logits = self.map(feats)  # batch_size x 1 x 13 x 16 x 16

        flat_feats = feats.view(-1, 32*16*16)
        # NOTE(review): MapModule receives a 2-D tensor here. This relies on
        # how it handles inputs with fewer than 3 dims -- TODO confirm.
        hero_logits = self.hero(flat_feats)

        map_logits = map_logits.view(flat_feats.shape[0], -1, 13, 16, 16)

        # Hero head: one softmax over all 4 x 16 x 16 positions, then a
        # straight-through one-hot.
        hero = F.softmax(hero_logits, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(map_logits.shape[0], map_logits.shape[1], 4, map_logits.shape[3], map_logits.shape[4])

        # Map head: per-cell softmax over 13 classes, then a straight-through
        # one-hot. The last class is then dropped.
        map_probs = F.softmax(map_logits, dim=2)
        index = map_probs.max(2, keepdim=True)[1]
        map_hard = torch.zeros_like(map_probs, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        map_hard = (map_hard - map_probs).detach() + map_probs
        map_hard = map_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, map_hard], 2)

        # One-cell zero border, then set channel 5's border to 1 (presumably
        # the wall channel -- TODO confirm).
        out_hard = F.pad(out_hard, (1, 1, 1, 1))
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query5Dec(nn.Module):
    '''
    Mean-pooling query decoder without a Sigmoid.

    Each IO embedding is decoded and the results are averaged over the IOs.
    The average is decoded directly to 17 raw logit channels (no constant
    extra class), discretised with straight-through one-hot estimators and
    padded by one cell on every side.

    Shape comments assume MapModule(m, 3) applies ``m`` to the trailing three
    dims and keeps the leading ones (nps.network; TODO confirm).
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        '''
        kernel_size, conv_stack, fc_stack: not used by this decoder (the
            other QueryNDec classes take the same constructor arguments).
        '''
        super(Query5Dec, self).__init__()
        # 512 -> flattened 128 x 8 x 8 feature map.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, 128*8*8),
            nn.ReLU(inplace=True)
        )
        # 128 x 8 x 8 -> 17 x 16 x 16 unbounded logits (no final activation).
        # align_corners=False is the implicit default, stated explicitly.
        self.decoder2 = MapModule(nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 17, 3, padding=1),
        ), 3)

        self.init_weights()

    def init_weights(self):
        '''No custom initialisation: every layer keeps PyTorch's default init.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids, output_grids: unused.

        Returns a batch_size x 1 x 16 x 18 x 18 tensor. Its forward values
        are (up to float rounding) one-hot, and gradients flow through the
        softmax probabilities (straight-through estimator):
          * channels 0-3: a single 1 over all 4 x 16 x 16 positions;
          * channels 4-15: at most one 1 per cell (all-zero when the 17th
            logit channel wins);
          * a one-cell zero border on every side, except that channel 5's
            border is set to 1.
        '''
        out = self.decoder1(joint_emb)  # batch_size x nb_ios x 8192
        out = out.mean(1)               # average over the IOs
        out = out.view(out.shape[0], -1, 128, 8, 8)
        out = self.decoder2(out)
        out = out.view(out.shape[0], -1, 17, 16, 16)

        # Channels 0-3: one softmax over every (channel, row, col) position,
        # then a straight-through one-hot.
        hero = out[:, :, :4, :, :].view(out.shape[0], out.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(out.shape[0], out.shape[1], 4, out.shape[3], out.shape[4])

        # Channels 4-16: per-cell softmax, then a straight-through one-hot.
        # The last channel is then dropped.
        out = out[:, :, 4:, :, :]
        out = F.softmax(out, dim=2)
        index = out.max(2, keepdim=True)[1]
        out_hard = torch.zeros_like(out, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        out_hard = (out_hard - out).detach() + out
        out_hard = out_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, out_hard], 2)

        # One-cell zero border, then set channel 5's border to 1 (presumably
        # the wall channel -- TODO confirm).
        out_hard = F.pad(out_hard, (1, 1, 1, 1))
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1
        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query6Dec(nn.Module):
    '''
    Decoder shaped like an inverted grid encoder.

    Each IO embedding is expanded to a feature map and run through a
    GridDecoder. The results are max-pooled over the IOs and projected to
    IMG_SIZE[0] + 1 logit channels. Those logits are discretised with
    straight-through one-hot estimators and padded by one cell on every side.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query6Dec, self).__init__()

        # Half of the encoder widths: the encoder presumably produced one half
        # of its channels from I and the other half from O.
        half_dim = conv_stack[0] // 2
        self.conv_stack_dec = [width // 2 for width in conv_stack]
        inner_h = IMG_SIZE[1] - 2
        inner_w = IMG_SIZE[2] - 2

        self.decoder1 = nn.Sequential(
            nn.Linear(512, int(half_dim * inner_h * inner_w)),
            nn.ReLU(inplace=True),
        )

        self.decoder2 = MapModule(
            nn.Sequential(GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)),
            3,
        )

        same_pad = int((kernel_size - 1) / 2)
        self.in_grid_dec = MapModule(
            nn.Sequential(
                nn.Conv2d(int(half_dim), IMG_SIZE[0] + 1,
                          kernel_size=kernel_size, padding=same_pad),
                nn.ReLU(inplace=True),
            ),
            3,
        )

    def init_weights(self):
        '''Placeholder: re-initialises nothing and is not called from __init__.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids: must have at least two dims; its leading sizes are read
            but otherwise unused. output_grids: unused.

        Returns straight-through one-hot grids padded by one cell per side.
        Channel 5's border is set to 1.
        '''
        _batch_size, _nb_ios = input_grids.size()[:2]
        inner_h, inner_w = IMG_SIZE[1] - 2, IMG_SIZE[2] - 2

        feats = self.decoder1(joint_emb)
        feats = feats.view(feats.shape[0], -1, self.conv_stack_dec[0], inner_h, inner_w)
        feats = self.decoder2(feats)
        pooled = feats.max(1)[0]  # max over the IO dimension

        logits = self.in_grid_dec(pooled)
        logits = logits.view(logits.shape[0], -1, IMG_SIZE[0] + 1, inner_h, inner_w)
        n_b, n_s, _, h, w = logits.shape

        # First four channels: a single softmax over every (channel, row, col).
        hero_soft = F.softmax(logits[:, :, :4, :, :].view(n_b, n_s, -1), dim=-1)
        _, hero_idx = hero_soft.max(-1, keepdim=True)
        hero_onehot = torch.zeros_like(hero_soft, memory_format=torch.legacy_contiguous_format)
        hero_onehot.scatter_(-1, hero_idx, 1.0)
        hero_st = ((hero_onehot - hero_soft).detach() + hero_soft).view(n_b, n_s, 4, h, w)

        # Remaining channels: per-cell softmax. The last class is dropped
        # after discretisation.
        cell_soft = F.softmax(logits[:, :, 4:, :, :], dim=2)
        _, cell_idx = cell_soft.max(2, keepdim=True)
        cell_onehot = torch.zeros_like(cell_soft, memory_format=torch.legacy_contiguous_format)
        cell_onehot.scatter_(2, cell_idx, 1.0)
        cell_st = ((cell_onehot - cell_soft).detach() + cell_soft)[:, :, :-1, :, :]

        grid = F.pad(torch.cat([hero_st, cell_st], 2), (1, 1, 1, 1))
        grid[:, :, 5, :, 0] = 1
        grid[:, :, 5, 0, :] = 1
        grid[:, :, 5, :, -1] = 1
        grid[:, :, 5, -1, :] = 1
        return grid

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query7Dec(nn.Module):
    '''
    Inverted-encoder decoder with two output heads.

    Each IO embedding is expanded to a feature map and run through a
    GridDecoder, and the results are max-pooled over the IOs. A linear
    ``hero`` head and a convolutional ``map`` head produce logits, which are
    discretised with straight-through one-hot estimators and padded by one
    cell on every side.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query7Dec, self).__init__()

        # Because we are going to get dim from I and dim from O
        initial_dim = conv_stack[0] // 2

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.decoder1 = nn.Sequential(
            nn.Linear(512, int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))),
            nn.ReLU(inplace=True)
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        # NOTE(review): both heads hard-code 32 x 16 x 16 features per grid.
        # This assumes GridDecoder yields 32 channels and that
        # IMG_SIZE[1] - 2 == IMG_SIZE[2] - 2 == 16 -- TODO confirm.
        self.hero = MapModule(nn.Sequential(
            nn.Linear(32*16*16, 4*16*16),
        ), 3)

        # 13 per-cell map logits. The last class is dropped after discretisation.
        self.map = MapModule(nn.Sequential(
            nn.Conv2d(32, 13, 3, padding=1),
        ), 3)

    def init_weights(self):
        '''Placeholder: re-initialises nothing and is not called from __init__.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids, output_grids: unused.

        Returns a batch_size x 1 x 16 x 18 x 18 tensor (given the 16 x 16
        assumption above). Its forward values are (up to float rounding)
        one-hot, and gradients flow through the softmax probabilities
        (straight-through estimator). Channels 0-3 hold a single 1 over all
        4 x 16 x 16 positions; channels 4-15 hold at most one 1 per cell.
        The tensor has a one-cell zero border, except that channel 5's border
        is set to 1.
        '''
        feats = self.decoder1(joint_emb)
        feats = feats.view(feats.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        feats = self.decoder2(feats)
        feats, _ = feats.max(1)  # max-pool over the IOs: batch_size x C x H x W

        map_logits = self.map(feats)

        flat_feats = feats.view(-1, 32 * 16 * 16)
        hero_logits = self.hero(flat_feats)

        map_logits = map_logits.view(flat_feats.shape[0], -1, 13, 16, 16)

        # Hero head: one softmax over all 4 x 16 x 16 positions, then a
        # straight-through one-hot.
        hero = F.softmax(hero_logits, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(map_logits.shape[0], map_logits.shape[1], 4, map_logits.shape[3], map_logits.shape[4])

        # Map head: per-cell softmax over 13 classes, then a straight-through
        # one-hot. The last class is then dropped.
        map_probs = F.softmax(map_logits, dim=2)
        index = map_probs.max(2, keepdim=True)[1]
        map_hard = torch.zeros_like(map_probs, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        map_hard = (map_hard - map_probs).detach() + map_probs
        map_hard = map_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, map_hard], 2)

        # One-cell zero border, then set channel 5's border to 1 (presumably
        # the wall channel -- TODO confirm).
        out_hard = F.pad(out_hard, (1, 1, 1, 1))
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query8Dec(nn.Module):
    '''
    Inverted-encoder decoder with noise injection.

    Each IO embedding is expanded to dec_dim features and concatenated with
    an equally sized standard-normal noise vector. The result is mixed back
    to dec_dim features and run through a GridDecoder. Those outputs are
    max-pooled over the IOs and projected to IMG_SIZE[0] + 1 logit channels.
    Finally the logits are discretised with straight-through one-hot
    estimators and padded by one cell on every side.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query8Dec, self).__init__()

        # Because we are going to get dim from I and dim from O
        initial_dim = conv_stack[0] // 2

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        # Size of a flattened initial_dim x (IMG_SIZE[1]-2) x (IMG_SIZE[2]-2) map.
        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
            nn.ReLU(inplace=True)
        )
        # Mixes [features, noise] (2 * dec_dim) back down to dec_dim.
        self.decoder_noise = nn.Sequential(
            nn.Linear(2 * self.dec_dim, self.dec_dim),
            nn.ReLU(inplace=True)
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        self.in_grid_dec = MapModule(nn.Sequential(
            nn.Conv2d(int(initial_dim), IMG_SIZE[0] + 1,
                      kernel_size=kernel_size, padding=int((kernel_size - 1) / 2)),
            nn.ReLU(inplace=True)
        ), 3)

    def init_weights(self):
        '''Placeholder: re-initialises nothing and is not called from __init__.'''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        joint_emb: batch_size x nb_ios x 512 tensor.
        input_grids, output_grids: unused.

        Returns straight-through one-hot grids (IMG_SIZE[0] channels,
        presumably batch_size x 1 x IMG_SIZE[0] x IMG_SIZE[1] x IMG_SIZE[2]
        -- TODO confirm against GridDecoder). The grids are padded by one
        cell per side, and channel 5's border is set to 1. Fresh noise is
        sampled on every call, including in eval mode.
        '''
        out = self.decoder1(joint_emb)  # ... x dec_dim

        # Standard-normal noise with the same shape as out, created on out's
        # device and dtype. It was previously sampled on CPU and moved with a
        # hard-coded .cuda(), which broke CPU / non-default-GPU use.
        noise = torch.randn_like(out)
        out = torch.cat([out, noise], -1)
        out = self.decoder_noise(out)
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        out, _ = out.max(1)  # max-pool over the IOs

        out = self.in_grid_dec(out)
        out = out.view(out.shape[0], -1, IMG_SIZE[0] + 1, IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)

        # Channels 0-3: one softmax over every (channel, row, col) position,
        # then a straight-through one-hot.
        hero = out[:, :, :4, :, :].view(out.shape[0], out.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(out.shape[0], out.shape[1], 4, out.shape[3], out.shape[4])

        # Remaining channels: per-cell softmax, then a straight-through
        # one-hot. The last class is then dropped.
        out = out[:, :, 4:, :, :]
        out = F.softmax(out, dim=2)
        index = out.max(2, keepdim=True)[1]
        out_hard = torch.zeros_like(out, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        out_hard = (out_hard - out).detach() + out
        out_hard = out_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, out_hard], 2)

        # One-cell zero border, then set channel 5's border to 1 (presumably
        # the wall channel -- TODO confirm).
        out_hard = F.pad(out_hard, (1, 1, 1, 1))
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1
        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns decode_process(joint_emb, input_grids, output_grids).
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query9Dec(nn.Module):
    '''
    decoder is the inversed encoder, two heads, noise.

    The embedding goes through `decoder1`, is concatenated with Gaussian
    noise and fused by `decoder_noise`, then decoded by a GridDecoder and
    max-pooled over dim 1. Two heads follow:
      * `hero`: a linear layer giving 4 * 16 * 16 logits under one softmax;
      * `map`:  a 3x3 conv giving 13 per-cell channel logits (the last
        channel is dropped after the softmax).
    Both heads are discretised with the straight-through estimator.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query9Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        # Flat size of conv_stack_dec[0] maps of (IMG_SIZE[1]-2) x (IMG_SIZE[2]-2).
        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
            nn.ReLU(inplace=True)
        )
        # Fuses [features, noise] (2 * dec_dim) back down to dec_dim.
        self.decoder_noise = nn.Sequential(
            nn.Linear(2 * self.dec_dim, self.dec_dim),
            nn.ReLU(inplace=True)
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        self.hero = MapModule(nn.Sequential(
            nn.Linear(32*16*16, 4*16*16),
        ), 3)

        self.map = MapModule(nn.Sequential(
            nn.Conv2d(32, 13, 3, padding=1),
        ), 3)

    def init_weights(self):
        '''
        No-op kept for interface compatibility: this decoder has no
        initial_h / initial_c parameters to initialise.
        '''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a straight-through one-hot grid.

        Returns a tensor of shape (batch, k, 16, 18, 18): 4 hero channels
        followed by 12 map channels, zero-padded by one cell per spatial side,
        with channel 5 set to 1 on that border. k is 1 when the pooled
        features are (batch, 32, 16, 16). `input_grids` / `output_grids` are
        accepted for interface parity and are not read.
        '''
        out = self.decoder1(joint_emb)

        # Same shape as `out` (its last dim is dec_dim), sampled on out's
        # device/dtype instead of on the CPU followed by a hard-coded .cuda().
        noise = torch.randn_like(out)
        out = self.decoder_noise(torch.cat([out, noise], -1))
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        # Max-pool over dim 1; expected to yield (batch, 32, 16, 16).
        feats, _ = out.max(1)

        grid_logits = self.map(feats)
        flat_feats = feats.view(-1, 32 * 16 * 16)
        hero_logits = self.hero(flat_feats)

        grid_logits = grid_logits.view(flat_feats.shape[0], -1, 13, 16, 16)

        # Hero head: one softmax over all 4 * 16 * 16 logits, straight-through.
        hero = F.softmax(hero_logits, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(grid_logits.shape[0], grid_logits.shape[1], 4, grid_logits.shape[3], grid_logits.shape[4])

        # Map head: per-cell softmax over 13 channels, straight-through,
        # then drop the last channel.
        grid = F.softmax(grid_logits, dim=2)
        index = grid.max(2, keepdim=True)[1]
        grid_hard = torch.zeros_like(grid, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        grid_hard = (grid_hard - grid).detach() + grid
        grid_hard = grid_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, grid_hard], 2)
        out_hard = F.pad(out_hard, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query10Dec(nn.Module):
    '''
    generation without constraint.

    The embedding is decoded by a GridDecoder at full IMG_SIZE[1] x
    IMG_SIZE[2] resolution and max-pooled over dim 1 (kept as a singleton
    dim). A conv then projects it to IMG_SIZE[0] channels, and a sigmoid
    squashes the result: independent per-entry values in (0, 1), with no
    one-hot constraint.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query10Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1]) * (IMG_SIZE[2]))

        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
            nn.ReLU(inplace=True)
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack),
        ), 3)
        self.decoder3 = MapModule(nn.Sequential(
            nn.Conv2d(int(initial_dim), IMG_SIZE[0],
                      kernel_size=kernel_size, padding=int((kernel_size -1)/2)),
        ), 3)

    def init_weights(self):
        '''
        No-op kept for interface compatibility: this decoder has no
        initial_h / initial_c parameters to initialise.
        '''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into per-entry probabilities.

        Returns sigmoid(decoder3(max over dim 1 of decoder2(...))), with
        dim 1 kept as a singleton. `input_grids` / `output_grids` are
        accepted for interface parity and are not read.
        '''
        out = self.decoder1(joint_emb)

        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1], IMG_SIZE[2])
        out = self.decoder2(out)
        out, _ = out.max(1, keepdim=True)

        out = self.decoder3(out)
        # torch.sigmoid instead of the deprecated F.sigmoid (same result).
        return torch.sigmoid(out)

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query11Dec(nn.Module):
    '''
    generation without constraint, noise.

    Like Query10Dec, but Gaussian noise is concatenated to the decoder1
    features and fused by `decoder_noise` before grid decoding.

    decoder3 is a bare conv. A ReLU used to sit between it and the final
    sigmoid, which confined every output to [0.5, 1). The ReLU has no
    parameters, so state_dict keys are unchanged.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query11Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1]) * (IMG_SIZE[2]))

        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
            nn.ReLU(inplace=True)
        )
        # Fuses [features, noise] (2 * dec_dim) back down to dec_dim.
        self.decoder_noise = nn.Sequential(
            nn.Linear(2 * self.dec_dim, self.dec_dim),
            nn.ReLU(inplace=True)
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack),
        ), 3)
        # Raw logits: the sigmoid in decode_process needs the full real range.
        self.decoder3 = MapModule(nn.Sequential(
            nn.Conv2d(int(initial_dim), IMG_SIZE[0],
                      kernel_size=kernel_size, padding=int((kernel_size -1)/2)),
        ), 3)

    def init_weights(self):
        '''
        No-op kept for interface compatibility: this decoder has no
        initial_h / initial_c parameters to initialise.
        '''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) plus noise into per-entry
        probabilities in (0, 1).

        `input_grids` / `output_grids` are accepted for interface parity and
        are not read.
        '''
        out = self.decoder1(joint_emb)

        # Same shape as `out` (its last dim is dec_dim), on out's device/dtype.
        noise = torch.randn_like(out)
        out = self.decoder_noise(torch.cat([out, noise], -1))
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1], IMG_SIZE[2])
        out = self.decoder2(out)

        out, _ = out.max(1, keepdim=True)

        out = self.decoder3(out)
        # torch.sigmoid instead of the deprecated F.sigmoid (same result).
        return torch.sigmoid(out)

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query92Dec(nn.Module):
    '''
    decoder is the inversed encoder, two heads, noise.

    Standard-normal noise is added to the (activation-free) decoder1 output.
    The result is decoded by a GridDecoder and max-pooled over dim 1, then
    split into a `hero` head (4 * 16 * 16 logits under one softmax) and a
    `map` head (13 per-cell channel logits, last channel dropped). Both are
    discretised with the straight-through estimator.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query92Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        # No activation: the noise is added directly to the linear output.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
        )
        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        self.hero = MapModule(nn.Sequential(
            nn.Linear(32*16*16, 4*16*16),
        ), 3)

        self.map = MapModule(nn.Sequential(
            nn.Conv2d(32, 13, 3, padding=1),
        ), 3)

    def init_weights(self):
        '''
        No-op kept for interface compatibility: this decoder has no
        initial_h / initial_c parameters to initialise.
        '''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a straight-through one-hot grid.

        Returns (batch, k, 16, 18, 18): 4 hero channels followed by 12 map
        channels, zero-padded by one cell per spatial side, with channel 5
        set to 1 on that border. `input_grids` / `output_grids` are accepted
        for interface parity and are not read.
        '''
        out = self.decoder1(joint_emb)

        # Out-of-place add, with noise created on out's own device/dtype
        # (was: in-place `+=` with a hard-coded .cuda() copy).
        out = out + torch.randn_like(out)
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        # Max-pool over dim 1; expected to yield (batch, 32, 16, 16).
        feats, _ = out.max(1)

        grid_logits = self.map(feats)
        flat_feats = feats.view(-1, 32 * 16 * 16)
        hero_logits = self.hero(flat_feats)

        grid_logits = grid_logits.view(flat_feats.shape[0], -1, 13, 16, 16)

        # Hero head: one softmax over all 4 * 16 * 16 logits, straight-through.
        hero = F.softmax(hero_logits, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(grid_logits.shape[0], grid_logits.shape[1], 4, grid_logits.shape[3], grid_logits.shape[4])

        # Map head: per-cell softmax over 13 channels, straight-through,
        # then drop the last channel.
        grid = F.softmax(grid_logits, dim=2)
        index = grid.max(2, keepdim=True)[1]
        grid_hard = torch.zeros_like(grid, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        grid_hard = (grid_hard - grid).detach() + grid
        grid_hard = grid_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, grid_hard], 2)
        out_hard = F.pad(out_hard, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query82Dec(nn.Module):
    '''
    decoder is the inversed encoder, noise.

    Standard-normal noise is added to the (activation-free) decoder1 output.
    The result is decoded by a GridDecoder, max-pooled over dim 1, and
    projected by `in_grid_dec` to IMG_SIZE[0] + 1 channels. Channels 0-3
    form the hero head (one softmax over all of their entries). The rest
    form a per-cell categorical whose last channel is dropped. Both heads
    are discretised with the straight-through estimator.

    NOTE(review): in_grid_dec ends with a ReLU, so every logit fed to the
    softmaxes is >= 0 -- confirm this is intended.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query82Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        # No activation: the noise is added directly to the linear output.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        self.in_grid_dec = MapModule(nn.Sequential(
            nn.Conv2d(int(initial_dim), IMG_SIZE[0] + 1,
                      kernel_size=kernel_size, padding=int((kernel_size -1)/2)),
            nn.ReLU(inplace=True)
        ), 3)

    def init_weights(self):
        '''
        No-op kept for interface compatibility: this decoder has no
        initial_h / initial_c parameters to initialise.
        '''

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a straight-through one-hot grid
        with IMG_SIZE[0] channels. The grid is zero-padded by one cell per
        spatial side, with channel 5 set to 1 on that border.

        `input_grids` / `output_grids` are accepted for interface parity and
        are not read.
        '''
        out = self.decoder1(joint_emb)

        # Out-of-place add, with noise created on out's own device/dtype
        # (was: in-place `+=` with a hard-coded .cuda() copy).
        out = out + torch.randn_like(out)
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        out, _ = out.max(1)

        logits = self.in_grid_dec(out)
        logits = logits.view(logits.shape[0], -1, IMG_SIZE[0] + 1, IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)

        # Hero head: one softmax over every entry of channels 0-3.
        hero = logits[:, :, :4, :, :].view(logits.shape[0], logits.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        # Straight-through: forward value of hero_hard, gradient of hero.
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(logits.shape[0], logits.shape[1], 4, logits.shape[3], logits.shape[4])

        # Cell head: per-cell softmax over the remaining channels.
        cells = F.softmax(logits[:, :, 4:, :, :], dim=2)
        index = cells.max(2, keepdim=True)[1]
        cells_hard = torch.zeros_like(cells, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        cells_hard = (cells_hard - cells).detach() + cells
        # Drop the last channel of the cell head.
        cells_hard = cells_hard[:, :, :-1, :, :]

        out_hard = torch.cat([hero_hard, cells_hard], 2)
        out_hard = F.pad(out_hard, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1
        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query93Dec(Query92Dec):
    '''
    decoder is the inversed encoder, two heads, noise.

    Identical to Query92Dec, except that the grid decoder is a GridDecoder2.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super().__init__(kernel_size, conv_stack, fc_stack)
        # Swap out the GridDecoder that the parent constructor installed.
        grid_decoder = GridDecoder2(kernel_size, self.conv_stack_dec, fc_stack)
        self.decoder2 = MapModule(nn.Sequential(grid_decoder), 3)

class Query83Dec(Query82Dec):
    '''
    decoder is the inversed encoder, noise.

    Identical to Query82Dec, except that the grid decoder is a GridDecoder2.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super().__init__(kernel_size, conv_stack, fc_stack)
        # Swap out the GridDecoder that the parent constructor installed.
        grid_decoder = GridDecoder2(kernel_size, self.conv_stack_dec, fc_stack)
        self.decoder2 = MapModule(nn.Sequential(grid_decoder), 3)


class Query94Dec(nn.Module):
    '''
    decoder is the inversed encoder, two heads, noise.

    Like Query92Dec, but the `map` head predicts 12 channels (11 kept after
    dropping the last). An all-zero channel is inserted at index 4, between
    the 4 hero channels and the map channels.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query94Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        # No activation: the noise is added directly to the linear output.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
        )
        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        self.hero = MapModule(nn.Sequential(
            nn.Linear(32*16*16, 4*16*16),
        ), 3)

        self.map = MapModule(nn.Sequential(
            nn.Conv2d(32, 12, 3, padding=1),
        ), 3)

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a straight-through one-hot grid.

        Returns (batch, k, 16, 18, 18): 4 hero channels, one all-zero channel,
        then 11 map channels. The grid is zero-padded by one cell per spatial
        side, with channel 5 set to 1 on that border. `input_grids` /
        `output_grids` are accepted for interface parity and are not read.
        '''
        out = self.decoder1(joint_emb)

        # Out-of-place add, with noise created on out's own device/dtype
        # (was: in-place `+=` with a hard-coded .cuda() copy).
        out = out + torch.randn_like(out)
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        # Max-pool over dim 1; expected to yield (batch, 32, 16, 16).
        feats, _ = out.max(1)

        grid_logits = self.map(feats)
        flat_feats = feats.view(-1, 32 * 16 * 16)
        hero_logits = self.hero(flat_feats)

        grid_logits = grid_logits.view(flat_feats.shape[0], -1, 12, 16, 16)

        # Hero head: one softmax over all 4 * 16 * 16 logits, straight-through.
        hero = F.softmax(hero_logits, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(grid_logits.shape[0], grid_logits.shape[1], 4, grid_logits.shape[3], grid_logits.shape[4])

        # Map head: per-cell softmax over 12 channels, straight-through,
        # then drop the last channel.
        grid = F.softmax(grid_logits, dim=2)
        index = grid.max(2, keepdim=True)[1]
        grid_hard = torch.zeros_like(grid, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        grid_hard = (grid_hard - grid).detach() + grid
        grid_hard = grid_hard[:, :, :-1, :, :]

        # All-zero channel at index 4, on grid_hard's device/dtype
        # (was hard-coded to device='cuda').
        zeros = grid_hard.new_zeros((*grid_hard.shape[:2], 1, *grid_hard.shape[-2:]))
        out_hard = torch.cat([hero_hard, zeros, grid_hard], 2)
        out_hard = F.pad(out_hard, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query74Dec(Query94Dec):
    '''
    decoder is the inversed encoder, two heads.

    Same modules as Query94Dec, but no noise is added to the decoder1
    output. `__init__` and `forward` are inherited unchanged.
    '''
    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a straight-through one-hot grid.

        Returns (batch, k, 16, 18, 18): 4 hero channels, one all-zero channel,
        then 11 map channels. The grid is zero-padded by one cell per spatial
        side, with channel 5 set to 1 on that border. `input_grids` /
        `output_grids` are accepted for interface parity and are not read.
        '''
        out = self.decoder1(joint_emb)

        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        # Max-pool over dim 1; expected to yield (batch, 32, 16, 16).
        feats, _ = out.max(1)

        grid_logits = self.map(feats)
        flat_feats = feats.view(-1, 32 * 16 * 16)
        hero_logits = self.hero(flat_feats)

        grid_logits = grid_logits.view(flat_feats.shape[0], -1, 12, 16, 16)

        # Hero head: one softmax over all 4 * 16 * 16 logits, straight-through.
        hero = F.softmax(hero_logits, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(grid_logits.shape[0], grid_logits.shape[1], 4, grid_logits.shape[3], grid_logits.shape[4])

        # Map head: per-cell softmax over 12 channels, straight-through,
        # then drop the last channel.
        grid = F.softmax(grid_logits, dim=2)
        index = grid.max(2, keepdim=True)[1]
        grid_hard = torch.zeros_like(grid, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        grid_hard = (grid_hard - grid).detach() + grid
        grid_hard = grid_hard[:, :, :-1, :, :]

        # All-zero channel at index 4, on grid_hard's device/dtype
        # (was hard-coded to device='cuda').
        zeros = grid_hard.new_zeros((*grid_hard.shape[:2], 1, *grid_hard.shape[-2:]))
        out_hard = torch.cat([hero_hard, zeros, grid_hard], 2)
        out_hard = F.pad(out_hard, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1

        return out_hard


class Query84Dec(nn.Module):
    '''
    decoder is the inversed encoder, noise.

    Like Query82Dec, but `in_grid_dec` outputs IMG_SIZE[0] channels. After
    the last cell channel is dropped, an all-zero channel is inserted at
    index 4, between the hero channels and the cell channels.

    NOTE(review): in_grid_dec ends with a ReLU, so every logit fed to the
    softmaxes is >= 0 -- confirm this is intended.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super(Query84Dec, self).__init__()

        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        # No activation: the noise is added directly to the linear output.
        self.decoder1 = nn.Sequential(
            nn.Linear(512, self.dec_dim),
        )

        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)

        self.in_grid_dec = MapModule(nn.Sequential(
            nn.Conv2d(int(initial_dim), IMG_SIZE[0],
                      kernel_size=kernel_size, padding=int((kernel_size -1)/2)),
            nn.ReLU(inplace=True)
        ), 3)

    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a straight-through one-hot grid
        with IMG_SIZE[0] channels. The grid is zero-padded by one cell per
        spatial side, with channel 5 set to 1 on that border.

        `input_grids` / `output_grids` are accepted for interface parity and
        are not read.
        '''
        out = self.decoder1(joint_emb)

        # Out-of-place add, with noise created on out's own device/dtype
        # (was: in-place `+=` with a hard-coded .cuda() copy).
        out = out + torch.randn_like(out)
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        out, _ = out.max(1)

        logits = self.in_grid_dec(out)
        logits = logits.view(logits.shape[0], -1, IMG_SIZE[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)

        # Hero head: one softmax over every entry of channels 0-3.
        hero = logits[:, :, :4, :, :].view(logits.shape[0], logits.shape[1], -1)
        hero = F.softmax(hero, dim=-1)
        index = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(hero, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        # Straight-through: forward value of hero_hard, gradient of hero.
        hero_hard = (hero_hard - hero).detach() + hero
        hero_hard = hero_hard.view(logits.shape[0], logits.shape[1], 4, logits.shape[3], logits.shape[4])

        # Cell head: per-cell softmax over the remaining channels.
        cells = F.softmax(logits[:, :, 4:, :, :], dim=2)
        index = cells.max(2, keepdim=True)[1]
        cells_hard = torch.zeros_like(cells, memory_format=torch.legacy_contiguous_format).scatter_(2, index, 1.0)
        cells_hard = (cells_hard - cells).detach() + cells
        # Drop the last channel of the cell head.
        cells_hard = cells_hard[:, :, :-1, :, :]

        # All-zero channel at index 4, on cells_hard's device/dtype
        # (was hard-coded to device='cuda').
        zeros = cells_hard.new_zeros((*cells_hard.shape[:2], 1, *cells_hard.shape[-2:]))
        out_hard = torch.cat([hero_hard, zeros, cells_hard], 2)
        out_hard = F.pad(out_hard, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_hard[:, :, 5, :, 0] = 1
        out_hard[:, :, 5, 0, :] = 1
        out_hard[:, :, 5, :, -1] = 1
        out_hard[:, :, 5, -1, :] = 1
        return out_hard

    def forward(self, joint_emb, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width
        '''
        return self.decode_process(joint_emb, input_grids, output_grids)

class Query95Dec(Query94Dec):
    '''
    decoder is the inversed encoder, two heads, noise.

    Identical to Query94Dec, except that the grid decoder is a GridDecoder2.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super().__init__(kernel_size, conv_stack, fc_stack)
        # Swap out the GridDecoder that the parent constructor installed.
        grid_decoder = GridDecoder2(kernel_size, self.conv_stack_dec, fc_stack)
        self.decoder2 = MapModule(nn.Sequential(grid_decoder), 3)

class Query75Dec(Query74Dec):
    '''
    decoder is the inversed encoder, two heads.

    Identical to Query74Dec (no noise), except that the grid decoder is a
    GridDecoder2.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super().__init__(kernel_size, conv_stack, fc_stack)
        # Swap out the GridDecoder that the parent constructor installed.
        grid_decoder = GridDecoder2(kernel_size, self.conv_stack_dec, fc_stack)
        self.decoder2 = MapModule(nn.Sequential(grid_decoder), 3)

class Query95softDec(Query95Dec):
    '''
    decoder is the inversed encoder, two heads, noise.

    Same modules as Query95Dec, but both heads are returned as soft softmax
    probabilities rather than straight-through one-hot samples. `__init__`
    and `forward` are inherited unchanged.
    '''
    def decode_process(self, joint_emb, input_grids, output_grids):
        '''
        Decode `joint_emb` (last dim 512) into a soft grid.

        Returns (batch, k, 16, 18, 18): 4 hero-probability channels, one
        all-zero channel, then 11 map-probability channels. The grid is
        zero-padded by one cell per spatial side, with channel 5 set to 1 on
        that border. `input_grids` / `output_grids` are accepted for
        interface parity and are not read.
        '''
        out = self.decoder1(joint_emb)

        # Out-of-place add, with noise created on out's own device/dtype
        # (was: in-place `+=` with a hard-coded .cuda() copy).
        out = out + torch.randn_like(out)
        out = out.view(out.shape[0], -1, self.conv_stack_dec[0], IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        out = self.decoder2(out)

        # Max-pool over dim 1; expected to yield (batch, 32, 16, 16).
        feats, _ = out.max(1)

        grid_logits = self.map(feats)
        flat_feats = feats.view(-1, 32 * 16 * 16)
        hero_logits = self.hero(flat_feats)

        grid_logits = grid_logits.view(flat_feats.shape[0], -1, 12, 16, 16)

        # Hero head: one softmax over all 4 * 16 * 16 logits (kept soft).
        hero_soft = F.softmax(hero_logits, dim=-1)
        hero_soft = hero_soft.view(grid_logits.shape[0], grid_logits.shape[1], 4, grid_logits.shape[3], grid_logits.shape[4])

        # Map head: per-cell softmax over 12 channels (kept soft), last
        # channel dropped.
        grid_soft = F.softmax(grid_logits, dim=2)
        grid_soft = grid_soft[:, :, :-1, :, :]

        # All-zero channel at index 4, on grid_soft's device/dtype
        # (was hard-coded to device='cuda').
        zeros = grid_soft.new_zeros((*grid_soft.shape[:2], 1, *grid_soft.shape[-2:]))
        out_soft = torch.cat([hero_soft, zeros, grid_soft], 2)
        out_soft = F.pad(out_soft, (1, 1, 1, 1))

        # Force channel 5 to 1 on all four padded border edges.
        out_soft[:, :, 5, :, 0] = 1
        out_soft[:, :, 5, 0, :] = 1
        out_soft[:, :, 5, :, -1] = 1
        out_soft[:, :, 5, -1, :] = 1

        return out_soft

class Query85Dec(Query84Dec):
    '''
    decoder is the inversed encoder, noise.

    Identical to Query84Dec, except that the grid decoder is a GridDecoder2.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        super().__init__(kernel_size, conv_stack, fc_stack)
        # Swap out the GridDecoder that the parent constructor installed.
        grid_decoder = GridDecoder2(kernel_size, self.conv_stack_dec, fc_stack)
        self.decoder2 = MapModule(nn.Sequential(grid_decoder), 3)

class Seq2Seq(nn.Module):
    '''
    Chain an encoder and a decoder.

    encoder(inp_grids, out_grids) produces a joint embedding; the decoder
    receives that embedding together with the original grids.
    '''
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, inp_grids, out_grids):
        '''Encode the grid pair, then decode the embedding alongside the grids.'''
        return self.decoder(self.encoder(inp_grids, out_grids), inp_grids, out_grids)
        
