import numpy as np
import torch.nn as nn


class SimpleMLP(nn.Module):
    """Fully connected ReLU network mapping a flattened observation to outputs.

    Args:
        obs_space: Shape of a single observation (e.g. a tuple of ints) or a
            plain int; the product of its entries is the flattened input
            size. NOTE(review): presumably an observation-space ``shape``
            rather than a space object -- confirm against callers.
        num_outputs: Width of the final linear layer's output.
        hidden_sizes: Widths of the hidden linear layers, each followed by a
            ReLU. Defaults to ``(128, 128, 64)``.
    """

    def __init__(
        self,
        obs_space,
        num_outputs: int,
        *,
        hidden_sizes=(128, 128, 64),
    ):
        super().__init__()

        # Cast to a plain Python int: ``np.prod`` yields a NumPy scalar, and
        # for an empty shape ``()`` it yields the float 1.0, which
        # ``nn.Linear`` cannot use as a layer size.
        self.num_features = int(np.prod(obs_space))

        # Build Linear/ReLU pairs so the default reproduces the original
        # layout (indices 0, 2, 4, 6 are the Linear layers).
        layers = []
        in_features = self.num_features
        for width in hidden_sizes:
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, num_outputs))
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Flatten each sample in the batch and run it through the network.

        Args:
            input: Tensor whose first dimension is the batch; the remaining
                dimensions must hold ``num_features`` elements per sample.

        Returns:
            Tensor of shape ``(batch, num_outputs)``.
        """
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors,
        # e.g. transposed or sliced inputs, copying only when necessary.
        input_flat = input.reshape(input.shape[0], self.num_features)
        return self.net(input_flat)
