
import torch.nn as nn
import torch
import torch.nn.functional as F

__all__ = ["mlp"]


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        dim_in: Size of the last dimension of the input tensor.
        dim_hidden: Width of the hidden layer.
        dim_out: Number of output features (presumably class logits, given
            the ``num_class`` argument used by the ``mlp`` factory).
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Return raw (unnormalized) outputs of shape ``(*, dim_out)``.

        ``x`` must have ``dim_in`` as its last dimension; any leading
        dimensions (e.g. a batch axis) are preserved, as with ``nn.Linear``.
        """
        x = self.layer_input(x)
        x = self.relu(x)
        return self.layer_hidden(x)

    def pred_prob(self, x):
        """Return softmax probabilities over the output features.

        The softmax is taken over the last dimension (the ``dim_out`` axis
        produced by ``layer_hidden``), so every row of a batched input sums
        to 1. For an unbatched 1-D input this matches ``dim=0``.
        """
        logits = self.forward(x)
        # dim=-1, not dim=0: for a (batch, dim_out) tensor, dim=0 would
        # normalize across the batch instead of across the output features.
        return nn.functional.softmax(logits, dim=-1)

# class MLP(nn.Module):
#     def __init__(self, in_features=32*32*3, num_classes=10, hidden_dim=200):
#         super().__init__()
#         self.fc1 = nn.Linear(in_features, hidden_dim)
#         self.fc2 = nn.Linear(hidden_dim, num_classes)
#         self.act = nn.ReLU(inplace=True)

#     def forward(self, x):
#         x = x.view(x.size(0), -1)
#         x = self.act(self.fc1(x))
#         x = self.fc2(x)
#         return F.log_softmax(x, dim=1)


def mlp(args):
    """Build an ``MLP`` from a config/namespace object.

    Reads ``args.input`` (input size), ``args.hidden`` (hidden width) and
    ``args.num_class`` (output size) and returns a new ``MLP`` instance.
    """
    dim_in, dim_hidden, dim_out = args.input, args.hidden, args.num_class
    return MLP(dim_in=dim_in, dim_hidden=dim_hidden, dim_out=dim_out)
    

