# External imports
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

# Project imports
from nps.beam import Beam
from nps.reinforce import Rolls
from nps.data import IMG_SIZE
from nps.network import *
from nps.query_dec import *

class Query(nn.Module):
    '''
    Encoder/decoder pair whose two halves are selected by the ``arch`` string.

    Original author's note on this model: "mean max without sigmoid."

    ``arch`` is split on '.':
      * the first field selects the decoder class (see ``_DECODERS``);
      * the last field selects the encoder: '2' gives ``IOsEncoder2``,
        any other value gives ``IOsEncoder``.
    When ``arch`` contains no '.', both fields are the same string, so
    ``arch='2'`` pairs ``Query2Dec`` with ``IOsEncoder2``.
    '''

    # Decoder key (first '.'-separated field of ``arch``) -> decoder class name.
    # The names are resolved from the module globals only once the key has
    # been validated, so an unsupported arch is rejected before any
    # sub-module is built.
    _DECODERS = {
        '1': 'Query1Dec',
        '2': 'Query2Dec',
        '3': 'Query3Dec',
        '4': 'Query4Dec',
        '5': 'Query5Dec',
        '6': 'Query6Dec',
        '7': 'Query7Dec',
        '8': 'Query8Dec',
        '9': 'Query9Dec',
        '10': 'Query10Dec',
        '11': 'Query11Dec',
        '74': 'Query74Dec',
        '75': 'Query75Dec',
        '82': 'Query82Dec',
        '83': 'Query83Dec',
        '84': 'Query84Dec',
        '85': 'Query85Dec',
        '92': 'Query92Dec',
        '93': 'Query93Dec',
        '94': 'Query94Dec',
        '95': 'Query95Dec',
        '95soft': 'Query95softDec',
    }

    def __init__(self, kernel_size, conv_stack, fc_stack, arch='5'):
        '''
        kernel_size, conv_stack, fc_stack: forwarded unchanged to both the
            encoder and the decoder constructors.
        arch: architecture selector such as '5' or '5.2' (see class docstring).

        Raises ValueError when the decoder field of ``arch`` is not supported.
        '''
        super(Query, self).__init__()

        fields = arch.split('.')
        decoder_key = fields[0]
        encoder_key = fields[-1]

        # Fail fast. An unknown arch used to be reported with a print only,
        # leaving ``self.decoder`` unset so that the real failure showed up
        # later as an AttributeError inside forward().
        if decoder_key not in self._DECODERS:
            raise ValueError(
                'Error arch: {!r}; supported decoder keys: {}'.format(
                    arch, ', '.join(sorted(self._DECODERS))))

        # Encoder first, then decoder: same construction and registration
        # order as before.
        encoder_cls = IOsEncoder2 if encoder_key == '2' else IOsEncoder
        self.encoder = encoder_cls(kernel_size, conv_stack, fc_stack)

        decoder_cls = globals()[self._DECODERS[decoder_key]]
        self.decoder = decoder_cls(kernel_size, conv_stack, fc_stack)

        self.init_weights()

    def init_weights(self):
        '''
        Hook for custom weight initialisation; currently a no-op.

        The previously disabled code here initialised ``initial_h`` and
        ``initial_c``, attributes that this class does not define.
        '''

    def forward(self, input_grids, output_grids):
        '''
        {input, output}_grids: batch_size x nb_ios x channels x height x width

        Returns whatever the selected decoder returns for
        (joint embedding, input_grids, output_grids).
        '''
        joint_emb = self.encoder(input_grids, output_grids)
        return self.decoder(joint_emb, input_grids, output_grids)


class Discriminator(nn.Module):
    '''
    A DisEncoder followed by a three-layer fully connected head that ends in
    a single output unit (no final activation is applied).
    '''
    def __init__(self, kernel_size=3, conv_stack=[32, 32, 32], fc_stack=[512]):
        '''
        kernel_size, conv_stack, fc_stack: forwarded unchanged to DisEncoder.
        '''
        super(Discriminator, self).__init__()

        # The list defaults are shared between calls; they are only passed
        # along here and never modified by this class.
        self.encoder = DisEncoder(kernel_size, conv_stack, fc_stack)

        # NOTE(review): the head is hard-wired to 512 input features and does
        # not look at fc_stack. Presumably the encoder emits fc_stack[-1]
        # features per element -- confirm before using a non-default fc_stack.
        head_layers = [
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        ]
        self.decoder = nn.Sequential(*head_layers)

    def forward(self, input_grids):
        '''
        Encode ``input_grids`` and return the head's output for the embedding.
        '''
        return self.decoder(self.encoder(input_grids))

class DisEncoder(nn.Module):
    '''
    Grid encoder used by Discriminator: one convolution + LeakyReLU on each
    grid, followed by a GridEncoder. Unlike an input/output encoder it takes
    a single set of grids.
    '''
    def __init__(self, kernel_size, conv_stack, fc_stack):
        '''
        kernel_size: kernel size of the first convolution; also forwarded to
            GridEncoder.
        conv_stack: its first entry is the channel count produced by the
            first convolution; the whole list is forwarded to GridEncoder.
        fc_stack: forwarded to GridEncoder.
        '''
        super(DisEncoder, self).__init__()

        # Channel count produced by the first convolution.
        first_width = int(conv_stack[0])
        # With the default Conv2d stride of 1 this padding keeps height and
        # width unchanged for odd kernel sizes.
        pad = int((kernel_size - 1) / 2)

        # TODO: we know that our grids are mostly sparse, and only positive.
        # That means that a different initialisation might be more appropriate.
        first_conv = nn.Sequential(
            nn.Conv2d(IMG_SIZE[0], first_width,
                      kernel_size=kernel_size, padding=pad),
            nn.LeakyReLU(inplace=True),
        )
        # NOTE(review): MapModule(module, 3) presumably applies ``module`` over
        # the trailing 3 dimensions for every leading index -- confirm in
        # nps.network.
        self.in_grid_enc = MapModule(first_conv, 3)

        # The single-element Sequential wrapper is kept so that parameter
        # names in the state_dict stay the same.
        grid_stack = nn.Sequential(
            GridEncoder(kernel_size, conv_stack, fc_stack)
        )
        self.joint_enc = MapModule(grid_stack, 3)

    def forward(self, input_grids):
        '''
        input_grids: batch_size x nb_ios x channels x height x width

        Returns the output of the GridEncoder stage applied to the
        per-grid convolutional embedding.
        '''
        per_grid_emb = self.in_grid_enc(input_grids)
        return self.joint_enc(per_grid_emb)
