import random
import torch.nn.functional as F
from torch import nn

import params
import torch
from model.model import BaseModel
from model.query_MI6 import QueryMI
from model.encoder import RNNEncoder
from model.combinar import Combinar


class CombinarMI(Combinar):
    """``Combinar`` variant whose query module is a ``QueryMI`` instance.

    Exposes two views of the model's trainable parameters, presumably so
    callers can drive separate optimizers (TODO confirm against callers):

    * ``main_parameters`` -- every trainable parameter of the model.
    * ``info_parameters`` -- the trainable parameters of the query's
      ``attention_z`` / ``encoder_mu_z`` / ``encoder_logvar_z`` modules,
      plus its reconstructor modules when those are enabled.
    """

    def __init__(self, le=None):
        """Build the base model, then assign a ``QueryMI`` to ``self.query``.

        Args:
            le: Forwarded unchanged to ``Combinar.__init__``.
        """
        super(CombinarMI, self).__init__(le)
        # NOTE(review): if Combinar.__init__ also sets self.query, that
        # module is replaced here -- can't tell from this file.
        # Dropout is fixed at 0.0 for both input and RNN; the vocabulary
        # size comes from the project-wide ``params`` config module.
        self.query = QueryMI(
            vocab_size=params.vocab_size,
            input_dropout_p=0.0,
            dropout_p=0.0,
            rnn_cell='lstm',
        )

    @staticmethod
    def _requires_grad_only(parameters):
        """Return a lazy iterator over the entries of *parameters* with ``requires_grad`` set."""
        return filter(lambda p: p.requires_grad, parameters)

    def main_parameters(self):
        """Return the trainable parameters of the whole model.

        Returns:
            A one-shot ``filter`` iterator over ``self.parameters()`` keeping
            only tensors with ``requires_grad``. Wrap it in ``list()`` if it
            must be traversed more than once.
        """
        # NOTE(review): this also yields every parameter returned by
        # info_parameters(); confirm the overlap is intended.
        # The local name avoids shadowing the module-level ``params`` config.
        return self._requires_grad_only(self.parameters())

    def info_parameters(self):
        """Return the trainable parameters of the query's information heads.

        The parameters are collected, in order, from ``query.attention_z``,
        ``query.encoder_mu_z`` and ``query.encoder_logvar_z``. Then come
        ``query.io_reconstructor`` if ``query.io_recon`` is truthy, and
        ``query.program_reconstructor`` if ``query.program_recon`` is truthy.

        Returns:
            A one-shot ``filter`` iterator keeping only tensors with
            ``requires_grad``.
        """
        query = self.query
        modules = [query.attention_z, query.encoder_mu_z, query.encoder_logvar_z]

        # Reconstruction modules are optional and included only when enabled.
        if query.io_recon:
            modules.append(query.io_reconstructor)
        if query.program_recon:
            modules.append(query.program_reconstructor)

        # Collect eagerly (as a list) so the set of modules is fixed at call
        # time; the name avoids shadowing the module-level ``params`` config.
        collected = [p for module in modules for p in module.parameters()]
        return self._requires_grad_only(collected)