
import torch
from torch import nn
import torch.nn.functional as F

def MLP(in_sz,out_sz,hidden=100,layer_len=1):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    The returned stack consists of ``layer_len`` hidden stages, each an
    ``nn.Linear`` followed by ``nn.ReLU``, and a final ``nn.Linear`` that
    maps to ``out_sz`` with no activation after it.

    Args:
        in_sz: Number of input features of the first linear layer.
        out_sz: Number of output features of the last linear layer.
        hidden: Width shared by every hidden linear layer.
        layer_len: Number of hidden Linear+ReLU stages; must be >= 1.

    Returns:
        An ``nn.Sequential`` holding ``2 * layer_len + 1`` modules.

    Raises:
        AssertionError: If ``layer_len`` is less than 1 (only when
            assertions are enabled).
    """
    assert layer_len >= 1

    # Feature widths seen by the hidden stages: in_sz -> hidden -> ... -> hidden.
    widths = [in_sz] + [hidden] * layer_len

    modules = []
    # Consecutive width pairs give each hidden layer's (fan_in, fan_out).
    # Modules are created front to back, so parameter initialisation draws
    # from the RNG in network order.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nn.ReLU())

    # Output projection: deliberately left without a trailing activation.
    modules.append(nn.Linear(hidden, out_sz))
    return nn.Sequential(*modules)
