import numpy as np
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from env.statement import num_statements
from env.operator import num_operators
from model.encoder import * 
from model.model import BaseModel

class QueryMI(BaseModel):
    """Latent-variable model that generates a query (a set of program inputs).

    I/O examples (``encode_io``) and a tokenized program (``encode_program``)
    are encoded into the mean / log-variance of a diagonal Gaussian
    (``encode_into_z``). A latent is sampled (``reparameterize``) and decoded
    into a per-slot distribution over input values (``decode_query`` ->
    ``decode_process``). ``predict_from_io`` does the same from I/O features
    alone, via ``encode_into_t``.

    Arguments:
        embedding: embedding passed to the dense I/O encoders.
        vocab_size (int): program-token vocabulary size.
        input_dropout_p (float): input dropout of the RNN program encoder.
        dropout_p (float): dropout of the program encoders.
        rnn_cell: RNN cell class for the program encoder (``nn.LSTM`` is
            special-cased in ``encode_program``).
        num_layers (int): currently ignored; the encoders always use 1 layer.
        num_att_layers (int): number of layers in each ``MLP`` head.
        att_ff_size (int): hidden width of each ``MLP`` head.
        z_size (int): width of the latent fed to ``decoder_z`` and to the
            reconstructors. Latents come from ``encoder_mu_*`` with
            ``params.dist_dim`` features, so presumably
            ``z_size == params.dist_dim`` -- TODO confirm.
    """
    def __init__(self, embedding, vocab_size, input_dropout_p, dropout_p, rnn_cell, num_layers=1, num_att_layers=2, att_ff_size=256, z_size=10):
        super(QueryMI, self).__init__()
        # NOTE(review): the num_layers argument is overridden here. Honoring it
        # would change the architecture (and checkpoint layout), so it is kept.
        num_layers = 1
        self.io_recon = params.io_recon
        self.program_recon = params.program_recon
        self.hidden_size = hidden_size = params.dense_output_size
        self.num_layers = num_layers
        # I/O encoders: encode_io uses the per-example encoder and
        # encode_io_t uses the query encoder.
        self.encoder_dense = DensePerIOEncoder(embedding)
        self.encoder_dense_t = DenseQueryEncoder(embedding)
        # Program encoders (token sequence -> hidden vector).
        self.encoder_rnn = RNNEncoder(vocab_size, int(hidden_size / 2),
                                      input_dropout_p=input_dropout_p,
                                      dropout_p=dropout_p,
                                      n_layers=num_layers,
                                      bidirectional=False,
                                      rnn_cell=rnn_cell,
                                      variable_lengths=True)
        self.encoder_transformer = TransformerEncoder(vocab_size, int(hidden_size / 2),
                                              n_layers=num_layers,
                                              n_heads=8,
                                              pf_dim=hidden_size,
                                              dropout=dropout_p,
                                              max_length=params.max_prog_len+2)
        # Gaussian heads: *_z for the (io, program) posterior, *_t for io only.
        self.encoder_mu_z = nn.Linear(hidden_size, params.dist_dim)
        self.encoder_mu_t = nn.Linear(hidden_size, params.dist_dim)
        self.encoder_logvar_z = nn.Linear(hidden_size, params.dist_dim)
        self.encoder_logvar_t = nn.Linear(hidden_size, params.dist_dim)
        self.decoder_gen = MLP(hidden_size, att_ff_size, hidden_size,
                               num_layers=num_att_layers)
        self.decoder_z = nn.Linear(z_size, hidden_size)
        # TODO: tune MLP params.
        self.attention_z = MLP(2*hidden_size, att_ff_size, hidden_size, num_layers=num_att_layers)
        self.attention_z2 = MLP(hidden_size, att_ff_size, hidden_size, num_layers=num_att_layers)
        self.attention_t = MLP(hidden_size, att_ff_size, hidden_size, num_layers=num_att_layers)

        self.intersection = Intersection(params.dist_dim)

        if self.io_recon:
            self.io_reconstructor = MLP(
                    z_size, att_ff_size, hidden_size,
                    num_layers=num_att_layers)

        # Setup answer reconstruction.
        if self.program_recon:
            self.program_reconstructor = MLP(
                    z_size, att_ff_size, hidden_size,
                    num_layers=num_att_layers)

        ############ latent code #############
        if params.latent_code:
            # NOTE(review): these heads expect dense_output_size + 1 features,
            # but decode_query feeds decode_process with decoder_gen's output,
            # which presumably has hidden_size == dense_output_size features.
            # Confirm where the extra feature comes from.
            self.list_head = nn.Linear(params.dense_output_size + 1, 3 * params.max_list_len * params.integer_range)
            self.int_head = nn.Linear(params.dense_output_size + 1, 3 * 1 * params.integer_range)
            self.latent_decoder = MLP(hidden_size, hidden_size, 5, num_layers=2)
        else:
            # NOTE(review): list_head / int_head are not created in this branch,
            # so decode_process raises AttributeError when latent_code is off.
            self.decoder_dense = DenseDecoder(params.dense_output_size, params.dense_output_size)

    def forward(self, x, typ, program, plengths, hard_softmax=False):
        """Encode (io, program), sample a latent, and decode a query.

        Returns the ``(query, index)`` pair produced by ``decode_process``.
        """
        io_features = self.encode_io(x, typ)
        program_features = self.encode_program(program, plengths)
        mus, logvars = self.encode_into_z(io_features, program_features)
        zs = self.reparameterize(mus, logvars)
        query, index = self.decode_query(zs, io_features, typ, hard_softmax)
        return query, index

    def decode_process(self, x, typ, hard_softmax=False):
        """Turn decoder features into a distribution over each query input slot.

        Each of the 3 input slots is built from three candidates: an int
        (one value at position 0, null elsewhere), a list (a value or null at
        every position), and null (null at every position). The candidates
        are mixed with ``typ[:, 0, :-1]``, the type tensor of the first
        example without its last slot (presumably a one-hot over
        int/list/null -- TODO confirm).

        Args:
            x: decoder features, shape (batch, list_head.in_features).
            typ: type tensor indexed as ``typ[:, 0, :-1]``.
            hard_softmax: if True (and ``params.gumbel_softmax`` is off),
                return straight-through one-hot vectors of the argmax.

        Returns:
            (probs, index): ``probs`` has shape
            (batch, 1, 3, max_list_len, integer_range + 1); its last value
            column is the null token. ``index`` has shape
            (batch, 1, 3, max_list_len) and holds the argmax value per position.
        """
        batch = x.shape[0]
        # Allocate constants on the input's device instead of a hard-coded 'cuda'.
        device = x.device
        # Filler logits. After the softmax, a `huge` entry makes that column
        # (the null token) certain.
        tiny, huge = 1e-10, 1e10

        # List candidate: per-position value logits plus a null column.
        list_ = self.list_head(x)
        list_ = list_.view(batch, 3, params.max_list_len, params.integer_range)
        list_ = torch.cat([list_, torch.full((*list_.shape[:-1], 1), tiny, device=device)], -1)

        # Int candidate: value logits at position 0 only. Positions >= 1 are
        # padded and forced to the null column.
        int_ = self.int_head(x)
        int_ = int_.view(batch, 3, 1, params.integer_range)
        int_ = torch.cat([int_, torch.full((*int_.shape[:2], params.max_list_len - 1, params.integer_range), tiny, device=device)], -2)
        int_ = torch.cat([int_, torch.full((*int_.shape[:-1], 1), huge, device=device)], -1)
        int_[:, :, 0, -1] = tiny

        # Null candidate: every position is the null token.
        null_ = torch.full(list_.shape, tiny, device=device)
        null_[:, :, :, -1] = huge

        # Flatten each candidate to (batch, 3, 1, K) and stack in the order
        # int, list, null along dim 2.
        list_ = list_.view(batch, 3, -1).unsqueeze(2)
        int_ = int_.view(batch, 3, -1).unsqueeze(2)
        null_ = null_.view(batch, 3, -1).unsqueeze(2)
        choice = torch.cat([int_, list_, null_], 2)
        # Select or mix the candidates per slot using the first example's input types.
        x = typ[:, 0, :-1].unsqueeze(2).float() @ choice
        x = x.view(batch, 3, params.max_list_len, params.integer_range + 1)

        if params.gumbel_softmax:
            x_hard = F.gumbel_softmax(x, tau=1, hard=True)
            x = x_hard
            index = x.max(-1, keepdim=True)[1]
        elif hard_softmax:
            x = torch.softmax(x, -1)
            index = x.max(-1, keepdim=True)[1]
            # Straight-through estimator: one-hot forward, softmax gradient.
            x_hard = torch.zeros_like(x, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
            x_hard = (x_hard - x).detach() + x
            x = x_hard
        else:
            x = torch.softmax(x, -1)
            index = x.max(-1, keepdim=True)[1]

        x = x.unsqueeze(1)
        return x, index.squeeze(-1).unsqueeze(1)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        This is the reparameterization trick for a diagonal Gaussian.
        """
        std = torch.exp(0.5*logvar)
        # Standard-normal noise. The original rand_like (uniform [0, 1)) biased
        # samples by 0.5*std and gave them the wrong variance.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def encode_program(self, program, plengths):
        """Encode a token program with the RNN and return the top layer's final hidden state."""
        _, encoder_hidden = self.encoder_rnn(
                program, plengths, None)
        if self.encoder_rnn.rnn_cell == nn.LSTM:
            # LSTM returns (h, c); keep only h.
            encoder_hidden = encoder_hidden[0]

        # Pick the hidden vector from the top layer (presumably dim 0 indexes layers).
        encoder_hidden = encoder_hidden[-1:, :, :].squeeze(0)
        return encoder_hidden

    def encode_program2(self, program, plengths):
        """Encode a token program with the transformer encoder (``plengths`` is unused)."""
        encoder_hidden = self.encoder_transformer(program)
        return encoder_hidden

    def encode_io(self, io, typ):
        """Encode I/O examples with the per-example encoder, using the first 2 type channels."""
        return self.encoder_dense(io, typ[:, :, :, :2])

    def encode_io_t(self, io, typ):
        """Encode I/O examples with the query encoder, using the first 2 type channels."""
        return self.encoder_dense_t(io, typ[:, :, :, :2])

    def encode_into_z(self, io_features, program_features):
        """Map concatenated (io, program) features to Gaussian ``(mus, logvars)``."""
        together = torch.cat([io_features, program_features], -1)
        attended_hiddens = self.attention_z(together)
        mus = self.encoder_mu_z(attended_hiddens)
        logvars = self.encoder_logvar_z(attended_hiddens)
        return mus, logvars

    def encode_into_z2(self, program_features):
        """Map program features alone to ``(mus, logvars)`` through the *_z heads."""
        attended_hiddens = self.attention_z2(program_features)
        mus = self.encoder_mu_z(attended_hiddens)
        logvars = self.encoder_logvar_z(attended_hiddens)
        return mus, logvars

    def encode_into_t(self, io_features):
        """Map I/O features alone to Gaussian ``(mus, logvars)`` through the *_t heads."""
        attended_hiddens = self.attention_t(io_features)
        mus = self.encoder_mu_t(attended_hiddens)
        logvars = self.encoder_logvar_t(attended_hiddens)
        return mus, logvars

    def decode_query(self, zs, io_features, typ, hard_softmax):
        """Decode latents plus I/O features into ``(query, index)`` via ``decode_process``."""
        z_hiddens = self.decoder_z(zs)
        hiddens = self.decoder_gen(io_features + z_hiddens)
        return self.decode_process(hiddens, typ, hard_softmax)

    def decode_query2(self, zs, io_features):
        """Project latents to hidden size; ``io_features`` is currently unused."""
        return self.decoder_z(zs)

    def reconstruct_inputs(self, io_features, program_features):
        """Reconstruct I/O and/or program features from a sampled latent.

        Returns ``(recon_io_features, recon_program_features)``. Each entry is
        None when the corresponding reconstruction is disabled in ``params``.
        """
        recon_io_features = None
        recon_program_features = None
        mus, logvars = self.encode_into_z(io_features, program_features)
        zs = self.reparameterize(mus, logvars)
        if self.io_recon:
            recon_io_features = self.io_reconstructor(zs)
        if self.program_recon:
            recon_program_features = self.program_reconstructor(zs)
        return recon_io_features, recon_program_features

    def predict_from_io(self, io, typ, hard_softmax=False):
        """Generate a query from I/O examples alone (no program)."""
        io_features = self.encode_io(io, typ)
        mus, logvars = self.encode_into_t(io_features)
        ts = self.reparameterize(mus, logvars)
        query, index = self.decode_query(ts, io_features, typ, hard_softmax)
        return query, index

    def distance(self, dist1, dist2):
        """Placeholder: not implemented, always returns None."""
        pass

    def decode_process_bak(self, x, typ, hard_softmax=False):
        """Older variant of ``decode_process`` that softmaxes each candidate before mixing.

        The output shapes match ``decode_process``.
        """
        batch = x.shape[0]
        # Allocate constants on the input's device instead of a hard-coded 'cuda'.
        device = x.device

        list_ = self.list_head(x)
        list_ = list_.view(batch, 3, params.max_list_len, params.integer_range)
        list_ = torch.softmax(list_, -1)
        list_ = torch.cat([list_, torch.zeros(*list_.shape[:-1], 1, device=device)], -1)

        int_ = self.int_head(x)
        int_ = int_.view(batch, 3, 1, params.integer_range)
        int_ = torch.softmax(int_, -1)
        int_ = torch.cat([int_, torch.zeros(*int_.shape[:2], params.max_list_len - 1, params.integer_range, device=device)], -2)
        int_ = torch.cat([int_, torch.ones(*int_.shape[:-1], 1, device=device)], -1)
        int_[:, :, 0, -1] = 0.

        null_ = torch.zeros(*list_.shape, device=device)
        null_[:, :, :, -1] = 1.

        list_ = list_.view(batch, 3, -1).unsqueeze(2)
        int_ = int_.view(batch, 3, -1).unsqueeze(2)
        null_ = null_.view(batch, 3, -1).unsqueeze(2)
        choice = torch.cat([int_, list_, null_], 2)
        # Input types of the first example select the candidate per slot.
        x = typ[:, 0, :-1].unsqueeze(2).float() @ choice
        x = x.view(batch, 3, params.max_list_len, params.integer_range + 1)

        if params.gumbel_softmax:
            x_hard = F.gumbel_softmax(x, tau=1, hard=True)
            x = x_hard
            index = x.max(-1, keepdim=True)[1]
        elif hard_softmax:
            index = x.max(-1, keepdim=True)[1]
            x_hard = torch.zeros_like(x, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
            x_hard = (x_hard - x).detach() + x
            x = x_hard
        else:
            index = x.max(-1, keepdim=True)[1]

        x = x.unsqueeze(1)
        return x, index.squeeze(-1).unsqueeze(1)

class Intersection(nn.Module):
    """Attention-weighted pooling of paired embeddings over dimension 1.

    A two-layer network scores the concatenation of each
    (alpha, beta) pair. The scores are softmax-normalized over dim 1
    (presumably the conjunct axis), and they weight sums of both inputs
    over that same dim.

    Arguments:
        dim (int): size of the last axis of each input embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        width = 2 * dim
        self.layer1 = nn.Linear(width, width)
        self.layer2 = nn.Linear(width, dim)
        # Xavier-initialize the weights (biases keep nn.Linear's default init).
        for linear in (self.layer1, self.layer2):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, alpha_embeddings, beta_embeddings):
        """Return the attention-weighted sums of both inputs over dim 1.

        Inputs are presumably (batch, num_conj, dim) -- TODO confirm. The
        outputs then drop dim 1, giving (batch, dim) each.
        """
        paired = torch.cat((alpha_embeddings, beta_embeddings), dim=-1)
        hidden = F.relu(self.layer1(paired))
        weights = F.softmax(self.layer2(hidden), dim=1)
        pooled_alpha = (weights * alpha_embeddings).sum(dim=1)
        pooled_beta = (weights * beta_embeddings).sum(dim=1)
        return pooled_alpha, pooled_beta
