import torch
from torch import nn
import math


class MLPNet(nn.Module):
    """Four-layer perceptron that ends in a softmax over the output features.

    The stack is ``in_dim -> 512 -> 64 -> 32 -> out_dim`` with a ReLU after
    each hidden linear layer and dropout applied just before the final
    linear layer.

    Args:
        dropout: Probability handed to ``nn.Dropout`` (``p``).
        in_dim: Size of the last dimension of the input tensor.
        out_dim: Number of output features (softmax entries).

    Note:
        ``forward`` returns softmax-normalised values, not raw logits.
        NOTE(review): if this is trained with a loss that applies its own
        (log-)softmax, e.g. ``nn.CrossEntropyLoss``, the softmax would be
        applied twice - confirm against the training code.
    """

    def __init__(self, dropout: float, in_dim: int = 4096, out_dim: int = 2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, out_dim),
            # Normalise over the feature axis produced by the last Linear.
            # dim=-1 equals the former dim=1 for (batch, features) input and
            # additionally handles unbatched 1-D and higher-rank inputs.
            nn.Softmax(dim=-1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack.

        Args:
            x: Tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_dim`` whose entries sum to 1.
        """
        return self.layers(x)


class MLPNet_No_Softmax(nn.Module):
    """Four-layer perceptron that returns the final linear layer's output.

    Layer widths are ``in_dim -> 512 -> 64 -> 32 -> out_dim``. Every hidden
    linear layer is followed by a ReLU, and dropout with probability
    ``dropout`` is applied right before the last linear layer. No softmax
    is applied to the result.
    """

    def __init__(self, dropout, in_dim=4096, out_dim=2):
        super().__init__()
        widths = (in_dim, 512, 64, 32)
        stack = []
        # One Linear + ReLU pair per consecutive pair of widths, in order,
        # so module indices inside the Sequential stay 0..5.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Dropout(p=dropout))
        stack.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its unnormalised output."""
        return self.layers(x)
