import torch


class MLP(torch.nn.Module):
    """
    Stack of linear layers with an activation between consecutive layers.

    Shape: [*, in_dim] -> [*, out_dim] (only the last dimension changes).
    The first layer maps in_dim -> hidden_dim, the last maps
    hidden_dim -> out_dim, and every layer in between is
    hidden_dim -> hidden_dim. With layer_num=1 the single layer maps
    in_dim -> out_dim directly. The activation is applied after every
    layer except the last one.

    Example:
        >>> model = MLP(5, 11)
        >>> y = model(torch.rand(40, 13, 5))
        >>> y.shape
        torch.Size([40, 13, 11])
    """

    def __init__(self, in_dim, out_dim, hidden_dim=8, layer_num=3, activation=torch.nn.functional.relu):
        super().__init__()
        final_idx = layer_num - 1
        # Layers are created front to back so parameter initialisation
        # consumes the random generator in a fixed order.
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(
                hidden_dim if idx > 0 else in_dim,
                hidden_dim if idx < final_idx else out_dim,
            )
            for idx in range(layer_num)
        ])
        self.activation = activation

    def forward(self, x):
        """Apply each linear layer in order, activating all but the last output."""
        activations_left = len(self.linears) - 1
        for layer in self.linears:
            x = layer(x)
            if activations_left > 0:
                x = self.activation(x)
                activations_left -= 1
        return x


class DeepSet(torch.nn.Module):
    """
    Sum-pooled set model: list([set_size, in_dim]) -> [batch_size, out_dim]

    Each set in the input list is passed through ``phi`` and summed over
    its elements (dim 0), which yields one row per set. The rows are
    concatenated into a batch and passed through ``rho``. Sets may have
    different sizes. If ``use_sigmoid`` is true, a sigmoid is applied to
    the output of ``rho``.

    Args:
        phi: module applied to every set, [set_size, in_dim] -> [set_size, mid_dim].
        rho: module applied to the pooled batch, [batch_size, mid_dim] -> [batch_size, out_dim].
        use_sigmoid: whether to squash the final output with ``torch.sigmoid``.

    Example:
        >>> model = DeepSet(MLP(7, 32), MLP(32, 13))
        >>> y = model([torch.rand(10, 7), torch.rand(17, 7)])
        >>> y.shape
        torch.Size([2, 13])
    """

    def __init__(self, phi, rho, use_sigmoid=False):
        super().__init__()
        self.phi = phi
        self.rho = rho
        self.use_sigmoid = use_sigmoid

    def _mid_feat_one(self, X):
        """ [set_size, in_dim] -> [1, mid_dim] (phi output summed over the set) """
        if not isinstance(X, torch.Tensor):
            # Lists / arrays are converted to the default float dtype, which
            # is what the legacy ``torch.Tensor(X)`` constructor produced.
            # Tensors are passed through untouched so their dtype, device
            # and autograd history are kept.
            X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        # Call the module itself (not ``.forward``) so hooks registered on
        # ``phi`` are executed.
        return self.phi(X).sum(dim=0, keepdim=True)

    def forward(self, X_list):
        """
        Run the model on a batch of sets.

        Args:
            X_list: non-empty sequence of sets, each [set_size, in_dim];
                entries may be tensors or array-likes convertible to one.

        Returns:
            Tensor of shape [len(X_list), out_dim].
        """
        z_batch = torch.cat([self._mid_feat_one(X) for X in X_list])
        y_batch = self.rho(z_batch)
        if self.use_sigmoid:
            y_batch = torch.sigmoid(y_batch)
        return y_batch


if __name__ == '__main__':
    # When executed as a script, run the docstring examples in this module.
    from doctest import testmod
    testmod()
