from torch import nn


class ValueMLP(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        hidden_dim: Width of both hidden layers.
        input_dim: Expected size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.

    The final Linear layer has no activation applied after it.
    """

    def __init__(self, hidden_dim, input_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Consecutive pairs give (in_features, out_features) for each Linear.
        sizes = (input_dim, hidden_dim, hidden_dim, output_dim)
        modules = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            # A ReLU separates each Linear from the previous one; none is
            # placed before the first layer or after the last.
            if modules:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(fan_in, fan_out))
        # Creating the Linears in input->output order keeps parameter
        # initialization in the same RNG sequence, and the Sequential indices
        # (Linears at 0, 2, 4) keep state_dict keys stable.
        self.value_func = nn.Sequential(*modules)

    def forward(self, data):
        """Run ``data`` through the layer stack and return the raw output."""
        return self.value_func(data)

