from math import dist
import numpy as np
from sklearn.feature_selection import mutual_info_regression
from sklearn.metrics import mutual_info_score
import torch

from mi_estimators.mi_hsic import hsic_estimate
from mi_estimators.mi_kde import mutual_information_kde
from mi_estimators.mi_correlation import mutual_information_correlation
from mi_estimators.mi_mine import mutual_information_mine
import npeet.entropy_estimators as ee

from causallearn.utils.cit import CIT

import xicorpy


def true_mi_bivariate_normal(rho):
    """Return the exact mutual information (nats) of a standard 2-D Gaussian.

    For unit-variance components with correlation ``rho`` the analytic value
    is ``-1/2 * log(1 - rho^2)``.
    """
    unexplained_variance = 1 - rho**2
    return -0.5 * np.log(unexplained_variance)


def sample_bivariate_normal(rho, n_samples=20000):
    """Draw ``n_samples`` points from a zero-mean, unit-variance 2-D Gaussian.

    The two coordinates have correlation ``rho``. They are returned as two
    separate 1-D arrays ``(X, Y)`` taken from the columns of one joint draw.
    """
    covariance = np.array([[1.0, rho], [rho, 1.0]])
    draws = np.random.multivariate_normal(
        np.zeros(2), covariance, size=n_samples
    )
    first, second = draws.T
    return first, second


# Sample a correlated pair whose X marginal is uniform (a non-Gaussian counterpart to the sampler above).
def sample_bivariate_uniform(rho, n_samples=20000):
    """
    Generate a correlated pair (X, Y) whose X marginal is Uniform(-1, 1).

    The construction is::

        X = U
        Y = rho * U + sqrt(1 - rho^2) * noise

    where ``U`` and ``noise`` are drawn independently from Uniform(-1, 1).
    Both have variance 1/3, so ``Var(Y) == Var(X)``. The Pearson correlation
    of X and Y is therefore exactly ``rho``.

    Only X is uniform. Y is a weighted sum of two independent uniforms, so
    its marginal is trapezoidal in general. It is uniform only when rho is
    0 or +/-1, and triangular when |rho| = 1/sqrt(2). The joint distribution
    is not uniform.

    Args:
        rho: Target correlation; must satisfy ``-1 <= rho <= 1``.
        n_samples: Number of samples to draw.

    Returns:
        Tuple ``(X, Y)`` of 1-D arrays of length ``n_samples``.

    Raises:
        ValueError: If ``rho`` lies outside ``[-1, 1]`` (or is NaN). In that
            case ``sqrt(1 - rho^2)`` is undefined and every Y would be NaN.
    """
    # Fail loudly instead of returning NaN-filled samples.
    if not -1.0 <= rho <= 1.0:
        raise ValueError(f"rho must be in [-1, 1], got {rho}")

    # Draw order (U first, then noise) is kept so seeded runs are reproducible.
    U = np.random.uniform(-1, 1, size=n_samples)
    noise = np.random.uniform(-1, 1, size=n_samples)
    X = U
    Y = rho * U + np.sqrt(1 - rho**2) * noise
    return X, Y


def estimate_mi_sklearn_regression(X, Y):
    """Return scikit-learn's kNN-based MI estimate between X and Y.

    ``X`` is reshaped into a single-column feature matrix and ``Y`` is the
    continuous target. ``mutual_info_regression`` yields one score per
    feature, so exactly one score is unpacked and returned.
    """
    (mi_value,) = mutual_info_regression(
        X.reshape(-1, 1), Y, discrete_features=False
    )
    return mi_value


def estimate_mi_binning(X, Y, nbins=200):
    """Estimate MI by discretising each variable into equal-width bins.

    Each of X and Y is cut into ``nbins`` bins that span its own observed
    range. The resulting integer labels are scored with sklearn's
    ``mutual_info_score``.
    """

    def _bin_labels(values):
        edges = np.histogram_bin_edges(values, bins=nbins)
        # digitize is 1-based, and the maximum value lands one index past
        # the last bin. Shift to 0-based and fold that overflow into the top bin.
        return np.clip(np.digitize(values, edges) - 1, 0, nbins - 1)

    return mutual_info_score(_bin_labels(X), _bin_labels(Y))


def _sample_pair(distribution, rho):
    """Draw one (X, Y) sample for ``distribution`` and return its known MI.

    Returns:
        Tuple ``(X, Y, true_mi)``. ``true_mi`` is the analytic mutual
        information in nats, or ``None`` when no closed form is available.

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
    """
    if distribution == "gaussian":
        X, Y = sample_bivariate_normal(rho=rho, n_samples=2000)
        return X, Y, true_mi_bivariate_normal(rho)
    if distribution == "uniform":
        X, Y = sample_bivariate_uniform(rho=rho, n_samples=2000)
        # No closed-form MI for this dependent uniform construction.
        return X, Y, None

    independent_samplers = {
        "independent_uniform": lambda: np.random.uniform(-1, 1, size=(5000, 2)),
        "independent_normal": lambda: np.random.normal(0, 1, size=(5000, 2)),
        "independent_laplace": lambda: np.random.laplace(0, 1, size=(5000, 2)),
    }
    if distribution not in independent_samplers:
        raise ValueError(f"Unknown distribution: {distribution}")

    X, Y = independent_samplers[distribution]().T
    # X and Y are independent, so the true MI is 0 whatever rho is. Here rho
    # only rescales X, and MI is invariant to a non-zero rescaling.
    # rho == 0 is skipped because it would collapse X to a constant.
    if rho != 0:
        X = rho * X
    return X, Y, 0.0


def _as_column_tensor(values):
    """Return a freshly allocated float32 tensor of shape (n, 1) from a 1-D array."""
    return torch.tensor(values).float().unsqueeze(1)


def main(
    distribution="gaussian",
    *,
    correlation_values=(0.0, 0.3, 0.7, 0.9),
    seed=123,
):
    """Compare MI / dependence estimators on synthetic bivariate data.

    Seeds NumPy's global RNG with ``seed``. Then, for each value in
    ``correlation_values``, it draws a sample from ``distribution``, runs
    every estimator and prints the results next to the true MI when known.

    Args:
        distribution: One of ``"gaussian"``, ``"uniform"``,
            ``"independent_uniform"``, ``"independent_normal"`` or
            ``"independent_laplace"``.
        correlation_values: Values of rho to sweep.
        seed: Seed for ``np.random`` so runs are reproducible.

    Raises:
        ValueError: If ``distribution`` is unknown.
    """
    np.random.seed(seed)

    for rho in correlation_values:
        X, Y, true_mi = _sample_pair(distribution, rho)

        print(f"X.shape: {X.shape}")

        # Keep this call order fixed. Some estimators (e.g. sklearn's kNN MI,
        # which jitters continuous data) may draw from the global NumPy RNG,
        # so reordering them would change results for a given seed.
        mi_regression = estimate_mi_sklearn_regression(X, Y)
        mi_binning = estimate_mi_binning(X, Y, nbins=20)
        mi_pytorch = mutual_information_kde(
            torch.tensor(X), torch.tensor(Y), bandwidth=0.1
        )
        mi_correlation = mutual_information_correlation(
            torch.tensor(X), torch.tensor(Y)
        )
        # Each call gets its own freshly built tensors, so an estimator that
        # modifies its inputs in place cannot affect the next one.
        mi_mine = mutual_information_mine(
            _as_column_tensor(X), _as_column_tensor(Y)
        )
        hsic = hsic_estimate(_as_column_tensor(X), _as_column_tensor(Y))

        # Estimate MI using NPEET with several neighbour counts.
        npeet = ee.mi(X, Y, k=3)
        npeet_5 = ee.mi(X, Y, k=5)
        npeet_10 = ee.mi(X, Y, k=10)

        codec = xicorpy.compute_conditional_dependence(X, Y)
        # CIT takes a single (n_samples, n_variables) matrix; test column 0
        # against column 1. NOTE(review): this is presumably a p-value
        # rather than a statistic — confirm against causallearn's docs.
        kci = CIT(np.column_stack((X, Y)), "kci")(0, 1)

        print(f"--- {distribution.capitalize()} Bivariate with rho={rho} ---")
        if true_mi is not None:
            print(f"True MI (nats):        {true_mi:.5f}")
        else:
            print("True MI (nats):        (unknown for uniform)")
        print(f"MI (sklearn regression): {mi_regression:.5f}")
        print(f"MI (binning approach):   {mi_binning:.5f}")
        print(f"MI (KDE pytorch approach): {mi_pytorch:.5f}")
        print(f"MI (Correlation approach): {mi_correlation:.5f}")
        print(f"MI (MINE approach): {mi_mine:.5f}")
        print(f"HSIC (ChatGPT approach): {hsic:.5f}")
        print(f"KCI: {kci:.5f}")
        print(f"NPEET: {npeet:.5f}")
        print(f"NPEET_5: {npeet_5:.5f}")
        print(f"NPEET_10: {npeet_10:.5f}")
        print(f"Codec: {codec:.5f}")
        print()


if __name__ == "__main__":
    # Run the full benchmark once per synthetic distribution, in a fixed order.
    for dist_name in (
        "gaussian",
        "uniform",
        "independent_uniform",
        "independent_normal",
        "independent_laplace",
    ):
        main(distribution=dist_name)
