import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

# Experiment grid: every dataset below is combined with every model.
datasets = [
    "cifar",
    "cub",
    "imagenet",
]
models = [
    "convnext",
    "resnet50",
    "vit",
]

# Columns plotted from each metrics file, and the legend text used for each
# one (same order). Note that the "a_align" column is displayed as "α_anchor".
alpha_cols = ["a_center", "a_triplet", "a_align", "a_ce", "a_var"]
custom_labels = ["α_center", "α_triplet", "α_anchor", "α_ce", "α_var"]

# Figures are written under ./results, which is created if it does not exist.
root_out = Path("./results")
root_out.mkdir(parents=True, exist_ok=True)

# Run directories are looked up relative to the current working directory.
root_in = Path("")

# Global matplotlib font sizes applied to every figure created afterwards.
plt.rcParams["font.size"] = 14         # base font size
plt.rcParams["axes.titlesize"] = 16    # axes title
plt.rcParams["axes.labelsize"] = 14    # x/y axis labels
plt.rcParams["xtick.labelsize"] = 12   # x tick labels
plt.rcParams["ytick.labelsize"] = 12   # y tick labels
plt.rcParams["legend.fontsize"] = 12   # legend entries
# Render one α-evolution figure per (dataset, model) combination.
#
# For each pair the metrics are read from
#   <root_in>/<dataset>_<model>/ce,triplet,align,center,var/metrics.csv
# and the figure is written to <root_out>/<dataset>_<model>.png.
# A combination whose metrics file is missing, empty, or lacks a required
# column is reported with a [WARN] line and skipped, so one bad run does not
# abort the whole batch.
for ds, md in itertools.product(datasets, models):
    csv_path = (
        root_in
        / f"{ds}_{md}"
        / "ce,triplet,align,center,var"
        / "metrics.csv"
    )

    if not csv_path.exists():
        print(f"[WARN] {csv_path} not found – skipping.")
        continue

    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header row for pandas to parse.
        print(f"[WARN] {csv_path} is empty – skipping.")
        continue

    # Validate up front instead of failing with a KeyError halfway through
    # drawing the figure.
    missing = [c for c in ["epoch", *alpha_cols] if c not in df.columns]
    if missing:
        print(f"[WARN] {csv_path} is missing column(s) {missing} – skipping.")
        continue

    fig, ax = plt.subplots(figsize=(10, 6))

    # Attach each legend label to its own line so the column/label pairing is
    # explicit rather than relying on two parallel lists at legend() time.
    for col, label in zip(alpha_cols, custom_labels):
        ax.plot(df["epoch"], df[col], label=label)

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Alpha value")
    ax.set_title(f"{ds.upper()} – {md.upper()} α-evolution")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()

    out_file = root_out / f"{ds}_{md}.png"
    fig.savefig(out_file, dpi=300, bbox_inches="tight")
    # Close explicitly: pyplot keeps every figure alive until it is closed.
    plt.close(fig)
    print(f"[OK]  Saved: {out_file.resolve()}")