import torch
# Adapted from https://github.com/josipd/torch-two-sample

def pdist_0(sample_12, norm=2):
    r"""Compute the symmetric matrix of pairwise distances within one sample.

    Arguments
    ---------
    sample_12 : torch.Tensor
        The sample; indexed as a 2-D tensor whose first dimension has size
        ``n``.
    norm : float
        The l_p norm to be used (forwarded to :func:`torch.pdist`).

    Returns
    -------
    torch.Tensor
        Matrix of shape ``(n, n)`` with zeros on the diagonal. The ``[i, j]``-th
        entry is ``|| sample_12[i, :] - sample_12[j, :] ||_p``. The result has
        the dtype and device of ``sample_12``."""
    n = sample_12.size(0)
    # Allocate with the input's dtype/device so the indexed assignment below
    # does not hit a dtype mismatch (e.g. for float64 samples).
    distances = torch.zeros(n, n, dtype=sample_12.dtype,
                            device=sample_12.device)
    # torch.pdist returns the strict upper triangle in row-major order, which
    # is the same order torch.triu_indices enumerates.
    upper = torch.triu_indices(n, n, offset=1, device=sample_12.device)
    distances[upper[0], upper[1]] = torch.pdist(sample_12, p=float(norm))
    # Mirror the upper triangle into the lower one.
    return distances + distances.t()

def pdist_1(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all (``eps``-smoothed) pairwise distances.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    eps : float
        Small constant added before the root is taken.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). For ``norm == 2`` the [i, j]-th entry is
        ``sqrt(eps + | |x_i|^2 + |y_j|^2 - 2 <x_i, y_j> |)``; otherwise it is
        ``(eps + sum_k |x_ik - y_jk|^p) ^ (1 / p)``. Both are the l_p distance
        ``|| sample_1[i, :] - sample_2[j, :] ||_p`` up to the ``eps`` term."""
    rows, cols = sample_1.size(0), sample_2.size(0)
    norm = float(norm)

    if norm != 2.:
        # Generic l_p: materialise every pairwise difference explicitly.
        width = sample_1.size(1)
        left = sample_1.unsqueeze(1).expand(rows, cols, width)
        right = sample_2.unsqueeze(0).expand(rows, cols, width)
        powered = torch.abs(left - right) ** norm
        return (eps + torch.sum(powered, dim=2, keepdim=False)) ** (1. / norm)

    # Euclidean case: expand |x - y|^2 = |x|^2 + |y|^2 - 2 <x, y>.
    sq_norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
    sq_norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
    sq_norm_sums = (sq_norms_1.expand(rows, cols) +
                    sq_norms_2.transpose(0, 1).expand(rows, cols))
    inner_products = sample_1.mm(sample_2.t())
    # abs() guards against slightly negative values from rounding.
    return torch.sqrt(eps + torch.abs(sq_norm_sums - 2 * inner_products))


class KernelStatistic:
    r"""Base class for statistics computed from a sum of Gaussian kernels.

    Calling an instance builds the kernel matrix of the pooled samples and
    passes it to :py:meth:`statistic`, which subclasses must implement.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample."""

    def __init__(self, n_1):
        self.n_1 = n_1

    def __call__(self, sample_1, sample_2, alphas=None, ret_matrix=False, mode=1):
        r"""Evaluate the statistic.

        The kernel used is

        .. math::

            k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},

        for the provided ``alphas``.

        Arguments
        ---------
        sample_1: :class:`torch:torch.autograd.Variable`
            The first sample, of size ``(n_1, d)``.
        sample_2: variable of shape (n_2, d)
            The second sample, of size ``(n_2, d)``.
        alphas : list of :class:`float`
            The kernel parameters. If ``None``, a single parameter equal to
            the median of the pairwise distance matrix is used.
        ret_matrix: bool
            If set, the call will also return a second variable.

            This variable can then be used to compute a p-value using the
            ``pval`` method of the subclass.
        mode:
            The string ``'1'`` selects ``pdist_1`` for the distances; any
            other value selects ``pdist_0``.
            NOTE(review): the integer default ``1`` does not equal ``'1'``
            and therefore uses ``pdist_0``; this selection is kept unchanged
            for backward compatibility -- confirm the intended default.

        Returns
        -------
        The test statistic, as returned by :py:meth:`statistic`.
        :class:`torch:torch.autograd.Variable`
            The kernel matrix of the pooled samples. Returned only if
            ``ret_matrix`` was set to true.

        Raises
        ------
        ValueError
            If ``alphas`` is given but empty."""
        if alphas is not None:
            alphas = list(alphas)
            if not alphas:
                # Without this check ``kernels`` would stay ``None`` and the
                # failure would surface obscurely inside ``statistic``.
                raise ValueError(
                    "alphas must contain at least one kernel parameter")

        sample_12 = torch.cat((sample_1, sample_2), 0)
        if mode == '1':
            # pdist_1 is a two-sample function: the pooled sample is passed
            # on both sides to obtain the full pairwise matrix.
            distances = pdist_1(sample_12, sample_12, norm=2)
        else:
            distances = pdist_0(sample_12, norm=2)

        if alphas is None:
            alphas = [distances.median()]

        # Sum one Gaussian kernel matrix per alpha.
        kernels = None
        for alpha in alphas:
            kernels_a = torch.exp(- alpha * distances ** 2)
            kernels = kernels_a if kernels is None else kernels + kernels_a

        stat = self.statistic(kernels)

        if ret_matrix:
            return stat, kernels
        return stat

    def statistic(self, kernels):
        """Compute the statistic from the pooled kernel matrix (abstract)."""
        raise NotImplementedError

    def pval(self, distances, n_permutations=1000):
        """Compute a permutation p-value from the kernel matrix (abstract)."""
        raise NotImplementedError


class MMD(KernelStatistic):
    r"""The *unbiased* MMD test of :cite:`gretton2012kernel`.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample.

    Raises
    ------
    ZeroDivisionError
        If ``n_1`` or ``n_2`` makes one of the normalising products zero
        (e.g. a sample of size 1)."""

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # Weights of the within-sample-1, within-sample-2 and cross terms.
        self.a00 = 1. / (n_1 * (n_1 - 1))
        self.a11 = 1. / (n_2 * (n_2 - 1))
        self.a01 = - 1. / (n_1 * n_2)

    def statistic(self, kernels):
        r"""Compute the statistic from the pooled kernel matrix.

        Arguments
        ---------
        kernels:
            Square kernel matrix whose first ``n_1`` rows/columns belong to
            the first sample and the remaining ones to the second.

        Returns
        -------
        The weighted sum of the cross block and of twice the strict upper
        triangles of the two within-sample blocks (diagonals excluded)."""
        split = self.n_1
        within_1 = kernels[:split, :split]
        within_2 = kernels[split:, split:]
        between = kernels[:split, split:]

        # Twice the strict upper triangle equals the off-diagonal sum when
        # the block is symmetric.
        cross_term = 2 * self.a01 * between.sum()
        first_term = self.a00 * 2*torch.triu(within_1, diagonal=1).sum()
        second_term = self.a11 * 2*torch.triu(within_2, diagonal=1).sum()
        return (cross_term + first_term + second_term)

    def pval(self, distances, n_permutations=1000):
        r"""Compute a p-value using a permutation test.

        Arguments
        ---------
        distances:
            The matrix computed using `__call__` (with ``ret_matrix`` set).
        n_permutations: int
            The number of random draws from the permutation null.

        Returns
        -------
        float
            The fraction of permuted statistics that are at least as large
            as the unpermuted one.
        list
            All computed statistics; the first entry is the unpermuted
            statistic, followed by one entry per permutation.
        """
        n = distances.size(0)
        order = torch.arange(n)  # identity ordering for the first pass

        exceed = 0.
        all_stats = []
        for draw in range(n_permutations + 1):
            # Apply the same ordering to rows and columns.
            shuffled = distances[order, :][:, order]
            value = self.statistic(shuffled)
            all_stats.append(value)

            if draw == 0:
                observed = value
            elif observed <= value:
                exceed += 1

            # Ordering for the next pass.
            order = torch.randperm(n)

        return exceed / n_permutations, all_stats
    

def centering(K):
    r"""Double-centre a square matrix.

    Arguments
    ---------
    K : torch.Tensor
        Square matrix of shape ``(n, n)``.

    Returns
    -------
    torch.Tensor
        ``Q K Q`` with ``Q = I - 1/n * ones(n, n)``, i.e. ``K`` with its row
        and column means removed."""
    n = K.shape[0]
    # Build the centring matrix with K's dtype/device; torch.mm requires
    # matching operands, so default float32-on-CPU tensors would fail for
    # float64 or CUDA inputs.
    unit = torch.ones(n, n, dtype=K.dtype, device=K.device)
    identity = torch.eye(n, dtype=K.dtype, device=K.device)
    Q = identity - unit / n

    return torch.mm(torch.mm(Q, K), Q)

class HSIC(KernelStatistic):
    r"""HSIC independence test of :cite:`gretton2007kernel`.

    The statistic is the sum of the elementwise product of the two centred
    within-sample kernel blocks, divided by ``n``.
    NOTE(review): the original documentation called this the *unbiased*
    estimator -- confirm against the reference.

    Arguments
    ---------
    n: int
        The number of points in the first sample."""

    def __init__(self, n):
        self.n_1 = n

    def statistic(self, kernels):
        r"""Compute the statistic from the pooled kernel matrix.

        Arguments
        ---------
        kernels:
            Square kernel matrix; the leading ``n x n`` block is used for
            the first sample and the block from index ``n`` on for the
            second.

        Returns
        -------
        ``sum(centering(K_x).t() * centering(K_y)) / n``."""
        n = self.n_1
        centred_x = centering(kernels[:n, :n])
        centred_y = centering(kernels[n:, n:])

        return torch.sum(centred_x.t() * centred_y) / n

    def pval(self, distances, n_permutations=1000, tqdm=True):
        r"""Compute a p-value using a permutation test, which permutes only the second domain

        Arguments
        ---------
        distances:
            The matrix computed using `__call__`
        n_permutations: int
            The number of random draws from the permutation null.
        tqdm: bool
            Unused.

        Returns
        -------
        float
            The fraction of permuted statistics that are at least as large
            as the unpermuted one.
        list
            All computed statistics; the first entry is the unpermuted
            statistic, followed by one entry per permutation.
        """
        n = self.n_1
        order = torch.arange(n)  # identity ordering for the first pass

        # The first-sample block is never permuted, so centre it once.
        centred_x_t = centering(distances[:n, :n]).t()
        block_y = distances[n:, n:]

        exceed = 0.
        all_stats = []
        for draw in range(n_permutations + 1):
            # Reorder rows and columns of the second-sample block together.
            centred_y = centering(block_y[order, :][:, order])
            value = torch.sum(centred_x_t * centred_y) / n
            all_stats.append(value)

            if draw == 0:
                observed = value
            elif observed <= value:
                exceed += 1

            # Ordering for the next pass.
            order = torch.randperm(n)

        return exceed / n_permutations, all_stats