from data_generation.data_from_dict import data_from_dict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import entropy, gaussian_kde
import pandas as pd
from data_generation.data_from_dict import data_from_dict
from discovery_baselines import conditional_entropy_knn, knn_entropy_estimator


def _kde_entropy(kde, samples):
    """Return the resubstitution entropy estimate ``-mean(log(kde(samples)))``.

    Args:
        kde: A fitted ``gaussian_kde`` (callable returning density values).
        samples: Points at which the density is evaluated; 1-D for a
            univariate KDE, or stacked ``(d, n)`` for a multivariate one.

    Returns:
        The negative mean log-density, in nats.
    """
    pdf_values = kde(samples)
    # The small offset keeps log() finite where the estimated density is 0.
    return -np.mean(np.log(pdf_values + 1e-10))


def _plot_conditional_densities(ax, target, given, given_name, window):
    """Plot KDEs of ``target`` restricted to narrow windows of ``given``.

    Five window centres are spread between the 10th and 90th percentile of
    ``given``; each window is the open interval ``(centre - window,
    centre + window)``.  Every curve is evaluated on a grid spanning the
    range of ``target`` -- the variable whose density is being drawn.
    """
    grid = np.linspace(np.min(target), np.max(target), 100)
    centres = np.linspace(np.percentile(given, 10), np.percentile(given, 90), 5)
    for centre in centres:
        in_window = (given > centre - window) & (given < centre + window)
        subset = target[in_window]
        # gaussian_kde cannot be fitted on fewer than two samples.
        if subset.size < 2:
            continue
        try:
            conditional_kde = gaussian_kde(subset)
        except (np.linalg.LinAlgError, ValueError):
            # Degenerate window (e.g. all samples identical): skip the curve.
            continue
        ax.plot(grid, conditional_kde(grid), label=f"{given_name} ≈ {centre:.2f}")

    # Only draw a legend when at least one curve was actually plotted.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()


def analyze_distributions(x, y, calculate_entropy=True, *, window=0.2):
    """Print entropy estimates for two samples and plot their densities.

    When ``calculate_entropy`` is true, KDE-based estimates of H(X), H(Y),
    H(X, Y), H(X|Y) and H(Y|X) are printed, followed by the values returned
    by ``knn_entropy_estimator`` and ``conditional_entropy_knn``.  A 2x3
    figure is then shown with the joint density, both marginals, both sets
    of conditional densities and (if computed) a text summary that also
    includes the mutual information.

    Args:
        x: 1-D sample of the first variable.
        y: 1-D sample of the second variable, same length as ``x``.
        calculate_entropy: Whether to compute, print and display the
            entropy estimates.
        window: Half-width of the conditioning interval used for the
            conditional density plots.

    Returns:
        None.  The function prints to stdout and calls ``plt.show()``.
    """
    # Array views are used for KDE fitting and boolean masking; the original
    # objects are still handed to seaborn and to the KNN estimators.
    x_arr = np.asarray(x)
    y_arr = np.asarray(y)

    if calculate_entropy:
        # Estimate probability densities using Kernel Density Estimation (KDE)
        kde_x = gaussian_kde(x_arr)
        kde_y = gaussian_kde(y_arr)
        xy_samples = np.vstack([x_arr, y_arr])
        kde_xy = gaussian_kde(xy_samples)

        # Marginal and joint entropies, each evaluated on the fitting samples
        Hx = _kde_entropy(kde_x, x_arr)
        Hy = _kde_entropy(kde_y, y_arr)
        Hxy = _kde_entropy(kde_xy, xy_samples)

        # Chain rule: H(Y|X) = H(X, Y) - H(X) and H(X|Y) = H(X, Y) - H(Y)
        H_y_given_x = Hxy - Hx
        H_x_given_y = Hxy - Hy

        mutual_info = Hx + Hy - Hxy

        # Print entropy values
        print(f"Entropy H(X): {Hx:.4f}")
        print(f"Entropy H(Y): {Hy:.4f}")
        print(f"Joint Entropy H(X, Y): {Hxy:.4f}")
        print(f"Conditional Entropy H(X|Y): {H_x_given_y:.4f}")
        print(f"Conditional Entropy H(Y|X): {H_y_given_x:.4f}")

        # NOTE(review): the labels below assume conditional_entropy_knn takes
        # (target, given) in that order -- confirm against discovery_baselines.
        print(f"Entropy KNN {knn_entropy_estimator(x)}")
        print(f"Entropy KNN {knn_entropy_estimator(y)}")
        print(f"Conditional entropy H(X|Y): {conditional_entropy_knn(x, y)}")
        print(f"Conditional entropy H(Y|X): {conditional_entropy_knn(y, x)}")

    # Create a single figure with subplots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Plot joint density as a 2D KDE
    sns.kdeplot(x=x, y=y, cmap="Blues", fill=True, ax=axes[0, 0])
    axes[0, 0].set_xlabel("X")
    axes[0, 0].set_ylabel("Y")
    axes[0, 0].set_title("Joint Density P(X, Y)")
    axes[0, 0].axis("equal")

    # Plot marginal density estimates
    sns.histplot(x, kde=True, bins=30, ax=axes[0, 1])
    axes[0, 1].set_title("Marginal Distribution P(X)")
    sns.histplot(y, kde=True, bins=30, ax=axes[0, 2])
    axes[0, 2].set_title("Marginal Distribution P(Y)")

    # Plot conditional distributions P(Y | X): densities of y, drawn over y
    _plot_conditional_densities(axes[1, 0], y_arr, x_arr, "X", window)
    axes[1, 0].set_xlabel("Y")
    axes[1, 0].set_ylabel("Density")
    axes[1, 0].set_title("Conditional Distribution P(Y | X)")

    # Plot conditional distributions P(X | Y): densities of x, drawn over x
    _plot_conditional_densities(axes[1, 1], x_arr, y_arr, "Y", window)
    axes[1, 1].set_xlabel("X")
    axes[1, 1].set_ylabel("Density")
    axes[1, 1].set_title("Conditional Distribution P(X | Y)")

    # The last panel carries no plot; it is reused for the text summary
    axes[1, 2].axis("off")
    if calculate_entropy:
        text_summary = f"""
        Entropy H(X): {Hx:.4f}
        Entropy H(Y): {Hy:.4f}
        Joint Entropy H(X, Y): {Hxy:.4f}
        Conditional Entropy H(Y|X): {H_y_given_x:.4f}
        Conditional Entropy H(X|Y): {H_x_given_y:.4f}
        Mutual Information I(X; Y): {mutual_info:.4f}
        """
        axes[1, 2].text(0.1, 0.5, text_summary, fontsize=12, verticalalignment="center")

    plt.tight_layout()
    plt.show()


def main():
    """Generate a two-variable dataset and run the distribution analysis.

    A configuration dict is passed to ``data_from_dict``; the result is
    wrapped in a DataFrame whose columns 0 and 1 are used as ``x`` and ``y``
    for ``analyze_distributions``.
    """
    noise_type = "uniform"
    transform_type = "tanh"

    # Source variable settings.
    x_spec = {"type": noise_type, "length": 10000}

    # Transformation settings; no "alpha" argument is passed.
    transformation_spec = {
        "type": transform_type,
        "args": {"num_hidden": 10, "num_parents": 1},
    }

    # Keys are kept in the same order as before; no "seed" entry is set.
    data_generation_config = dict(
        X=x_spec,
        transformation=transformation_spec,
        shape="sequence",
        depth=2,
        noise_type=noise_type,
        noise_parameters={"mean": 0, "std": 1},
        standardize=True,
    )

    generated = data_from_dict(data_generation_config)
    df = pd.DataFrame(generated)
    x, y = (np.array(df[column]) for column in (0, 1))

    # Flag controlling whether entropy estimates are computed and shown
    calculate_entropy = True
    analyze_distributions(x, y, calculate_entropy)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
