import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Stack of 1x1 ``nn.Conv1d`` layers with one skip connection to the input.

    Layer ``l`` maps ``dims[l]`` channels to ``dims[l + 1]`` channels. The
    layer at index ``len(dims) // 2`` additionally receives the original
    network input concatenated along the channel axis, so its input width is
    ``dims[l] + dims[0]``. Every layer except the last is followed by
    ``F.leaky_relu``.

    Args:
        dims: Sequence of channel sizes; ``len(dims) - 1`` layers are built.
        last_op: Optional callable applied to the output of the final layer.
            It is skipped when falsy (e.g. ``None``).
        dropout_prob: Currently unused; accepted so existing call sites that
            pass it keep working.
    """

    def __init__(self, dims, last_op=None, dropout_prob=0.3):
        super().__init__()

        self.dims = dims
        self.skip_layer = [len(dims) // 2]
        self.last_op = last_op

        # ``self.layers`` is a plain Python list, which nn.Module does not
        # track; each conv is therefore also registered under "conv<l>" so
        # its parameters show up in ``parameters()`` / ``state_dict()``.
        self.layers = []
        for l in range(len(dims) - 1):
            in_channels = dims[l]
            if l in self.skip_layer:
                # The skip layer also consumes the raw network input.
                in_channels += dims[0]
            conv = nn.Conv1d(in_channels, dims[l + 1], 1)
            self.layers.append(conv)
            self.add_module("conv%d" % l, conv)

    def forward(self, latet_code, return_all=False):
        """Run the network on ``latet_code``.

        Args:
            latet_code: Input tensor. Because the skip connection concatenates
                on dimension 1 and the layers are ``nn.Conv1d``, a batched
                ``(batch, dims[0], length)`` layout is expected.
            return_all: When true, wrap the result in a list.

        Returns:
            The final output tensor (after ``last_op`` if one is set), or a
            one-element list containing it when ``return_all`` is true.
        """
        y = latet_code
        last_index = len(self.layers) - 1
        for l in range(len(self.layers)):
            # Resolve the layer through the registered submodules rather than
            # the untracked ``self.layers`` list, so whatever module is
            # currently registered under this name is the one that runs.
            conv = self._modules["conv%d" % l]
            if l in self.skip_layer:
                y = torch.cat([y, latet_code], 1)
            y = conv(y)
            if l != last_index:
                y = F.leaky_relu(y)

        if self.last_op:
            y = self.last_op(y)

        if return_all:
            # Always include the final output; previously the list was only
            # filled when ``last_op`` was set, so callers got ``[]`` otherwise.
            return [y]
        return y