import torch
import torch.nn as nn

class NN(torch.nn.Module):
    """Shared MLP trunk followed by one sigmoid output head per entry of ``k_list``.

    The input is lifted from ``sum(k_list)`` features to ``emb_dim``. It then
    passes through ``num_layer`` hidden blocks
    (Linear -> BatchNorm1d -> LeakyReLU(0.1) -> Dropout). Head ``i`` maps the
    shared embedding to ``k_list[i]`` sigmoid-activated values. The head outputs
    are concatenated along dim 1, so the output width equals ``sum(k_list)``,
    the same as the input width.

    Args:
        k_list: Output width of each head. Their sum is also the expected
            input feature width. (Presumably one entry per categorical
            feature's cardinality — TODO confirm with callers.)
        emb_dim: Width of the shared hidden embedding.
        num_layer: Number of hidden blocks between the lift-up layer and heads.
        dropout_prob: Dropout probability used in every hidden block.
    """

    def __init__(self, k_list, emb_dim, num_layer, dropout_prob=0.2):
        super().__init__()
        self.k_list = k_list
        self.input_dim = sum(k_list)
        self.emb_dim = emb_dim
        self.num_layer = num_layer

        # NOTE: attribute names and the Sequential layouts below define the
        # state_dict keys (e.g. "mid_layers.0.0.weight"); keep them stable so
        # existing checkpoints still load.
        self.lift_up_layer = nn.Sequential(
            nn.Linear(self.input_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
        )

        self.mid_layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(emb_dim, emb_dim),
                nn.BatchNorm1d(emb_dim),
                nn.LeakyReLU(0.1),
                nn.Dropout(p=dropout_prob),
            )
            for _ in range(num_layer)
        )

        self.squat_layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(emb_dim, emb_dim),
                nn.BatchNorm1d(emb_dim),
                nn.LeakyReLU(0.1),
                nn.Linear(emb_dim, k),
                nn.Sigmoid(),
            )
            for k in k_list
        )

    def forward(self, x):
        """Run the trunk and all heads, concatenating head outputs on dim 1.

        Args:
            x: Input tensor, assumed to be ``(batch, sum(k_list))``. The dim-1
                concatenation and BatchNorm1d over ``emb_dim`` imply 2-D input.
                In training mode BatchNorm1d requires ``batch > 1``.

        Returns:
            Tensor of shape ``(batch, sum(k_list))`` whose values come from the
            per-head sigmoids. If ``k_list`` is empty, an empty
            ``(batch, 0)`` tensor is returned.
        """
        x = self.lift_up_layer(x)

        for block in self.mid_layers:
            x = block(x)

        head_outputs = [head(x) for head in self.squat_layers]
        if not head_outputs:
            # No heads: previously this path hit an unbound local variable.
            return x.new_zeros((x.size(0), 0))

        # One concatenation instead of repeated pairwise cats, which would
        # re-copy the growing result for every head.
        return torch.cat(head_outputs, dim=1)