import numpy as np
from kernel import pairwise_distances, kernel_matrix


def mmd_median(
    X,
    Y,
    seed,
    alpha=0.05,
    kernel="gaussian",
    B1=1000,
    return_p_val=False,
):
    """Run an MMD two-sample permutation test with a median-heuristic bandwidth.

    The statistic is the unbiased (U-statistic) estimate of MMD^2 between the
    samples X and Y:

        1/(m(m-1)) sum_{i!=j} k(x_i, x_j) + 1/(n(n-1)) sum_{i!=j} k(y_i, y_j)
        - 2/(mn) sum_{i,j} k(x_i, y_j)

    The kernel bandwidth is the median of all pairwise distances in the
    pooled sample. The p-value is the fraction of the B1 permuted statistics,
    plus the observed statistic itself, that are >= the observed statistic.

    Parameters
    ----------
    X : numpy array
        First sample; observations along axis 0 (m = X.shape[0]). Presumably
        of shape (m, d), as expected by ``kernel.pairwise_distances``;
        TODO confirm.
    Y : numpy array
        Second sample with the same layout as X (n = Y.shape[0]).
    seed : int or None
        Seed for the permutation RNG. A private ``RandomState`` is used, so
        NumPy's global random state is not modified.
    alpha : float
        Test level, strictly between 0 and 1.
    kernel : str
        An L2-based kernel ("gaussian", "imq", "rq", "matern_{0.5..4.5}_l2")
        or an L1-based kernel ("laplace", "matern_{0.5..4.5}_l1").
    B1 : int
        Number of random permutations (positive Python or NumPy integer).
    return_p_val : bool
        If True, return the p-value instead of the test decision.

    Returns
    -------
    int or float
        1 if ``p_val <= alpha`` (reject), otherwise 0; or the permutation
        p-value when ``return_p_val`` is True.

    Raises
    ------
    AssertionError
        If m < 2, n < 2, alpha is not in (0, 1), or B1 is not a positive
        integer.
    ValueError
        If ``kernel`` is not supported.
    """
    m = X.shape[0]
    n = Y.shape[0]
    assert n >= 2 and m >= 2
    assert 0 < alpha < 1
    # Accept numpy integers too; bool is an int subclass but is not a
    # meaningful permutation count, so reject it explicitly.
    assert (
        isinstance(B1, (int, np.integer))
        and not isinstance(B1, bool)
        and B1 > 0
    )

    l2_kernels = (
        "imq",
        "rq",
        "gaussian",
        "matern_0.5_l2",
        "matern_1.5_l2",
        "matern_2.5_l2",
        "matern_3.5_l2",
        "matern_4.5_l2",
    )
    l1_kernels = (
        "laplace",
        "matern_0.5_l1",
        "matern_1.5_l1",
        "matern_2.5_l1",
        "matern_3.5_l1",
        "matern_4.5_l1",
    )
    if kernel in l2_kernels:
        metric = "l2"
    elif kernel in l1_kernels:
        metric = "l1"
    else:
        raise ValueError("Kernel not implemented")

    # Permutations. A local RandomState(seed) yields exactly the same stream
    # as np.random.seed(seed) followed by np.random.permutation, without
    # clobbering the caller's global RNG state.
    B = int(B1)
    mn = m + n
    rng = np.random.RandomState(seed)
    perms = np.array([rng.permutation(mn) for _ in range(B)])  # (B, m+n)
    # Row B is the identity, so column B below yields the observed statistic.
    idx = np.vstack((perms, np.arange(mn)))  # (B+1, m+n)

    # Sign vectors for the permuted samples, one column per permutation.
    # V11: +1 for points assigned to X, -1 for points assigned to Y.
    # V10 / V01 are derived exactly from V11 (entries are +-1), which avoids
    # building three separate tiled copies.
    signs = np.concatenate((np.ones(m), -np.ones(n)))
    V11 = signs[idx].transpose()  # (m+n, B+1)
    V10 = (V11 + 1) / 2  # 1 for X, 0 for Y
    V01 = (V11 - 1) / 2  # 0 for X, -1 for Y

    # Kernel matrix with median-heuristic bandwidth over distinct pairs.
    # NOTE(review): if at least half of the pairwise distances are 0 (heavily
    # duplicated data) the bandwidth is 0; how kernel_matrix handles that is
    # not visible here.
    Z = np.concatenate((X, Y))
    pairwise_matrix = pairwise_distances(Z, metric=metric)
    distances = pairwise_matrix[np.triu_indices(pairwise_matrix.shape[0], k=1)]
    bandwidth = np.median(distances)
    K = kernel_matrix(
        pairwise_matrix, kernel, bandwidth, metric=metric, rq_kernel_exponent=0.5
    )
    # Zero the diagonal so the quadratic forms below only sum over i != j
    # (U-statistic).
    np.fill_diagonal(K, 0)

    # Per-column quadratic forms: v10'Kv10 = S_XX, v01'Kv01 = S_YY and
    # v11'Kv11 = S_XX + S_YY - 2 S_XY. The weights below recombine them into
    # S_XX/(m(m-1)) + S_YY/(n(n-1)) - 2 S_XY/(mn).
    M = (
        np.sum(V10 * (K @ V10), 0) * (1 / (m * (m - 1)) - 1 / (m * n))
        + np.sum(V01 * (K @ V01), 0) * (1 / (n * (n - 1)) - 1 / (m * n))
        + np.sum(V11 * (K @ V11), 0) / (m * n)
    )
    MMD_original = M[B]
    # The observed statistic is counted among the B+1 values, so p_val > 0.
    p_val = np.mean(M >= MMD_original)
    output = p_val <= alpha

    if return_p_val:
        return p_val
    return output.astype(int)
