# plot_pg_error_boxplots_logy.py
# -*- coding: utf-8 -*-
import argparse
import math
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import font_manager as fm
import numpy as np
import pandas as pd


# Global style: seaborn darkgrid + Arial.
# "seaborn-v0_8-*" is the style naming used since matplotlib 3.6.
plt.style.use("seaborn-v0_8-darkgrid")

# Best-effort font setup: register every .ttf/.otf found in a local "Arial"
# directory (resolved relative to the current working directory, not to this
# script) and make the regular Arial face the default text font.
font_dir = Path("Arial")
font_files = list(font_dir.glob("*.ttf")) + list(font_dir.glob("*.otf"))
for font_file in font_files:
    fm.fontManager.addfont(str(font_file))
# Match the regular face case-insensitively ("arial.ttf" vs "Arial.ttf");
# fall back to the conventional lowercase name if no file matches.
regular_font = next(
    (f for f in font_files if f.name.lower() == "arial.ttf"),
    font_dir / "arial.ttf",
)
try:
    prop = fm.FontProperties(fname=str(regular_font))
    plt.rcParams["font.family"] = prop.get_name()
except Exception as err:  # deliberate best effort: keep the style's default font
    warnings.warn(
        f"Could not use {regular_font} as the default font ({err}); "
        "falling back to the style default."
    )

# Mathtext settings. These only affect text wrapped in $...$ (LaTeX/usetex is
# NOT enabled). The "∞" that format_pg_name() puts into titles is plain
# Unicode and is drawn with the regular text font, not with mathtext.
plt.rcParams["mathtext.fontset"] = "dejavuserif"
plt.rcParams["mathtext.default"] = "regular"

# Font sizes shared by all figures produced by this script.
plt.rcParams.update(
    {
        "font.size": 14,
        "axes.titlesize": 18,
        "axes.labelsize": 16,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 14,
    }
)


def format_pg_name(pg: str) -> str:
    """Return a display label for point group *pg*, rendering "inf" as "∞".

    Labels without "inf" are returned unchanged (the same object).

    >>> format_pg_name("Cinfv")
    'C∞v'
    """
    return pg.replace("inf", "∞") if "inf" in pg else pg


def _style_boxplot(bp, colors) -> None:
    """Colour each box and draw whiskers, caps, medians and fliers in black.

    Args:
        bp: Dict of artists returned by ``Axes.boxplot(..., patch_artist=True)``.
        colors: One colour per box, in the same order as ``bp["boxes"]``.
    """
    for box, color in zip(bp["boxes"], colors):
        box.set_facecolor(color)
        box.set_edgecolor("k")
        box.set_alpha(0.7)
        box.set_linewidth(1.2)
    for line in bp["whiskers"] + bp["caps"]:
        line.set_color("k")
        line.set_linewidth(1.0)
    for median in bp["medians"]:
        median.set_color("k")
        median.set_linewidth(1.2)
    for flier in bp["fliers"]:
        flier.set_markeredgecolor("k")
        flier.set_markerfacecolor("none")
        flier.set_markersize(3)


def plot_pg_errors_boxplot(
    mode: str = "single",
    pred_csv: str = "./qm9_hegnn_16c.csv",
    top_k: int = 12,
    max_cols: int = 4,
    n_degrees: int = 12,
    target_col: str = "alpha",
):
    """Box-plot per-sample absolute errors vs. model degree, one panel per point group.

    Reads ``pred_csv`` and computes ``|prediction - target_col|`` for every
    degree column of the selected ``mode``. It then draws one log-scale box-plot
    panel for each of the ``top_k`` most frequent ``point_group`` values.
    The figure is saved as a PDF in the current directory and then shown.

    Args:
        mode: ``"single"`` reads columns ``pred_{l}``; ``"mul"`` reads columns
            ``pred_0_{l}`` (x axis labelled "Model Degree 0 ~ l").
        pred_csv: CSV with a ``point_group`` column, the ``target_col`` column
            and the prediction columns.
        top_k: Number of most frequent point groups to plot.
        max_cols: Maximum number of panels per row.
        n_degrees: Degrees ``0 .. n_degrees - 1`` to look for (default 12).
            Degrees whose prediction column is absent are skipped with a notice.
        target_col: Column holding the reference value (default ``"alpha"``).

    Raises:
        ValueError: If ``mode`` is invalid, if ``top_k``, ``max_cols`` or
            ``n_degrees`` is not positive, if no prediction column exists for
            ``mode``, or if the CSV contains no point groups.
    """
    if mode not in ("single", "mul"):
        raise ValueError("mode must be 'single' or 'mul'")
    if min(top_k, max_cols, n_degrees) < 1:
        raise ValueError("top_k, max_cols and n_degrees must be positive")

    df = pd.read_csv(pred_csv)
    # Floor for the absolute errors so zero errors stay finite on the log
    # y axis. NOTE: this also lifts every error below 1e-2 up to 1e-2.
    clip_lower = 1e-2

    if mode == "single":
        pred_prefix, err_prefix = "pred_", "err_"
    else:
        pred_prefix, err_prefix = "pred_0_", "err_0_"

    # Only plot degrees whose prediction column actually exists; previously a
    # missing column crashed later with a KeyError.
    degrees = [deg for deg in range(n_degrees) if f"{pred_prefix}{deg}" in df.columns]
    if not degrees:
        raise ValueError(f"no '{pred_prefix}<l>' columns found in {pred_csv}")
    missing = sorted(set(range(n_degrees)) - set(degrees))
    if missing:
        print(f"Warning: no '{pred_prefix}<l>' column for l = {missing}; skipping those degrees.")

    for deg in degrees:
        df[f"{err_prefix}{deg}"] = (
            (df[f"{pred_prefix}{deg}"] - df[target_col]).abs().clip(lower=clip_lower)
        )

    pg_counts = df.groupby("point_group").size().sort_values(ascending=False)
    top_groups = pg_counts.head(top_k).index.tolist()
    if not top_groups:
        # Guards the division in math.ceil(n / n_cols) below.
        raise ValueError(f"no 'point_group' values found in {pred_csv}")

    print("Top point groups by count:")
    print(pg_counts.head(top_k))
    print()

    n = len(top_groups)
    n_cols = min(max_cols, n)
    n_rows = math.ceil(n / n_cols)

    # squeeze=False always yields a 2-D (n_rows, n_cols) array of Axes.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4.6 * n_cols, 3.8 * n_rows),
        sharey=True,
        squeeze=False,
    )
    axes_flat = axes.flatten()

    # Colours are indexed by degree so a degree keeps its colour even when
    # other degrees are missing.
    palette = plt.cm.viridis(np.linspace(0.05, 0.95, n_degrees))
    box_colors = [palette[deg] for deg in degrees]
    positions = np.array(degrees)
    xlabel = "Model Degree l" if mode == "single" else "Model Degree 0 ~ l"

    grouped = df.groupby("point_group")
    for i, (ax, pg) in enumerate(zip(axes_flat, top_groups)):
        df_pg = grouped.get_group(pg)
        box_data = [df_pg[f"{err_prefix}{deg}"].dropna().values for deg in degrees]

        bp = ax.boxplot(
            box_data,
            positions=positions,
            widths=0.6,
            showfliers=True,
            patch_artist=True,
        )
        _style_boxplot(bp, box_colors)

        ax.set_yscale("log")
        ax.set_title(f"{format_pg_name(pg)} ({pg_counts[pg]:,} samples)", fontweight="bold")
        ax.set_xticks(positions)
        ax.set_xticklabels([str(deg) for deg in degrees])
        ax.grid(alpha=0.6, linestyle="--")

        # Label the lowest visible panel of every column. When the last row is
        # only partly filled, panels above its empty slots are the lowest ones.
        if i + n_cols >= n:
            ax.set_xlabel(xlabel, fontsize=16)

    # Hide grid cells left over after the last point group.
    for ax in axes_flat[n:]:
        ax.axis("off")

    # NOTE: the plotted values are per-sample absolute errors; the "MAE" label
    # is kept as in earlier figures.
    for row in axes:
        row[0].set_ylabel("MAE loss (log scale)", fontsize=16)

    fig.tight_layout()

    outname = {"single": "pg_error_curves_l.pdf", "mul": "pg_error_curves_0_l.pdf"}[mode]
    fig.savefig(outname, dpi=600, bbox_inches="tight", pad_inches=0)
    print(f"Saved figure to: {outname}")
    plt.show()


def _positive_int(text: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    value = int(text)  # a ValueError here is reported by argparse as invalid
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def parse_args(argv=None):
    """Parse command-line options for the point-group error plot.

    Args:
        argv: Argument list to parse; ``None`` (default) uses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``mode``, ``pred_csv``, ``top_k`` and ``max_cols``.
    """
    parser = argparse.ArgumentParser(
        description="Plot per-point-group error distributions across degree l."
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="mul",
        choices=["single", "mul"],
        help="'single' plots pred_<l> columns, 'mul' plots pred_0_<l> columns "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--pred_csv",
        type=str,
        default="./qm9_hegnn_16c.csv",
        help="CSV file with predictions (default: %(default)s)",
    )
    parser.add_argument(
        "--top_k",
        type=_positive_int,
        default=16,
        help="number of most frequent point groups to plot (default: %(default)s)",
    )
    parser.add_argument(
        "--max_cols",
        type=_positive_int,
        default=4,
        help="maximum number of panels per row (default: %(default)s)",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: forward the parsed CLI options to the plotting routine.
    cli_args = parse_args()
    plot_kwargs = {
        "mode": cli_args.mode,
        "pred_csv": cli_args.pred_csv,
        "top_k": cli_args.top_k,
        "max_cols": cli_args.max_cols,
    }
    plot_pg_errors_boxplot(**plot_kwargs)
