import numpy as np
from scipy.special import logsumexp
from kernel import pairwise_distances, kernel_matrix, compute_bandwidths_from_distances


def mmdfuse(
    X,
    Y,
    seed,
    alpha=0.05,
    kernels=("laplace", "gaussian"),
    lambda_multiplier=1,
    number_bandwidths=10,
    number_permutations=1000,
    return_p_val=False,
):
    """
    Two-Sample MMD-FUSE test.

    Given data from one distribution and data from another distribution,
    return 0 if the test fails to reject the null
    (i.e. data comes from the same distribution),
    or return 1 if the test rejects the null
    (i.e. data comes from different distributions).

    Parameters
    ----------
    X : array_like
        The shape of X must be of the form (m, d) where m is the number
        of samples and d is the dimension.
    Y : array_like
        The shape of Y must be of the form (n, d) where n is the number
        of samples and d is the dimension.
        If Y has more samples than X, the two samples are swapped internally
        so that the first sample is always the larger one.
    seed : int or None
        Seed for the permutations, passed to ``np.random.RandomState``.
        A private generator is used, so NumPy's global random state is not
        modified by this function.
    alpha : scalar
        The value of alpha (level of the test) must be between 0 and 1.
    kernels : str or sequence of str
        A single kernel name or a non-empty sequence of kernel names.
        Each name must be one of: "gaussian", "laplace", "imq", "rq",
        "matern_0.5_l1", "matern_1.5_l1", "matern_2.5_l1", "matern_3.5_l1",
        "matern_4.5_l1", "matern_0.5_l2", "matern_1.5_l2", "matern_2.5_l2",
        "matern_3.5_l2", "matern_4.5_l2".
    lambda_multiplier : scalar
        The value of lambda_multiplier must be positive.
        The regulariser lambda is taken to be
        np.sqrt(minimum_m_n * (minimum_m_n - 1)) * lambda_multiplier
        where minimum_m_n is the minimum of the sample sizes of X and Y.
    number_bandwidths : int
        The number of bandwidths per kernel to include in the collection.
        Must be greater than 1.
    number_permutations : int
        Number of permuted test statistics to approximate the quantiles.
        Must be positive.
    return_p_val : bool
        If True, the p-value is returned together with the test output.
        If False, only the test output Indicator(p_val <= alpha) is returned.

    Returns
    -------
    output : int
        0 if the aggregated MMD-FUSE test fails to reject the null
            (i.e. data comes from the same distribution)
        1 if the aggregated MMD-FUSE test rejects the null
            (i.e. data comes from different distributions)
    p_val : float
        Only returned (as ``(output, p_val)``) when return_p_val is True.
        The original statistic is counted among the
        number_permutations + 1 statistics, so
        p_val >= 1 / (number_permutations + 1).

    Raises
    ------
    AssertionError
        If any argument fails validation.
    """
    # Lists of kernels for l1 and l2
    all_kernels_l1 = (
        "laplace",
        "matern_0.5_l1",
        "matern_1.5_l1",
        "matern_2.5_l1",
        "matern_3.5_l1",
        "matern_4.5_l1",
    )
    all_kernels_l2 = (
        "imq",
        "rq",
        "gaussian",
        "matern_0.5_l2",
        "matern_1.5_l2",
        "matern_2.5_l2",
        "matern_3.5_l2",
        "matern_4.5_l2",
    )

    def is_integer(value):
        # Accept Python and NumPy integers, but not bools.
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    # Assertions
    X = np.asarray(X)
    Y = np.asarray(Y)
    if Y.shape[0] > X.shape[0]:
        X, Y = Y, X
    m = X.shape[0]  # larger sample size
    n = Y.shape[0]  # smaller sample size (n <= m by the swap above)
    assert n >= 2 and m >= 2, "X and Y must each contain at least 2 samples"
    assert 0 < alpha < 1, "alpha must be strictly between 0 and 1"
    assert lambda_multiplier > 0, "lambda_multiplier must be positive"
    assert is_integer(number_bandwidths) and number_bandwidths > 1, (
        "number_bandwidths must be an integer greater than 1"
    )
    assert is_integer(number_permutations) and number_permutations > 0, (
        "number_permutations must be a positive integer"
    )
    # A single name is a one-element collection; materialize any other
    # iterable so it can be both validated and measured with len().
    kernels = (kernels,) if isinstance(kernels, str) else tuple(kernels)
    assert len(kernels) > 0, "at least one kernel must be given"
    for kernel in kernels:
        assert kernel in all_kernels_l1 + all_kernels_l2, (
            f"unknown kernel {kernel!r}"
        )

    number_kernels = len(kernels)
    kernels_l1 = [k for k in kernels if k in all_kernels_l1]
    kernels_l2 = [k for k in kernels if k in all_kernels_l2]

    # Setup for permutations.
    # A private RandomState draws exactly the same stream as seeding the
    # global legacy generator would, without mutating global NumPy state.
    rng = np.random.RandomState(seed)
    B = number_permutations
    # (B+1, m+n): each row is an independent permutation of the m+n indices
    idx = np.array([rng.permutation(m + n) for _ in range(B + 1)])

    def permuted_columns(v):
        # (m+n, B+1): column b is v permuted by idx[b]; the last column is
        # v itself, i.e. the original (unpermuted) statistic.
        V = v[idx]  # V[b, j] = v[idx[b, j]]
        V[B] = v
        return V.transpose()

    # 11: +1 on the first sample, -1 on the second
    V11 = permuted_columns(np.concatenate((np.ones(m), -np.ones(n))))
    # 10: +1 on the first sample only
    V10 = permuted_columns(np.concatenate((np.ones(m), np.zeros(n))))
    # 01: -1 on the second sample only
    V01 = permuted_columns(np.concatenate((np.zeros(m), -np.ones(n))))

    # Compute all permuted MMD estimates: one row per (kernel, bandwidth)
    N = number_bandwidths * number_kernels
    M = np.zeros((N, B + 1))
    Z = np.concatenate((X, Y))
    row = 0
    for metric, kernels_for_metric in (("l1", kernels_l1), ("l2", kernels_l2)):
        if not kernels_for_metric:
            continue
        # Pairwise distance matrix (shared by every kernel of this metric)
        pairwise_matrix = pairwise_distances(Z, metric=metric)

        # Collection of bandwidths, from the strictly-upper-triangular
        # (off-diagonal) pairwise distances
        distances = pairwise_matrix[np.triu_indices(pairwise_matrix.shape[0], k=1)]
        bandwidths = compute_bandwidths_from_distances(distances, number_bandwidths)

        for kernel in kernels_for_metric:
            for i in range(number_bandwidths):
                # compute kernel matrix and set diagonal to zero
                K = kernel_matrix(pairwise_matrix, kernel, bandwidths[i], metric=metric)
                np.fill_diagonal(K, 0)
                # compute standard deviation
                unscaled_std = np.sqrt(np.sum(K**2))
                # compute MMD permuted values with lambda = np.sqrt(n * (n - 1)),
                # following the reasoning of Schrab et al. MMDAgg Appendix C
                M[row] = (
                    (
                        np.sum(V10 * (K @ V10), 0)
                        * (n - m + 1)
                        * (n - 1)
                        / (m * (m - 1))
                        + np.sum(V01 * (K @ V01), 0)
                        * (m - n + 1)
                        / m
                        + np.sum(V11 * (K @ V11), 0)
                        * (n - 1)
                        / m
                    )
                    / unscaled_std
                    * np.sqrt(n * (n - 1))
                )
                row += 1

    # Fuse over all (kernel, bandwidth) pairs: log of the mean of
    # exp(lambda_multiplier * M) per permutation, computed stably.
    all_statistics = logsumexp(lambda_multiplier * M, axis=0, b=1 / N)  # (B+1,)
    original_statistic = all_statistics[-1]

    # The original statistic is included in the comparison (counts itself).
    p_val = np.mean(all_statistics >= original_statistic)
    output = p_val <= alpha

    if return_p_val:
        return output.astype(int), p_val
    return output.astype(int)
