from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import BoundaryNorm


def plot_atom_heatmap(
    input_path="visualizations/raw_data/atom_groups.csv",
    output_path="visualizations/plots/atom_groups.png",
    dark_text_rows=3,
):
    """Render a CSV table as a discretely-colored, annotated heatmap image.

    The CSV is read with its first column as the row index. Every remaining
    cell is drawn as a heatmap cell, colored by a set of non-uniform bins
    spanning 0-100, and labelled with its value to one decimal place.

    Args:
        input_path: CSV file to read; its first column becomes the row index.
        output_path: Destination image file. Missing parent directories are
            created before saving.
        dark_text_rows: Number of leading rows whose labels are drawn in
            black; all later rows get white labels. The default of 3
            reproduces the original hard-coded behavior.
    """
    df = pd.read_csv(input_path, index_col=0)

    # Non-uniform bin edges: fine steps near 0 and near 100, coarse 10-point
    # steps in between, so small differences at the extremes get distinct
    # colors.
    boundaries = [
        0, 0.5, 1, 1.5, 2, 2.5, 3, 4, 5,
        10, 20, 30, 40, 50, 60, 70, 80, 90,
        92, 94, 95, 96, 97, 98, 98.5, 99, 100,
    ]
    cmap = plt.get_cmap("viridis")
    # One discrete color per bin; clip=True maps values outside [0, 100] to
    # the first/last color instead of leaving them out of range.
    norm = BoundaryNorm(boundaries, cmap.N, clip=True)

    fig, ax = plt.subplots(figsize=(10, 9))
    try:
        # Seaborn's own annotations stay off (annot defaults to False), which
        # is why no fmt/annot_kws are passed. Cells are labelled manually
        # below so the text color can vary per row.
        sns.heatmap(
            df,
            cmap=cmap,
            norm=norm,
            cbar_kws={"label": "Percentage"},
            linewidths=0.5,
            linecolor="gray",
            ax=ax,
        )

        for i, row in enumerate(df.to_numpy()):
            color = "black" if i < dark_text_rows else "white"
            for j, value in enumerate(row):
                if pd.isna(value):
                    # seaborn masks NaN cells and leaves them blank; don't
                    # stamp a literal "nan" on top of them.
                    continue
                # Heatmap cells are unit squares, so +0.5 centers the label.
                ax.text(
                    j + 0.5,
                    i + 0.5,
                    f"{value:.1f}",
                    ha="center",
                    va="center",
                    color=color,
                    fontsize=10,
                    fontweight="bold",
                )

        ax.tick_params(axis="x", rotation=45, labelsize=12)
        ax.tick_params(axis="y", rotation=0, labelsize=12)

        fig.tight_layout()

        output = Path(output_path)
        # Without this, savefig raises if the target directory is missing.
        output.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Run only when executed as a script, not on import: render the heatmap
    # using the function's default input and output paths.
    plot_atom_heatmap()
