import torch
import torch.nn as nn

class FCNet(nn.Module):
  """Fully connected network: a stack of ``nn.Linear`` layers with ReLUs.

  For ``n_layers >= 2`` the layout is::

      Linear(n_in, n_hid) -> ReLU
      -> [Linear(n_hid, n_hid) -> ReLU] * (n_layers - 2)
      -> Linear(n_hid, n_out)

  No activation follows the final linear layer. For ``n_layers == 1`` the
  network is a single ``Linear(n_in, n_out)`` and ``n_hid`` is unused.

  Args:
    n_in: Number of input features.
    n_hid: Width of every hidden layer.
    n_out: Number of output features.
    n_layers: Total number of ``nn.Linear`` layers; must be >= 1.

  Raises:
    ValueError: If ``n_layers < 1``.
  """

  def __init__(self, n_in, n_hid, n_out, n_layers):
    super().__init__()
    if n_layers < 1:
      raise ValueError(f"n_layers must be >= 1, got {n_layers}")
    if n_layers == 1:
      # Special case: the general path below always emits a separate input
      # and output layer, i.e. at least two linear layers.
      layers = [nn.Linear(n_in, n_out)]
    else:
      layers = [nn.Linear(n_in, n_hid)]
      for _ in range(n_layers - 2):
        layers.append(nn.ReLU())
        layers.append(nn.Linear(n_hid, n_hid))
      layers.append(nn.ReLU())
      layers.append(nn.Linear(n_hid, n_out))
    self.network = nn.Sequential(*layers)

  def forward(self, x, params=None):
    """Apply the network to ``x``.

    Args:
      x: Input tensor whose last dimension is ``n_in`` (``nn.Linear`` acts
        on the last dimension).
      params: Accepted but ignored; the module's own registered parameters
        are always used. NOTE(review): presumably kept so callers that pass
        explicit parameters share one signature — confirm before removing.

    Returns:
      Output tensor whose last dimension is ``n_out``.
    """
    return self.network(x)