import torch
from autoencoders.torch_utils import make_mask
from torch.nn.modules.pooling import AvgPool1d
from torch.nn import Sequential, ReLU, Linear


def random_sampling_initialization(X, X_length, num_centroids):
    """Initialise centroids by randomly sampling rows of each input sequence.

    For batch element ``i`` this picks ``num_centroids[i]`` rows among the
    first ``X_length[i]`` (non-padded) positions of ``X[i]``. Sampling is
    without replacement when enough positions exist, and with replacement
    otherwise.

    Args:
        X: tensor of shape [batch_size, max_X_length, feature_size].
        X_length: per-sequence lengths, shape [batch_size].
        num_centroids: per-sequence centroid counts, shape [batch_size].
            Integer or integral-valued float tensors are accepted.

    Returns:
        ``(centroids, num_centroids)`` where ``centroids`` has shape
        [batch_size, max(num_centroids), feature_size]. Slots beyond
        ``num_centroids[i]`` are zero. ``num_centroids`` is returned
        unchanged, mirroring the other initialisers in this module.
    """
    bsize, _, dim = X.shape
    # .tolist() + int() accepts both long counts and the float counts
    # produced by ceil-based helpers.
    counts = [int(c) for c in num_centroids.tolist()]
    lengths = [int(length) for length in X_length.tolist()]
    max_count = max(counts, default=0)

    out_centroids = X.new_zeros((bsize, max_count, dim))
    for i, (n, length) in enumerate(zip(counts, lengths)):
        if n <= 0 or length <= 0:
            # Nothing to sample from (or nothing requested): leave zeros.
            continue
        if n <= length:
            idx = torch.randperm(length, device=X.device)[:n]
        else:
            # Not enough distinct positions: fall back to sampling with replacement.
            idx = torch.randint(length, (n,), device=X.device)
        out_centroids[i, :n] = X[i, idx]

    return out_centroids, num_centroids


def avgpool(X, X_length, num_centroids):
    """Initialise centroids by sum-pooling windows of the input sequence.

    The window size is ``ceil(X_length[0] / num_centroids[0])``. Only the
    first batch element determines it, because a single pooling kernel is
    applied to the whole batch.

    Args:
        X: tensor of shape [batch_size, max_X_length, feature_size].
        X_length: per-sequence lengths, shape [batch_size].
        num_centroids: per-sequence centroid counts, shape [batch_size].

    Returns:
        ``(pooled, num_centroids)``. ``pooled`` has shape
        [batch_size, n_windows, feature_size], where each row is the sum of
        one window. ``num_centroids`` is returned unchanged.
    """
    window = torch.ceil(X_length[0] / num_centroids[0]).long().item()

    # AvgPool1d pools over the last axis, so move the sequence axis there.
    channels_first = X.transpose(1, 2)
    averaged = AvgPool1d(window, padding=window // 2)(channels_first)

    # count_include_pad is True by default, so average * window == window sum.
    summed = averaged * window
    return summed.transpose(1, 2), num_centroids


def init_with_fixed(X, X_length, num_centroids, centroids=None):
    """Initialise every sequence's centroids from a shared fixed centroid bank.

    Batch element ``i`` receives the first ``num_centroids[i]`` rows of
    ``centroids``. The remaining slots up to ``max(num_centroids)`` are zero.

    Args:
        X: input tensor [batch_size, max_X_length, feature_size]. Only its
            batch size, device and dtype are used.
        X_length: unused. It is kept so the signature matches the other
            initialisers.
        num_centroids: per-sequence centroid counts, shape [batch_size].
            Integer or integral-valued float tensors are accepted.
        centroids: tensor [num_available, feature_size] of fixed centroids.
            Required despite the ``None`` default. The default presumably
            exists so the function can be bound with ``functools.partial``.

    Returns:
        ``(out_centroids, num_centroids)`` where ``out_centroids`` has shape
        [batch_size, max(num_centroids), feature_size]. ``num_centroids`` is
        returned unchanged.

    Raises:
        ValueError: if ``centroids`` is missing or has fewer rows than the
            largest requested count.
    """
    if centroids is None:
        raise ValueError("init_with_fixed requires a `centroids` tensor")

    # int() accepts both long counts and the float counts produced by
    # ceil-based helpers, which cannot be used directly as slice bounds.
    counts = [int(c) for c in num_centroids.tolist()]
    max_count = max(counts, default=0)
    if max_count > centroids.size(0):
        raise ValueError(
            "requested %d centroids but only %d are available"
            % (max_count, centroids.size(0)))

    bsize = X.size(0)
    # Match X's dtype: the centroids are compared against X downstream.
    out_centroids = torch.zeros(
        (bsize, max_count, centroids.size(-1)), dtype=X.dtype, device=X.device)

    for i, n in enumerate(counts):
        out_centroids[i, :n] = centroids[:n]

    return out_centroids, num_centroids


def set_to_k(X, X_length, k=3):
    """Assign the same centroid count ``k`` to every sequence in the batch.

    ``X`` is unused. The result has the same shape and device as
    ``X_length``, and its dtype follows torch type promotion of
    ``X_length`` with ``k``.
    """
    return k * torch.ones_like(X_length)


def reduce_by_k(X, X_length, k=3):
    """Return ``ceil(X_length / k)`` centroids per sequence, as integers.

    ``X`` is unused. It is kept so the signature matches ``set_to_k``.

    The counts are cast to ``long``. They are used downstream as slice
    bounds and sequence lengths, which require integers, whereas the plain
    ``ceil`` of a true division yields a float tensor.
    """
    return torch.ceil(X_length / k).long()


class PairwiseValueFunction(torch.nn.Module):
    """Apply a two-layer MLP to every concatenated pair ``[x_i ; u_j]``.

    The MLP maps size ``2 * dim`` to size ``dim``, so both inputs must have
    ``feature_size == dim``.
    """

    def __init__(self, dim):
        super().__init__()
        # The attribute name and layer order determine the state_dict keys
        # (mlp.0.*, mlp.2.*); keep them stable.
        self.mlp = Sequential(
            Linear(2 * dim, dim),
            ReLU(),
            Linear(dim, dim),
        )

    def forward(self, x, u):
        """Return ``mlp([x_i ; u_j])`` for every pair of positions (i, j).

        Args:
            x: tensor [batch_size, x_len, feature_size].
            u: tensor [batch_size, u_len, feature_size].

        Returns:
            tensor [batch_size, x_len, u_len, dim].
        """
        n_x, n_u = x.size(1), u.size(1)
        # Broadcast both inputs onto a [batch, x_len, u_len, feature] grid
        # (expand creates views, not copies).
        grid_x = x[:, :, None, :].expand(-1, -1, n_u, -1)
        grid_u = u[:, None, :, :].expand(-1, n_x, -1, -1)
        pairs = torch.cat((grid_x, grid_u), dim=-1)
        return self.mlp(pairs)


class CentroidAttention(torch.nn.Module):
    """Iteratively refine a set of centroids by attending from an input sequence.

    Centroids are produced by ``initialization_f``. Each of the ``T``
    iterations then does the following:

    1. Compute softmax attention weights from every input position over the
       non-padded centroids.
    2. Compute a pairwise value for every (input, centroid) pair.
    3. Sum the weighted values over the non-padded input positions.
    4. Add ``epsilon = 1 / T`` times that update to the centroids.
    """

    def __init__(self, dim, T=1, alpha=1., initialization_f=avgpool, num_centroids_f=None, reduction=1):
        super(CentroidAttention, self).__init__()

        # Callable (X, X_length, num_centroids) -> (centroids, centroid_length).
        self.initialization_f = initialization_f

        if num_centroids_f is None:
            # Default: ceil(X_length / reduction) centroids per sequence.
            # A partial of a module-level function (unlike a nested closure)
            # keeps the module picklable, e.g. for torch.save(model).
            from functools import partial
            num_centroids_f = partial(reduce_by_k, k=reduction)
        # Callable (X, X_length) -> per-sequence centroid counts.
        self.num_centroids_f = num_centroids_f

        self.similarity_softmax = torch.nn.Softmax(dim=-1)
        self.T = T  # number of refinement iterations
        self.epsilon = 1 / T  # step size applied to each update
        self.alpha = alpha  # scale applied to the dot-product scores
        self.value_function = PairwiseValueFunction(dim)
        # Shared projection applied to both inputs and centroids.
        self.linear_transform = Linear(dim, dim)

    @staticmethod
    def _valid_mask(lengths, max_len, device):
        """Return a bool [batch, max_len] mask, True where position < length."""
        lengths = torch.as_tensor(lengths, device=device)
        positions = torch.arange(max_len, device=device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def _similarity_function(self, X, centroids, centroid_mask):
        """Compute attention weights of each input position over the centroids.

        Args:
            X: [batch_size, max_X_length, feature_size]
            centroids: [batch_size, max_centroid_length, feature_size]
            centroid_mask: bool [batch_size, max_centroid_length]. True marks
                padded centroid slots, which receive (near) zero weight.

        Returns:
            [batch_size, max_X_length, max_centroid_length], softmax-normalised
            over the centroid axis.
        """
        Q = self.linear_transform(X)  # [batch_size, max_X_length, feature_size]
        K = self.linear_transform(centroids)  # [batch_size, max_centroid_length, feature_size]

        # [batch_size, max_X_length, max_centroid_length]
        dotproduct = torch.matmul(Q, K.transpose(2, 1))

        expected = (X.size(0), X.size(1), centroids.size(1))
        assert dotproduct.shape == expected, (tuple(dotproduct.shape), expected)

        dotproduct = dotproduct * self.alpha

        # Exclude padded centroids from the softmax. masked_fill broadcasts
        # the mask over the X axis and is out-of-place (autograd-friendly).
        dotproduct = dotproduct.masked_fill(centroid_mask.unsqueeze(1), -10e9)

        return self.similarity_softmax(dotproduct)

    def forward(self, X, X_length):
        """Run the centroid attention module.

        Args:
            X: [batch_size, max_X_length, feature_size]
            X_length: [batch_size] number of valid positions per sequence.

        Returns:
            ``(centroids, centroid_length)``. The centroids have shape
            [batch_size, max_centroid_length, feature_size]. ``centroid_length``
            is whatever ``initialization_f`` returned as the per-sequence count.
        """
        num_centroids = self.num_centroids_f(X, X_length)

        centroids, centroid_length = self.initialization_f(
            X, X_length, num_centroids)

        # Masks are built explicitly so their polarity is unambiguous:
        # X_weight is 1 on real input positions and 0 on padding,
        # shaped [batch, max_X_length, 1, 1] to broadcast against the values.
        X_valid = self._valid_mask(X_length, X.size(1), X.device)
        X_weight = X_valid.to(X.dtype).unsqueeze(-1).unsqueeze(-1)
        # centroid_pad is True on padded centroid slots.
        centroid_pad = ~self._valid_mask(
            centroid_length, centroids.size(1), X.device)

        for _ in range(self.T):
            # [batch_size, max_X_length, max_centroid_length]
            similarities = self._similarity_function(
                X, centroids, centroid_pad)

            # [batch_size, max_X_length, max_centroid_length, feature_size]
            values = self.value_function(X, centroids)

            # Weighted sum over valid input positions:
            # [batch_size, max_centroid_length, feature_size]
            updates = (similarities.unsqueeze(-1) * values * X_weight).sum(dim=1)

            centroids = centroids + self.epsilon * updates

        return centroids, centroid_length
