import pandas as pd
import numpy as np
from scipy import sparse
import pandas as pd
import re
from typing import Tuple, Optional, Dict, Any, List
import inspect
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import Ridge
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.utils import check_random_state
from sklearn.neural_network import MLPRegressor
from packaging.version import Version
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from xgboost import XGBRegressor
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from time import perf_counter


# ---------------------------------------------------------
# Utilities 
# ---------------------------------------------------------
def attach_annotations(
    preds: pd.DataFrame,
    df: pd.DataFrame,
    x_test_index,
    DRUG_ANN_PATH: str = "Repurposing_Public_23Q2_Extended_Primary_Data_Matrix.csv",
    MODEL_PATH: str = "Model.csv",
    CELL_ANN_PATH: str = "Cell_lines_annotations_20181226.txt",
):
    """Return a copy of ``preds`` with cell-line and drug annotation columns joined on.

    Rows are never added, dropped or re-indexed: both annotation tables are reduced to
    one row per (non-missing) key before a left join, so the result keeps exactly the
    index of ``preds``.

    Parameters
    ----------
    preds : pd.DataFrame
        Prediction table (e.g. from ``predict_table``). It is not modified.
    df : pd.DataFrame
        Modelling frame; only used to look up ``DepMap_ID`` for ``x_test_index`` when
        ``preds`` has no ``DepMap_ID`` column.
    x_test_index
        Index labels (into ``df``) of the rows in ``preds``.
    DRUG_ANN_PATH : str
        Delimited drug table. Matched on an ID column (broad_id, pert_id, ...) when any
        ``preds["compound"]`` value occurs in it, otherwise on a stripped name column.
    MODEL_PATH : str
        Accepted for backward compatibility but not used (the previous implementation
        read this file and then discarded it).
    CELL_ANN_PATH : str
        Delimited cell-line annotation table with a DepMap-style ID column.

    Returns
    -------
    pd.DataFrame
        Annotated copy of ``preds``. Annotation columns that already exist in ``preds``
        are left as they are instead of being duplicated with suffixes.

    Raises
    ------
    ValueError
        If ``CELL_ANN_PATH`` has no recognizable ID column.
    """
    def _one_row_per_key(table: pd.DataFrame) -> pd.DataFrame:
        # Drop missing and repeated keys so a left join cannot duplicate prediction rows.
        idx = table.index
        return table[idx.notna() & ~idx.duplicated(keep="first")]

    preds = preds.copy()  # never mutate the caller's frame

    if "DepMap_ID" not in preds.columns and "DepMap_ID" in df.columns:
        preds = preds.join(df.loc[x_test_index, ["DepMap_ID"]])

    if "DepMap_ID" in preds.columns:
        preds["DepMap_ID_norm"] = preds["DepMap_ID"].map(norm_depmap_id)

    # ---- Cell-line annotations ----
    cell_ann = pd.read_csv(CELL_ANN_PATH, sep=None, engine="python")
    id_col = pick_first(list(cell_ann.columns), ["depMapID","depmap_id","DepMap_ID","ModelID","model_id"])
    if id_col is None:
        raise ValueError(f"Could not find an ID column in {CELL_ANN_PATH}. Columns: {list(cell_ann.columns)[:20]}")
    cell_ann["depmap_norm"] = cell_ann[id_col].astype(str).map(norm_depmap_id)
    cell_ann = _one_row_per_key(cell_ann.set_index("depmap_norm"))

    rename_map = {
        "Pathology": "pathology",
        "Site_Primary": "primary_site",
        "Site_Subtype1": "site_subtype1",
        "Site_Subtype2": "site_subtype2",
        "Site_Subtype3": "site_subtype3",
        "Histology": "histology",
        "Hist_Subtype1": "hist_subtype1",
        "Hist_Subtype2": "hist_subtype2",
        "Hist_Subtype3": "hist_subtype3",
        "Disease": "primary_disease",
        "Gender": "sex",
        "Age": "age",
        "Race": "race",
        "tcga_code": "tcga_code",
        "lineage": "lineage",
        "lineage_subtype": "lineage_subtype",
    }
    keep_cols = [c for c in rename_map if c in cell_ann.columns]
    cell_meta = cell_ann[keep_cols].rename(columns=rename_map)

    if "DepMap_ID_norm" in preds.columns:
        to_add = [c for c in cell_meta.columns if c not in preds.columns]
        # join(on=...) keeps the index of preds
        preds = preds.join(cell_meta[to_add], on="DepMap_ID_norm")
    else:
        print("Warning: preds has no DepMap_ID_norm; cannot attach cell annotations.")

    print("Attached from cell_ann:", [rename_map[c] for c in keep_cols])

    # ---- Drug annotations ----
    drug_ann = pd.read_csv(DRUG_ANN_PATH, sep=None, engine="python")
    name_col = pick_first(list(drug_ann.columns), ["pert_iname","compound","name","compound_name","broad_name"])
    drug_id_col = pick_first(list(drug_ann.columns), ["broad_id","pert_id","repurposing_id","repurposing_id_with_data"])
    ann_cols = [
        c for c in ["moa","moa_category","targets","target","target_or_pathway","smiles"]
        if c in drug_ann.columns and c not in preds.columns
    ]

    if "compound" in preds.columns and drug_id_col and preds["compound"].isin(drug_ann[drug_id_col]).any():
        lookup = _one_row_per_key(drug_ann[ann_cols].set_index(drug_ann[drug_id_col]))
        preds = preds.join(lookup, on="compound")
    elif "compound" in preds.columns and name_col:
        names = drug_ann[name_col].astype(str).str.strip()
        lookup = _one_row_per_key(drug_ann[ann_cols].set_index(names))
        preds["compound_norm"] = preds["compound"].astype(str).str.strip()
        preds = preds.join(lookup, on="compound_norm").drop(columns=["compound_norm"])
    else:
        print("Drug annotations provided but could not match on ID or name; check column headers.")

    added = [c for c in ["lineage","primary_disease","primary_site","lineage_subtype","tcga_code","sex","age","moa","targets","target_or_pathway"] if c in preds.columns]
    print("Attached columns now in preds:", added)

    return preds
    
    
def pick_first(columns: List[str], candidates: List[str]) -> Optional[str]:
    """Return the first candidate present in ``columns``, matched case-insensitively.

    The returned name uses the spelling found in ``columns``. When several columns
    differ only by case, the last one wins. Returns None if no candidate matches.
    """
    by_lower = {}
    for col in columns:
        by_lower[col.lower()] = col
    return next((by_lower[cand.lower()] for cand in candidates if cand.lower() in by_lower), None)

def get_model_cols(preds_df: pd.DataFrame, prefix: str = "pred_") -> List[str]:
    """List the columns of ``preds_df`` whose string form starts with ``prefix``, in order."""
    return list(filter(lambda col: str(col).startswith(prefix), preds_df.columns))

def norm_depmap_id(x: Any) -> Optional[str]:
    """Normalize a DepMap cell-line identifier to the canonical ``ACH-XXXXXX`` form.

    Accepted spellings include 'ACH-123', 'ach-000123', 'ACH000123', '123', 123 and
    123.0, all mapping to 'ACH-000123'. Other values are returned stripped and
    upper-cased.

    Missing values (None, NaN, pd.NA, or the strings '', 'nan', 'none', '<na>' that
    ``astype(str)`` produces from them) return None. Returning None lets ``dropna``
    remove them instead of leaving a literal 'NAN' ID that joins missing rows together.
    """
    if isinstance(x, float) and x.is_integer():
        x = int(x)  # IDs read as floats, e.g. 123.0
    elif pd.api.types.is_scalar(x) and pd.isna(x):
        return None

    s = str(x).strip().upper()
    if s in {"", "NAN", "NONE", "<NA>"}:
        return None
    if s.startswith("ACH-"):
        digits = re.sub(r"\D", "", s[len("ACH-"):])
        return f"ACH-{int(digits):06d}" if digits else s
    m = re.fullmatch(r"(\d+)", s)
    if m:
        return f"ACH-{int(m.group(1)):06d}"
    m = re.fullmatch(r"ACH0*?(\d+)", s)
    if m:
        return f"ACH-{int(m.group(1)):06d}"
    return s

# --------------
# PRISM loader 
# --------------

def load_prism_auto(PRISM_PATH: str) -> Tuple[pd.DataFrame, str]:
    """Load a wide PRISM matrix into long format and guess the scale of its values.

    The first CSV column holds compound identifiers; every other column is one cell
    line. Cell values are coerced to numbers (unparseable -> NaN, then dropped).

    Returns
    -------
    (long, scale)
        ``long`` has columns ``compound`` (str), ``DepMap_ID`` (normalized with
        ``norm_depmap_id``) and ``response``. ``scale`` is a heuristic label:
        "viability_0_1" if >85% of finite values lie in [0, 1]; "viability_pct" if
        >85% lie in [0, 100] or >95% lie in the [1st-5, 99th+5] percentile band
        clipped to [0, 100]; otherwise "lfc".

    Raises
    ------
    ValueError
        If the matrix contains no finite response values.
    """
    wide = pd.read_csv(PRISM_PATH)
    compound_col = wide.columns[0]
    ach_cols = list(wide.columns[1:])

    values = wide[ach_cols].apply(pd.to_numeric, errors="coerce")
    # Name both axes so the stacked frame gets stable column names.
    values.index = pd.Index(wide[compound_col].astype(str), name="compound")
    values.columns = pd.Index(ach_cols, name="DepMap_ID")

    long = values.stack().rename("response").reset_index()
    long["DepMap_ID"] = long["DepMap_ID"].astype(str).str.strip().str.upper().map(norm_depmap_id)
    long = long.dropna(subset=["DepMap_ID", "compound", "response"]).reset_index(drop=True)

    finite = long["response"].to_numpy(dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        raise ValueError(f"No finite response values found in {PRISM_PATH}.")

    q1, q99 = np.percentile(finite, [1, 99])
    in_0_100 = np.mean((finite >= 0) & (finite <= 100))
    in_0_1 = np.mean((finite >= 0) & (finite <= 1))
    core_lo, core_hi = max(0, q1 - 5), min(100, q99 + 5)
    core_in_0_100 = np.mean((finite >= core_lo) & (finite <= core_hi))

    if in_0_1 > 0.85:
        scale = "viability_0_1"
    elif (in_0_100 > 0.85) or (core_in_0_100 > 0.95):
        scale = "viability_pct"
    else:
        scale = "lfc"

    return long, scale


# ---------------------------------------------------------
# Drug ranking utilities
# ---------------------------------------------------------

def _rank_drugs_by(resp_long: pd.DataFrame,
                   criterion: str = "count",
                   min_per_drug: int = 80) -> pd.DataFrame:
    df = resp_long.dropna(subset=["compound", "y"]).copy()
    df = df[np.isfinite(df["y"])]
    g = df.groupby("compound")["y"]
    stats = pd.DataFrame({
        "count": g.size(),
        "mean_y": g.mean(),
        "mean_abs_y": g.apply(lambda x: np.mean(np.abs(x))),
        "median_y": g.median(),
        "std_y": g.std(ddof=0),
        "iqr_y": g.quantile(0.75) - g.quantile(0.25),
        "q95_y": g.quantile(0.95),
        "q05_y": g.quantile(0.05),
    })
    stats = stats[stats["count"] >= int(min_per_drug)]
    stats["score"] = stats[criterion]
    return stats.sort_values("score", ascending=False)

def _rank_drugs_by_site_heterogeneity(df: pd.DataFrame, site_col: str,
                                      min_per_drug: int = 50, min_sites: int = 3) -> pd.DataFrame:
    
    df = df.dropna(subset=[site_col, "compound", "y"]).copy()
    g = df.groupby('compound')
    stats = g['y'].agg(count='size', mean='mean', var='var')

    def between_var(sub):
        by = sub.groupby(site_col)['y']
        means = by.mean()
        counts = by.size()
        mu = sub['y'].mean()
        w = counts / counts.sum()
        n_sites_ge5 = int((counts >= 5).sum())
        return float(((means - mu) ** 2 * w).sum()), n_sites_ge5

    bv, nsites = zip(*g.apply(lambda d: between_var(d)))
    stats['between_site_var'] = np.array(bv, dtype=float)
    stats['n_sites_ge5'] = np.array(nsites, dtype=int)
    stats['hetero_score'] = stats['between_site_var'] * np.log1p(stats['n_sites_ge5'])

    sel = stats[(stats['count'] >= min_per_drug) & (stats['n_sites_ge5'] >= min_sites)]
    return sel.sort_values('hetero_score', ascending=False)

def _site_union_top_var_genes(df_tr: pd.DataFrame, gene_cols: List[str], site_col: str,
                              k_per_site: int = 60, k_total: int = 100) -> List[str]:
    keep_sets = []
    for site, sub in df_tr.groupby(site_col):
        X = sub[gene_cols].apply(pd.to_numeric, errors='coerce').astype(np.float32)
        vars_arr = np.nanvar(X.to_numpy(dtype=np.float32, copy=False), axis=0)
        topk = pd.Series(vars_arr, index=gene_cols).nlargest(min(k_per_site, len(gene_cols))).index
        keep_sets.append(set(topk))
    keep = set().union(*keep_sets)
    Xall = df_tr[list(keep)].apply(pd.to_numeric, errors='coerce').astype(np.float32)
    vars_all = np.nanvar(Xall.to_numpy(dtype=np.float32, copy=False), axis=0)
    order = pd.Series(vars_all, index=Xall.columns).nlargest(min(k_total, len(keep))).index
    return list(order)

# ---------------------------------------------------------
# Metadata prep and site cleaning
# ---------------------------------------------------------

def _clean_site_label(s):
    if pd.isna(s):
        return None
    s = str(s).replace("\u00A0", " ")      
    s = " ".join(s.strip().split())        
    s = s.strip(",:;")
    up = s.upper()
    if up in {"UNABLE TO CLASSIFY", "UNKNOWN", "N/A", "NA", "NOT AVAILABLE", "NONE", ""}:
        return None
    return s

def _pick_site_column(cm: pd.DataFrame, candidates_in_order=None) -> tuple[str, pd.Series]:
    """Choose the site-like column of ``cm`` with the most usable labels after cleaning.

    Each candidate present in ``cm`` is cleaned with ``_clean_site_label`` and scored by
    its number of non-null cleaned values. Ties go to the earlier candidate. Prints a
    short diagnostic.

    Returns (column_name, cleaned_series). Raises ValueError if no candidate is present.
    """
    if candidates_in_order is None:
        candidates_in_order = [
            "Disease", "primary_site", "Site_Primary",
            "lineage", "lineage_subtype",
            "OncotreePrimaryDisease", "OncotreePrimaryTissue", "oncotree_code", "tcga_code",
        ]
    present = [col for col in candidates_in_order if col in cm.columns]
    if not present:
        raise ValueError(f"No site-like column in metadata. Columns: {list(cm.columns)[:25]}")

    scored = [(col, cm[col].map(_clean_site_label)) for col in present]
    # max() keeps the first of equally-scored items, i.e. the earliest candidate
    best_col, best_clean = max(scored, key=lambda item: int(item[1].notna().sum()))
    best_valid = int(best_clean.notna().sum())

    print(f"[cell meta] site candidates tried: {present} | picked: '{best_col}' with {best_valid} valid labels out of {len(cm)} rows")
    print(f"[cell meta] preview cleaned labels (first 10): {list(best_clean.dropna().astype(str).unique()[:10])}")
    return best_col, best_clean

def _prep_cell_meta(path: str, site_col: str = "primary_site") -> pd.DataFrame:
    """Read a cell-line metadata table and reduce it to (DepMap_ID, cleaned site label).

    The file is read with delimiter sniffing first, then as TSV, then as plain CSV.
    The ID column is found by name (lower-cased, underscores removed), renamed to
    ``DepMap_ID`` and normalized with ``norm_depmap_id``. The site column is chosen by
    ``_pick_site_column``. Rows without an ID are dropped and only the first row per ID
    is kept.

    Raises ValueError if no ID column is found.
    """
    cm = None
    for read_kwargs in ({"sep": None, "engine": "python"}, {"sep": "\t"}):
        try:
            cm = pd.read_csv(path, **read_kwargs)
        except Exception:
            continue
        break
    if cm is None:
        cm = pd.read_csv(path)

    by_key = {col.lower().replace("_", ""): col for col in cm.columns}
    id_keys = ("depmapid", "depmap id", "modelid", "model id", "cellline", "cellid", "sampleid")
    id_col = next((by_key[key] for key in id_keys if key in by_key), None)
    if id_col is None and "DepMap_ID" in cm.columns:
        id_col = "DepMap_ID"
    if id_col is None:
        raise ValueError(f"Could not find DepMap ID column in CELL_META_PATH. Columns: {list(cm.columns)[:25]}")

    cm = cm.rename(columns={id_col: "DepMap_ID"})
    cm["DepMap_ID"] = cm["DepMap_ID"].map(norm_depmap_id)

    _, site_labels = _pick_site_column(cm)
    return (
        pd.DataFrame({"DepMap_ID": cm["DepMap_ID"], site_col: site_labels})
        .dropna(subset=["DepMap_ID"])
        .drop_duplicates("DepMap_ID", keep="first")
    )

def _try_model_bridge(MODEL_PATH: str, site_col: str = "primary_site") -> Optional[pd.DataFrame]:
    """Build a (DepMap_ID, site label) table from a DepMap Model.csv-style file.

    The file is read with delimiter sniffing, then as TSV, then as plain CSV. Returns
    None (with a printed reason) if it has no DepMap-like ID column or no site-like
    column. Otherwise returns one row per normalized ID with a non-null cleaned site
    label, the site column being chosen by ``_pick_site_column``.
    """
    md = None
    for read_kwargs in ({"sep": None, "engine": "python"}, {"sep": "\t"}):
        try:
            md = pd.read_csv(MODEL_PATH, **read_kwargs)
        except Exception:
            continue
        break
    if md is None:
        md = pd.read_csv(MODEL_PATH)

    id_col = pick_first(list(md.columns), ["DepMap_ID","ModelID","model_id","depmap_id","ModelId"])
    if id_col is None:
        print(f"[model bridge] no DepMap-like ID in {MODEL_PATH}; cols: {list(md.columns)[:25]}")
        return None

    site_cands = [
        "OncotreePrimaryDisease","OncotreePrimaryTissue","primary_site","Disease",
        "lineage","lineage_subtype","oncotree_code","OncotreeCode","tcga_code"
    ]
    present = [col for col in site_cands if col in md.columns]
    if not present:
        print(f"[model bridge] no site-ish columns found in {MODEL_PATH}; cols: {list(md.columns)[:25]}")
        return None

    _, site_labels = _pick_site_column(md, present)
    # norm_depmap_id is idempotent, so one normalization pass is enough
    ids = md[id_col].astype(str).map(norm_depmap_id)
    bridge = pd.DataFrame({"DepMap_ID": ids, site_col: site_labels})
    return bridge.dropna(subset=["DepMap_ID", site_col]).drop_duplicates("DepMap_ID", keep="first")

# ---------------------------------------------------------
# Dataset preparation (main)
# ---------------------------------------------------------

def _site_stratified_split(df: pd.DataFrame, site_col: str, test_size: float, seed: int):
    """Split ``df`` stratified by ``site_col``, or randomly when stratification is impossible."""
    strat = df[site_col]
    if strat.nunique() > 1 and strat.value_counts().min() >= 2:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        try:
            tr_idx, te_idx = next(sss.split(df, strat))
        except ValueError as err:
            # e.g. fewer test rows than sites
            print(f"[split] site stratification not possible ({err})")
        else:
            print(f"[split] stratified by '{site_col}' with {strat.nunique()} sites")
            return df.iloc[tr_idx].copy(), df.iloc[te_idx].copy()
    df_tr, df_te = train_test_split(df, test_size=test_size, random_state=seed)
    print("[split] fallback random split (insufficient site diversity)")
    return df_tr, df_te


def _pair_or_site_split(df: pd.DataFrame, site_col: str, use_pair_strat: bool,
                        test_size: float, seed: int):
    """Split ``df`` stratified by (site, drug) pair when possible, else by site, else randomly.

    Pairs seen only once cannot be stratified; those rows are added to the training set
    and keep their original index labels.
    """
    if use_pair_strat:
        pair = df[site_col].astype(str) + "::" + df["compound"].astype(str)
        vc = pair.value_counts()
        keep_pairs = vc[vc >= 2].index
        mask = pair.isin(keep_pairs)
        if mask.sum() >= 2 and len(keep_pairs) > 1:
            strat_df, strat = df[mask], pair[mask]
            sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
            try:
                tr_idx, te_idx = next(sss.split(strat_df, strat))
            except ValueError as err:
                # e.g. fewer test rows than pairs
                print(f"[split] (site,drug) stratification not possible ({err}); trying site stratification")
            else:
                print("Unique pairs total:", pair.nunique())
                print("Train pairs:", strat.iloc[tr_idx].nunique(),
                      "Test pairs:", strat.iloc[te_idx].nunique())
                df_tr = strat_df.iloc[tr_idx].copy()
                df_te = strat_df.iloc[te_idx].copy()
                leftovers = df[~mask]
                if not leftovers.empty:
                    df_tr = pd.concat([df_tr, leftovers], axis=0)
                print(f"[split] stratified by (site,drug) pairs; kept {len(keep_pairs)} pairs")
                return df_tr, df_te
    return _site_stratified_split(df, site_col, test_size, seed)


def load_cancer_data(
    PRISM_PATH: str,
    EXPR_PATH: str,
    TOP_K_DRUGS: int,
    CRITERION: str,              # "count" | "coverage" | "site_heterogeneity" | any _rank_drugs_by statistic
    MIN_PER_DRUG: int,
    PER_DRUG_CAP: Optional[int],
    MAX_GENES: int,
    TEST_SIZE: float,
    RSEED: int,
    *,
    CELL_META_PATH: Optional[str] = None,
    MODEL_PATH: Optional[str] = None,
    SITE_COL: str = "primary_site",
    COVERAGE_SITE_MIN: int = 20,
    PER_SITE_CAP: Optional[int] = None,
    PER_SITE_FRACTION_CAP: Optional[float] = None,
    USE_PAIR_STRAT: bool = True,
    USE_SITE_UNION_GENES: bool = False,
    K_PER_SITE: int = 60
):
    """Assemble train/test matrices for (cell line, drug) response regression.

    Steps:
      1. Load PRISM responses (``load_prism_auto``) and map them to a target ``y`` that
         grows with killing: 1 - viability, or -clip(LFC, -6, 6).
      2. Load expression, keep one row per normalized DepMap ID, apply log1p and a
         per-gene z-score (constant genes become NaN and are imputed downstream).
      3. Attach a cleaned site label from ``CELL_META_PATH`` (optionally back-filled
         from ``MODEL_PATH``) and drop rows without one.
      4. Pick ``TOP_K_DRUGS`` drugs by ``CRITERION``: "coverage", "site_heterogeneity",
         or a per-drug statistic of ``_rank_drugs_by`` ("count", "mean_y",
         "mean_abs_y", "median_y", "std_y", "iqr_y", "q95_y", "q05_y"). Unknown names
         fall back to "count".
      5. Optionally cap rows per drug and per site, join expression, and split
         train/test (stratified by (site, drug) pair, else by site, else random).
      6. Keep ``MAX_GENES`` genes: the highest log1p-scale variance over the training
         rows, or the per-site union when ``USE_SITE_UNION_GENES``.

    Returns
    -------
    x_train, y_train, x_test, y_test, meta_df, df
        ``x_*`` hold the selected gene columns plus a categorical ``compound`` column.
        ``y_*`` are float32 targets. ``meta_df`` holds DepMap_ID / compound / site for
        the test rows. ``df`` is the full merged frame the split was drawn from.

    Raises
    ------
    ValueError
        If no rows with a known site remain, or no rows remain after drug selection.
    """
    rankable_stats = ("count", "mean_y", "mean_abs_y", "median_y", "std_y", "iqr_y", "q95_y", "q05_y")

    resp_long, scale = load_prism_auto(PRISM_PATH)

    # ---- Expression: one row per cell line, log1p, per-gene z-score ----
    expr = pd.read_csv(EXPR_PATH)
    expr = expr.set_index("DepMap_ID") if "DepMap_ID" in expr.columns else expr.set_index(expr.columns[0])
    expr.index = expr.index.map(norm_depmap_id)
    expr.index.name = "DepMap_ID"
    # Duplicated IDs would multiply response rows in the merge below.
    expr = expr[expr.index.notna() & ~expr.index.duplicated(keep="first")]
    log_expr = np.log1p(expr.apply(pd.to_numeric, errors="coerce").astype(np.float32))
    gene_sd = log_expr.std(axis=0, ddof=0)  # kept to rank genes on the log1p scale later
    # Constant genes get NaN instead of 0/0 or x/0 artefacts.
    expr = (log_expr - log_expr.mean(axis=0)) / gene_sd.replace(0.0, np.nan)
    del log_expr
    gene_cols = expr.columns.tolist()

    # ---- Response -> target y ----
    resp_long["DepMap_ID"] = resp_long["DepMap_ID"].map(norm_depmap_id)
    resp_long = resp_long.dropna(subset=["DepMap_ID", "compound", "response"])
    resp_long = resp_long[resp_long["DepMap_ID"].isin(expr.index)].copy()

    v = pd.to_numeric(resp_long["response"], errors="coerce")
    if scale == "viability_0_1":
        resp_long["y"] = 1.0 - v
    elif scale == "viability_pct":
        resp_long["y"] = 1.0 - (v / 100.0)
    elif scale == "lfc":
        resp_long["y"] = -v.clip(lower=-6, upper=6)
    else:
        resp_long["y"] = (1.0 - v) if v.between(0, 1).mean() > 0.9 else (-v)
    resp_long = resp_long.dropna(subset=["y"])

    print(resp_long["y"].describe())
    print(resp_long.groupby("compound")["y"].mean().head())

    # ---- Site labels ----
    if CELL_META_PATH is not None:
        cell_meta = _prep_cell_meta(CELL_META_PATH, site_col=SITE_COL)

        ids_resp = set(resp_long["DepMap_ID"].unique())
        ids_meta = set(cell_meta["DepMap_ID"].unique())
        overlap = len(ids_resp & ids_meta)
        print(f"[site merge] ID overlap (PRISM vs meta): {overlap} of {len(ids_resp)} PRISM and {len(ids_meta)} meta")

        before = len(resp_long)
        resp_long = resp_long.merge(cell_meta, on="DepMap_ID", how="left")
        resp_long[SITE_COL] = resp_long[SITE_COL].map(_clean_site_label)
        matched = resp_long[SITE_COL].notna().sum()
        print(f"[site merge] from CELL_META_PATH: {matched}/{before} ({100*matched/max(1,before):.1f}%) non-unknown labels")

        if matched < max(25, int(0.05 * before)) and MODEL_PATH is not None:
            print("[site merge] few usable site labels; trying MODEL bridge...")
            bridge = _try_model_bridge(MODEL_PATH, site_col=SITE_COL)
            if bridge is not None:
                missing_mask = resp_long[SITE_COL].isna()
                resp_long = resp_long.merge(
                    bridge.rename(columns={SITE_COL: f"{SITE_COL}__bridge"}),
                    on="DepMap_ID", how="left"
                )
                fill = resp_long[f"{SITE_COL}__bridge"].map(_clean_site_label)
                resp_long.loc[missing_mask, SITE_COL] = fill[missing_mask]
                matched2 = resp_long[SITE_COL].notna().sum()
                print(f"[site merge] after MODEL bridge: {matched2}/{before} ({100*matched2/max(1,before):.1f}%) non-unknown labels")
                resp_long = resp_long.drop(columns=[f"{SITE_COL}__bridge"], errors="ignore")
            else:
                print("[site merge] MODEL bridge not usable.")
    else:
        found = pick_first(list(resp_long.columns),
                           [SITE_COL,"primary_site","Site_Primary","Disease",
                            "primary_disease","lineage","lineage_subtype","oncotree_code","tcga_code"])
        if found is None:
            resp_long[SITE_COL] = np.nan
            print("[site merge] no metadata & no site-like field in PRISM table.")
        elif found != SITE_COL:
            resp_long = resp_long.rename(columns={found: SITE_COL})
        resp_long[SITE_COL] = resp_long[SITE_COL].map(_clean_site_label)

    print("[site] preview (first 10 unique cleaned labels):",
          list(resp_long[SITE_COL].dropna().astype(str).unique()[:10]))
    print("[site] non-null cleaned site rows:",
          int(resp_long[SITE_COL].notna().sum()), "/", len(resp_long))

    resp_long = resp_long[resp_long[SITE_COL].notna()].copy()
    if resp_long.empty:
        raise ValueError(
            "After dropping unknown sites, no rows remain.\n"
            "Diagnostics: wrong ID namespace or the metadata's site column is unusable.\n"
            "Suggestions: (a) pass MODEL_PATH to bridge IDs and site labels, "
            "(b) verify CELL_META_PATH has 'Disease' or 'Site_Primary' populated, "
            "(c) temporarily do not drop unknowns to inspect raw labels."
        )

    # ---- Rank & pick drugs ----
    crit = CRITERION.lower()
    if crit == "coverage":
        cts = resp_long.groupby(["compound", SITE_COL]).size().reset_index(name="n")
        coverage = (cts[cts["n"] >= COVERAGE_SITE_MIN]
                    .groupby("compound")[SITE_COL].nunique()
                    .sort_values(ascending=False))
        picked = coverage.head(TOP_K_DRUGS).index
    elif crit == "site_heterogeneity":
        ranked = _rank_drugs_by_site_heterogeneity(resp_long, SITE_COL, MIN_PER_DRUG, min_sites=3)
        picked = ranked.head(TOP_K_DRUGS).index
    else:
        # Honour any statistic computed by _rank_drugs_by; unknown names keep "count".
        metric = crit if crit in rankable_stats else "count"
        stats = _rank_drugs_by(resp_long, criterion=metric, min_per_drug=MIN_PER_DRUG)
        picked = stats.head(TOP_K_DRUGS).index

    resp_sub = resp_long[resp_long["compound"].isin(picked)].copy()
    if resp_sub.empty:
        raise ValueError(
            f"No rows left after picking drugs via '{CRITERION}' (MIN_PER_DRUG={MIN_PER_DRUG}); "
            "relax the drug-selection settings."
        )

    if PER_DRUG_CAP is not None:
        resp_sub = (resp_sub.groupby("compound", group_keys=False)
                    .apply(lambda g: g.sample(n=min(PER_DRUG_CAP, len(g)), random_state=RSEED))
                    .reset_index(drop=True))

    if (PER_SITE_CAP is not None) or (PER_SITE_FRACTION_CAP is not None):
        if PER_SITE_FRACTION_CAP is not None:
            sizes = resp_sub.groupby(SITE_COL).size()
            # At least one row per site: flooring a small fraction could otherwise empty every site.
            cap = max(1, int(np.floor(PER_SITE_FRACTION_CAP * np.median(sizes.values))))
        else:
            cap = int(PER_SITE_CAP)
        resp_sub = (resp_sub.groupby(SITE_COL, group_keys=False)
                    .apply(lambda g: g.sample(n=min(len(g), cap), random_state=RSEED))
                    .reset_index(drop=True))

    print(f"Picked {len(picked)} drugs via '{CRITERION}'; rows kept (known sites only): {len(resp_sub)}")

    if resp_sub[SITE_COL].nunique() < 2:
        print("[warn] Only one site remains after filtering; consider relaxing caps/coverage.")

    df = resp_sub.merge(expr, left_on="DepMap_ID", right_index=True, how="inner")
    df["compound"] = df["compound"].astype("category")

    df_tr, df_te = _pair_or_site_split(df, SITE_COL, USE_PAIR_STRAT, TEST_SIZE, RSEED)

    # ---- Gene selection (training rows only) ----
    if USE_SITE_UNION_GENES:
        keep = _site_union_top_var_genes(df_tr, gene_cols, SITE_COL, k_per_site=K_PER_SITE, k_total=MAX_GENES)
    else:
        # Genes are globally z-scored, so their variances are all ~1. Multiplying by the
        # log1p-scale variance gives each gene's training variance on the log1p scale.
        # Standardizing again before taking the variance would make every gene tie.
        Xtr_z = df_tr[gene_cols].to_numpy(dtype=np.float32)
        vars_arr = np.nanvar(Xtr_z, axis=0) * np.square(gene_sd.to_numpy(dtype=np.float64))
        vars_ = pd.Series(vars_arr, index=gene_cols).sort_values(ascending=False)  # NaN last
        keep = vars_.index[:min(MAX_GENES, len(gene_cols))].tolist()

    def _make_X(d: pd.DataFrame) -> pd.DataFrame:
        Xg = d[keep].apply(pd.to_numeric, errors="coerce").astype(np.float32)
        return pd.concat([Xg, d["compound"].rename("compound")], axis=1)

    x_train = coerce_numeric_like_strings(_make_X(df_tr))
    x_test = coerce_numeric_like_strings(_make_X(df_te))

    y_train = df_tr["y"].astype(np.float32)
    y_test = df_te["y"].astype(np.float32)

    meta_df = df_te[["DepMap_ID", "compound", SITE_COL]].copy()

    return x_train, y_train, x_test, y_test, meta_df, df



# ---------------------------------------------------------
# Baselines & training
# ---------------------------------------------------------

def _ohe_kwargs_dense():
    """OneHotEncoder kwargs for dense output, using the argument name of the installed sklearn.

    sklearn >= 1.2 calls the flag ``sparse_output``; older releases call it ``sparse``.
    """
    flag = "sparse_output" if Version(sklearn.__version__) >= Version("1.2") else "sparse"
    return {"handle_unknown": "ignore", flag: False}

def _ohe_kwargs_sparse():
    """OneHotEncoder kwargs for sparse output, using the argument name of the installed sklearn.

    sklearn >= 1.2 calls the flag ``sparse_output``; older releases call it ``sparse``.
    """
    flag = "sparse_output" if Version(sklearn.__version__) >= Version("1.2") else "sparse"
    return {"handle_unknown": "ignore", flag: True}

def build_preprocessors(genes, compound_col="compound"):
    """Create the two ColumnTransformers used by the baseline models.

    Returns (pre_linear, pre_tree):
      * pre_linear - median-impute + standardize the gene columns and one-hot encode
        ``compound_col`` densely; output is always dense (sparse_threshold=0.0).
      * pre_tree - median-impute the gene columns and one-hot encode ``compound_col``
        sparsely; output stays sparse (sparse_threshold=1.0).
    ``compound_col`` is removed from ``genes`` if present. All other columns are dropped.
    """
    gene_list = [g for g in genes if g != compound_col]

    def _numeric_pipeline(scale: bool) -> Pipeline:
        steps = [("impute", SimpleImputer(strategy="median"))]
        if scale:
            steps.append(("scale", StandardScaler(with_mean=True)))
        return Pipeline(steps)

    def _column_transformer(numeric, ohe_kwargs, threshold):
        return ColumnTransformer(
            transformers=[
                ("num", numeric, list(gene_list)),
                ("drug", OneHotEncoder(**ohe_kwargs), [compound_col]),
            ],
            sparse_threshold=threshold,
            remainder="drop",
            n_jobs=None,
        )

    pre_linear = _column_transformer(_numeric_pipeline(scale=True), _ohe_kwargs_dense(), 0.0)
    pre_tree = _column_transformer(_numeric_pipeline(scale=False), _ohe_kwargs_sparse(), 1.0)
    return pre_linear, pre_tree

    
def split_cols_strict(X: pd.DataFrame):
    """Partition the columns of ``X`` into (numeric, non-numeric) lists, keeping column order."""
    numeric, other = [], []
    for col in X.columns:
        bucket = numeric if pd.api.types.is_numeric_dtype(X[col]) else other
        bucket.append(col)
    return numeric, other

def coerce_numeric_like_strings(X: pd.DataFrame, frac_threshold=0.9):
    """Convert object columns whose non-missing values are (mostly) numeric strings to numbers.

    An ``object``-dtype column is replaced by ``pd.to_numeric(col, errors="coerce")``
    when at least ``frac_threshold`` of its non-missing entries parse as numbers.
    Missing entries do not count as parse failures. Entries that fail to parse become
    NaN. All-missing and non-object columns are left unchanged. The input is not
    modified; a converted copy is returned.
    """
    X = X.copy()
    for c in X.columns:
        col = X[c]
        if col.dtype != object:
            continue
        present = col.notna()
        if not present.any():
            continue
        z = pd.to_numeric(col, errors="coerce")
        if z[present].notna().mean() >= frac_threshold:
            X[c] = z
    return X

def infer_gene_keep(X_df):
    """Return every column name of ``X_df`` except ``"compound"``, in order."""
    return list(filter(lambda name: name != "compound", X_df.columns))


class ToDense(BaseEstimator, TransformerMixin):
    """Stateless pipeline step that densifies its input.

    Anything exposing ``toarray()`` (e.g. SciPy sparse matrices) is converted with it;
    other inputs pass through unchanged.
    """

    def fit(self, X, y=None):
        """No-op; returns ``self`` so the step can sit inside a Pipeline."""
        return self

    def transform(self, X):
        """Return ``X.toarray()`` when available, else ``X`` itself."""
        if not hasattr(X, "toarray"):
            return X
        return X.toarray()

def sanitize_sparse(X):
    """Replace NaN and +/-inf with 0.

    Sparse input returns a CSR copy whose non-finite stored values are set to 0 (the
    input is untouched). Dense input goes through ``np.nan_to_num`` with all
    replacements set to 0.
    """
    if not sparse.issparse(X):
        return np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
    cleaned = X.tocsr(copy=True)
    stored = cleaned.data
    stored[~np.isfinite(stored)] = 0.0
    return cleaned

def _fit_xgb_with_early_stopping(pipe, x_train, y_train, rounds, verbose_every, val_fraction):
    """Fit a ('pre', XGBRegressor 'model') pipeline with early stopping on a training hold-out.

    The preprocessor is fitted on the inner-training slice and used to transform the
    validation slice, so the eval set is in the model's feature space. The test set is
    never used for early stopping.
    """
    pre = pipe.named_steps["pre"]
    model = pipe.named_steps["model"]
    seed = model.get_params().get("random_state")
    x_fit, x_val, y_fit, y_val = train_test_split(
        x_train, y_train, test_size=val_fraction, random_state=seed
    )
    X_fit = pre.fit_transform(x_fit, y_fit)
    X_val = pre.transform(x_val)
    # Constructor-level parameters (xgboost >= 1.6); the fit-level ones are deprecated.
    model.set_params(early_stopping_rounds=rounds, eval_metric="rmse")
    model.fit(X_fit, y_fit, eval_set=[(X_val, y_val)], verbose=verbose_every)


def eval_model(name, estimator, x_train, y_train, x_test, y_test, pre, verbose=0, use_es=False,
               es_rounds=100, es_verbose=50, es_val_fraction=0.1):
    """Fit ``pre`` + ``estimator`` as a Pipeline on the training data and score it on the test set.

    HistGradientBoostingRegressor gets a ``ToDense`` step before the model. With
    ``use_es`` and an XGBRegressor, early stopping uses ``es_rounds`` and a
    ``es_val_fraction`` hold-out taken from the *training* data (so the final model
    sees the remaining training rows). ``verbose`` is forwarded as ``verbosity`` to
    estimators that have that attribute.

    Returns
    -------
    (pipe, metrics)
        The fitted Pipeline and {"name", "R2", "RMSE", "MAE", "train_time_s"}.
        RMSE is the square root of the mean squared error.
    """
    if hasattr(estimator, "verbosity") and verbose is not None:
        try:
            estimator.set_params(verbosity=verbose)
        except (ValueError, TypeError):
            pass  # best-effort: an unsupported verbosity value must not abort training

    steps = [("pre", pre)]
    if isinstance(estimator, HistGradientBoostingRegressor):
        steps.append(("to_dense", ToDense()))
    steps.append(("model", estimator))
    pipe = Pipeline(steps)

    t0 = perf_counter()
    if use_es and isinstance(estimator, XGBRegressor):
        _fit_xgb_with_early_stopping(pipe, x_train, y_train, es_rounds, es_verbose, es_val_fraction)
    else:
        pipe.fit(x_train, y_train)
    tr_time = perf_counter() - t0

    pred = pipe.predict(x_test)
    r2 = r2_score(y_test, pred)
    rmse = float(np.sqrt(mean_squared_error(y_test, pred)))
    mae = mean_absolute_error(y_test, pred)

    best_info = ""
    if isinstance(estimator, XGBRegressor):
        mdl = pipe.named_steps["model"]
        if getattr(mdl, "best_iteration", None) is not None:
            best_info = f" | best_iter: {mdl.best_iteration}"

    print(f"\n--- {name} ---")
    print(f"Train time: {tr_time:.2f}s  |  R2: {r2:.3f}  |  RMSE: {rmse:.4f}  |  MAE: {mae:.4f}{best_info}")
    return pipe, {"name": name, "R2": r2, "RMSE": rmse, "MAE": mae, "train_time_s": tr_time}

# ------------------------------------------------------------
# Predictors
# ------------------------------------------------------------

def train_models(x_train, y_train, x_test, y_test, components_cap=320, RSEED=42):
    """Train and score Ridge, HistGradientBoosting, XGBoost and an MLP on one split.

    Ridge and the MLP use the dense, centered preprocessor; the tree models use median
    imputation plus sparse one-hot encoding. ``components_cap`` is accepted for
    backward compatibility and is not used.

    Returns
    -------
    (fitted, results)
        ``fitted`` maps "ridge", "hgb", "xgb", "mlp" to fitted Pipelines, plus the
        shared ColumnTransformers under "_pre_linear" and "_pre_tree". ``results`` holds
        the metric dicts from ``eval_model`` in training order.
    """
    pre_linear, pre_tree = build_preprocessors(infer_gene_keep(x_train))

    # (key, display name, estimator, preprocessor, verbose)
    specs = [
        ("ridge", "Ridge (L2)",
         Ridge(alpha=0.5, random_state=RSEED),
         pre_linear, 0),
        ("hgb", "HistGradientBoosting",
         HistGradientBoostingRegressor(max_iter=500, learning_rate=0.05,
                                       max_depth=7, random_state=RSEED),
         pre_tree, 1),
        ("xgb", "XGBRegressor",
         XGBRegressor(objective="reg:squarederror", n_estimators=400, max_depth=8,
                      subsample=0.9, colsample_bytree=0.8, learning_rate=0.05,
                      reg_alpha=1e-3, reg_lambda=1.0,
                      tree_method="hist", random_state=RSEED, n_jobs=-1),
         pre_tree, 0),
        ("mlp", "MLP (dense, centered)",
         MLPRegressor(hidden_layer_sizes=(128,64), learning_rate_init=1e-3,
                      batch_size=64, max_iter=250, early_stopping=True,
                      validation_fraction=0.1, n_iter_no_change=5,
                      random_state=RSEED, verbose=False),
         pre_linear, 0),
    ]

    fitted, results = {}, []
    for key, label, estimator, pre, verb in specs:
        pipe, metrics = eval_model(label, estimator, x_train, y_train, x_test, y_test, pre,
                                   verbose=verb, use_es=False)
        fitted[key] = pipe
        results.append(metrics)

    fitted["_pre_linear"] = pre_linear
    fitted["_pre_tree"] = pre_tree
    return fitted, results
    
def extract_mlp_params_from_pipeline(mlp_pipe):
    """Return the constructor parameters of the first MLPRegressor step in ``mlp_pipe``.

    Only names accepted by ``MLPRegressor.__init__`` are kept, so the result can be
    passed as ``MLPRegressor(**params)``. Raises ValueError if the pipeline has no
    MLPRegressor step.
    """
    mlp_est = next(
        (step for step in mlp_pipe.named_steps.values() if isinstance(step, MLPRegressor)),
        None,
    )
    if mlp_est is None:
        raise ValueError("Pipeline 'mlp' does not contain an MLPRegressor step.")

    allowed = set(inspect.signature(MLPRegressor.__init__).parameters)
    allowed.discard("self")
    shallow = mlp_est.get_params(deep=False)
    return {key: value for key, value in shallow.items() if key in allowed}

def mlp_predict(fitted_models, X):
    """Predict ``X`` with the model stored under the "mlp" key of ``fitted_models``."""
    mlp_model = fitted_models["mlp"]
    return mlp_model.predict(X)
    
def build_gate_features(fitted_models, X):
    """Transform ``X`` with the MLP's preprocessor and return a dense float32 array.

    The preprocessor is resolved in this order:
      1. a step of the ``"mlp"`` pipeline named ``"preprocessor"``, ``"pre"`` or
         ``"transform"`` (checked in that order);
      2. otherwise the first pipeline step that has a ``transform`` method;
      3. otherwise ``fitted_models["_pre_linear"]``.

    Raises
    ------
    KeyError
        If ``fitted_models`` has no ``"mlp"`` entry, or if no preprocessor
        can be found by any of the rules above.
    """
    steps = getattr(fitted_models["mlp"], "named_steps", {})

    preprocessor = next(
        (steps[name] for name in ("preprocessor", "pre", "transform") if name in steps),
        None,
    )
    if preprocessor is None:
        preprocessor = next(
            (step for step in steps.values() if hasattr(step, "transform")),
            None,
        )
    if preprocessor is None:
        preprocessor = fitted_models.get("_pre_linear", None)
    if preprocessor is None:
        raise KeyError("No preprocessor found in 'mlp' pipeline and no '_pre_linear' in fitted_models.")

    features = preprocessor.transform(X)
    # Densify outputs exposing toarray() (e.g. sparse matrices).
    if hasattr(features, "toarray"):
        features = features.toarray()
    return np.asarray(features, dtype=np.float32)
    
# ---------------------------------------------------------
# Prediction table & plots 
# ---------------------------------------------------------

def predict_table(fitted_models: Dict[str, Any],
                  X_test: pd.DataFrame,
                  y_test: pd.Series,
                  meta: Optional[pd.DataFrame]) -> pd.DataFrame:
    """Build a per-row table of true values, model predictions and the winning model.

    Parameters
    ----------
    fitted_models : dict
        Name -> fitted estimator. Entries whose key starts with ``"_"`` (e.g. the
        ``"_pre_linear"`` / ``"_pre_tree"`` preprocessors stored next to the
        models) or that have no ``predict`` method are skipped.
    X_test : pandas.DataFrame
        Features passed to each estimator's ``predict``. Its index becomes the
        table index.
    y_test : pandas.Series
        True targets, aligned with ``X_test`` by position.
    meta : pandas.DataFrame or None
        Optional extra columns, looked up by ``X_test``'s index labels and
        appended after the prediction columns.

    Returns
    -------
    pandas.DataFrame
        Contains ``true``, one ``pred_<name>`` column per estimator, the ``meta``
        columns, and ``winner``. ``winner`` is the model name with the smallest
        absolute error; ties go to the model listed first.

    Raises
    ------
    ValueError
        If ``fitted_models`` contains no usable estimator.
    """
    # Keep only real estimators: fitted-model dicts may also carry fitted
    # preprocessors under underscore-prefixed keys, which cannot predict.
    predictors = {
        name: model
        for name, model in fitted_models.items()
        if not str(name).startswith("_") and hasattr(model, "predict")
    }
    if not predictors:
        raise ValueError("fitted_models contains no estimator with a predict() method.")

    # Fix the prediction columns before merging meta, so a meta column that
    # happens to start with "pred_" is never mistaken for a model.
    pred_cols = [f"pred_{name}" for name in predictors]
    preds = pd.DataFrame(
        {f"pred_{name}": model.predict(X_test) for name, model in predictors.items()},
        index=X_test.index,
    )
    preds.insert(0, "true", y_test.values)
    if meta is not None:
        for c in meta.columns:
            preds[c] = meta.loc[preds.index, c].values

    abs_err = preds[pred_cols].sub(preds["true"], axis=0).abs()
    preds["winner"] = abs_err.idxmin(axis=1).str.replace("^pred_", "", regex=True)
    return preds

def winner_summary(preds: pd.DataFrame, group_col: str, top_n: int = 10) -> pd.DataFrame:
    """Summarise how the rows won by each model are spread across ``group_col``.

    Returns one row per (winner, group) pair with its ``count`` and
    ``frac_within_winner``, the share of that winner's rows falling in the
    group. Rows are sorted by winner, then by descending fraction, and only
    the top ``top_n`` groups per winner are kept.
    """
    counts = preds.groupby(["winner", group_col]).size().reset_index(name="count")
    winner_totals = counts.groupby("winner")["count"].transform("sum")
    counts["frac_within_winner"] = counts["count"] / winner_totals

    ranked = counts.sort_values(["winner", "frac_within_winner"], ascending=[True, False])
    return ranked.groupby("winner").head(top_n)

def plot_winner_enrichment(preds: pd.DataFrame, group_col: str, top_n: int = 8,
                           *, ylabel: Optional[str] = "Primary site"):
    """Bar-plot, per winning model, its top ``top_n`` groups of ``group_col``.

    Parameters
    ----------
    preds : pandas.DataFrame
        Table containing a ``winner`` column and ``group_col``.
    group_col : str
        Column to group by (e.g. a primary-site annotation).
    top_n : int
        Number of groups kept per winner (see ``winner_summary``).
    ylabel : str or None
        Y-axis label. Defaults to "Primary site" (the previous hard-coded
        label). Pass ``None`` to use ``group_col``.

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``sns.catplot``, after ``plt.show()``.
    """
    tbl = winner_summary(preds, group_col, top_n=top_n)
    g = sns.catplot(
        data=tbl, x="frac_within_winner", y=group_col,
        hue="winner", kind="bar", height=6, aspect=1.6
    )
    g.set_axis_labels("Fraction within winner group",
                      group_col if ylabel is None else ylabel)

    # FacetGrid.set_titles() only titles row/col facets. This catplot has
    # none, so that call was a silent no-op. Title the axes directly instead.
    title = f"Top {top_n} {group_col} where each model wins"
    for ax in g.axes.flat:
        ax.set_title(title)

    plt.show()
    return g

def metrics_by_region(X_test: pd.DataFrame,
                       y_test: pd.Series,
                       types: pd.Series,
                       fitted_models: Dict[str, Any]) -> pd.DataFrame:
    """Compute R2 / RMSE / MAE per model and per group of ``types``.

    Parameters
    ----------
    X_test : pandas.DataFrame
        Features passed to each estimator's ``predict``.
    y_test, types : pandas.Series
        True targets and group labels. They are placed in a frame indexed by
        ``X_test.index``, so their index labels must match ``X_test``'s.
    fitted_models : dict
        Name -> fitted estimator. Entries whose key starts with ``"_"`` (e.g.
        stored preprocessors) or that have no ``predict`` method are skipped.

    Returns
    -------
    pandas.DataFrame
        One row per (model, type) with columns ``type``, ``model``, ``n``,
        ``R2``, ``RMSE`` and ``MAE``. ``R2`` is NaN for groups whose true
        values are all identical.
    """
    predictors = {
        name: mdl
        for name, mdl in fitted_models.items()
        if not str(name).startswith("_") and hasattr(mdl, "predict")
    }
    base = pd.DataFrame({"y": y_test, "type": types}, index=X_test.index)

    out = []
    for name, mdl in predictors.items():
        tmp = base.copy()
        tmp["yhat"] = mdl.predict(X_test)
        g = tmp.groupby("type")
        pairs = g[["y", "yhat"]]
        out.append(pd.DataFrame({
            "model": name,
            "n": g.size(),
            "R2": pairs.apply(lambda t: r2_score(t["y"], t["yhat"]) if t["y"].nunique() > 1 else np.nan),
            # mean_squared_error returns the MSE; take the root so the
            # column really is RMSE, as its name says.
            "RMSE": pairs.apply(lambda t: float(np.sqrt(mean_squared_error(t["y"], t["yhat"])))),
            "MAE": pairs.apply(lambda t: mean_absolute_error(t["y"], t["yhat"])),
        }).reset_index())
    return pd.concat(out, axis=0, ignore_index=True)


def make_case_table(
    X_test: pd.DataFrame,
    y_test: pd.Series,
    meta_df: Optional[pd.DataFrame],
    *,

    pred_Ridge: np.ndarray,
    pred_XGB:   np.ndarray,
    pred_HGB:   np.ndarray,
    pred_MLP:   np.ndarray,

    pred_IABMA: Optional[np.ndarray] = None,
    pred_MoE:   Optional[np.ndarray] = None,
    pred_DLA: Optional[np.ndarray] = None,
    pred_SMC:   Optional[np.ndarray] = None,
    pred_BHS:   Optional[np.ndarray] = None,

    w_IABMA: Optional[np.ndarray] = None,
    w_MoE:   Optional[np.ndarray] = None,
    w_DLA: Optional[np.ndarray] = None,
    w_SMC:   Optional[np.ndarray] = None,
    w_BHS:   Optional[np.ndarray] = None,

    case_id_col: str = "case_id",
    primary_site_col: str = "primary_site",
    compound_col: str = "compound",
) -> pd.DataFrame:
    """Assemble a per-case table of truth, base/ensemble predictions, errors and weights.

    Parameters
    ----------
    X_test : pandas.DataFrame
        Only its index is used; it becomes the table index.
    y_test : pandas.Series
        True targets, aligned with ``X_test`` by position (``y_true`` column).
    meta_df : pandas.DataFrame or None
        If given, the columns ``case_id_col``, ``primary_site_col`` and
        ``compound_col`` that exist in it are copied (looked up by
        ``X_test``'s index labels).
    pred_Ridge, pred_XGB, pred_HGB, pred_MLP : array-like
        Base-model predictions, flattened into ``pred_<name>`` columns.
    pred_IABMA, pred_MoE, pred_DLA, pred_SMC, pred_BHS : array-like, optional
        Ensemble predictions. Each provided one yields a ``pred_<name>`` column.
    w_IABMA, w_MoE, w_DLA, w_SMC, w_BHS : array-like, optional
        Per-row weights of shape ``(n_rows, 4)``. Column ``j`` is the weight on
        Ridge, XGB, HGB, MLP respectively. They are stored as
        ``w_<method>_<base>``. A missing matrix gives NaN columns. A matrix of
        the wrong shape also gives NaN columns, plus a ``UserWarning``.
    case_id_col, primary_site_col, compound_col : str
        Names of the metadata columns to copy. They are placed first.

    Returns
    -------
    pandas.DataFrame
        Metadata, ``y_true``, predictions, ``RMSE_<method>`` and weights, in
        that order. Any unexpected columns go last. ``RMSE_<method>`` is the
        per-row error ``|y_true - pred|``; for a single sample this equals the
        RMSE. It is NaN when that method's predictions were not given.
    """
    import warnings  # only used to report malformed weight matrices

    idx = X_test.index
    out = pd.DataFrame(index=idx)

    meta_cols = [case_id_col, primary_site_col, compound_col]
    if meta_df is not None:
        for col in meta_cols:
            if col in meta_df.columns:
                out[col] = meta_df.loc[idx, col].values
    out["y_true"] = y_test.values

    base_keys = ["Ridge", "XGB", "HGB", "MLP"]
    ensemble_keys = ["IABMA", "MoE", "DLA", "SMC", "BHS"]

    base_preds = {"Ridge": pred_Ridge, "XGB": pred_XGB, "HGB": pred_HGB, "MLP": pred_MLP}
    for key in base_keys:
        out[f"pred_{key}"] = np.asarray(base_preds[key]).ravel()

    ensemble_preds = {"IABMA": pred_IABMA, "MoE": pred_MoE, "DLA": pred_DLA,
                      "SMC": pred_SMC, "BHS": pred_BHS}
    for key in ensemble_keys:
        if ensemble_preds[key] is not None:
            out[f"pred_{key}"] = np.asarray(ensemble_preds[key]).ravel()

    # Per-row error. abs() equals sqrt(err**2) but cannot overflow to inf for
    # very large errors.
    for key in ensemble_keys:
        col = f"pred_{key}"
        out[f"RMSE_{key}"] = (out["y_true"] - out[col]).abs() if col in out.columns else np.nan

    ensemble_weights = {"IABMA": w_IABMA, "MoE": w_MoE, "DLA": w_DLA,
                        "SMC": w_SMC, "BHS": w_BHS}
    expected_shape = (len(out), len(base_keys))
    for key in ensemble_keys:
        cols = [f"w_{key}_{bk}" for bk in base_keys]
        W = ensemble_weights[key]
        if W is not None:
            W = np.asarray(W)
            if W.ndim == 2 and W.shape == expected_shape:
                for j, c in enumerate(cols):
                    out[c] = W[:, j]
                continue
            # Keep the best-effort NaN fill, but don't hide a caller mistake.
            warnings.warn(
                f"w_{key} has shape {W.shape}, expected {expected_shape}; "
                f"its weight columns are filled with NaN.",
                stacklevel=2,
            )
        for c in cols:
            out[c] = np.nan

    # Order by the caller's metadata column names, not hard-coded defaults.
    # dict.fromkeys de-duplicates if two *_col arguments coincide.
    desired = list(dict.fromkeys(
        meta_cols
        + ["y_true"]
        + [f"pred_{k}" for k in base_keys + ensemble_keys]
        + [f"RMSE_{k}" for k in ensemble_keys]
        + [f"w_{k}_{bk}" for k in ensemble_keys for bk in base_keys]
    ))
    cols = [c for c in desired if c in out.columns] + [c for c in out.columns if c not in desired]
    return out[cols]


def plot_cases_from_table(tbl, case_ids, *, base_order=None, method_order=None):
    """Plot per-case base-model errors next to ensemble weights for selected cases.

    For each id in ``case_ids``, the first row of ``tbl`` with that
    ``"case_id"`` is drawn as a pair of panels:

      * an error panel: the absolute error ``|y_true - pred|`` of each base
        model, labelled "RMSE" (for a single sample the RMSE equals the
        absolute error);
      * a weight panel: grouped bars with the weight each ensemble method
        assigns to each base model.

    Columns read from ``tbl`` (name matching is case-insensitive):
      * ``pred_<base>`` for base in ridge / xgb / hgb / mlp;
      * ``w_<method>_<base>``, where neither token may contain ``"_"``;
      * ``y_true`` and ``case_id``.

    Parameters
    ----------
    tbl : pandas.DataFrame
        Case table, e.g. the output of ``make_case_table``.
    case_ids : sequence
        Case ids to plot, left to right. Ids not found in ``tbl["case_id"]``
        produce blank panels.
    base_order : sequence of str, optional
        Order and subset of base models for both panels. Names that do not
        match a known base are dropped.
    method_order : sequence of str, optional
        Display names of the ensemble methods to show, in order. By default
        IABMA, MoE, DLA, SMC, BHS come first, followed by any other detected
        methods.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``case_ids`` is empty.
    """
    desired_bases = ["ridge", "xgb", "hgb", "mlp"]
    pretty_base   = {"ridge": "Ridge", "xgb": "XGB", "hgb": "HGB", "mlp": "MLP"}

    # Base name (lower-case) -> its "pred_<base>" column. The first matching
    # column wins.
    pred_map = {}
    for c in tbl.columns:
        if isinstance(c, str):
            m = re.match(r'(?i)^pred_(.+)$', c)
            if m:
                b = m.group(1).lower()
                if b in desired_bases and b not in pred_map:
                    pred_map[b] = c

    # Bases drawn in each error panel. base_order (if given) sets the order/subset.
    if base_order is None:
        bases_for_errors = [b for b in desired_bases if b in pred_map]
    else:
        bases_for_errors = [str(b).lower() for b in base_order if str(b).lower() in pred_map]

    # Method token (as spelled in the column) -> {base: "w_<method>_<base>" column}.
    weight_map = {}
    for c in tbl.columns:
        if isinstance(c, str):
            m = re.match(r'(?i)^w_([^_]+)_([^_]+)$', c)
            if m:
                meth_tok = m.group(1)
                base_tok = m.group(2).lower()
                if base_tok in desired_bases:
                    weight_map.setdefault(meth_tok, {})[base_tok] = c

    def disp_label(token: str) -> str:
        # Canonical spelling for known methods; other tokens are returned unchanged.
        t = token.lower()
        if t == "dla":   return "DLA"
        if t == "moe":   return "MoE"
        if t == "iabma": return "IABMA"
        if t == "smc":   return "SMC"
        if t == "bhs":   return "BHS"
        return token

    tokens_present = [t for t, bmap in weight_map.items() if any(b in bmap for b in desired_bases)]
    if method_order is None:
        # Known methods in a fixed order, then any others in discovery order.
        preferred = ["IABMA", "MoE", "DLA", "SMC", "BHS"]
        tok_order = []
        for p in preferred:
            match = next((t for t in tokens_present if t.lower() == p.lower()), None)
            if match: tok_order.append(match)
        tok_order += [t for t in tokens_present if all(t.lower() != p.lower() for p in preferred)]
    else:
        # Only the requested methods (matched on display label), in that order.
        tok_order = []
        for req in method_order:
            match = next((t for t in tokens_present if disp_label(t).lower() == str(req).lower()), None)
            if match: tok_order.append(match)

    # Bases on the weight panel's x axis.
    if base_order is None:
        bases_for_weights = [b for b in desired_bases if any(b in weight_map.get(t, {}) for t in tok_order)]
    else:
        bases_for_weights = [str(b).lower() for b in base_order if str(b).lower() in desired_bases]
    if not bases_for_weights:
        bases_for_weights = [b for b in desired_bases if any(b in weight_map.get(t, {}) for t in tok_order)]

    N = len(case_ids)
    if N == 0:
        raise ValueError("case_ids must contain at least one case id.")
    # Subplot columns per case: [error panel, spacer, weight panel], plus a
    # wider spacer column between consecutive cases.
    inner_spacer = 0.30
    outer_spacer = 0.55
    pair_widths = []
    for i in range(N):
        pair_widths += [1.0, inner_spacer, 1.8]
        if i < N - 1:
            pair_widths += [outer_spacer]

    num_cols = len(pair_widths)
    fig_w = 1.6 * sum(pair_widths)
    fig_h = 3.0
    fig, axes = plt.subplots(
        1, num_cols, figsize=(fig_w, fig_h),
        gridspec_kw={"width_ratios": pair_widths, "wspace": 0.05}
    )
    axes = np.array(axes)

    base_fill_color = "#9aa0a6"
    base_edge_color = "#606060"

    method_color = {
        "DLA":   "#1f77b4",  # blue
        "IABMA": "#00441b",  # deep green (hatched)
        "MoE":   "#ff7f0e",  # orange
        "SMC":   "#9467bd",  # purple
        "BHS":   "#d62728",  # red
    }
    method_disp = [disp_label(t) for t in tok_order]
    # Colors for methods without a fixed color. Request at least one color per
    # method so the iterator cannot run dry; seaborn cycles the qualitative
    # palette when more colors are requested than it has.
    fallback_palette = iter(sns.color_palette("tab10", max(10, len(method_disp))))
    for m in method_disp:
        if m not in method_color:
            method_color[m] = next(fallback_palette)

    pair_axes = []
    col_idx = 0
    for p, cid in enumerate(case_ids):
        ax_err = axes[col_idx]
        ax_w   = axes[col_idx + 2]
        pair_axes.append((ax_err, ax_w))

        # Hide the spacer columns: the inner one, and the outer one if a case follows.
        axes[col_idx + 1].axis("off")
        if col_idx + 3 < num_cols:
            axes[col_idx + 3].axis("off")

        # Only the first row with this case id is plotted.
        sub = tbl.loc[tbl["case_id"] == cid]
        if sub.empty:
            ax_err.axis("off"); ax_w.axis("off")
            col_idx += 4 if (p < N - 1) else 3
            continue
        row = sub.iloc[0]

        # Error panel: per-base absolute error, in base_order when given.
        rmse_rows = []
        for b in bases_for_errors:
            colname = pred_map[b]
            val = row.get(colname, np.nan)
            if pd.notna(val):
                rmse = abs(float(row["y_true"]) - float(val))
                rmse_rows.append((pretty_base[b], rmse))
        if rmse_rows:
            rmse_df = pd.DataFrame(rmse_rows, columns=["Base", "RMSE"])
            sns.barplot(
                data=rmse_df, x="Base", y="RMSE",
                ax=ax_err, errorbar=None, width=0.7,
                color=base_fill_color, edgecolor=base_edge_color, linewidth=0.5
            )
            ax_err.set_xlabel("")
            ax_err.set_ylabel("RMSE", labelpad=1)
            ax_err.tick_params(axis="y", labelsize=8)
            ax_err.grid(axis="y", alpha=0.35, linewidth=0.6)
            ax_err.tick_params(axis="x", labelsize=8)
            ax_err.margins(x=0.02)
        else:
            ax_err.text(0.5, 0.5, "No base preds", ha="center", va="center")
            ax_err.axis("off")

        # W[i, j] = weight of method i on base j; NaN where the value is missing.
        B = len(bases_for_weights)
        H = len(method_disp)
        base_labels = [pretty_base[b] for b in bases_for_weights]
        W = np.full((H, B), np.nan, dtype=float)
        for i, tok in enumerate(tok_order):
            bmap = weight_map.get(tok, {})
            for j, b in enumerate(bases_for_weights):
                cname = bmap.get(b)
                if cname is not None:
                    v = row.get(cname, np.nan)
                    if pd.notna(v):
                        W[i, j] = float(v)

        ax_w.set_xlabel("")
        ax_w.set_ylabel("Weight", labelpad=1)
        ax_w.tick_params(axis="y", labelsize=8)
        ax_w.grid(axis="y", alpha=0.35, linewidth=0.6)
        ax_w.set_xticks(np.arange(B))
        ax_w.set_xticklabels(base_labels, ha="right", fontsize=8)
        ax_w.set_ylim(0, 1.05)

        if H == 0 or B == 0 or np.all(np.isnan(W)):
            ax_w.text(0.5, 0.5, "No weights", ha="center", va="center")
            ax_w.axis("off")
        else:
            # Grouped bars: one bar per method inside each base group, centred on
            # the base's tick. IABMA is drawn hatched and opaque to stand out.
            total_width = 0.78
            bar_w = total_width / max(H, 1)
            offsets = (np.arange(H) - (H - 1) / 2) * bar_w
            x = np.arange(B)

            for i, m in enumerate(method_disp):
                color = method_color[m]
                hatch = "//" if m == "IABMA" else None
                edgec = "white" if m == "IABMA" else None
                lw    = 1.0 if m == "IABMA" else 0.0
                alpha = 1.0 if m == "IABMA" else 0.75

                mask = ~np.isnan(W[i])
                ax_w.bar(
                    x[mask] + offsets[i],
                    W[i, mask],
                    width=bar_w * 0.92,
                    label=m,
                    color=color,
                    hatch=hatch,
                    edgecolor=edgec,
                    linewidth=lw,
                    align="center",
                    alpha=alpha
                )

        for ax in (ax_err, ax_w):
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        col_idx += 4 if (p < N - 1) else 3

    # Reserve the lower part of the figure for case titles and the legend,
    # then draw once before reading axes positions below.
    plt.subplots_adjust(bottom=0.44, top=0.92, left=0.05, right=0.99)
    fig = plt.gcf()
    fig.canvas.draw()

    # Case titles sit under each panel pair, shifted slightly left of its centre.
    title_y_vals = []
    for (ax_err, ax_w), cid in zip(pair_axes, case_ids):
        box_err = ax_err.get_position()
        box_w   = ax_w.get_position()
        left  = box_err.x0
        right = box_w.x1
        pair_center = (left + right) / 2.0
        pair_width  = (right - left)
        x_center = pair_center - 0.13 * pair_width    # small left shift
        y_bottom = min(box_err.y0, box_w.y0)
        title_y = y_bottom - 0.1
        title_y_vals.append(title_y)
        fig.text(x_center, title_y, f"Case {cid}", ha="center", va="top", fontsize=9)

    if method_disp:
        # Legend below the lowest case title, clamped so it stays inside the figure.
        legend_y = max(0.004, min(title_y_vals) - 0.125)
        handles = [Patch(facecolor=method_color[m],
                         label=m,
                         hatch="//" if m == "IABMA" else None,
                         edgecolor="white" if m == "IABMA" else "none",
                         linewidth=1.0 if m == "IABMA" else 0.0)
                   for m in method_disp]
        fig.legend(
            handles=handles,
            ncol=min(len(handles), 5),
            frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, legend_y),
            fontsize=8,
            handlelength=1.4,
            columnspacing=0.8,
        )

    return fig

