from torch import nn
import torch


class Model(nn.Module):
    """Fully connected ReLU network (MLP) mapping ``dim_in`` features to ``dim_out``.

    Architecture::

        Linear(dim_in, dim_h) -> ReLU
        -> nh x [Linear(dim_h, dim_h) -> ReLU]
        -> Linear(dim_h, dim_out)            # no output activation

    Args:
        dim_in: Size of the last dimension of the input tensor.
        nh: Number of hidden ``dim_h -> dim_h`` layers (may be 0).
        dim_h: Width of every hidden layer.
        dim_out_dist: Stored as ``self.dim_out_dist``; not used by this class.
        dim_out: Size of the last dimension of the output tensor.
        lr: Stored as ``self.learning_rate``; not used by this class —
            presumably read by an external training loop (TODO confirm).
    """

    def __init__(
        self,
        dim_in: int = 3,
        nh: int = 3,
        dim_h: int = 20,
        dim_out_dist: int = 1,
        dim_out: int = 1,
        lr: float = 0.001,
    ) -> None:
        super().__init__()

        # NOTE(review): ``flatten`` is never applied in ``forward``; it is kept
        # only so external code referencing ``model.flatten`` keeps working.
        self.flatten = nn.Flatten()

        # Configuration / training settings kept as attributes. ``forward``
        # does not read ``nh``, ``dim_out_dist``, ``loss_function`` or
        # ``learning_rate``.
        self.nh = nh
        self.dim_out = dim_out
        self.dim_out_dist = dim_out_dist
        self.loss_function = nn.MSELoss()
        self.learning_rate = lr

        # These attribute names become state_dict keys ("input.weight",
        # "hidden.0.bias", "output.weight", ...), so they must not be renamed
        # or saved checkpoints would stop loading.
        self.input = nn.Linear(dim_in, dim_h)
        self.hidden = nn.ModuleList([nn.Linear(dim_h, dim_h) for _ in range(nh)])
        self.output = nn.Linear(dim_h, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim_in``; leading
               (batch) dimensions are carried through by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``dim_out`` (raw linear output, no final activation).
        """
        h = torch.relu(self.input(x))
        # Iterate the ModuleList itself so the layer count has a single source
        # of truth (len(self.hidden)) instead of a separately stored ``nh``.
        for layer in self.hidden:
            h = torch.relu(layer(h))
        return self.output(h)
