import torch
import torch.nn as nn


class MLP(nn.Module):
    """
    Fully-connected neural network.

    A stack of ``nn.Linear`` layers. Every layer except the last is followed
    by an activation; the output of the last layer is returned as-is.

    Args:
        layer_sizes: Sequence of widths ``[in, hidden_1, ..., hidden_k, out]``.
            Each consecutive pair defines one linear layer, so at least two
            entries are required.
        activation: Either a single callable applied after every hidden
            layer, or a list / tuple / ``nn.ModuleList`` of callables where
            entry ``j`` is applied after ``linears[j]``. Entries may be
            ``nn.Module`` instances or plain functions (e.g. ``torch.sin``).
            ``None`` (the default) creates a new ``nn.SiLU()`` for this
            instance.
        dtype: dtype of the linear layers' parameters. Defaults to
            ``torch.float32``.

    Raises:
        ValueError: If ``layer_sizes`` has fewer than two entries, or if a
            per-layer ``activation`` sequence has fewer entries than there
            are hidden layers.
    """

    def __init__(self, layer_sizes, activation=None, *, dtype=torch.float32):
        super().__init__()
        if len(layer_sizes) < 2:
            raise ValueError(
                "layer_sizes must contain at least 2 entries (input and "
                f"output width), got {len(layer_sizes)}"
            )
        if activation is None:
            # A default of ``nn.SiLU()`` in the signature would be evaluated
            # once and that single module shared by every MLP; build a fresh
            # one per instance instead.
            activation = nn.SiLU()

        num_hidden = len(layer_sizes) - 2
        is_sequence = isinstance(activation, (list, tuple, nn.ModuleList))
        if is_sequence and len(activation) < num_hidden:
            raise ValueError(
                f"activation has {len(activation)} entries but the network "
                f"has {num_hidden} hidden layers"
            )
        self.activation = activation

        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out, dtype=dtype)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        )

        if isinstance(activation, (list, tuple)):
            # Modules held in a plain list/tuple are invisible to
            # parameters(), state_dict() and .to(). Register them so that
            # parametric activations (e.g. nn.PReLU) are trained, saved and
            # moved with the rest of the network. This is registered after
            # the linears so the linear parameters keep their original order.
            self._activation_modules = nn.ModuleList(
                act for act in activation if isinstance(act, nn.Module)
            )

    def forward(self, inputs):
        """
        Apply the network to ``inputs``.

        Args:
            inputs: Tensor whose last dimension equals ``layer_sizes[0]``
                (the ``in_features`` of the first ``nn.Linear``).

        Returns:
            Tensor whose last dimension equals ``layer_sizes[-1]``. No
            activation is applied after the final layer.
        """
        # Decide once whether activations are per-layer or shared.
        per_layer = isinstance(self.activation, (list, tuple, nn.ModuleList))
        x = inputs
        for j, linear in enumerate(self.linears[:-1]):
            act = self.activation[j] if per_layer else self.activation
            x = act(linear(x))
        return self.linears[-1](x)
