# experiments/tdc_benchmarks.py
from __future__ import annotations
import argparse, os, math
from typing import Dict, Any, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import xgboost as xgb
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.decomposition import PCA
from sklearn.metrics import (
    mean_squared_error,
    mean_absolute_error,
    roc_auc_score,
    average_precision_score,
)
from sklearn.preprocessing import LabelEncoder

# CI across seeds (t-interval)
from scipy.stats import t as student_t

# Optional plotting
try:
    import matplotlib.pyplot as plt

    _HAS_PLT = True
except Exception:
    _HAS_PLT = False

# --- TDC loaders: ADME / Tox / HTS (single_pred) ---
# Docs: get_split(method='scaffold', seed=...) is supported.
from tdc.single_pred import ADME, Tox, HTS

# --- Your project's imports ---
from lmkit.sparse.sae import SAEKit
from lmkit.sparse import utils as sae_utils


# ------------------------------
# Utilities
# ------------------------------
def pad_and_stack(
    seqs: List[List[int]] | List[np.ndarray], pad_value: int
) -> np.ndarray:
    if not seqs:
        return np.zeros((0, 0), dtype=np.int32)
    max_len = max(len(s) for s in seqs)
    out = np.full((len(seqs), max_len), pad_value, dtype=np.int32)
    for i, s in enumerate(seqs):
        arr = np.asarray(s, dtype=np.int32)
        out[i, : arr.shape[0]] = arr
    return out


def _find_smiles_column(df: pd.DataFrame) -> str:
    """Return the name of the first known SMILES column present in ``df``.

    Candidates are checked in the order ``Drug``, ``SMILES``, ``smiles``, ``X``.

    Raises:
        KeyError: If none of the candidate columns exists.
    """
    known_names = ("Drug", "SMILES", "smiles", "X")
    match = next((name for name in known_names if name in df.columns), None)
    if match is None:
        raise KeyError(f"Could not find a SMILES column among: {list(df.columns)}")
    return match


def generate_ecfp_features(smiles_list: List[str], radius=2, n_bits=2048) -> np.ndarray:
    """Build a Morgan bit-vector fingerprint matrix for a list of SMILES.

    Args:
        smiles_list: SMILES strings, one per molecule.
        radius: Morgan radius passed to RDKit.
        n_bits: Fingerprint length (number of columns).

    Returns:
        ``float32`` array of shape ``(len(smiles_list), n_bits)`` holding 0/1
        values. SMILES that RDKit fails to parse keep an all-zero row.
    """
    matrix = np.zeros((len(smiles_list), n_bits), dtype=np.float32)
    for row, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            bitvect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            # The bit string is a run of "0"/"1" characters; view it one byte
            # at a time and compare against b"1" to get a boolean mask.
            chars = np.frombuffer(bitvect.ToBitString().encode(), "S1")
            matrix[row, :] = (chars == b"1").astype(np.float32, copy=False)
    return matrix


def generate_transformer_and_sae_features_for_layer(
    smiles_list: List[str], sae_kit: SAEKit, sae_layer: int, batch_size: int = 64
) -> Tuple[np.ndarray, np.ndarray]:
    """Pool per-molecule dense and SAE features for one SAE layer.

    SMILES are tokenized and processed in batches. For each batch the captured
    residual activations are mean-pooled over non-padded tokens, and the SAE
    activations are max-pooled after padded positions are zeroed out.

    Args:
        smiles_list: SMILES strings to featurize.
        sae_kit: Loaded kit providing the tokenizer, LM and SAE accessors.
        sae_layer: Key into ``sae_kit.sae_configs`` / ``get_encoded``.
        batch_size: Number of molecules per forward pass.

    Returns:
        ``(dense, sae)``: two ``float32`` arrays with one row per molecule.
    """
    pooled_dense, pooled_sae = [], []
    for start in range(0, len(smiles_list), batch_size):
        chunk = smiles_list[start : start + batch_size]
        token_ids = [sae_kit.tokenizer.encode(smi).ids for smi in chunk]

        ids_np = pad_and_stack(token_ids, sae_kit.tokenizer.pad_token_id)
        # Padded positions are marked with -1 so they can be masked out below.
        pos_np = pad_and_stack(
            [np.arange(len(ids), dtype=np.int32) for ids in token_ids], -1
        )
        inputs = jnp.asarray(ids_np, dtype=jnp.int32)
        positions = jnp.asarray(pos_np, dtype=jnp.int32)

        captured = sae_utils.run_and_capture(
            sae_kit.run_fn,
            inputs,
            positions,
            sae_kit.lm_params,
            sae_kit.lm_config,
            sae_kit.hooks,
        )
        layer_cfg = sae_kit.sae_configs[sae_layer]
        hidden = captured[(layer_cfg.layer_id, layer_cfg.placement)]  # (B, T, H)
        codes = sae_kit.get_encoded(inputs, positions, layer_id=sae_layer)  # (B, T, K)

        valid = (positions >= 0).astype(jnp.float32)
        weights = valid[..., None]
        token_counts = valid.sum(axis=-1, keepdims=True)

        # Mean over real tokens; the maximum() guards against division by zero
        # for a sequence with no valid positions.
        mean_hidden = (hidden * weights).sum(axis=1) / jnp.maximum(token_counts, 1.0)
        # NOTE(review): padded positions contribute 0 to the max, which is only
        # neutral if SAE activations are non-negative -- confirm.
        max_codes = (codes * weights).max(axis=1)

        pooled_dense.append(np.asarray(mean_hidden, dtype=np.float32))
        pooled_sae.append(np.asarray(max_codes, dtype=np.float32))
    return np.vstack(pooled_dense), np.vstack(pooled_sae)


def concat_sae_across_layers(
    smiles: List[str], sae_kit: SAEKit, layers: List[int], batch_size=64
):
    """Featurize ``smiles`` at several SAE layers and join the SAE features.

    Returns:
        ``(dense, sae)`` where ``dense`` is the dense feature matrix produced
        for the first entry of ``layers`` and ``sae`` is the column-wise
        concatenation of the SAE feature matrices of all layers, in order.
    """
    first_dense = None
    sae_blocks = []
    for layer in layers:
        dense, sparse = generate_transformer_and_sae_features_for_layer(
            smiles, sae_kit, layer, batch_size=batch_size
        )
        sae_blocks.append(sparse)
        # Only the dense features of the first layer are kept.
        if first_dense is None:
            first_dense = dense
    return first_dense, np.concatenate(sae_blocks, axis=1)


def infer_task_type_from_y(y: np.ndarray) -> str:
    """Guess whether labels describe a classification or regression task.

    Rules, applied to the non-missing labels:
      * no labels at all -> ``"regression"``
      * integer dtype    -> classification when there are at most 10 distinct
        values, otherwise regression
      * float dtype      -> classification only when there are at most 10
        distinct values and, rounded to 6 decimals, they all lie in {0.0, 1.0}
      * any other dtype  -> ``"classification"``
    """
    clean = pd.Series(y).dropna()
    if clean.empty:
        return "regression"

    distinct = np.unique(clean.values)
    few_values = distinct.size <= 10
    kind = clean.dtype.kind

    if kind in ("i", "u"):
        return "classification" if few_values else "regression"
    if kind != "f":
        return "classification"

    looks_binary = few_values and set(np.unique(np.round(distinct, 6))).issubset(
        {0.0, 1.0}
    )
    return "classification" if looks_binary else "regression"


def prepare_labels_for_classification(y_train, y_valid, y_test):
    """Encode the labels of all three splits with one shared ``LabelEncoder``.

    Labels are stringified before encoding, so the encoder is fitted on the
    union of the string forms of every split.

    Returns:
        ``(y_tr, y_va, y_te, classes)``: the encoded labels of each split and
        the encoder's ``classes_`` array.
    """
    stacked = pd.concat(
        [pd.Series(part) for part in (y_train, y_valid, y_test)], axis=0
    )
    encoder = LabelEncoder()
    codes = encoder.fit_transform(stacked.astype(str))

    # Cut the encoded vector back into the original split sizes.
    train_end = len(y_train)
    valid_end = train_end + len(y_valid)
    return (
        codes[:train_end],
        codes[train_end:valid_end],
        codes[valid_end:],
        encoder.classes_,
    )


# ------------------------------
# XGBoost (version-agnostic ES)
# ------------------------------
def _fit_with_early_stopping(
    model, X_tr, y_tr, X_va, y_va, *, metric_name, early_stopping_rounds
):
    """Fit ``model`` with early stopping on the validation split.

    The xgboost sklearn wrapper has exposed early stopping in different places
    over time, so the strategies are tried from oldest to newest:

      1. an ``EarlyStopping`` callback passed to ``fit()``,
      2. the ``early_stopping_rounds`` keyword of ``fit()``,
      3. the estimator-level ``early_stopping_rounds`` parameter.

    Strategy 3 matters for releases whose ``fit()`` rejects both keywords:
    without it the model would silently train every round with no early
    stopping at all.

    Returns:
        The fitted ``model`` (same object that was passed in).
    """
    eval_set = [(X_va, y_va)]

    # Strategy 1: callback handed to fit().
    try:
        from xgboost.callback import EarlyStopping

        stopper = EarlyStopping(
            rounds=early_stopping_rounds,
            metric_name=metric_name,
            data_name="validation_0",
            save_best=True,
        )
        model.fit(X_tr, y_tr, eval_set=eval_set, callbacks=[stopper])
        return model
    except Exception:
        # Deliberately broad: any failure here (missing callback module,
        # unsupported keyword, callback/metric mismatch) just moves on to the
        # next strategy. Genuine data errors resurface below, where only
        # TypeError is tolerated.
        pass

    # Strategy 2: keyword accepted by fit().
    try:
        model.fit(
            X_tr,
            y_tr,
            eval_set=eval_set,
            early_stopping_rounds=early_stopping_rounds,
        )
        return model
    except TypeError:
        pass

    # Strategy 3: estimator-level parameter. It is reset afterwards so that a
    # later refit built from get_params() (which has no eval_set) does not
    # inherit an early-stopping configuration it cannot satisfy.
    try:
        model.set_params(early_stopping_rounds=early_stopping_rounds)
        model.fit(X_tr, y_tr, eval_set=eval_set)
    finally:
        model.set_params(early_stopping_rounds=None)
    return model


def _fit_xgb_classifier(
    X_tr,
    y_tr,
    X_va,
    y_va,
    *,
    params: Dict[str, Any] | None = None,
    early_stopping_rounds=50,
) -> xgb.XGBClassifier:
    """Train an XGBoost classifier with early stopping on the validation split.

    Args:
        X_tr, y_tr: Training features and integer-encoded labels.
        X_va, y_va: Validation features and labels used for early stopping.
        params: Optional overrides merged on top of the default
            hyper-parameters.
        early_stopping_rounds: Patience, in boosting rounds.

    Returns:
        The fitted classifier. Binary problems use ``binary:logistic`` with
        ``auc`` as the eval metric and a ``scale_pos_weight`` of
        ``#negatives / #positives``; other problems use ``multi:softprob``
        with ``mlogloss``.
    """
    n_classes = int(len(np.unique(y_tr)))
    is_binary = n_classes == 2

    spw = 1.0
    if is_binary:
        pos = (y_tr == 1).sum()
        neg = (y_tr == 0).sum()
        if pos > 0:
            spw = float(neg) / pos

    default = dict(
        n_estimators=2000,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=1.0,
        reg_alpha=0.0,
        objective="binary:logistic" if is_binary else "multi:softprob",
        eval_metric="auc" if is_binary else "mlogloss",
        random_state=42,
        tree_method="hist",
        n_jobs=-1,
        scale_pos_weight=spw,
    )
    if params:
        default.update(params)

    model = xgb.XGBClassifier(**default)
    return _fit_with_early_stopping(
        model,
        X_tr,
        y_tr,
        X_va,
        y_va,
        metric_name=default["eval_metric"],
        early_stopping_rounds=early_stopping_rounds,
    )


def _fit_xgb_regressor(
    X_tr,
    y_tr,
    X_va,
    y_va,
    *,
    params: Dict[str, Any] | None = None,
    early_stopping_rounds=50,
) -> xgb.XGBRegressor:
    """Train an XGBoost regressor with early stopping on the validation split.

    Args:
        X_tr, y_tr: Training features and targets.
        X_va, y_va: Validation features and targets used for early stopping.
        params: Optional overrides merged on top of the default
            hyper-parameters.
        early_stopping_rounds: Patience, in boosting rounds.

    Returns:
        The fitted regressor (squared-error objective, ``rmse`` eval metric
        unless overridden through ``params``).
    """
    default = dict(
        n_estimators=2000,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=1.0,
        reg_alpha=0.0,
        objective="reg:squarederror",
        eval_metric="rmse",
        random_state=42,
        tree_method="hist",
        n_jobs=-1,
    )
    if params:
        default.update(params)

    model = xgb.XGBRegressor(**default)
    return _fit_with_early_stopping(
        model,
        X_tr,
        y_tr,
        X_va,
        y_va,
        metric_name=default["eval_metric"],
        early_stopping_rounds=early_stopping_rounds,
    )


def _best_n_estimators(model) -> int:
    """Return how many boosting rounds early stopping found to be best.

    Lookup order:
      1. ``model.best_iteration``: a 0-based round index, so the number of
         rounds to keep is ``best_iteration + 1``.
      2. ``model.get_booster().best_ntree_limit``: already a count (legacy
         xgboost attribute).
      3. ``model.n_estimators``, or 200 when that is missing or ``None``.
    """
    best_iter = getattr(model, "best_iteration", None)
    if best_iter is not None:
        # best_iteration indexes rounds from 0; convert the index to a count
        # so this branch agrees with best_ntree_limit below.
        return int(best_iter) + 1
    try:
        return int(model.get_booster().best_ntree_limit)
    except Exception:
        # No early-stopping information at all (e.g. unfitted model or an
        # xgboost version without best_ntree_limit): use the configured size.
        configured = getattr(model, "n_estimators", None)
        return int(configured) if configured is not None else 200


def _refit_with_best_n(model, X, y, is_clf: bool):
    """Retrain a fresh estimator on ``(X, y)`` with the best round count.

    The new estimator copies the hyper-parameters of ``model`` but fixes
    ``n_estimators`` to the round count reported by ``_best_n_estimators``
    (at least 1).

    Args:
        model: Estimator previously fitted with early stopping.
        X, y: Data for the final fit.
        is_clf: ``True`` to build an ``XGBClassifier``, otherwise an
            ``XGBRegressor``.

    Returns:
        The newly fitted estimator.
    """
    n_best = max(1, _best_n_estimators(model))
    params = model.get_params()
    params["n_estimators"] = n_best
    # This refit has no eval_set, so an inherited estimator-level
    # early-stopping setting could not be honoured; drop it if present.
    params.pop("early_stopping_rounds", None)

    estimator_cls = xgb.XGBClassifier if is_clf else xgb.XGBRegressor
    new = estimator_cls(**params)
    new.fit(X, y, verbose=False)
    return new


# ------------------------------
# Dataset loading
# ------------------------------
# TDC single_pred loader families as (label, loader class) pairs;
# load_single_pred_dataset tries them in this order.
_FAMILIES = [("ADME", ADME), ("Tox", Tox), ("HTS", HTS)]


def load_single_pred_dataset(name: str):
    """Instantiate dataset ``name`` from the first family that accepts it.

    Each entry of ``_FAMILIES`` is tried in order; a family whose loader
    raises is recorded and skipped.

    Returns:
        ``(family_label, dataset)`` for the first loader that succeeds.

    Raises:
        ValueError: If every family fails; the message lists each family
            together with its error text.
    """
    failures = []
    for family, loader in _FAMILIES:
        try:
            dataset = loader(name=name)
        except Exception as exc:
            failures.append((family, str(exc)))
        else:
            return family, dataset
    raise ValueError(
        f"Dataset '{name}' not found in single_pred families. Tried: {failures}"
    )


# ------------------------------
# Bootstrap CI helpers
# ------------------------------
def _percentile_ci(samples: np.ndarray, ci: float) -> Tuple[float, float]:
    """Return the central percentile interval of bootstrap ``samples``.

    NaN replicates (failed bootstrap draws) are ignored so that a few failures
    do not turn the whole interval into NaN.

    Args:
        samples: Bootstrap statistics.
        ci: Coverage level in (0, 1), e.g. ``0.95``.

    Returns:
        ``(low, high)`` percentiles at ``(1 - ci) / 2`` and ``(1 + ci) / 2``;
        ``(nan, nan)`` when no non-NaN sample is available.
    """
    values = np.asarray(samples, dtype=float)
    values = values[~np.isnan(values)]
    if values.size == 0:
        return (float("nan"), float("nan"))
    lo_q = (1.0 - ci) / 2 * 100.0
    hi_q = (1.0 + ci) / 2 * 100.0
    lo, hi = np.percentile(values, [lo_q, hi_q])
    return float(lo), float(hi)


def _bootstrap_binary(
    y_true, y_prob, metric: str, n_boot: int, rng: np.random.Generator
):
    """Class-stratified bootstrap of a binary ranking metric.

    Positives and negatives are resampled separately (with replacement, each
    to its original count), so every replicate keeps the class balance.

    Args:
        y_true: Binary labels (0/1).
        y_prob: Scores for the positive class.
        metric: ``"AUROC"`` or ``"AUPRC"``; any other value yields no
            replicates (the resampling draws still happen).
        n_boot: Number of bootstrap replicates.
        rng: Random generator used for resampling.

    Returns:
        Float array of replicate scores (NaN where the metric failed), or a
        single-NaN array when either class has fewer than two examples.
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_prob)
    positives = np.where(labels == 1)[0]
    negatives = np.where(labels == 0)[0]
    if min(positives.size, negatives.size) < 2:
        return np.array([np.nan])

    if metric == "AUROC":
        scorer = roc_auc_score
    elif metric == "AUPRC":
        scorer = average_precision_score
    else:
        scorer = None

    replicates = []
    for _ in range(n_boot):
        # Positives are drawn before negatives to keep the RNG stream stable.
        picked = np.concatenate(
            [
                rng.choice(positives, size=positives.size, replace=True),
                rng.choice(negatives, size=negatives.size, replace=True),
            ]
        )
        boot_labels = labels[picked]
        boot_scores = scores[picked]
        if scorer is None:
            continue
        try:
            replicates.append(scorer(boot_labels, boot_scores))
        except Exception:
            replicates.append(np.nan)
    return np.array(replicates, dtype=float)


def _bootstrap_regression(
    y_true, y_pred, metric: str, n_boot: int, rng: np.random.Generator
):
    """Plain (non-stratified) bootstrap of a regression error metric.

    Args:
        y_true: Ground-truth targets.
        y_pred: Model predictions, aligned with ``y_true``.
        metric: ``"RMSE"`` or ``"MAE"``; any other value yields no replicates
            (the resampling draws still happen).
        n_boot: Number of bootstrap replicates.
        rng: Random generator used for resampling.

    Returns:
        Float array with one score per replicate.
    """
    targets = np.asarray(y_true)
    predictions = np.asarray(y_pred)
    n_items = targets.size
    index_pool = np.arange(n_items)

    replicates = []
    for _ in range(n_boot):
        picked = rng.choice(index_pool, size=n_items, replace=True)
        boot_true = targets[picked]
        boot_pred = predictions[picked]
        if metric == "RMSE":
            replicates.append(math.sqrt(mean_squared_error(boot_true, boot_pred)))
        elif metric == "MAE":
            replicates.append(mean_absolute_error(boot_true, boot_pred))
    return np.array(replicates, dtype=float)


# ------------------------------
# Feature sets & evaluation (single seed)
# ------------------------------
def build_feature_sets(
    tr_smiles,
    va_smiles,
    te_smiles,
    sae_kit,
    sae_layers: List[int],
    pca_dims: int,
    ecfp_bits: int,
    batch_size: int,
):
    """Compute every feature representation for the train/valid/test splits.

    Args:
        tr_smiles, va_smiles, te_smiles: SMILES lists of the three splits.
        sae_kit: Loaded LM + SAE kit used for the learned features.
        sae_layers: SAE layers whose features are concatenated.
        pca_dims: Requested number of PCA components for the dense embeddings.
            It is reduced to ``min(n_train, embedding_dim)`` when larger,
            because PCA cannot produce more components than that.
        ecfp_bits: Length of the ECFP fingerprints.
        batch_size: Batch size for the LM forward passes.

    Returns:
        Dict mapping a feature-set name to its ``(train, valid, test)``
        matrices: SAE features, dense transformer embeddings, PCA of the
        embeddings (fitted on train only), ECFP fingerprints, and the
        SAE + ECFP concatenation.
    """
    splits = (tr_smiles, va_smiles, te_smiles)

    # ECFP
    tr_ecfp, va_ecfp, te_ecfp = (
        generate_ecfp_features(smiles, n_bits=ecfp_bits) for smiles in splits
    )

    # SAE (concat layers); Dense from first layer
    (dense_tr, sae_tr), (dense_va, sae_va), (dense_te, sae_te) = (
        concat_sae_across_layers(smiles, sae_kit, sae_layers, batch_size=batch_size)
        for smiles in splits
    )

    # PCA on dense, fitted on the training split only. Small datasets or
    # narrow models can have fewer samples/features than pca_dims, which PCA
    # rejects, so the component count is capped.
    n_components = pca_dims
    max_rank = min(dense_tr.shape[0], dense_tr.shape[1])
    if isinstance(pca_dims, (int, np.integer)) and pca_dims > max_rank:
        print(
            f"    [WARN] pca_dims={pca_dims} exceeds min(n_train, n_features)="
            f"{max_rank}; using {max_rank} components."
        )
        n_components = max_rank
    pca = PCA(n_components=n_components, random_state=42)
    tr_pca = pca.fit_transform(dense_tr)
    va_pca = pca.transform(dense_va)
    te_pca = pca.transform(dense_te)

    # Hybrid
    tr_hybrid = np.concatenate([sae_tr, tr_ecfp], axis=1)
    va_hybrid = np.concatenate([sae_va, va_ecfp], axis=1)
    te_hybrid = np.concatenate([sae_te, te_ecfp], axis=1)

    return {
        "SAE Features": (sae_tr, sae_va, sae_te),
        "Transformer Embeddings": (dense_tr, dense_va, dense_te),
        "PCA on Embeddings": (tr_pca, va_pca, te_pca),
        "ECFP Fingerprints": (tr_ecfp, va_ecfp, te_ecfp),
        "SAE ⊕ ECFP": (tr_hybrid, va_hybrid, te_hybrid),
    }


def _result_row(task_name, fname, metric, score, ci):
    """Build one result record; ``ci`` is a ``(low, high)`` pair."""
    lo, hi = ci
    return {
        "Task": task_name,
        "Features": fname,
        "Metric": metric,
        "Score": score,
        "Boot_CI_low": lo,
        "Boot_CI_high": hi,
    }


def _ci_from_bootstrap(samples, ci_level, min_samples=11):
    """Percentile CI over the non-NaN replicates, or NaNs when too few remain."""
    values = np.asarray(samples, dtype=float)
    values = values[~np.isnan(values)]
    if values.size < min_samples:
        return (np.nan, np.nan)
    return _percentile_ci(values, ci_level)


def evaluate_single_label_one_seed(
    task_name: str,
    smi_col: str,
    split: Dict[str, pd.DataFrame],
    sae_kit,
    sae_layers,
    pca_dims,
    ecfp_bits,
    batch_size,
    early_stop,
    boot_reps,
    ci_level,
    seed_for_boot,
):
    """Train and score every feature set on one train/valid/test split.

    For each feature set an XGBoost model is fitted with early stopping on the
    validation split, refitted on train + valid with the best round count, and
    evaluated on the test split. The task type is inferred from the labels:

      * binary classification -> ``AUROC`` and ``AUPRC``
      * multi-class           -> ``AUROC_macro`` (one-vs-rest)
      * regression            -> ``RMSE`` and ``MAE``

    Args:
        task_name: Name stored in the ``Task`` field of every row.
        smi_col: Name of the SMILES column in the split frames.
        split: Mapping with ``"train"``, ``"valid"`` and ``"test"`` frames,
            each holding ``smi_col`` and a ``Y`` column.
        sae_kit, sae_layers, pca_dims, ecfp_bits, batch_size: Forwarded to
            ``build_feature_sets``.
        early_stop: Early-stopping patience in boosting rounds.
        boot_reps: Bootstrap replicates per metric (0 disables the CIs).
        ci_level: Coverage of the bootstrap percentile interval.
        seed_for_boot: Seed of the bootstrap random generator.

    Returns:
        List of dicts with keys ``Task``, ``Features``, ``Metric``, ``Score``,
        ``Boot_CI_low`` and ``Boot_CI_high``.
    """
    train_df, valid_df, test_df = split["train"], split["valid"], split["test"]
    tr_smiles = train_df[smi_col].tolist()
    va_smiles = valid_df[smi_col].tolist()
    te_smiles = test_df[smi_col].tolist()

    y_tr, y_va, y_te = train_df["Y"].values, valid_df["Y"].values, test_df["Y"].values
    task_kind = infer_task_type_from_y(np.concatenate([y_tr, y_va, y_te]))
    feats = build_feature_sets(
        tr_smiles,
        va_smiles,
        te_smiles,
        sae_kit,
        sae_layers,
        pca_dims,
        ecfp_bits,
        batch_size,
    )

    rng = np.random.default_rng(seed_for_boot)
    rows = []

    is_clf = task_kind == "classification"
    if is_clf:
        # Label encoding does not depend on the feature set: do it once.
        ytr_enc, yva_enc, yte_enc, _ = prepare_labels_for_classification(
            y_tr, y_va, y_te
        )
        n_classes = len(np.unique(ytr_enc))

    for fname, (X_tr, X_va, X_te) in feats.items():
        X_fit = np.vstack([X_tr, X_va])

        if not is_clf:
            model = _fit_xgb_regressor(
                X_tr, y_tr, X_va, y_va, early_stopping_rounds=early_stop
            )
            model = _refit_with_best_n(
                model, X_fit, np.concatenate([y_tr, y_va]), is_clf=False
            )
            preds = model.predict(X_te)
            rmse = float(math.sqrt(mean_squared_error(y_te, preds)))
            mae = float(mean_absolute_error(y_te, preds))

            # RMSE is bootstrapped before MAE; both share one RNG stream.
            rmse_samp = (
                _bootstrap_regression(y_te, preds, "RMSE", boot_reps, rng)
                if boot_reps > 0
                else np.array([])
            )
            mae_samp = (
                _bootstrap_regression(y_te, preds, "MAE", boot_reps, rng)
                if boot_reps > 0
                else np.array([])
            )
            rows.append(
                _result_row(
                    task_name, fname, "RMSE", rmse,
                    _ci_from_bootstrap(rmse_samp, ci_level),
                )
            )
            rows.append(
                _result_row(
                    task_name, fname, "MAE", mae,
                    _ci_from_bootstrap(mae_samp, ci_level),
                )
            )
            continue

        model = _fit_xgb_classifier(
            X_tr, ytr_enc, X_va, yva_enc, early_stopping_rounds=early_stop
        )
        model = _refit_with_best_n(
            model, X_fit, np.concatenate([ytr_enc, yva_enc]), is_clf=True
        )

        if n_classes == 2:
            prob = model.predict_proba(X_te)[:, 1]

            # AUROC is undefined when the test split holds a single class.
            auroc = (
                float(roc_auc_score(yte_enc, prob))
                if len(np.unique(yte_enc)) > 1
                else np.nan
            )
            auroc_samp = (
                _bootstrap_binary(yte_enc, prob, "AUROC", boot_reps, rng)
                if boot_reps > 0
                else np.array([])
            )
            rows.append(
                _result_row(
                    task_name, fname, "AUROC", auroc,
                    _ci_from_bootstrap(auroc_samp, ci_level),
                )
            )

            auprc = float(average_precision_score(yte_enc, prob))
            auprc_samp = (
                _bootstrap_binary(yte_enc, prob, "AUPRC", boot_reps, rng)
                if boot_reps > 0
                else np.array([])
            )
            rows.append(
                _result_row(
                    task_name, fname, "AUPRC", auprc,
                    _ci_from_bootstrap(auprc_samp, ci_level),
                )
            )
        else:
            prob = model.predict_proba(X_te)
            try:
                auroc = float(
                    roc_auc_score(yte_enc, prob, multi_class="ovr", average="macro")
                )
            except ValueError as exc:
                # E.g. a class is missing from the test split.
                print(f"    [WARN] {task_name}/{fname}: macro AUROC undefined ({exc}).")
                auroc = np.nan

            # Bootstrap for macro-AUROC (simple, non-stratified). A resample
            # can miss a class entirely, which makes the metric undefined for
            # that replicate; record NaN instead of aborting the whole run.
            boot_scores = []
            if boot_reps > 0:
                n_test = len(yte_enc)
                index_pool = np.arange(n_test)
                for _ in range(boot_reps):
                    s = rng.choice(index_pool, size=n_test, replace=True)
                    try:
                        boot_scores.append(
                            roc_auc_score(
                                yte_enc[s], prob[s], multi_class="ovr", average="macro"
                            )
                        )
                    except ValueError:
                        boot_scores.append(np.nan)
            rows.append(
                _result_row(
                    task_name, fname, "AUROC_macro", auroc,
                    _ci_from_bootstrap(boot_scores, ci_level, min_samples=1),
                )
            )
    return rows


# ------------------------------
# Runner (multi-seed + summaries)
# ------------------------------
def _t_interval(mean: float, std: float, n: int, ci: float) -> Tuple[float, float]:
    """Two-sided Student-t confidence interval for a mean.

    Args:
        mean: Sample mean.
        std: Sample standard deviation.
        n: Number of observations.
        ci: Coverage level, e.g. ``0.95``.

    Returns:
        ``(low, high)``; ``(nan, nan)`` when fewer than two observations are
        available or ``std`` is not finite.
    """
    usable = n >= 2 and np.isfinite(std)
    if not usable:
        return (np.nan, np.nan)
    upper_tail = 1 - (1.0 - ci) / 2
    margin = student_t.ppf(upper_tail, df=n - 1) * (std / math.sqrt(n))
    return (float(mean - margin), float(mean + margin))


def make_barplots(summary_df, metric, out_dir):
    """Save one bar chart per task for ``metric``, with across-seed error bars.

    Args:
        summary_df: Across-seed summary with columns ``Task``, ``Features``,
            ``Metric``, ``Mean``, ``Seed_CI_low`` and ``Seed_CI_high``.
        metric: Metric name to plot (rows with other metrics are ignored).
        out_dir: Output directory; images go to ``<out_dir>/plots``.

    Prints a notice and returns without plotting when matplotlib is missing
    or the summary has no rows for ``metric``.
    """
    if not _HAS_PLT:
        print("[plot] matplotlib is not available; skipping plots.")
        return
    sub = summary_df[summary_df["Metric"] == metric]
    if sub.empty:
        print(f"[plot] No rows for metric {metric}.")
        return

    present = set(sub["Features"].unique())
    features = [
        f
        for f in (
            "SAE Features",
            "Transformer Embeddings",
            "PCA on Embeddings",
            "ECFP Fingerprints",
            "SAE ⊕ ECFP",
        )
        if f in present
    ]
    if not features:
        return

    plot_dir = os.path.join(out_dir, "plots")
    os.makedirs(plot_dir, exist_ok=True)
    xs = np.arange(len(features))
    ylabel = metric + (
        " (higher is better)" if metric in ("AUROC", "AUROC_macro", "AUPRC") else ""
    )

    # Plot per task
    for task in sorted(sub["Task"].unique()):
        df = sub[sub["Task"] == task].set_index("Features").reindex(features)
        if df.empty:
            continue
        vals = df["Mean"].values
        yerr = np.vstack(
            [vals - df["Seed_CI_low"].values, df["Seed_CI_high"].values - vals]
        )
        # Single-seed runs have NaN intervals; draw those bars without whiskers.
        yerr = np.nan_to_num(yerr, nan=0.0)

        plt.figure(figsize=(8, 4))
        plt.bar(xs, vals, yerr=yerr, capsize=4)
        plt.xticks(xs, features, rotation=25, ha="right")
        plt.ylabel(ylabel)
        # NOTE(review): the title always says 95% even when the summary was
        # built with another CI level; the level is not available here.
        plt.title(f"{task} — {metric} with 95% CI (across seeds)")
        plt.tight_layout()
        # Task names may contain "/", which must not become a sub-directory.
        safe_task = str(task).replace("/", "_")
        plt.savefig(os.path.join(plot_dir, f"{safe_task}_{metric}.png"), dpi=150)
        plt.close()


def run_panel(args):
    """Run the whole benchmark panel described by the parsed CLI ``args``.

    For every task and every seed the dataset is split, all feature sets are
    evaluated, and the per-seed rows are written to
    ``<out_dir>/<task>_per_seed.csv``. Afterwards the scores are aggregated
    across seeds (mean, std, count, t-interval) into
    ``summary_across_seeds.csv``, a wide per-metric table is printed and
    saved, and optional bar charts are produced.

    Raises:
        ValueError: If ``args.tasks`` contains no task name.
    """
    # Load LM + SAE
    print("Loading LM + SAE kit...")
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=args.sae_dir
    )
    # Empty tokens (e.g. from a trailing comma) are ignored in all three lists.
    sae_layers = [int(x) for x in args.sae_layers.split(",") if x.strip()]
    seeds = [int(s) for s in args.seeds.split(",") if s.strip()]
    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    if not tasks:
        # Explicit exception: an assert would vanish under ``python -O``.
        raise ValueError("No tasks provided.")

    os.makedirs(args.out_dir, exist_ok=True)
    per_seed_rows: List[Dict[str, Any]] = []

    for name in tasks:
        print(f"\n=== {name} ===")
        _, ds = load_single_pred_dataset(name)  # ADME / Tox / HTS
        task_rows: List[Dict[str, Any]] = []

        # Try multi-seed scaffold splits
        for run_id, seed in enumerate(seeds, 1):
            print(f"  - Seed {seed} ({run_id}/{len(seeds)}) split: {args.split_method}")
            split = ds.get_split(method=args.split_method, seed=seed)

            # Find SMILES column
            train_df = split["train"]
            smi_col = _find_smiles_column(train_df)

            # Skip multi-label datasets
            if isinstance(train_df["Y"].iloc[0], (list, tuple, np.ndarray, dict)):
                print(f"    [WARN] {name} appears multi‑label. Skipping.")
                break

            rows = evaluate_single_label_one_seed(
                task_name=name,
                smi_col=smi_col,
                split=split,
                sae_kit=sae_kit,
                sae_layers=sae_layers,
                pca_dims=args.pca_dims,
                ecfp_bits=args.ecfp_bits,
                batch_size=args.batch_size,
                early_stop=args.early_stop,
                boot_reps=args.boot,
                ci_level=args.ci,
                # Offset so the bootstrap stream differs from the split seed.
                seed_for_boot=seed + 12345,
            )
            for r in rows:
                r["Seed"] = seed
            task_rows.extend(rows)

        # Save per-task, per-seed file (nothing to save for skipped tasks)
        if not task_rows:
            print("  No per-seed results for this task; nothing saved.")
            continue
        per_seed_rows.extend(task_rows)
        pd.DataFrame(task_rows).to_csv(
            os.path.join(args.out_dir, f"{name.replace('/', '_')}_per_seed.csv"),
            index=False,
        )
        print("  Saved per-seed CSV (with bootstrap CIs).")

    # ---- Aggregate across seeds ----
    all_df = pd.DataFrame(per_seed_rows)
    if all_df.empty:
        print("No results collected. Check dataset names or multi‑label warning.")
        return

    # Mean and t-interval (at the requested CI level) across seeds
    agg = (
        all_df.groupby(["Task", "Features", "Metric"])["Score"]
        .agg(["mean", "std", "count"])
        .reset_index()
        .rename(columns={"mean": "Mean", "std": "Std", "count": "N"})
    )
    intervals = [
        _t_interval(mean, std, int(n), args.ci)
        for mean, std, n in zip(agg["Mean"], agg["Std"], agg["N"])
    ]
    agg["Seed_CI_low"] = [lo for lo, _ in intervals]
    agg["Seed_CI_high"] = [hi for _, hi in intervals]

    # Save and pretty-print
    agg.to_csv(os.path.join(args.out_dir, "summary_across_seeds.csv"), index=False)

    feature_order = [
        "SAE Features",
        "Transformer Embeddings",
        "PCA on Embeddings",
        "ECFP Fingerprints",
        "SAE ⊕ ECFP",
    ]

    # Wide views per metric
    def _wide(metric_name):
        sub = agg[agg["Metric"] == metric_name]
        if sub.empty:
            return
        wide = sub.pivot(index="Task", columns="Features", values="Mean")
        wide = wide[[c for c in feature_order if c in wide.columns]]
        print(f"\n=== Across-seed mean: {metric_name} ===")
        # to_string documents float_format as a one-parameter function, so a
        # callable is passed instead of a "%"-style format string.
        print(wide.to_string(float_format="{:.4f}".format))
        wide.to_csv(os.path.join(args.out_dir, f"summary_mean_{metric_name}.csv"))

    for m in ["AUROC", "AUROC_macro", "AUPRC", "RMSE", "MAE"]:
        _wide(m)

    # Optional plots (per task × metric)
    if args.plot:
        for m in ("AUROC", "AUPRC", "RMSE"):
            make_barplots(agg, metric=m, out_dir=args.out_dir)

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the benchmark panel.
    # The text is passed as ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would put this sentence into the usage
    # line in place of the program name.
    ap = argparse.ArgumentParser(
        description="Run TDC single‑pred benchmarks with SAE / baselines + error bars"
    )
    # Locations of the language model checkpoint and the trained SAEs.
    ap.add_argument("--model_dir", required=True)
    ap.add_argument("--ckpt_id", required=True)
    ap.add_argument("--sae_dir", required=True)
    ap.add_argument(
        "--sae_layers",
        type=str,
        default="5",
        help="Comma‑sep layer ids to concatenate in SAE (e.g., '5,8,10').",
    )
    # Feature-extraction and training knobs.
    ap.add_argument("--pca_dims", type=int, default=256)
    ap.add_argument("--ecfp_bits", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--early_stop", type=int, default=50)
    ap.add_argument(
        "--split_method",
        type=str,
        default="scaffold",
        help="TDC split (use 'scaffold' unless dataset provides a fixed one).",
    )
    ap.add_argument(
        "--tasks",
        type=str,
        required=True,
        help="Comma‑sep list, e.g., 'BBB_Martins,hERG,AMES'.",
    )
    ap.add_argument(
        "--seeds",
        type=str,
        default="1,2,3,4,5",
        help="Comma‑sep seeds for repeated splits.",
    )
    # Uncertainty estimation and reporting.
    ap.add_argument(
        "--boot", type=int, default=500, help="Bootstrap reps per split (0 to disable)."
    )
    ap.add_argument(
        "--ci", type=float, default=0.95, help="CI level for bootstrap & across seeds."
    )
    ap.add_argument(
        "--plot",
        action="store_true",
        help="Make matplotlib bar charts with error bars.",
    )
    ap.add_argument("--out_dir", type=str, default="experiments/tdc_outputs")
    args = ap.parse_args()
    run_panel(args)
