import numpy as np
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from model.encoder import * 
from model.model import BaseModel
from model.network import *

class QueryMI(BaseModel):
    """Latent-variable model that decodes a grid-shaped "query".

    Input/output grids are encoded with ``IOsEncoder`` and a program token
    sequence with an RNN (``encode_program``) or a Transformer
    (``encode_program2``). The features are mapped to ``mu``/``logvar`` of a
    latent, a sample is drawn with the reparameterization trick, and the
    sample (added to the IO features) is decoded into a grid tensor by
    ``decode_process``.

    Arguments:
        vocab_size (int): size of the program token vocabulary.
        input_dropout_p (float): input dropout of the RNN program encoder.
        dropout_p (float): dropout of the RNN / Transformer program encoders.
        rnn_cell: RNN cell specification forwarded to ``RNNEncoder``.
        num_layers (int): number of layers of the program encoders.
        num_att_layers (int): number of layers of the attention/decoder MLPs.
        att_ff_size (int): hidden width of the attention/decoder MLPs.
        z_size (int): input width of ``decoder_z``. NOTE(review): latents
            sampled from ``encoder_mu_*`` have width ``params.dist_dim``, so
            presumably ``z_size`` must equal it -- confirm.

    Most other hyper-parameters are read from the global ``params`` module.
    """

    def __init__(self, vocab_size, input_dropout_p, dropout_p, rnn_cell, num_layers=2, num_att_layers=2, att_ff_size=256, z_size=10):
        super().__init__()
        kernel_size = params.kernel_size
        conv_stack = params.conv_stack
        fc_stack = params.fc_stack

        self.io_recon = params.io_recon
        self.program_recon = params.program_recon
        self.hidden_size = hidden_size = params.dense_output_size
        self.num_layers = num_layers

        # NOTE: submodules are created in a fixed order; changing the order
        # changes how the global RNG initializes their parameters.
        self.encoder_dense = IOsEncoder(kernel_size, conv_stack, fc_stack)
        # Program encoder: consumes token ids plus per-sample program lengths.
        self.encoder_rnn = RNNEncoder(vocab_size, hidden_size // 2,
                                      input_dropout_p=input_dropout_p,
                                      dropout_p=dropout_p,
                                      n_layers=num_layers,
                                      bidirectional=False,
                                      rnn_cell=rnn_cell,
                                      variable_lengths=True)
        # NOTE(review): 74 + 2 is presumably the max program length plus
        # start/end tokens -- confirm.
        self.encoder_transformer = TransformerEncoder(vocab_size, hidden_size // 2,
                                                      n_layers=num_layers,
                                                      n_heads=8,
                                                      pf_dim=hidden_size,
                                                      dropout=dropout_p,
                                                      max_length=74 + 2)
        # Gaussian heads: *_z for (IO, program) inputs, *_t for IO-only inputs.
        self.encoder_mu_z = nn.Linear(hidden_size, params.dist_dim)
        self.encoder_mu_t = nn.Linear(hidden_size, params.dist_dim)
        self.encoder_logvar_z = nn.Linear(hidden_size, params.dist_dim)
        self.encoder_logvar_t = nn.Linear(hidden_size, params.dist_dim)
        self.decoder_gen = MLP(hidden_size, att_ff_size, hidden_size,
                               num_layers=num_att_layers)
        self.decoder_z = nn.Linear(z_size, hidden_size)
        # TODO: tune MLP params
        self.attention_z = MLP(2 * hidden_size, att_ff_size, hidden_size, num_layers=num_att_layers)
        self.attention_z2 = MLP(hidden_size, att_ff_size, hidden_size, num_layers=num_att_layers)
        self.attention_t = MLP(hidden_size, att_ff_size, hidden_size, num_layers=num_att_layers)

        self.intersection = Intersection(params.dist_dim)

        ###########################################################################
        initial_dim = conv_stack[0] // 2  # Because we are going to get dim from I and dim from O

        self.conv_stack_dec = [channel // 2 for channel in conv_stack]

        self.dec_dim = int(initial_dim * (IMG_SIZE[1] - 2) * (IMG_SIZE[2] - 2))

        ############ latent code #############
        if params.latent_code:
            # Extra input features: +1 always, +5 more when params.noise is set.
            extra = 1 + 5 if params.noise else 1
            self.decoder1 = nn.Sequential(
                nn.Linear(hidden_size + extra, self.dec_dim),
            )
            self.latent_decoder = MLP(hidden_size, hidden_size, params.query_num, num_layers=2)
        else:
            self.decoder1 = nn.Sequential(
                nn.Linear(hidden_size, self.dec_dim),
            )
        #####################################
        self.decoder2 = MapModule(nn.Sequential(
            GridDecoder2(kernel_size, self.conv_stack_dec, fc_stack)
        ), 3)
        self.hero_head = MapModule(nn.Sequential(
            nn.Linear(32 * 16 * 16, 4 * 16 * 16),
        ), 3)
        self.map_head = MapModule(nn.Sequential(
            nn.Conv2d(32, 12, 3, padding=1),
        ), 3)

    def forward(self, x, typ, program, plengths, hard_softmax=False):
        """Encode IO grids and a program, sample a latent, decode a query.

        ``x`` and ``typ`` are forwarded to ``encode_io`` as input and output
        grids; ``program``/``plengths`` go to ``encode_program``.

        NOTE(review): ``decode_query`` returns the single tensor produced by
        ``decode_process``; unpacking it into ``query, index`` splits it
        along its first dimension and only succeeds when that dimension has
        size 2. Confirm the intended return contract.
        """
        io_features = self.encode_io(x, typ)
        program_features = self.encode_program(program, plengths)
        mus, logvars = self.encode_into_z(io_features, program_features)
        zs = self.reparameterize(mus, logvars)
        query, index = self.decode_query(zs, io_features, typ, hard_softmax)
        return query, index

    def decode_process_bak(self, emb, hard_softmax=False):
        """Variant of ``decode_process`` that honours ``params.gumbel``.

        When ``params.gumbel`` is truthy, both the hero and the map
        distributions are hard Gumbel-softmax samples (``hard_softmax`` is
        then ignored); otherwise this is identical to ``decode_process``.
        """
        # TODO: are both decoder1 and decoder2 needed?
        return self._decode_grid(emb, hard_softmax, use_gumbel=params.gumbel)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        The noise must be standard normal for the sample to follow the
        Gaussian parameterized by ``mu``/``logvar``. The previous
        ``torch.rand_like`` drew U[0, 1) noise, which shifts every sample by
        ``0.5 * std`` on average and never produces negative noise.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode_program(self, program, plengths):
        """Encode a program with the RNN and return its top-layer hidden state.

        For LSTM cells the encoder's hidden output is a ``(h, c)`` pair and
        only ``h`` is used.
        """
        _, encoder_hidden = self.encoder_rnn(program, plengths, None)
        if self.encoder_rnn.rnn_cell == nn.LSTM:
            encoder_hidden = encoder_hidden[0]
        # Last entry along the layer axis == top layer.
        return encoder_hidden[-1]

    def encode_program2(self, program, plengths):
        """Encode a program with the Transformer encoder.

        ``plengths`` is accepted for signature parity with
        ``encode_program`` but is not used.
        """
        return self.encoder_transformer(program)

    def encode_io(self, input_grids, output_grids):
        """Encode an input/output grid pair with ``IOsEncoder``."""
        return self.encoder_dense(input_grids, output_grids)

    def encode_into_z(self, io_features, program_features):
        """Map concatenated IO + program features to ``(mu, logvar)``."""
        together = torch.cat([io_features, program_features], -1)
        attended_hiddens = self.attention_z(together)
        mus = self.encoder_mu_z(attended_hiddens)
        logvars = self.encoder_logvar_z(attended_hiddens)
        return mus, logvars

    def encode_into_z2(self, program_features):
        """Map program features alone to ``(mu, logvar)`` via the z heads."""
        attended_hiddens = self.attention_z2(program_features)
        mus = self.encoder_mu_z(attended_hiddens)
        logvars = self.encoder_logvar_z(attended_hiddens)
        return mus, logvars

    def encode_into_t(self, io_features):
        """Map IO features alone to ``(mu, logvar)`` via the t heads."""
        attended_hiddens = self.attention_t(io_features)
        mus = self.encoder_mu_t(attended_hiddens)
        logvars = self.encoder_logvar_t(attended_hiddens)
        return mus, logvars

    def decode_query(self, zs, io_features, typ, hard_softmax):
        """Decode latent samples, conditioned on IO features, into a grid.

        ``typ`` is accepted for interface compatibility but is not used by the
        decoder. Returns the single tensor produced by ``decode_process``.
        """
        z_hiddens = self.decoder_z(zs)
        hiddens = self.decoder_gen(io_features + z_hiddens)
        # decode_process takes (emb, hard_softmax); passing `typ` as an extra
        # positional argument used to raise a TypeError.
        return self.decode_process(hiddens, hard_softmax=hard_softmax)

    def decode_query2(self, zs, io_features):
        """Project latent samples to hidden size with ``decoder_z``.

        ``io_features`` is currently unused.
        """
        return self.decoder_z(zs)

    def reconstruct_inputs(self, io_features, program_features):
        """Sample a latent and reconstruct IO and/or program features.

        Each reconstruction is returned only when the matching
        ``params.io_recon`` / ``params.program_recon`` flag was set at
        construction; otherwise ``None`` is returned in its place.

        NOTE(review): ``io_reconstructor`` / ``program_reconstructor`` are not
        assigned in ``__init__``; enabling either flag raises AttributeError
        unless they are set elsewhere (e.g. in BaseModel) -- confirm.
        """
        recon_io_features = None
        recon_program_features = None
        mus, logvars = self.encode_into_z(io_features, program_features)
        zs = self.reparameterize(mus, logvars)
        if self.io_recon:
            recon_io_features = self.io_reconstructor(zs)
        if self.program_recon:
            recon_program_features = self.program_reconstructor(zs)
        return recon_io_features, recon_program_features

    def predict_from_io(self, io, typ, hard_softmax=False):
        """Predict a query from IO grids alone, using the t-latent.

        NOTE(review): same ``query, index`` unpacking caveat as ``forward``.
        """
        # TODO: predict from io
        io_features = self.encode_io(io, typ)
        mus, logvars = self.encode_into_t(io_features)
        ts = self.reparameterize(mus, logvars)
        query, index = self.decode_query(ts, io_features, typ, hard_softmax)
        return query, index

    def distance(self, dist1, dist2):
        """Unimplemented stub; always returns None."""
        pass

    def decode_process(self, emb, hard_softmax=False):
        """Decode hidden embeddings into a padded grid tensor.

        Axis 2 of the result has 16 channels: 4 hero channels (a softmax
        over all 4x16x16 hero logits of a row), 1 all-zero channel, and the
        first 11 of 12 per-cell softmax map channels. Both spatial axes are
        padded by one cell on each side, and the padded border of channel 5
        is set to 1. With ``hard_softmax`` the distributions are replaced by
        straight-through one-hots.
        """
        # TODO: are both decoder1 and decoder2 needed?
        return self._decode_grid(emb, hard_softmax, use_gumbel=False)

    @staticmethod
    def _to_distribution(logits, dim, hard_softmax, use_gumbel):
        """Turn ``logits`` into a (possibly one-hot) distribution along ``dim``.

        With ``use_gumbel``: a hard Gumbel-softmax sample (``hard_softmax``
        ignored). Otherwise softmax, and with ``hard_softmax`` the one-hot of
        the argmax whose gradient is the softmax gradient (straight-through).
        """
        if use_gumbel:
            return F.gumbel_softmax(logits, tau=1, hard=True, dim=dim)
        probs = F.softmax(logits, dim=dim)
        if not hard_softmax:
            return probs
        index = probs.max(dim, keepdim=True)[1]
        one_hot = torch.zeros_like(
            probs, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        return (one_hot - probs).detach() + probs

    def _decode_grid(self, emb, hard_softmax, use_gumbel):
        """Shared body of ``decode_process`` and ``decode_process_bak``."""
        features = self.decoder1(emb)
        features = features.view(features.shape[0], -1, self.conv_stack_dec[0],
                                 IMG_SIZE[1] - 2, IMG_SIZE[2] - 2)
        features = self.decoder2(features)

        map_logits = self.map_head(features)

        flat = features.view(-1, 32 * 16 * 16)
        hero_logits = self.hero_head(flat)

        # flat.shape[0] folds all leading dims of `features` together; the map
        # logits are regrouped so their first axis matches the hero rows.
        map_logits = map_logits.view(flat.shape[0], -1, 12, 16, 16)

        # Hero first, then map: keeps the Gumbel RNG draw order unchanged.
        hero = self._to_distribution(hero_logits, -1, hard_softmax, use_gumbel)
        hero = hero.view(map_logits.shape[0], map_logits.shape[1], 4,
                         map_logits.shape[3], map_logits.shape[4])

        cells = self._to_distribution(map_logits, 2, hard_softmax, use_gumbel)
        cells = cells[:, :, :-1, :, :]  # drop the last of the 12 map channels

        # Follow the map's device/dtype instead of hard-coding 'cuda'.
        empty = torch.zeros(*cells.shape[0:2], 1, *cells.shape[-2:],
                            device=cells.device, dtype=cells.dtype)
        grid = torch.cat([hero, empty, cells], 2)

        grid = F.pad(grid, (1, 1, 1, 1))

        # Mark the one-cell padded border in channel 5.
        grid[:, :, 5, :, 0] = 1
        grid[:, :, 5, 0, :] = 1
        grid[:, :, 5, :, -1] = 1
        grid[:, :, 5, -1, :] = 1

        return grid

class Intersection(nn.Module):
    """Attention-weighted pooling of paired alpha/beta embeddings.

    Both inputs are concatenated on the last axis and passed through
    ``Linear -> ReLU -> Linear``, giving ``dim`` attention logits per
    element. The logits are softmax-normalized over axis 1, and each input
    is summed over axis 1 with those weights.

    NOTE(review): the normalization and sums run over axis 1; presumably
    inputs are laid out as (batch_size, num_conj, dim) -- confirm against
    callers.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        joint_width = 2 * dim
        self.layer1 = nn.Linear(joint_width, joint_width)
        self.layer2 = nn.Linear(joint_width, dim)

        # Re-initialize weights (biases keep the nn.Linear default).
        for linear in (self.layer1, self.layer2):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, alpha_embeddings, beta_embeddings):
        """Return the pooled ``(alpha_embedding, beta_embedding)`` pair."""
        joint = torch.cat([alpha_embeddings, beta_embeddings], dim=-1)
        logits = self.layer2(torch.relu(self.layer1(joint)))
        weights = torch.softmax(logits, dim=1)

        pooled_alpha = (weights * alpha_embeddings).sum(dim=1)
        pooled_beta = (weights * beta_embeddings).sum(dim=1)
        return pooled_alpha, pooled_beta
