import torch


class AbelianGroupNetwork(torch.nn.Module):
    """Set network that sums ``phi`` features and maps the sum back through ``phi``.

    ``invertible_net`` is called as ``phi(x)`` and as ``phi(x, rev=True)``;
    the latter is presumably the inverse pass of the invertible network.
    """

    def __init__(self, invertible_net):
        super().__init__()
        self.phi = invertible_net

    def _phi_sum(self, items):
        """Return ``phi(items).sum(dim=0)`` as ``[dim]``, or ``None`` if ``items`` is empty/None."""
        if items is None or len(items) == 0:
            # An empty set contributes nothing; do not feed a shape-[0] tensor to phi.
            return None
        return self.phi(torch.Tensor(items)).sum(dim=0)

    def _mid_feat_one(self, positive=None, negative=None):
        """ ([pos_set_size, dim], [neg_set_size, dim]) -> [1, dim]

        Sums ``phi`` over the positive set and subtracts the sum of ``phi``
        over the negative set. Either side may be omitted or empty, but not
        both, because the feature width would then be unknown.

        Raises:
            ValueError: if both ``positive`` and ``negative`` are empty.
        """
        pos_mid = self._phi_sum(positive)
        neg_mid = self._phi_sum(negative)
        if pos_mid is None and neg_mid is None:
            raise ValueError("positive and negative cannot both be empty")
        if neg_mid is None:
            mid = pos_mid
        elif pos_mid is None:
            mid = -neg_mid
        else:
            mid = pos_mid - neg_mid
        return torch.unsqueeze(mid, 0)

    def forward_one(self, positive=None, negative=None):
        """ ([pos_set_size, dim], [neg_set_size, dim]) -> [1, dim] """
        return self.phi(self._mid_feat_one(positive, negative), rev=True)

    def forward(self, X_batch):
        """ list([set_size_i, dim]) -> [batch_size, dim]

        Raises:
            ValueError: if ``X_batch`` is empty.
        """
        if len(X_batch) == 0:
            raise ValueError("X_batch must contain at least one set")
        # One [dim] row per set; stack builds [batch_size, dim] directly, so
        # the width no longer has to be read from X_batch[0].shape.
        mid_batch = torch.stack([self.phi(torch.Tensor(X)).sum(dim=0) for X in X_batch])
        return self.phi(mid_batch, rev=True)


class BaseAbelianSemigroupNetwork(torch.nn.Module):
    """Base set network: ``phi`` features are folded with a binary op, then ``phi(..., rev=True)``.

    Subclasses provide the binary operation via ``binary_op_one``.
    """

    def __init__(self, invertible_net):
        super().__init__()
        self.phi = invertible_net  # [*, dim] -> [*, dim]

    def random_parameter(self):
        """Return a new one-element float32 parameter sampled uniformly from [-0.1, 0.1)."""
        return torch.nn.Parameter(torch.empty(1, dtype=torch.float32).uniform_(-0.1, 0.1))

    def binary_op_one(self, x1, x2):
        """
        not batch!
        [1, dim], [1, dim] -> [1, dim]

        ``_compose_one`` passes single-row slices, and ``forward`` concatenates
        the results along dim 0, so the leading axis of size 1 must be kept.
        """
        raise NotImplementedError

    def _compose_one(self, X):
        """ [set_size, dim] -> [1, dim]

        Folds the rows of ``X`` with ``binary_op_one`` by recursively
        splitting the set in half.

        Raises:
            ValueError: if ``X`` has no rows (an empty set cannot be reduced).
        """
        n, _ = X.shape
        if n == 0:
            # Without this guard both halves stay empty and the recursion never ends.
            raise ValueError("cannot compose an empty set")
        if n == 1:
            return X
        half = n // 2
        return self.binary_op_one(self._compose_one(X[:half]), self._compose_one(X[half:]))

    def _mid_feat_one(self, X):
        """ [set_size, dim] -> [1, dim] """
        return self._compose_one(self.phi(torch.Tensor(X)))

    def forward(self, X_batch):
        """ list([set_size_i, dim]) -> [batch_size, dim] """
        mid_batch = torch.cat([self._mid_feat_one(X) for X in X_batch])  # [batch_size, dim]
        return self.phi(mid_batch, rev=True)


class AbelianSemigroupNetwork1(BaseAbelianSemigroupNetwork):
    """Semigroup network whose binary op is a learnable symmetric bilinear form.

    The op is ``a0 + a1 * (x1 + x2) + a1 * (a1 - 1) * x1 * x2 / a0`` with
    learnable scalars ``a0 = alpha0`` and ``a1 = alpha1``. The coefficient of
    the product term is what makes the op associative for ``a0 != 0``.
    """

    def __init__(self, invertible_net):
        super().__init__(invertible_net)
        # alpha0 is created before alpha1, fixing registration and sampling order.
        self.alpha0 = self.random_parameter()
        self.alpha1 = self.random_parameter()

    def binary_op_one(self, x1, x2):
        """Combine two operands elementwise; symmetric in ``x1`` and ``x2``."""
        a0, a1 = self.alpha0, self.alpha1
        linear_part = a1 * (x1 + x2)
        # NOTE: divides by alpha0, so the result is non-finite if alpha0 is exactly zero.
        bilinear_part = a1 * (a1 - 1) * x1 * x2 / a0
        return a0 + linear_part + bilinear_part


class AbelianSemigroupNetwork2(BaseAbelianSemigroupNetwork):
    """Semigroup network whose binary op is elementwise multiplication of ``phi`` features.

    Instead of folding pairwise through ``binary_op_one``, the whole set is
    reduced in a single ``torch.prod`` call.
    """

    def __init__(self, invertible_net):
        super().__init__(invertible_net)
        # Not read by any method in this class; kept so existing state_dict
        # keys and the parameter sampling order stay unchanged.
        self.alpha2 = self.random_parameter()

    def _mid_feat_one(self, X):
        """ [set_size, dim] -> [1, dim] """
        # Apply phi before reducing, as the base class does; the inherited
        # forward() applies phi(..., rev=True) to this result.
        return torch.prod(self.phi(torch.Tensor(X)), dim=0, keepdim=True)


class AbelianSemigroupNetwork3(BaseAbelianSemigroupNetwork):
    """Semigroup network whose binary op is ``x1 + x2 + x1 * x2 * alpha2``.

    ``alpha2`` is a single learnable scalar weighting the product term.
    """

    def __init__(self, invertible_net):
        super().__init__(invertible_net)
        self.alpha2 = self.random_parameter()

    def binary_op_one(self, x1, x2):
        """Combine two operands elementwise; symmetric in ``x1`` and ``x2``."""
        interaction = x1 * x2 * self.alpha2
        return x1 + x2 + interaction
