import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def analyze_diversity_density_map(json_path):
    """Print score means and show a hexbin density map plus a boxplot.

    Reads ``structure_scores`` and ``output_scores`` (and, if present,
    ``diversity_score``) from the JSON file at *json_path*. It prints the mean
    of each score list, then shows two figures:

    1. a 2D hexbin of structure score (x) against output score (y);
    2. side-by-side boxplots of both score lists, each annotated with a
       dashed line and label at its mean.

    Args:
        json_path: Path to the evaluation JSON file.

    Raises:
        KeyError: If ``structure_scores`` or ``output_scores`` is missing.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # json.load yields plain lists, which have no ``.mean()`` method; convert
    # once here so both the summary and the mean lines below can use arrays.
    layer_divs = np.asarray(data["structure_scores"], dtype=float)
    output_divs = np.asarray(data["output_scores"], dtype=float)

    mean_layer = layer_divs.mean()
    mean_output = output_divs.mean()

    print(f"Average Structure Distribution: {mean_layer:.4f}")
    print(f"Average Output Distribution:   {mean_output:.4f}")
    print(f"Overall Diversity Score:       {data.get('diversity_score', 'N/A')}")

    # --- Hexbin Plot ---
    plt.figure(figsize=(8, 6))
    # mincnt=1 hides empty hexagons so only populated bins are colored.
    plt.hexbin(layer_divs, output_divs, gridsize=30, cmap="viridis", mincnt=1)
    plt.colorbar(label="Point Count")
    plt.xlabel("Structure Distribution")
    plt.ylabel("Output Distribution")
    plt.title("2D Density (Hexbin) of AIG Distributions")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # --- Boxplot + Mean Line ---
    plt.figure(figsize=(8, 5))
    data_combined = [layer_divs, output_divs]
    sns.boxplot(data=data_combined, palette=["skyblue", "salmon"])

    # Overlay a short dashed line and a label at each group's mean; boxes sit
    # at x positions 0 and 1 in the order of ``data_combined``.
    for i, scores in enumerate(data_combined):
        mean_val = scores.mean()
        plt.plot([i - 0.2, i + 0.2], [mean_val, mean_val], color='black', linestyle='--', linewidth=2)
        plt.text(i, mean_val + 0.01, f"μ={mean_val:.3f}", ha='center', va='bottom', fontsize=10)

    plt.xticks([0, 1], ["Structure", "Output"])
    plt.ylabel("Distribution Score")
    plt.title("Distribution Scores with Mean Lines")
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()


def improved_kde_heatmap(json_path):
    """Print a diversity summary and show a filled 2D KDE heatmap.

    Reads ``structure_scores`` (x) and ``output_scores`` (y), plus the
    optional ``diversity_score``, from the JSON file at *json_path*. It prints
    the means and the diversity score, then shows a filled bivariate kernel
    density estimate of the two score lists with a colorbar.

    Args:
        json_path: Path to the evaluation JSON file.

    Raises:
        KeyError: If ``structure_scores`` or ``output_scores`` is missing.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    x = np.asarray(data["structure_scores"], dtype=float)
    y = np.asarray(data["output_scores"], dtype=float)
    diversity_score = data.get("diversity_score", "N/A")

    print("== Diversity Summary ==")
    print(f"Structure mean: {x.mean():.4f}")
    print(f"Output mean:    {y.mean():.4f}")
    print(f"Diversity Score: {diversity_score}")

    plt.figure(figsize=(8, 6))
    # Let seaborn attach the colorbar to the contour set it actually draws.
    # Grabbing ``ax.collections[0]`` by hand depends on how Matplotlib stores
    # contour levels: one collection per level before 3.8, a single
    # ContourSet from 3.8 on.
    sns.kdeplot(
        x=x,
        y=y,
        cmap="viridis",
        fill=True,
        bw_adjust=0.7,   # < 1 narrows the bandwidth for a less smoothed map
        levels=100,
        thresh=1e-5,     # draw almost the entire density, not just the peak
        cut=3,
        cbar=True,
        cbar_kws={"label": "Density Estimate"},
    )

    plt.xlabel("Structure Distribution Score")
    plt.ylabel("Output Distribution Score")
    plt.title("2D KDE Heatmap of AIG Diversity")
    plt.grid(True, linestyle='--', alpha=0.4)
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    import sys

    # An evaluation file may be passed as the first command-line argument;
    # without one, fall back to the previously hard-coded default.
    default_path = 'generated_aigs/in10_out10/and30/evaluation.json'
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    improved_kde_heatmap(path)
    # Alternative view (hexbin + boxplot); uncomment to use instead:
    # analyze_diversity_density_map(path)


