import random
import torch.nn.functional as F
from torch import nn

import params
import torch
from model.model import BaseModel, PCCoder
from model.query_MI6 import QueryMI
from model.encoder import RNNEncoder
from model.combinar import Combinar

from dsl.program import Program
from dsl.example import Example
from dsl.value import Value

from env.statement import Statement, statement_to_index
from env.statement import num_statements
from env.env import ProgramEnv


class CombinarMI(Combinar):
    """Combinar variant that generates its queries with a ``QueryMI`` module.

    Each step asks ``self.query`` for a new query, executes it through
    ``self.env_step`` and feeds the accumulated variable encodings to
    ``self.model``.  ``env_step``, ``model`` and ``embedding`` are not defined
    in this class; presumably they are provided by ``Combinar``.
    """

    def __init__(self, le):
        super(CombinarMI, self).__init__(le)
        # ``self.embedding`` is read here, so it must already exist after the
        # parent constructor ran.  Both dropout rates are disabled and the
        # recurrent cell is an LSTM.
        self.query = QueryMI(
            self.embedding,
            vocab_size=num_statements + 3,
            input_dropout_p=0.0,
            dropout_p=0.0,
            rnn_cell='lstm',
        )

    def forward(self, x, var, var_types, program, program_seq, step, drop, typ, plengths, query_step):
        """Generate a query from ``x`` and the program sequence, then run it.

        ``typ`` is cut along its second axis to the first
        ``max(1, query_step)`` entries before being handed to the query
        module, together with ``program_seq``, ``plengths`` and the
        ``params.hard_softmax`` setting.

        Returns:
            The tuple ``(pred, x, var, var_types)`` produced by :meth:`ps`.
        """
        # At least one type column is always kept, even when query_step is 0.
        n_types = int(max(1, query_step))
        query_inp, query_index = self.query(
            x, typ[:, :n_types], program_seq, plengths, params.hard_softmax
        )
        return self.ps(query_index, query_inp, program, step, drop, query_step, x, var, var_types)

    def ps(self, query_index, query_inp, program, step, drop, query_step, x, var, var_types):
        """Execute one query and predict from the accumulated results.

        The query is executed with ``self.env_step``.  When ``query_step`` is
        positive the three results are concatenated onto ``x``, ``var`` and
        ``var_types`` along dimension 1; otherwise the incoming ``x``,
        ``var`` and ``var_types`` are ignored and the fresh results are used
        on their own.

        Returns:
            ``(pred, x, var, var_types)`` where ``pred`` is
            ``self.model(var, var_types)`` evaluated on the updated tensors.
        """
        # NOTE(review): a removed debugging stub suggested the step results
        # have shapes like x[:, :1], (batch, 1, 12, 20, 513) and
        # (batch, 1, 12, 2) -- TODO confirm against env_step.
        io_out, var_out, types_out = self.env_step(query_index, query_inp, program, step, drop)

        if query_step > 0:
            # Extend the history gathered in previous query steps.
            io_out = torch.cat([x, io_out], 1)
            var_out = torch.cat([var, var_out], 1)
            types_out = torch.cat([var_types, types_out], 1)

        pred = self.model(var_out, types_out)
        return pred, io_out, var_out, types_out

    def predict(self, x, var, var_types, program, step, drop, typ, query_step):
        """Same as :meth:`forward`, but the query is derived from I/O only.

        Uses ``self.query.predict_from_io`` (no program sequence or lengths)
        with ``typ`` cut to its first ``max(1, query_step)`` entries along
        the second axis.

        Returns:
            The tuple ``(pred, x, var, var_types)`` produced by :meth:`ps`.
        """
        n_types = int(max(1, query_step))
        query_inp, query_index = self.query.predict_from_io(
            x, typ[:, :n_types], params.hard_softmax
        )
        return self.ps(query_index, query_inp, program, step, drop, query_step, x, var, var_types)

    def main_parameters(self):
        """Return an iterator over every parameter with ``requires_grad`` set."""
        return filter(lambda param: param.requires_grad, self.parameters())

    def info_parameters(self):
        """Return an iterator over the trainable parameters of selected query parts.

        Always covers ``attention_z``, ``encoder_mu_z`` and
        ``encoder_logvar_z`` of ``self.query``.  ``io_reconstructor`` and
        ``program_reconstructor`` are added only when ``self.query.io_recon``
        and ``self.query.program_recon`` are truthy, respectively.  Parameters
        whose ``requires_grad`` is false are filtered out.
        """
        query = self.query

        selected = []
        selected.extend(query.attention_z.parameters())
        selected.extend(query.encoder_mu_z.parameters())
        selected.extend(query.encoder_logvar_z.parameters())

        # Reconstruction parameters.
        if query.io_recon:
            selected.extend(query.io_reconstructor.parameters())
        if query.program_recon:
            selected.extend(query.program_reconstructor.parameters())

        return filter(lambda param: param.requires_grad, selected)