import numpy as np
import torch
from scipy.spatial.distance import squareform, pdist
from scipy.stats import rankdata
from sklearn.feature_selection import mutual_info_regression
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors


def empirical_copula_test(X, Y):
    """
    Correlate the rank-based pseudo-observations of two samples.
    Parameters:
        X: sequence of observations (assumed 1-D — TODO confirm with callers)
        Y: sequence of observations, same length as X
    Returns:
        Pearson correlation between rank(X) / len(X) and rank(Y) / len(Y).
    """
    # Map each sample onto its empirical-CDF scale.
    pseudo_x = rankdata(X)
    pseudo_x = pseudo_x / len(X)
    pseudo_y = rankdata(Y)
    pseudo_y = pseudo_y / len(Y)

    # Off-diagonal entry of the 2x2 correlation matrix.
    correlation = np.corrcoef(pseudo_x, pseudo_y)
    return correlation[0, 1]


def compute_distance_correlation(X, Y):
    """
    Compute the (sample) distance correlation between two data sets.
    Parameters:
        X: array-like of shape (n_samples, n_features) or (n_samples,)
        Y: array-like of shape (n_samples, n_features) or (n_samples,)
    Returns:
        Distance Correlation (float) in [0, 1].
    Raises:
        ValueError: if X and Y do not have the same number of samples.
    Comment:
        - If X and Y are independent, Distance Correlation ≈ 0.
        - If X and Y are dependent, Distance Correlation > 0.
        - If either input has zero distance variance (e.g. all samples
          identical), 0.0 is returned instead of NaN.
    """
    def as_samples(data):
        # pdist needs a 2-D array; treat a 1-D input as a single feature.
        arr = np.asarray(data, dtype=float)
        return arr.reshape(-1, 1) if arr.ndim == 1 else arr

    def centered_distances(data):
        # Pairwise Euclidean distances, double-centered so that every row
        # and every column sums to zero.
        dist = squareform(pdist(data, 'euclidean'))
        axis0_mean = dist.mean(axis=0, keepdims=True)
        axis1_mean = dist.mean(axis=1, keepdims=True)
        return dist - axis0_mean - axis1_mean + dist.mean()

    X = as_samples(X)
    Y = as_samples(Y)
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have the same number of samples, "
            f"got {X.shape[0]} and {Y.shape[0]}"
        )

    n_sq = X.shape[0] ** 2
    A = centered_distances(X)
    B = centered_distances(Y)

    # Squared distance covariance; clamp tiny negative round-off so the
    # final square root cannot produce NaN.
    dcov_sq = max(np.sum(A * B) / n_sq, 0.0)
    dvar_X = np.sqrt(np.sum(A * A) / n_sq)
    dvar_Y = np.sqrt(np.sum(B * B) / n_sq)

    denom = dvar_X * dvar_Y
    if denom == 0:
        # Degenerate input: distance correlation is defined as 0 here.
        return 0.0
    return float(np.sqrt(dcov_sq / denom))


def knn_ind_test(s, s_, k=5, n_tests=100, max_samples: int = 100):
    """
    k-NN two-sample statistic comparing paired (s, s_) rows with re-paired rows.

    For each of `n_tests` repetitions, three index sets are drawn with
    replacement from the rows of `s`. The first set pairs each row of `s`
    with its own row of `s_`; the other two pair independently drawn rows of
    `s` with the `s_` rows of the first set. Features of `s` and `s_` are
    divided by the square root of their last-dimension size before being
    concatenated.

    Parameters:
        s: torch.Tensor indexed along dim 0 (presumably (n, d) — TODO confirm)
        s_: torch.Tensor with the same number of rows as `s`
        k: number of neighbours forwarded to `knn_accuracy`
        n_tests: number of repetitions
        max_samples: upper bound on rows drawn per repetition (the draw size
            is never below 5, even if `s` has fewer rows)
    Returns:
        The value of `tvalue` comparing the accuracies of the paired set
        against those of the re-paired (null) set.
    """
    n_rows = s.shape[0]
    draw_size = max(min(max_samples, n_rows), 5)

    # One RNG call, split into three equally sized index blocks.
    all_idx = torch.randint(0, n_rows, (3 * n_tests, draw_size))
    a, b, c = torch.chunk(all_idx, 3, dim=0)

    s_scale = s.shape[-1] ** 0.5
    s_next_scale = s_.shape[-1] ** 0.5
    # Every sample set shares the same s_ rows (those selected by `a`).
    shared_next = s_[a] / s_next_scale

    def with_shared_next(rows):
        return torch.cat([s[rows] / s_scale, shared_next], dim=-1)

    X = with_shared_next(a)
    Y = with_shared_next(b)
    Xbar = with_shared_next(c)

    x_acc = knn_accuracy(X, Y, k)
    # Xbar vs. Y serves as the null hypothesis.
    xbar_acc = knn_accuracy(Xbar, Y, k)

    return tvalue(
        x_acc.mean(), x_acc.std(), n_tests,
        xbar_acc.mean(), xbar_acc.std(), n_tests,
    )


def knn_accuracy(x: torch.Tensor, y: torch.Tensor, k: int = 5) -> torch.Tensor:
    """
    Compute classifier 2-sample test with k-NN.

    Every point of the pooled set [x; y] is classified by the labels of its
    k nearest neighbours (itself excluded). The per-class accuracies are
    averaged into a single balanced accuracy.

    Parameters
    ----------
        x : torch.tensor
            The first dataset, shape (n_x, d) or (batch, n_x, d).
        y : torch.tensor
            The second dataset, shape (n_y, d) or (batch, n_y, d).
        k : int
            The number of nearest neighbors (clipped to min(n_x, n_y)).

    Returns
    -------
        acc : torch.tensor
            (x_acc + y_acc) / 2, where x_acc / y_acc are the mean fractions
            of same-set neighbours for points of x / y. Scalar tensor for
            2-D input, shape (batch,) for 3-D input. The result lives on the
            CUDA device when one is available, otherwise on the CPU.

    Raises
    ------
        ValueError
            If x is neither 2- nor 3-dimensional.
    """
    if x.dim() not in (2, 3):
        raise ValueError(f"Incompatible tensor dimension: {x.shape}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The sample axis is the second-to-last one for both supported layouts.
    n_x = x.shape[-2]
    n_y = y.shape[-2]
    k = min(k, n_x, n_y)

    xy = torch.cat([x, y], dim=-2).to(device)

    # Large penalty on the diagonal so a point is never its own neighbour.
    # The 2-D penalty broadcasts over the batch dimension for 3-D input.
    self_penalty = torch.diag(
        torch.full((n_x + n_y,), 1e12, dtype=xy.dtype, device=xy.device)
    )
    dists = torch.cdist(xy, xy) + self_penalty
    del xy

    _, indexes = dists.topk(k, dim=-1, largest=False)
    # Neighbours with pooled index < n_x belong to x.
    x_votes = (indexes < n_x).sum(dim=-1) / k
    x_acc = x_votes[..., :n_x].sum(dim=-1) / n_x
    y_acc = 1 - x_votes[..., n_x:].sum(dim=-1) / n_y
    return (x_acc + y_acc) / 2


def independence_test_knn_with_pvalue(X, Y, k=5, num_permutations=100) -> float:
    """
    Perform KNN-based independence test and compute a p-value using permutation test.

    A k-NN classifier is trained to tell original (X, Y) pairs (label 1)
    from pairs whose Y rows were shuffled (label 0). Its held-out accuracy
    is compared against accuracies obtained after permuting the training
    labels.

    Parameters:
    X: ndarray
        Input features of shape (n_samples, n_features) or (n_samples,).
    Y: ndarray
        Target variables of shape (n_samples, n_targets) or (n_samples,).
    k: int
        Number of neighbors for KNN classifier.
    num_permutations: int
        Number of permutations for the null distribution.

    Returns:
    p_value: float
        Fraction of label-permuted runs whose test accuracy is at least the
        accuracy obtained with the true labels.

    Notes:
        The Y shuffle and the label permutations use the global NumPy RNG;
        the train/test split uses a fixed random_state of 42.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    # Stacking below needs 2-D inputs; treat 1-D data as a single column.
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    if Y.ndim == 1:
        Y = Y.reshape(-1, 1)

    # Prepare the dataset: first half original pairs, second half shuffled.
    n_samples = X.shape[0]
    X_combined = np.vstack([X, X])
    Y_combined = np.vstack([Y, Y])
    labels_combined = np.hstack([np.ones(n_samples), np.zeros(n_samples)])

    # Break the X-Y pairing in the second half (in place, on the copy made
    # by vstack, so the caller's Y is untouched).
    np.random.shuffle(Y_combined[n_samples:])

    # Split into train and test sets
    X_train, X_test, Y_train, Y_test, labels_train, labels_test = train_test_split(
        X_combined, Y_combined, labels_combined, test_size=0.3, random_state=42
    )

    # The feature matrices do not change across permutations; build them once.
    train_features = np.hstack([X_train, Y_train])
    test_features = np.hstack([X_test, Y_test])

    # Observed accuracy with the true labels
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(train_features, labels_train)
    acc_original = knn.score(test_features, labels_test)

    # Null distribution: refit after permuting the training labels.
    null_accuracies = []
    for _ in range(num_permutations):
        np.random.shuffle(labels_train)
        knn.fit(train_features, labels_train)
        null_accuracies.append(knn.score(test_features, labels_test))

    # Compute p-value: fraction of null accuracies >= observed accuracy
    p_value = np.mean(np.array(null_accuracies) >= acc_original)
    return float(p_value)


def compute_mutual_information(X, Y):
    """
    Average the estimated mutual information between X and each column of Y.
    Parameters:
        X: np.ndarray of shape (n_samples, n_features)
        Y: np.ndarray of shape (n_samples, n_features)
    Returns:
        Average Mutual Information (float): for every column of Y the
        per-feature estimates from `mutual_info_regression` are averaged,
        and those column scores are averaged again.
    Comment:
        - Intended reading: values near 0 suggest independence, larger
          values suggest dependence.
    """
    # One score per target column: mean MI over all features of X.
    per_target_scores = [
        np.mean(mutual_info_regression(X, Y[:, col]))
        for col in range(Y.shape[1])
    ]
    return np.mean(per_target_scores)


def compute_kl_divergence_knn(X, Y, k=5):
    """
    k-NN distance-ratio score relating the joint (X, Y) space to the Y space.
    Parameters:
        X: np.ndarray of shape (n_samples, n_features)
        Y: np.ndarray of shape (n_samples, n_features)
        k: int, number of nearest neighbors.
    Returns:
        Mean over samples of log(nu / rho) (float), where rho is the largest
        of the k + 1 neighbour distances in the joint space and nu the same
        quantity in the Y space (each point is queried against the set it
        was fitted on, so it normally counts itself at distance 0).
    Comment:
        - With the default Euclidean metric a joint-space distance is never
          smaller than the corresponding Y-space distance, so nu <= rho and
          the returned value is always <= 0.
        - Duplicate samples can make rho (and nu) zero, which yields
          non-finite results.
        - NOTE(review): no dimension or volume correction is applied, so
          this is a heuristic score rather than a calibrated KL estimate —
          confirm intended usage.
    """
    joint = np.hstack((X, Y))

    distances_joint, _ = NearestNeighbors(n_neighbors=k + 1).fit(joint).kneighbors(joint)
    distances_y, _ = NearestNeighbors(n_neighbors=k + 1).fit(Y).kneighbors(Y)

    # Distance to the farthest returned neighbour in each space.
    rho = distances_joint[:, -1]
    nu = distances_y[:, -1]

    return np.mean(np.log(nu / rho))


def compute_HSIC(X, Y, sigma=1.0):
    """
    Empirical HSIC (Hilbert-Schmidt Independence Criterion) with RBF kernels.
    Parameters:
        X: np.ndarray of shape (n_samples, n_features)
        Y: np.ndarray of shape (n_samples, n_features)
        sigma: float, bandwidth for Gaussian RBF kernel (used for both
            X and Y as gamma = 1 / (2 * sigma^2)).
    Returns:
        HSIC value (float): trace(K H L H) / (n - 1)^2, with K and L the
        kernel matrices of X and Y and H the centering matrix.
    Comment:
        - If X and Y are independent, HSIC ≈ 0.
        - If X and Y are dependent, HSIC > 0.
    """
    n_samples = X.shape[0]
    gamma = 1 / (2 * sigma ** 2)

    kernel_x = rbf_kernel(X, gamma=gamma)
    kernel_y = rbf_kernel(Y, gamma=gamma)

    # Centering matrix: I - (1/n) * 1 1^T
    centering = np.eye(n_samples) - np.ones((n_samples, n_samples)) / n_samples

    product = kernel_x @ centering @ kernel_y @ centering
    normaliser = (n_samples - 1) ** 2
    return np.trace(product) / normaliser


def compute_mi_knn(X, Y, k=5):
    """
    k-NN distance-ratio dependence score for high-dimensional data.
    Parameters:
        X: np.ndarray of shape (n_samples, n_features)
        Y: np.ndarray of shape (n_samples, n_features)
        k: int, number of nearest neighbors.
    Returns:
        Mean over samples of log(epsilon / (rho * nu)) (float), where
        epsilon, rho and nu are the largest of the k + 1 neighbour distances
        in the joint (X, Y) space, the X space and the Y space respectively
        (each point is queried against the set it was fitted on).
    Comment:
        - The score is scale-dependent: multiplying X and Y by a constant c
          shifts the result by -log(c).
        - Duplicate samples can produce zero distances and therefore
          non-finite results.
        - NOTE(review): this is a heuristic distance ratio, not a calibrated
          mutual-information estimate; "≈ 0 under independence" is not
          guaranteed — confirm intended usage.
    """
    xy = np.hstack((X, Y))

    distances_xy, _ = NearestNeighbors(n_neighbors=k + 1).fit(xy).kneighbors(xy)
    distances_x, _ = NearestNeighbors(n_neighbors=k + 1).fit(X).kneighbors(X)
    distances_y, _ = NearestNeighbors(n_neighbors=k + 1).fit(Y).kneighbors(Y)

    # Distance to the farthest returned neighbour in each space.
    epsilon = distances_xy[:, -1]
    rho = distances_x[:, -1]
    nu = distances_y[:, -1]

    return np.log(epsilon / (rho * nu)).mean()


def tvalue(mu1, std1, n1, mu2, std2, n2):
    """
    Absolute Welch-style t statistic for the difference of two sample means.

    Works with Python floats, NumPy scalars and 0-dim torch tensors.

    Parameters:
        mu1, mu2: sample means
        std1, std2: sample standard deviations
        n1, n2: sample sizes
    Returns:
        |mu1 - mu2| / sqrt(std1^2 / n1 + std2^2 / n2).
        When the standard error is below 1e-8 the ratio is not formed:
        0 is returned if the means also agree to within 1e-8, otherwise
        100 * |mu1 - mu2|.
    """
    mu_del = mu1 - mu2
    std_del = ((std1 / np.sqrt(n1)) ** 2 + (std2 / np.sqrt(n2)) ** 2) ** 0.5

    # Builtin abs() dispatches to __abs__, so it also covers inputs that
    # have no .abs() method (plain floats, NumPy scalars).
    abs_del = abs(mu_del)

    if std_del < 1e-8:
        # Degenerate variance: avoid dividing by ~0. Use the magnitude of
        # the difference so this branch is two-sided like the general case.
        if abs_del < 1e-8:
            return 0
        return 100 * abs_del

    # NOTE: this is basically Welch's t-test (ie, unknown variances),
    # but as I couldn't find a way to use p-values to compare
    # the total MDP errors, I'm instead using the t-value as the error.
    # (A two-sided p-value could be derived from the t distribution with
    # the Welch-Satterthwaite degrees of freedom if ever needed.)
    return abs_del / std_del
