# plots_grouped_by_task.py
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -----------------------------
# Config
# -----------------------------
# Aggregated per-task summary produced by the experiment runs; edit if it lives elsewhere.
RESULTS_CSV = "experiments/tdc_outputs/core9_with_ci/summary_across_seeds.csv"
OUT_DIR = "figs_grouped"
# Created up front so the figure-saving code can write into it unconditionally.
os.makedirs(OUT_DIR, exist_ok=True)

# Feature sets in plotting order, each mapped to the short label shown in legends.
# FEATURE_ORDER is derived from this mapping (dicts keep insertion order), so a
# new feature set only has to be added in one place.
FEATURE_ALIAS = {
    "SAE Features": "SAE",
    "Transformer Embeddings": "LM",
    "PCA on Embeddings": "LM-PCA",
    "ECFP Fingerprints": "ECFP",
    "SAE ⊕ ECFP": "SAE+ECFP",
}
FEATURE_ORDER = list(FEATURE_ALIAS)

# Tasks shown in each plot, in x-axis order.
TASKS_CLS = [
    "AMES", "BBB_Martins", "CYP2D6_Veith",
    "CYP3A4_Veith", "Pgp_Broccatelli", "hERG",
]
TASKS_REG = ["Caco2_Wang", "Half_Life_Obach", "PPBR_AZ"]

# -----------------------------
# Load aggregated summary
# -----------------------------
# Expected columns: Task, Features, Metric, Mean, Std, N, Seed_CI_low, Seed_CI_high.
agg = pd.read_csv(RESULTS_CSV)
# Some summaries name the CI bounds CI_low_95 / CI_high_95; map them onto the
# Seed_CI_* names the rest of this script reads.
agg = agg.rename(columns={"CI_low_95": "Seed_CI_low", "CI_high_95": "Seed_CI_high"})

# Fail fast with a readable message rather than a KeyError raised deep inside
# the per-row CI computation or the plotting code.
_REQUIRED_COLUMNS = ("Task", "Features", "Metric", "Mean", "Seed_CI_low", "Seed_CI_high")
_missing_columns = [c for c in _REQUIRED_COLUMNS if c not in agg.columns]
if _missing_columns:
    raise ValueError(
        f"{RESULTS_CSV} is missing required column(s) {_missing_columns}; "
        f"found {list(agg.columns)}"
    )


def half_ci(row):
    """Return the half-width of a row's seed confidence interval.

    The result is the larger of ``Mean - Seed_CI_low`` and
    ``Seed_CI_high - Mean``, so a symmetric ``Mean ± half-width`` range
    covers both sides of an asymmetric interval. NaN is returned when
    either bound is missing.
    """
    lower = row["Seed_CI_low"]
    if pd.isnull(lower):
        return np.nan
    upper = row["Seed_CI_high"]
    if pd.isnull(upper):
        return np.nan
    centre = row["Mean"]
    return max(centre - lower, upper - centre)


# Whisker length per row (NaN where either CI bound is missing).
agg["HalfCI"] = agg.apply(func=half_ci, axis="columns")


# -----------------------------
# Plot helper
# -----------------------------
def _task_feature_matrices(sub, tasks):
    """Pivot plot rows into (task, feature) arrays of means and CI half-widths.

    Row ``i`` corresponds to ``tasks[i]`` and column ``j`` to ``FEATURE_ORDER[j]``.
    Combinations with no row in ``sub`` stay NaN. Every ``Features`` value in
    ``sub`` must be a key of FEATURE_ORDER.
    """
    feat_pos = {feat: j for j, feat in enumerate(FEATURE_ORDER)}
    means = np.full((len(tasks), len(FEATURE_ORDER)), np.nan)
    errs = np.full_like(means, np.nan)
    for i, task in enumerate(tasks):
        rows = sub[sub["Task"] == task]
        for feat, mean, hci in zip(rows["Features"], rows["Mean"], rows["HalfCI"]):
            j = feat_pos[feat]
            means[i, j] = float(mean)
            errs[i, j] = float(hci)
    return means, errs


def _draw_ci_whiskers(ax, x, offsets, means, errs, bar_width, floor=None):
    """Draw black ±half-width whiskers, with short caps, on top of each bar.

    Bars lacking a finite mean or half-width get no whisker. When ``floor`` is
    not None, lower whisker ends are raised to at least ``floor``.
    """
    style = {"linewidth": 1.2, "colors": "black"}
    cap_half = bar_width * 0.25
    for fi, offset in enumerate(offsets):
        y, e = means[:, fi], errs[:, fi]
        ok = np.isfinite(y) & np.isfinite(e)
        if not ok.any():
            continue
        xs = x[ok] + offset
        low = y[ok] - e[ok]
        high = y[ok] + e[ok]
        if floor is not None:
            low = np.maximum(low, floor)
        ax.vlines(xs, low, high, **style)
        ax.hlines(low, xs - cap_half, xs + cap_half, **style)
        ax.hlines(high, xs - cap_half, xs + cap_half, **style)


def grouped_barchart(
    df,
    metric,
    tasks,
    *,
    log_y=False,
    figsize=(5.0, 5.0),
    legend_ncol=5,
    title_prefix="",
    out_prefix="plot",
    group_spacing=0.30,  # extra gap between benchmark groups
):
    """Draw and save a grouped bar chart of ``metric`` per task.

    Each task gets one group of bars, with one bar series per feature set in
    FEATURE_ORDER. Bar height is ``Mean`` and the black whiskers span
    ``Mean ± HalfCI``. The figure is written as PNG and PDF into OUT_DIR.

    Args:
        df: Summary table with Task, Features, Metric, Mean and HalfCI columns.
        metric: Value of the Metric column to plot. "AUROC" is labelled
            "ROC-AUC" in the figure.
        tasks: Tasks to include, in x-axis order.
        log_y: Use a logarithmic y axis.
        figsize: Figure size passed to ``plt.figure``.
        legend_ncol: Maximum number of legend columns.
        title_prefix: Text prepended to the title.
        out_prefix: File-name prefix for the saved figures.
        group_spacing: Extra gap between task groups, as a fraction of the
            unit distance between group centres.

    Returns:
        ``(png_path, pdf_path)``, or ``(None, None)`` when there is nothing to plot.
    """
    sub = df[(df["Metric"] == metric) & (df["Task"].isin(tasks))]

    # Feature sets outside FEATURE_ORDER have no bar slot or legend alias;
    # skip them (with a warning) instead of crashing on the slot lookup.
    known = sub["Features"].isin(FEATURE_ORDER)
    if not known.all():
        dropped = sorted({str(f) for f in sub.loc[~known, "Features"]})
        print(f"[WARN] metric={metric}: skipping features not in FEATURE_ORDER: {dropped}")
        sub = sub[known]
    if sub.empty:
        print(f"[WARN] No rows for metric={metric}")
        return None, None

    means, errs = _task_feature_matrices(sub, tasks)
    n_tasks, n_feat = means.shape

    fig = plt.figure(figsize=figsize, constrained_layout=False)
    ax = fig.add_subplot(111)

    # Group centres, stretched by (1 + group_spacing) to push groups apart.
    x = np.arange(n_tasks, dtype=float) * (1.0 + group_spacing)

    # Bar layout within each group: bars side by side, centred on the group.
    bar_width = 0.18
    offsets = (np.arange(n_feat) - (n_feat - 1) / 2.0) * bar_width

    for fi, feat in enumerate(FEATURE_ORDER):
        ax.bar(x + offsets[fi], means[:, fi], width=bar_width, label=FEATURE_ALIAS[feat])

    # A log axis cannot show values <= 0, so lower whisker ends are clipped to a
    # tiny positive floor there. A linear axis shows the true interval, which
    # matters for metrics whose mean or lower bound can be negative.
    _draw_ci_whiskers(
        ax, x, offsets, means, errs, bar_width, floor=1e-9 if log_y else None
    )

    if log_y:
        ax.set_yscale("log")

    # Cosmetics
    ax.set_xticks(x)
    ax.set_xticklabels(tasks, rotation=25, ha="right")
    ax.set_xlabel("")

    # Rename AUROC to ROC-AUC in the figure only.
    display_metric = "ROC-AUC" if metric.upper() == "AUROC" else metric
    ax.set_ylabel(display_metric + (" (log scale)" if log_y else ""))
    ax.grid(True, axis="y", alpha=0.3)

    handles, labels = ax.get_legend_handles_labels()

    # Layout margins: extra top room for title + legend, and enough bottom room
    # for the rotated tick labels.
    fig.tight_layout()
    fig.subplots_adjust(top=0.82, bottom=0.20)

    # Title and legend both sit above the axes (axes coords > 1).
    ax.set_title(
        f"{title_prefix}{display_metric} (mean ± 95% CI)",
        loc="center",
        y=1.08,
        pad=1,
    )
    ax.legend(
        handles,
        labels,
        ncol=min(legend_ncol, len(labels)),
        loc="upper center",
        bbox_to_anchor=(0.5, 1.09),  # legend's top edge, just under the title
        bbox_transform=ax.transAxes,
        frameon=False,
        fontsize=9,
        handlelength=1.5,
        columnspacing=1.2,
        handletextpad=0.6,
    )

    safe_metric = display_metric.replace("/", "_").replace("-", "_")
    png = os.path.join(OUT_DIR, f"{out_prefix}_{safe_metric}.png")
    pdf = os.path.join(OUT_DIR, f"{out_prefix}_{safe_metric}.pdf")
    fig.savefig(png, dpi=150, bbox_inches="tight")
    fig.savefig(pdf, bbox_inches="tight")
    plt.close(fig)
    return png, pdf


# -----------------------------
# Generate the figures
# -----------------------------
# Classification tasks: ROC-AUC (stored as metric "AUROC") on a linear y axis.
# Regression tasks: RMSE on a log y axis, with tighter spacing between groups.
# For the classification plot, group_spacing of roughly 0.30-0.40 works well.
for _metric, _tasks, _log_y, _out_prefix, _spacing in (
    ("AUROC", TASKS_CLS, False, "by_task_ROC_AUC_square", 0.3),
    ("RMSE", TASKS_REG, True, "by_task_RMSE_square", 0.1),
):
    grouped_barchart(
        agg,
        _metric,
        _tasks,
        log_y=_log_y,
        figsize=(5.0, 5.0),
        legend_ncol=5,
        title_prefix="",
        out_prefix=_out_prefix,
        group_spacing=_spacing,
    )
