import numpy as np
from kernel import pairwise_distances, kernel_matrix, compute_bandwidths_from_distances


def mmd_split(
    X,
    Y,
    seed,
    alpha=0.05,
    split_ratio=0.5,
    kernel="gaussian",
    number_bandwidths=10,
    number_permutations=1000,
    return_p_val=False,
):
    """
    Two-sample MMD permutation test with the bandwidth chosen by data splitting.

    Both samples are shuffled and split into a selection part and a test
    part. On the selection part, the bandwidth maximising ``ratio_mmd_std``
    is picked among ``number_bandwidths`` candidates; on the test part, a
    permutation test is run with that bandwidth.

    Parameters
    ----------
    X, Y : array_like with a ``shape`` attribute
        Samples; the first axis indexes observations. If the numbers of
        observations differ, the larger sample is randomly subsampled
        (without replacement) to the size of the smaller one.
    seed : int
        Seed for NumPy's *global* random state. ``np.random.seed`` is called
        with ``seed + 12345`` for the subsampling and with ``seed`` for the
        shuffling and the permutations, so the global RNG is reseeded as a
        side effect.
    alpha : float
        Test level, strictly between 0 and 1.
    split_ratio : float
        Fraction (strictly between 0 and 1) of each sample used for the
        bandwidth selection; the rest is used for testing.
    kernel : str
        One of the supported kernel names (see the tuples below); it
        determines whether "l1" or "l2" pairwise distances are used.
    number_bandwidths : int
        Number of candidate bandwidths (> 1).
    number_permutations : int
        Number of permutations (> 0).
    return_p_val : bool
        If true, return the p-value instead of the test decision.

    Returns
    -------
    The permutation p-value if ``return_p_val`` is true, otherwise an integer
    equal to 1 when the p-value is at most ``alpha`` and 0 otherwise.

    Raises
    ------
    ValueError
        If, for unequal sample sizes, the smaller sample has fewer than 2
        observations, or if the split leaves fewer than 2 observations per
        sample in the selection or the test part (the MMD estimates divide
        by ``size * (size - 1)``).
    AssertionError
        If one of the other arguments is out of range.
    """
    kernels_l1 = (
        "laplace",
        "matern_0.5_l1",
        "matern_1.5_l1",
        "matern_2.5_l1",
        "matern_3.5_l1",
        "matern_4.5_l1",
    )
    kernels_l2 = (
        "imq",
        "rq",
        "gaussian",
        "matern_0.5_l2",
        "matern_1.5_l2",
        "matern_2.5_l2",
        "matern_3.5_l2",
        "matern_4.5_l2",
    )

    m = X.shape[0]
    n = Y.shape[0]
    # If sample sizes differ, randomly subsample the larger one
    if m != n:
        if min(m, n) < 2:
            raise ValueError(
                "After subsampling, each sample must have at least 2 observations"
            )
        np.random.seed(seed + 12345)  # different seed offset for subsampling
        if m > n:
            X = X[np.random.choice(m, size=n, replace=False)]
            m = n
        else:
            Y = Y[np.random.choice(n, size=m, replace=False)]
            n = m

    # Assertions
    assert n >= 2 and m >= 2
    assert 0 < alpha and alpha < 1
    assert 0 < split_ratio and split_ratio < 1
    assert number_bandwidths > 1 and type(number_bandwidths) is int
    assert number_permutations > 0 and type(number_permutations) is int
    assert kernel in kernels_l1 + kernels_l2
    metric = "l1" if kernel in kernels_l1 else "l2"

    # Both parts of the split need at least 2 observations per sample,
    # otherwise the estimators below divide by zero.
    split = int(n * split_ratio)
    if split < 2 or n - split < 2:
        raise ValueError(
            f"split_ratio={split_ratio} leaves {split} selection and "
            f"{n - split} test observations per sample; both must be at least 2"
        )

    # Setup random seed
    np.random.seed(seed)

    # Shuffle the data (for data splitting)
    X_shuffle = np.random.permutation(X)
    Y_shuffle = np.random.permutation(Y)

    # Split the data
    X_selection, X_test = X_shuffle[:split], X_shuffle[split:]
    Y_selection, Y_test = Y_shuffle[:split], Y_shuffle[split:]

    ##################
    # Kernel selection
    ##################
    Z_selection = np.concatenate((X_selection, Y_selection))
    pairwise_selection = pairwise_distances(Z_selection, metric=metric)

    # Candidate bandwidths are derived from the distinct pairs (upper triangle)
    distances = pairwise_selection[
        np.triu_indices(pairwise_selection.shape[0], k=1)
    ]
    bandwidths = compute_bandwidths_from_distances(distances, number_bandwidths)

    # Keep the bandwidth with the largest MMD / std ratio (first one on ties)
    ratios = np.zeros((number_bandwidths,))
    for i in range(number_bandwidths):
        K = kernel_matrix(pairwise_selection, kernel, bandwidths[i], metric=metric)
        ratios[i] = ratio_mmd_std(K)
    bandwidth_selected = bandwidths[np.argmax(ratios)]

    ########################
    # Setup for permutations
    ########################
    m = X_test.shape[0]
    n = Y_test.shape[0]
    B = number_permutations
    # (B+1, m+n): rows of permuted indices; the last row is drawn but then
    # overridden below so that column B holds the unpermuted statistic.
    idx = np.array([np.random.permutation(m + n) for _ in range(B + 1)])

    def permuted_columns(v):
        # (m+n, B+1): column b is v permuted by idx[b]; column B is v itself
        V = v[idx]
        V[B] = v
        return V.transpose()

    V11 = permuted_columns(np.concatenate((np.ones(m), -np.ones(n))))
    V10 = permuted_columns(np.concatenate((np.ones(m), np.zeros(n))))
    V01 = permuted_columns(np.concatenate((np.zeros(m), -np.ones(n))))

    ##########
    # Run test
    ##########
    Z_test = np.concatenate((X_test, Y_test))
    pairwise_test = pairwise_distances(Z_test, metric=metric)
    # NOTE(review): the diagonal of K is not zeroed here. If kernel_matrix
    # returns a constant diagonal, this shifts every permuted statistic by
    # the same amount and leaves the p-value unchanged - confirm.
    K = kernel_matrix(pairwise_test, kernel, bandwidth_selected, metric=metric)
    M = (
        np.sum(V10 * (K @ V10), 0) * (1 / (m * (m - 1)) - 1 / (m * n))
        + np.sum(V01 * (K @ V01), 0) * (1 / (n * (n - 1)) - 1 / (m * n))
        + np.sum(V11 * (K @ V11), 0) / (m * n)
    )  # (B+1,)

    # Compute test output: proportion of statistics at least as large as the
    # original one (the original is included in the count)
    original_MMD = M[-1]
    p_val = np.mean(M >= original_MMD)
    output = p_val <= alpha

    # Return output
    if return_p_val:
        return p_val
    return output.astype(int)

    
def mmd_split_different_kernels(
    X,
    Y,
    seed,
    alpha=0.05,
    split_ratio=0.5,
    kernels=("imq", "gaussian", "laplace", "matern_0.5_l2", "matern_1.5_l2", "matern_1.5_l1"),
    number_bandwidths=20,
    number_permutations=2000,
    return_p_val=False,
):
    """
    Two-sample MMD permutation test with kernel and bandwidth chosen by data splitting.

    Both samples are shuffled and split into a selection part and a test
    part. On the selection part, the (kernel, bandwidth) pair maximising
    ``ratio_mmd_std`` is picked among ``number_bandwidths`` candidate
    bandwidths for every kernel in ``kernels``; on the test part, a
    permutation test is run with the selected pair.

    Parameters
    ----------
    X, Y : array_like with a ``shape`` attribute
        Samples; the first axis indexes observations. If the numbers of
        observations differ, the larger sample is randomly subsampled
        (without replacement) to the size of the smaller one.
    seed : int
        Seed for NumPy's *global* random state. ``np.random.seed`` is called
        with ``seed + 12345`` for the subsampling and with ``seed`` for the
        shuffling and the permutations, so the global RNG is reseeded as a
        side effect.
    alpha : float
        Test level, strictly between 0 and 1.
    split_ratio : float
        Fraction (strictly between 0 and 1) of each sample used for the
        selection; the rest is used for testing.
    kernels : str or sequence of str
        Kernel name(s) to choose from; a single string is treated as a
        one-element collection. Each must be one of the supported names
        listed in the tuples below.
    number_bandwidths : int
        Number of candidate bandwidths per kernel (> 1).
    number_permutations : int
        Number of permutations (> 0).
    return_p_val : bool
        If true, return the p-value instead of the test decision.

    Returns
    -------
    The permutation p-value if ``return_p_val`` is true, otherwise an integer
    equal to 1 when the p-value is at most ``alpha`` and 0 otherwise.

    Raises
    ------
    ValueError
        If, for unequal sample sizes, the smaller sample has fewer than 2
        observations, if ``kernels`` is empty, or if the split leaves fewer
        than 2 observations per sample in the selection or the test part
        (the MMD estimates divide by ``size * (size - 1)``).
    AssertionError
        If one of the other arguments is out of range.
    """
    # Lists of kernels for l1 and l2
    all_kernels_l1 = (
        "laplace",
        "matern_0.5_l1",
        "matern_1.5_l1",
        "matern_2.5_l1",
        "matern_3.5_l1",
        "matern_4.5_l1",
    )
    all_kernels_l2 = (
        "imq",
        "rq",
        "gaussian",
        "matern_0.5_l2",
        "matern_1.5_l2",
        "matern_2.5_l2",
        "matern_3.5_l2",
        "matern_4.5_l2",
    )

    m = X.shape[0]
    n = Y.shape[0]
    # If sample sizes differ, randomly subsample the larger one
    if m != n:
        if min(m, n) < 2:
            raise ValueError(
                "After subsampling, each sample must have at least 2 observations"
            )
        np.random.seed(seed + 12345)  # different seed offset for subsampling
        if m > n:
            X = X[np.random.choice(m, size=n, replace=False)]
            m = n
        else:
            Y = Y[np.random.choice(n, size=m, replace=False)]
            n = m

    # Assertions
    assert n >= 2 and m >= 2
    assert 0 < alpha and alpha < 1
    assert 0 < split_ratio and split_ratio < 1
    assert number_bandwidths > 1 and type(number_bandwidths) is int
    assert number_permutations > 0 and type(number_permutations) is int
    if isinstance(kernels, str):
        # a single kernel name: wrap it so it is not iterated character-wise
        kernels = (kernels,)
    for kernel in kernels:
        assert kernel in all_kernels_l1 + all_kernels_l2
    number_kernels = len(kernels)
    if number_kernels == 0:
        raise ValueError("kernels must contain at least one kernel name")
    kernels_l1 = [k for k in kernels if k in all_kernels_l1]
    kernels_l2 = [k for k in kernels if k in all_kernels_l2]

    # Both parts of the split need at least 2 observations per sample,
    # otherwise the estimators below divide by zero and the p-value is
    # computed from NaN/inf statistics.
    split = int(n * split_ratio)
    if split < 2 or n - split < 2:
        raise ValueError(
            f"split_ratio={split_ratio} leaves {split} selection and "
            f"{n - split} test observations per sample; both must be at least 2"
        )

    # Setup random seed
    np.random.seed(seed)

    # Shuffle the data
    X_shuffle = np.random.permutation(X)
    Y_shuffle = np.random.permutation(Y)

    # Split the data
    X_selection, X_test = X_shuffle[:split], X_shuffle[split:]
    Y_selection, Y_test = Y_shuffle[:split], Y_shuffle[split:]

    # Kernel selection: one candidate per (kernel, bandwidth) pair, l1 kernels
    # first, then l2 kernels, each in the order given in `kernels`
    N = number_bandwidths * number_kernels
    ratios = np.zeros((N,))
    candidate_bandwidths = np.zeros((N,))
    candidate_kernels = []  # (kernel name, metric) for each candidate
    Z_selection = np.concatenate((X_selection, Y_selection))
    position = 0
    for metric, kernels_metric in (("l1", kernels_l1), ("l2", kernels_l2)):
        if not kernels_metric:
            continue
        # Pairwise distance matrix, shared by all kernels using this metric
        pairwise_selection = pairwise_distances(Z_selection, metric=metric)

        # Candidate bandwidths are derived from the distinct pairs (upper triangle)
        distances = pairwise_selection[
            np.triu_indices(pairwise_selection.shape[0], k=1)
        ]
        bandwidths = compute_bandwidths_from_distances(distances, number_bandwidths)

        for kernel in kernels_metric:
            for i in range(number_bandwidths):
                bandwidth = bandwidths[i]
                K = kernel_matrix(pairwise_selection, kernel, bandwidth, metric=metric)
                candidate_kernels.append((kernel, metric))
                candidate_bandwidths[position] = bandwidth
                ratios[position] = ratio_mmd_std(K)
                position += 1

    # Candidate with the largest MMD / std ratio (first one on ties)
    index_selected = np.argmax(ratios)
    kernel_selected, metric_selected = candidate_kernels[index_selected]
    bandwidth_selected = candidate_bandwidths[index_selected]

    # Setup for permutations
    m = X_test.shape[0]
    n = Y_test.shape[0]
    B = number_permutations
    # (B+1, m+n): rows of permuted indices; the last row is drawn but then
    # overridden below so that column B holds the unpermuted statistic.
    idx = np.array([np.random.permutation(m + n) for _ in range(B + 1)])

    def permuted_columns(v):
        # (m+n, B+1): column b is v permuted by idx[b]; column B is v itself
        V = v[idx]
        V[B] = v
        return V.transpose()

    V11 = permuted_columns(np.concatenate((np.ones(m), -np.ones(n))))
    V10 = permuted_columns(np.concatenate((np.ones(m), np.zeros(n))))
    V01 = permuted_columns(np.concatenate((np.zeros(m), -np.ones(n))))

    # Run test
    Z_test = np.concatenate((X_test, Y_test))
    pairwise_test = pairwise_distances(Z_test, metric=metric_selected)
    # NOTE(review): the diagonal of K is not zeroed here. If kernel_matrix
    # returns a constant diagonal, this shifts every permuted statistic by
    # the same amount and leaves the p-value unchanged - confirm.
    K = kernel_matrix(
        pairwise_test, kernel_selected, bandwidth_selected, metric=metric_selected
    )
    M = (
        np.sum(V10 * (K @ V10), 0)
        * (n - m + 1)
        / (m * n * (m - 1))
        + np.sum(V01 * (K @ V01), 0)
        * (m - n + 1)
        / (m * n * (n - 1))
        + np.sum(V11 * (K @ V11), 0) / (m * n)
    )  # (B+1,)

    # Compute test output: proportion of statistics at least as large as the
    # original one (the original is included in the count)
    original_MMD = M[-1]
    p_val = np.mean(M >= original_MMD)
    output = p_val <= alpha

    # Return output
    if return_p_val:
        return p_val
    return output.astype(int)

def ratio_mmd_std(K):
    """
    Return the ratio of an MMD estimate to the square root of a variance estimate.

    ``K`` is read as a 2n x 2n kernel matrix whose first n rows/columns
    belong to one sample and whose last n rows/columns belong to the other.
    The input matrix is not modified.

    Parameters
    ----------
    K : numpy.ndarray of shape (2n, 2n) with n >= 2

    Returns
    -------
    The MMD estimate (diagonals of the two within-sample blocks excluded)
    divided by the square root of the variance estimate, which is clipped at
    zero and regularised by adding 1e-8.

    Raises
    ------
    ValueError
        If ``K`` is not a square matrix with an even number of rows, or if
        each sample has fewer than 2 points (the estimate divides by
        ``n * (n - 1)``).
    """
    if K.ndim != 2 or K.shape[0] != K.shape[1] or K.shape[0] % 2 != 0:
        raise ValueError(
            f"K must be a square matrix with an even number of rows, got shape {K.shape}"
        )
    n = K.shape[0] // 2
    if n < 2:
        raise ValueError("K must contain at least 2 points per sample")
    regulariser = 10 ** (-8)

    Kxx = K[:n, :n]
    Kxy = K[:n, n:]
    Kyx = K[n:, :n]
    Kyy = K[n:, n:]

    # compute variance (uses the blocks with their diagonals included)
    H_row_sum = (
        np.sum(Kxx, axis=1)
        + np.sum(Kyy, axis=1)
        - np.sum(Kxy, axis=1)
        - np.sum(Kyx, axis=1)
    )
    var = (
        4 / n ** 3 * np.sum(H_row_sum ** 2)
        - 4 / n ** 4 * np.sum(H_row_sum) ** 2
    )
    # clip negative values, then regularise so the division below is safe
    var = np.maximum(var, 0) + regulariser

    # compute MMD estimate; the diagonals are zeroed on copies because the
    # blocks above are views and filling them would alter the caller's K
    Kxx_offdiag = Kxx.copy()
    Kyy_offdiag = Kyy.copy()
    np.fill_diagonal(Kxx_offdiag, 0)
    np.fill_diagonal(Kyy_offdiag, 0)
    s = np.ones(n)
    mmd = (
        s @ Kxx_offdiag @ s / (n * (n - 1))
        + s @ Kyy_offdiag @ s / (n * (n - 1))
        - 2 * s @ Kxy @ s / (n ** 2)
    )

    return mmd / np.sqrt(var)
