# plot_kfold_heatmap_three_segment_fixed_upper.py
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import font_manager as fm
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.ticker import FixedLocator, FixedFormatter
import numpy as np
import pandas as pd


# Font: register every TrueType/OpenType file found in ./Arial with matplotlib
# and, when the regular face (arial.ttf, matched in any letter case) is among
# them, make its family the default. This is best-effort: if the directory or
# the file is missing, or the font cannot be read, matplotlib's default family
# is kept.
font_dir = Path("Arial")
font_files = list(font_dir.glob("*.ttf")) + list(font_dir.glob("*.otf"))
for font_file in font_files:
    fm.fontManager.addfont(str(font_file))
# Match case-insensitively so "Arial.ttf" is found on case-sensitive filesystems.
arial_path = next((p for p in font_files if p.name.lower() == "arial.ttf"), None)
if arial_path is not None:
    try:
        prop = fm.FontProperties(fname=str(arial_path))
        plt.rcParams["font.family"] = prop.get_name()
    except (OSError, RuntimeError) as exc:
        # Unreadable font file: keep the default family, but say why.
        print(f"Could not use {arial_path} as the default font: {exc}")

# Global text sizes: each element keeps its base size plus a uniform
# font_offset bump.
font_offset = 4
plt.rcParams.update(
    {
        name: base_size + font_offset
        for name, base_size in (
            ("font.size", 14),
            ("axes.titlesize", 18),
            ("axes.labelsize", 16),
            ("xtick.labelsize", 12),
            ("ytick.labelsize", 12),
            ("legend.fontsize", 14),
        )
    }
)

# Class boundaries for the heatmap colours. plot_kfold_heatmap uses them as
# BoundaryNorm bounds [0, THR1, THR2, COLORBAR_UPPER] with clip=True.
THR1, THR2 = 1.0e-6, 1.0e-3
COLORBAR_UPPER = 0.1


def build_heatmap_matrix(csv_path: str, agg: str = "max"):
    """Aggregate per-degree values from a CSV into a (degree x setting) matrix.

    The CSV is read with ``pd.read_csv``. It must contain ``fold`` and
    ``rot_mod`` columns, and may contain ``deg0`` ... ``deg11`` columns.
    Missing ``deg`` columns or NaN cells are tolerated.

    Matrix columns are the (fold, rotation) combinations, all 2D first and
    then 3D::

        (2,"2d") (3,"2d") (4,"2d") (6,"2d") (2,"3d") (3,"3d") (4,"3d") (6,"3d")

    Rows are degrees 11, 10, ..., 0. Each cell is the ``agg`` reduction of the
    non-NaN ``deg{ell}`` values in the rows matching that fold and
    ``rot_mod``, or 0.0 when there are none.

    Args:
        csv_path: Path (or buffer) accepted by ``pd.read_csv``.
        agg: One of ``"max"`` (default), ``"mean"``, ``"median"``, ``"min"``.

    Returns:
        ``(x_folds, y_labels, Z)``:

        - ``x_folds`` is the fold number of each column.
        - ``y_labels`` is the degree of each row as a string.
        - ``Z`` is a ``(12, 8)`` float array.

    Raises:
        ValueError: If ``agg`` is not one of the supported reductions.
    """
    reducers = {"max": np.max, "mean": np.mean, "median": np.median, "min": np.min}
    try:
        reducer = reducers[agg]
    except KeyError:
        # Previously any unknown name silently meant "max"; fail loudly instead.
        raise ValueError(
            f"agg must be one of {sorted(reducers)}, got {agg!r}"
        ) from None

    df = pd.read_csv(csv_path)
    # Coerce so numeric comparisons work even if the column was read as text;
    # unparsable entries become NaN and never match a fold.
    df["fold"] = pd.to_numeric(df["fold"], errors="coerce")

    folds = [2, 3, 4, 6]
    combos = [(f, "2d") for f in folds] + [(f, "3d") for f in folds]
    x_folds = [f for f, _ in combos]
    degs = list(range(11, -1, -1))
    y_labels = [str(ell) for ell in degs]

    # Start from zeros: a cell with no data keeps the 0.0 placeholder.
    Z = np.zeros((len(degs), len(combos)), dtype=float)
    for j, (fold, rot) in enumerate(combos):
        block = df[(df["fold"] == fold) & (df["rot_mod"] == rot)]
        for i, ell in enumerate(degs):
            col = f"deg{ell}"
            if col not in block.columns:
                continue
            vals = block[col].dropna().to_numpy()
            if vals.size:
                Z[i, j] = reducer(vals)

    return x_folds, y_labels, Z


def _sci_tex(value: float) -> str:
    """Return a mathtext tick label for *value*: ``$0$``, a pure power of ten, or mantissa x power."""
    if value == 0:
        return r"$0$"
    # Tiny epsilon so log10 of an exact power of ten never floors one too low.
    exponent = int(np.floor(np.log10(abs(value)) + 1e-9))
    mantissa = value / 10.0**exponent
    if np.isclose(mantissa, 1.0):
        return rf"$10^{{{exponent}}}$"
    return rf"${mantissa:g}\times10^{{{exponent}}}$"


def _draw_group_label(ax, start: float, end: float, text: str, y: float = 1.02) -> None:
    """Draw a horizontal rule from *start* to *end* (axes fraction) at height *y* with a bold caption above it."""
    ax.annotate(
        "",
        xy=(end, y),
        xytext=(start, y),
        xycoords="axes fraction",
        arrowprops=dict(arrowstyle="-", lw=1.2, color="black"),
    )
    ax.text(
        (start + end) / 2,
        y + 0.02,
        text,
        ha="center",
        transform=ax.transAxes,
        fontweight="bold",
    )


def plot_kfold_heatmap(pred_csv: str, agg: str, outname: str = "kfold_heatmap.pdf"):
    """Plot the per-degree k-fold heatmap from *pred_csv*, save it, and show it.

    Cells are coloured by three discrete classes bounded by
    ``[0, THR1, THR2, COLORBAR_UPPER]``. Values above ``COLORBAR_UPPER`` are
    clipped into the top class. Empty cells (0.0 from ``build_heatmap_matrix``)
    and NaNs (mapped to 0) land in the lowest class.

    Args:
        pred_csv: CSV path forwarded to ``build_heatmap_matrix``.
        agg: Reduction forwarded to ``build_heatmap_matrix``. It also selects
            the wording of the colorbar label.
        outname: Output file passed to ``savefig``. Defaults to
            ``"kfold_heatmap.pdf"``.
    """
    x_folds, y_labels, Z = build_heatmap_matrix(pred_csv, agg)
    Z_clean = np.nan_to_num(Z, nan=0.0)
    n_rows, n_cols = Z_clean.shape
    # Assumes build_heatmap_matrix's layout: all 2D columns first, followed by
    # the same number of 3D columns, so the groups split at the midpoint.
    n_2d = n_cols // 2

    # Three discrete classes, all with alpha=0.4:
    #   royalblue  for [0, THR1)
    #   light cyan for [THR1, THR2)
    #   gold       for [THR2, COLORBAR_UPPER] (larger values clipped here)
    cmap = ListedColormap(
        [
            (65/255, 105/255, 225/255, 0.4),
            (127/255, 219/255, 255/255, 0.4),
            (255/255, 215/255, 0/255, 0.4),
        ]
    )
    bounds = [0, THR1, THR2, COLORBAR_UPPER]
    norm = BoundaryNorm(bounds, cmap.N, clip=True)

    fig, ax = plt.subplots(figsize=(8, 5))
    im = ax.imshow(Z_clean, origin="lower", cmap=cmap, norm=norm, aspect="auto")

    # White cell borders, then a black rule between the 2D and 3D groups.
    grid_lw = 1.0
    for c in range(n_cols + 1):
        ax.axvline(c - 0.5, color="white", linewidth=grid_lw)
    for r in range(n_rows + 1):
        ax.axhline(r - 0.5, color="white", linewidth=grid_lw)
    ax.axvline(n_2d - 0.5, color="black", linewidth=grid_lw)

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([f"{f}-fold" for f in x_folds])
    # origin="lower" draws row 0 at the bottom, so y_labels[0] labels the bottom row.
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(y_labels)
    ax.set_ylabel("Degree of Models", labelpad=10)

    # Ticks sit on the class boundaries. The labels are derived from the
    # thresholds so they can never disagree with the actual bounds.
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_ticks(bounds)
    cbar.set_ticklabels([_sci_tex(b) for b in bounds])
    # Unrecognised agg values keep the default "Max" wording.
    agg_word = {"mean": "Mean", "median": "Median", "min": "Min"}.get(agg, "Max")
    cbar.set_label(f"Embedding Difference Norm\n({agg_word} of All Settings)")

    # Bracketed group captions above the axes, in axes-fraction coordinates.
    split = n_2d / n_cols
    _draw_group_label(ax, 0.0, split, "2D Rotation")
    _draw_group_label(ax, split, 1.0, "3D Rotation")

    fig.tight_layout()
    fig.savefig(outname, transparent=True, dpi=600, bbox_inches="tight", pad_inches=0)
    print(f"Saved figure to {outname}")
    plt.show()


def parse_args():
    """Parse the command-line options of the heatmap script.

    Returns:
        argparse.Namespace with two fields:

        - ``pred_csv``: the input CSV path.
        - ``agg``: the per-cell reduction. It is restricted to the supported
          names, so a typo fails at the CLI instead of silently plotting the
          ``max``.
    """
    parser = argparse.ArgumentParser(
        description="Plot the k-fold embedding-difference heatmap."
    )
    parser.add_argument(
        "--pred_csv",
        type=str,
        default="./kfold_emb_degree_norm.csv",
        help="CSV with fold, rot_mod and deg0..deg11 columns (default: %(default)s)",
    )
    parser.add_argument(
        "--agg",
        type=str,
        default="max",
        choices=("max", "mean", "median", "min"),
        help="reduction applied to the rows of each heatmap cell (default: %(default)s)",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: read the CLI options, then draw and save the figure.
    cli = parse_args()
    plot_kfold_heatmap(pred_csv=cli.pred_csv, agg=cli.agg)
