
import torch
import torch.nn as nn


class EquivariantLayer(nn.Module):
    """
    Permutation equivariant linear layer for sets of shape (..., N, in_dim).

    Every element is transformed by its own linear map (`l1`) and receives an
    additional term obtained by applying a second linear map (`l2`) to the sum
    of the set elements, divided by the number of elements in the set.

    Args:
        in_dim (int): Input size
        out_dim (int): Output size
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.l1 = nn.Linear(in_dim, out_dim)
        self.l2 = nn.Linear(in_dim, out_dim)

    def forward(self, x, mask=None, **kwargs):
        """
        For input (..., N, in_dim) returns (..., N, out_dim).

        Args:
            x (tensor): Set elements with shape (..., N, in_dim)
            mask (tensor, optional): Tensor with shape (..., N, K); only the
                first entry along the last dimension is used. Non-zero marks
                a valid element, zero marks padding. Default: None

        Returns:
            Tensor with shape (..., N, out_dim). Rows at padded positions are
            zero, and a set with no valid elements yields all zeros.
        """
        y1 = self.l1(x)

        if mask is None:
            y2 = self.l2(x.sum(-2, keepdim=True)) / x.shape[-2]
            return y1 + y2

        # Keep only the first mask channel, shaped (..., N, 1) for broadcasting
        mask = mask[..., 0, None]

        # Pool over valid elements only, so values stored at padded positions
        # cannot influence the outputs of valid elements
        pooled = (x * mask.to(x.dtype)).sum(-2, keepdim=True)

        # A fully padded set has a zero count; dividing by one instead avoids
        # 0 / 0 = NaN (the numerator is already zeroed by the mask)
        count = mask.sum(-2, keepdim=True)
        count = count.masked_fill(count == 0, 1)

        y2 = self.l2(pooled) * mask / count
        return y1 * mask + y2


class EquivariantNet(nn.Module):
    """
    Neural network with permutation equivariant layers.
    Takes sets of elements of shape (..., N, dim). Permuting the elements across
    the second to last dimension results in the same permutation on the output.
    Args similar to `st.net.MLP` but uses `EquivariantLayer` instead of `nn.Linear`.

    Example:
    >>> net = stribor.net.EquivariantNet(2, [64, 64], 4)

    Args:
        in_dim (int): Input size
        hidden_dims (Sequence[int]): Hidden dimensions. Any iterable of ints
            (list, tuple, ...) is accepted; an empty one gives a single layer
        out_dim (int): Output size
        activation (str, optional): Activation function from `torch.nn`. Default: 'Tanh'
        final_activation (str, optional): Last activation. Default: None
    """

    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh', final_activation=None, **kwargs):
        super().__init__()

        self.activation = getattr(nn, activation)()
        self.final_activation = getattr(
            nn, final_activation)() if final_activation else nn.Identity()

        # Unpacking (rather than list concatenation) lets `hidden_dims` be a
        # tuple or any other iterable, not only a list
        dims = [in_dim, *hidden_dims, out_dim]
        self.layers = nn.ModuleList(
            [EquivariantLayer(in_, out_) for in_, out_ in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, mask=None, **kwargs):
        """
        For input (..., N, in_dim) returns (..., N, out_dim).

        Args:
            x (tensor): Set elements with shape (..., N, in_dim)
            mask (tensor, optional): Forwarded unchanged to every layer.
                Default: None
        """
        # `self.activation` follows every layer except the last one, which is
        # followed by `self.final_activation` instead
        *hidden_layers, last_layer = self.layers
        for layer in hidden_layers:
            x = self.activation(layer(x, mask=mask))
        return self.final_activation(last_layer(x, mask=mask))
