import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt

import random

class MLP(nn.Module):
    """Fully connected feed-forward network with ReLU hidden activations.

    ``d`` lists the layer widths: ``d[0]`` is the input feature size,
    ``d[-1]`` is the output feature size, and one ``nn.Linear(d[i], d[i+1])``
    is created for every consecutive pair. ReLU is applied after every layer
    except the last, whose output is returned without an activation.
    """

    def __init__(self, d):
        """Create the linear layers described by ``d``.

        Args:
            d: Indexable, sliceable sequence of layer widths (input width
                first, output width last). Must contain at least two entries.

        Raises:
            ValueError: If ``d`` has fewer than two entries, since no layer
                could be built and ``forward`` could never succeed.
        """
        super().__init__()
        if len(d) < 2:
            # Fail at construction time rather than with an opaque IndexError
            # on the first forward pass.
            raise ValueError(
                "MLP needs at least two layer sizes (input and output), "
                f"got {len(d)}"
            )
        self.d = d
        self.lin = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(d[:-1], d[1:])]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x``.

        Args:
            x: Input tensor whose last axis has size ``d[0]``. A 1-D tensor
                is treated as a single sample: it is temporarily given a
                leading axis, which is removed again from the result.

        Returns:
            The output of the last linear layer. When the output width
            ``d[-1]`` is 1, the trailing axis is squeezed away as well, so a
            1-D input yields a 0-dim tensor in that case.
        """
        single_sample = x.dim() == 1
        if single_sample:
            x = x.unsqueeze(0)

        # Every layer but the last is followed by ReLU.
        *hidden_layers, output_layer = self.lin
        for layer in hidden_layers:
            x = F.relu(layer(x))
        out = output_layer(x)

        if single_sample:
            out = out.squeeze(0)

        if self.d[-1] == 1:
            # Scalar-output networks return values without the size-1 axis.
            out = out.squeeze(-1)

        return out
