from torch import nn


class BisimNet(nn.Module):
    """MLP that maps a flattened input of ``2 * state_dim`` features to a score in (0, 1).

    ``forward`` flattens its input with ``nn.Flatten()`` (dimension 0 is kept,
    the remaining dimensions are merged). The result feeds a ``Linear`` layer
    expecting ``2 * state_dim`` features. Presumably each row holds a pair of
    states, e.g. shape ``(batch, 2, state_dim)`` -- TODO confirm with callers.

    Args:
        state_dim: Size of one state; the input layer takes twice this many
            features.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden layers. ``input_layer`` (+ ReLU) is always
            the first one. Each layer beyond the first appends a
            ``Linear(hidden_dim, hidden_dim)`` / ``ReLU`` pair to
            ``self.linears``. When ``num_layers <= 1``, ``self.linears`` is
            never created.
    """

    def __init__(self, state_dim, hidden_dim=96, num_layers=1):
        super().__init__()
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        self.input_layer = nn.Linear(2 * state_dim, hidden_dim)
        if num_layers > 1:
            # Interleaved pairs: Linear at even indices, ReLU at odd indices.
            # Modules are created Linear-then-ReLU per pair, in order.
            self.linears = nn.ModuleList(
                module
                for _ in range(num_layers - 1)
                for module in (nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
            )
        self.output_layer = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, states):
        """Return sigmoid scores for ``states``; the last output dimension has size 1."""
        hidden = self.relu(self.input_layer(self.flatten(states)))
        if self.num_layers > 1:
            # self.linears already contains a ReLU after each Linear.
            for module in self.linears:
                hidden = module(hidden)
        return self.sigmoid(self.output_layer(hidden))
