import numpy as np
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from env.statement import num_statements
from env.operator import num_operators
from model.encoder import DenseEncoder, DenseQueryEncoder, RNNEncoder, MLP
from model.model import BaseModel

class QueryMI(BaseModel):
    """Latent-variable query generator conditioned on I/O examples and a program.

    The model encodes I/O examples with a dense encoder and a program with
    an RNN encoder, maps the resulting features to the mean and
    log-variance of a diagonal Gaussian latent ``z``, samples ``z`` with the
    reparameterization trick and decodes the sample (together with the I/O
    features) into a query: a distribution over integer / list / null
    values for each of ``_NUM_SLOTS`` slots.

    A second, I/O-only path (``encode_into_t`` / ``predict_from_io``)
    produces a latent ``t`` without looking at the program.

    Args:
        embedding: embedding object handed to both ``DenseQueryEncoder``
            instances.
        vocab_size (int): vocabulary size forwarded to ``RNNEncoder``.
        input_dropout_p (float): input dropout forwarded to ``RNNEncoder``.
        dropout_p (float): dropout forwarded to ``RNNEncoder``.
        num_layers (int): number of RNN layers.
        rnn_cell: RNN cell selector forwarded to ``RNNEncoder``.
        num_att_layers (int): ``num_layers`` of every ``MLP`` built here.
        att_ff_size (int): hidden width of every ``MLP`` built here.
        z_size (int): dimensionality of the latent variables ``z`` / ``t``.

    Configuration read from ``params``: ``io_recon``, ``program_recon``,
    ``dense_output_size``, ``max_list_len`` and ``integer_range``.
    """

    # Number of value slots decoded per query. ``decode_process`` multiplies
    # ``typ[:, 0, :-1]`` against the decoded candidates, so that slice must
    # also have this many rows.
    # NOTE(review): presumably the (padded) number of program inputs - confirm.
    _NUM_SLOTS = 3

    def __init__(self, embedding, vocab_size, input_dropout_p, dropout_p, num_layers, rnn_cell, num_att_layers=2, att_ff_size=256, z_size=10):
        super().__init__()
        self.io_recon = params.io_recon
        self.program_recon = params.program_recon
        self.hidden_size = hidden_size = params.dense_output_size
        self.num_layers = num_layers

        # NOTE: submodules are created in a fixed order on purpose; changing
        # it would change state_dict key order and seeded initialisation.
        self.encoder_dense = DenseQueryEncoder(embedding)
        self.encoder_dense_t = DenseQueryEncoder(embedding)
        # token: state, length: program length
        self.encoder_rnn = RNNEncoder(vocab_size, hidden_size,
                                      input_dropout_p=input_dropout_p,
                                      dropout_p=dropout_p,
                                      n_layers=num_layers,
                                      bidirectional=False,
                                      rnn_cell=rnn_cell,
                                      variable_lengths=True)

        # Gaussian heads for the latents z (I/O + program) and t (I/O only).
        self.encoder_mu_z = nn.Linear(hidden_size, z_size)
        self.encoder_mu_t = nn.Linear(hidden_size, z_size)
        self.encoder_logvar_z = nn.Linear(hidden_size, z_size)
        self.encoder_logvar_t = nn.Linear(hidden_size, z_size)

        self.decoder_gen = MLP(hidden_size, att_ff_size, hidden_size,
                               num_layers=num_att_layers)
        self.decoder_z = nn.Linear(z_size, hidden_size)

        # TODO: MLP params
        # attention_z consumes concatenated I/O and program features, hence
        # the doubled input width.
        self.attention_z = MLP(2 * hidden_size, att_ff_size, hidden_size,
                               num_layers=num_att_layers)
        self.attention_z2 = MLP(hidden_size, att_ff_size, hidden_size,
                                num_layers=num_att_layers)
        self.attention_t = MLP(hidden_size, att_ff_size, hidden_size,
                               num_layers=num_att_layers)

        # Optional reconstruction heads mapping a latent sample back to
        # feature space.
        if self.io_recon:
            self.io_reconstructor = MLP(
                    z_size, att_ff_size, hidden_size,
                    num_layers=num_att_layers)

        if self.program_recon:
            self.program_reconstructor = MLP(
                    z_size, att_ff_size, hidden_size,
                    num_layers=num_att_layers)

        # Output heads: one categorical over integer_range per list element
        # (list_head) or per single integer (int_head), for every slot.
        self.list_head = nn.Linear(
            params.dense_output_size,
            self._NUM_SLOTS * params.max_list_len * params.integer_range)
        self.int_head = nn.Linear(
            params.dense_output_size,
            self._NUM_SLOTS * 1 * params.integer_range)

    def forward(self, x, typ, program, plengths, hard_softmax=False):
        """Generate a query from I/O examples and the program.

        Args:
            x: I/O example values, passed to the dense encoder.
            typ: type tensor for the examples; indexed with four dimensions
                in ``encode_io`` and with three in ``decode_process``.
            program: program token tensor for the RNN encoder.
            plengths: program lengths for the RNN encoder.
            hard_softmax (bool): if True, return straight-through one-hot
                queries (see ``decode_process``).

        Returns:
            tuple: ``(query, index)`` as returned by ``decode_process``.
        """
        io_features = self.encode_io(x, typ)
        program_features = self.encode_program(program, plengths)
        mus, logvars = self.encode_into_z(io_features, program_features)
        zs = self.reparameterize(mus, logvars)
        return self.decode_query(zs, io_features, typ, hard_softmax)

    def decode_process(self, x, typ, hard_softmax=False):
        """Turn decoder hidden states into a typed query distribution.

        Three candidate encodings are built for every slot, each of shape
        ``(batch, slots, max_list_len, integer_range + 1)`` where the extra
        last class is the "null" value:

        * int:  position 0 holds a softmax over integers (null prob. 0),
          every other position is one-hot null;
        * list: every position holds a softmax over integers (null prob. 0);
        * null: every position is one-hot null.

        ``typ[:, 0, :-1]`` is then used as mixing weights over
        ``[int, list, null]`` (in that order) for each slot.

        Args:
            x: hidden states of shape ``(batch, params.dense_output_size)``.
            typ: type tensor; ``typ[:, 0, :-1]`` must have shape
                ``(batch, slots, 3)``.
                NOTE(review): presumably the one-hot input types of the
                first example, with the output type dropped - confirm.
            hard_softmax (bool): if True, the forward value is a one-hot
                argmax while gradients flow through the soft distribution
                (straight-through estimator).

        Returns:
            tuple: ``(query, index)`` with ``query`` of shape
            ``(batch, 1, slots, max_list_len, integer_range + 1)`` and
            ``index`` (argmax over the last dimension) of shape
            ``(batch, 1, slots, max_list_len)``.
        """
        batch = x.shape[0]
        slots = self._NUM_SLOTS
        max_len = params.max_list_len
        int_range = params.integer_range
        # Allocate helper tensors next to the activations instead of on a
        # hard-coded 'cuda' device, so the model also runs on CPU and on
        # non-default GPUs.
        device = x.device

        # List candidate: per-position softmax plus a zero-probability null.
        list_ = self.list_head(x).view(batch, slots, max_len, int_range)
        list_ = torch.softmax(list_, -1)
        list_ = torch.cat(
            [list_, torch.zeros(*list_.shape[:-1], 1, device=device)], -1)

        # Int candidate: softmax in position 0, remaining positions null.
        int_ = self.int_head(x).view(batch, slots, 1, int_range)
        int_ = torch.softmax(int_, -1)
        int_ = torch.cat(
            [int_, torch.zeros(*int_.shape[:2], max_len - 1, int_range, device=device)], -2)
        int_ = torch.cat(
            [int_, torch.ones(*int_.shape[:-1], 1, device=device)], -1)
        # Position 0 carries the actual integer, so its null prob. is 0.
        int_[:, :, 0, -1] = 0.

        # Null candidate: all mass on the null class everywhere.
        null_ = torch.zeros(*list_.shape, device=device)
        null_[:, :, :, -1] = 1.

        # Flatten each candidate and stack them: (batch, slots, 3, D).
        list_ = list_.view(batch, slots, -1).unsqueeze(2)
        int_ = int_.view(batch, slots, -1).unsqueeze(2)
        null_ = null_.view(batch, slots, -1).unsqueeze(2)
        choice = torch.cat([int_, list_, null_], 2)

        # input type in 1 example
        x = typ[:, 0, :-1].unsqueeze(2).float() @ choice
        x = x.view(x.shape[0], slots, max_len, int_range + 1)

        index = x.max(-1, keepdim=True)[1]

        if hard_softmax:
            # Straight-through: one-hot forward value, soft gradient.
            x_hard = torch.zeros_like(
                x, memory_format=torch.legacy_contiguous_format
            ).scatter_(-1, index, 1.0)
            x = (x_hard - x).detach() + x

        x = x.unsqueeze(1)
        return x, index.squeeze(-1).unsqueeze(1)

    def reparameterize(self, mu, logvar):
        """Sample from N(mu, exp(logvar)) with the reparameterization trick.

        Args:
            mu: mean of the diagonal Gaussian.
            logvar: log-variance of the diagonal Gaussian, same shape as
                ``mu``.

        Returns:
            ``mu + eps * std`` with ``eps ~ N(0, I)``.
        """
        std = torch.exp(0.5 * logvar)
        # The noise must be standard normal. Uniform noise (rand_like) would
        # shift every sample to one side of mu and would not be a draw from
        # the Gaussian described by (mu, logvar).
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def encode_program(self, program, plengths):
        """Encode a program with the RNN encoder.

        Args:
            program: program token tensor.
            plengths: program lengths.

        Returns:
            The top-layer final hidden state of the RNN encoder.
        """
        _, encoder_hidden = self.encoder_rnn(
                program, plengths, None)
        if self.encoder_rnn.rnn_cell == nn.LSTM:
            # LSTM returns (h_n, c_n); only the hidden state is used.
            encoder_hidden = encoder_hidden[0]

        # Pick the hidden vector from the top layer.
        return encoder_hidden[-1]

    def encode_io(self, io, typ):
        """Encode I/O examples for the z-path (first two type channels only)."""
        return self.encoder_dense(io, typ[:, :, :, :2])

    def encode_io_t(self, io, typ):
        """Encode I/O examples with the separate t-path dense encoder."""
        return self.encoder_dense_t(io, typ[:, :, :, :2])

    def encode_into_z(self, io_features, program_features):
        """Return ``(mu, logvar)`` of z from I/O and program features.

        The two feature tensors are concatenated on the last dimension
        before being passed through ``attention_z``.
        """
        together = torch.cat([io_features, program_features], -1)
        attended_hiddens = self.attention_z(together)
        mus = self.encoder_mu_z(attended_hiddens)
        logvars = self.encoder_logvar_z(attended_hiddens)
        return mus, logvars

    def encode_into_z2(self, program_features):
        """Return ``(mu, logvar)`` of z from program features alone.

        Uses ``attention_z2`` but shares the z Gaussian heads with
        ``encode_into_z``.
        """
        attended_hiddens = self.attention_z2(program_features)
        mus = self.encoder_mu_z(attended_hiddens)
        logvars = self.encoder_logvar_z(attended_hiddens)
        return mus, logvars

    def encode_into_t(self, io_features):
        """Return ``(mu, logvar)`` of the I/O-only latent t."""
        attended_hiddens = self.attention_t(io_features)
        mus = self.encoder_mu_t(attended_hiddens)
        logvars = self.encoder_logvar_t(attended_hiddens)
        return mus, logvars

    def decode_query(self, zs, io_features, typ, hard_softmax):
        """Decode a latent sample plus I/O features into a query.

        Args:
            zs: latent samples of shape ``(batch, z_size)``.
            io_features: I/O features; added element-wise to the projected
                latent.
            typ: type tensor forwarded to ``decode_process``.
            hard_softmax (bool): forwarded to ``decode_process``.

        Returns:
            tuple: ``(query, index)`` from ``decode_process``.
        """
        z_hiddens = self.decoder_z(zs)
        hiddens = self.decoder_gen(io_features + z_hiddens)
        return self.decode_process(hiddens, typ, hard_softmax)

    def reconstruct_inputs(self, io_features, program_features):
        """Reconstruct I/O and/or program features from a fresh z sample.

        Returns:
            tuple: ``(recon_io_features, recon_program_features)``; an
            entry is ``None`` when the corresponding reconstruction head is
            disabled (``params.io_recon`` / ``params.program_recon``).
        """
        recon_io_features = None
        recon_program_features = None
        mus, logvars = self.encode_into_z(io_features, program_features)
        zs = self.reparameterize(mus, logvars)
        if self.io_recon:
            recon_io_features = self.io_reconstructor(zs)
        if self.program_recon:
            recon_program_features = self.program_reconstructor(zs)
        return recon_io_features, recon_program_features

    def predict_from_io(self, io, typ, hard_softmax=False):
        """Generate a query from I/O examples only, via the latent t.

        Returns:
            tuple: ``(query, index)`` from ``decode_process``.
        """
        # TODO: predict from io
        io_features = self.encode_io(io, typ)
        mus, logvars = self.encode_into_t(io_features)
        ts = self.reparameterize(mus, logvars)
        return self.decode_query(ts, io_features, typ, hard_softmax)
