import os
import argparse
from pathlib import Path
import sys
from typing import Optional

# Avoid PyTensor from recursing too deeply when cloning large graphs
sys.setrecursionlimit(10000)

import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from scipy.special import expit, logit
import arviz as az
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
import logging

# Set the "pymc" logger to INFO so PyMC's INFO-level messages are emitted.
logger = logging.getLogger("pymc")
logger.setLevel(logging.INFO)

"""
python run_mix_benchmark.py --input_csv /path/to/merged_is_correct_matrix.csv --outputdir /path/to/output_dir --sample_ratio 0.2 --train_ratio 0.9 --draws 1500 --tune 1500 --seed 123

python yourpath/run_mix_benchmark.py --cores 10 --sample_ratio 0.01

python run_mix_benchmark.py --sample_ratio 0.01 --cores 4 --chains 2 --slope_mode bench --engine numpyro --draws 800 --tune 800 --target_accept 0.85

python run_mix_benchmark.py --sample_per_bench 100 --cores 4 --chains 2 --slope_mode rasch --engine pymc --draws 800 --tune 800

python run_mix_benchmark.py --train_per_bench 500 --val_per_bench 100 --test_per_bench 100 --cores 4 --chains 4 --draws 500 --slope_mode rasch --seed 123
"""

# ======= Split training/testing =======
def split_train_test(Y_full: np.ndarray, train_ratio: float = 0.9, seed: int = 42):
    """Randomly hold out individual cells of the response matrix for testing.

    Each cell is kept for training independently with probability ``train_ratio``
    using ``np.random.default_rng(seed)``, so the split is reproducible.

    Args:
        Y_full: response matrix; NaN cells (if any) are treated as missing.
        train_ratio: probability that a cell stays in the training matrix.
        seed: seed for the cell-level random split.

    Returns:
        Y_train: float copy of ``Y_full`` with held-out (and missing) cells set to NaN.
        Y_test_mask: boolean mask (same shape as ``Y_full``) of held-out cells that
            actually have an observed value.
        Y_full: the input matrix, returned unchanged.
    """
    rng = np.random.default_rng(seed)
    train_mask = rng.random(Y_full.shape) < train_ratio
    Y_train = Y_full.astype(float)  # astype copies, so Y_full is never mutated
    # Cells missing in the input have no label to test against.
    observed = ~np.isnan(Y_train)
    Y_train[~train_mask] = np.nan
    Y_test_mask = ~train_mask & observed
    return Y_train, Y_test_mask, Y_full

# ======= Multi-benchmark MCMC modeling =======
def build_MCMC_multi(N: int,
                     J: int,
                     M: int,
                     bench_idx: np.ndarray,
                     Y_train: np.ndarray,
                     draws: int = 2000,
                     tune: int = 3000,
                     chains: int = 2,
                     cores: int = 1,
                     engine: str = "pymc",
                     slope_mode: str = "item",
                     target_accept: float = 0.9,
                     random_seed: int | None = None) -> az.InferenceData:
    """
    Jointly model multi-benchmark binary correctness data in PyMC and sample with NUTS.

    Ability of model i on benchmark m is theta[i, m] = psi[i] + zeta[i, m], where
    zeta[i, :] ~ MvNormal(0, L L^T) with L from an LKJ Cholesky prior
    (registered as ``chol_zeta`` with ``_corr`` / ``_stds`` deterministics).
    Item j on benchmark bench_idx[j] has logit a[j] * theta[i, bench_idx[j]] - b[j].

    Args:
        N: number of models (rows of Y_train).
        J: number of items (columns of Y_train).
        M: number of benchmarks.
        bench_idx: length-J integer array, values in [0, M): benchmark of each item.
        Y_train: (N, J) matrix of 0/1 responses; NaN cells are unobserved and get
            no likelihood term.
        draws, tune, chains, cores, target_accept: forwarded to ``pm.sample``.
        engine: NUTS backend, forwarded as ``nuts_sampler``.
        slope_mode: "item" (one slope per item), "bench" (one slope per benchmark)
            or "rasch" (all slopes fixed to 1).
        random_seed: forwarded to ``pm.sample``; None (default) leaves sampling unseeded.

    Returns:
        The ``InferenceData`` returned by ``pm.sample``.

    Raises:
        ValueError: on inconsistent shapes, out-of-range ``bench_idx`` or unknown ``slope_mode``.
    """
    bench_idx = np.asarray(bench_idx, dtype=int)
    Y_train = np.asarray(Y_train, dtype=float)
    if bench_idx.shape != (J,):
        raise ValueError(f"bench_idx must have shape ({J},), got {bench_idx.shape}")
    if Y_train.shape != (N, J):
        raise ValueError(f"Y_train must have shape ({N}, {J}), got {Y_train.shape}")
    # Negative indices would silently wrap around instead of failing.
    if bench_idx.size and (bench_idx.min() < 0 or bench_idx.max() >= M):
        raise ValueError(f"bench_idx values must lie in [0, {M})")
    if slope_mode not in ("item", "bench", "rasch"):
        raise ValueError(f"Unknown slope_mode {slope_mode!r}; expected 'item', 'bench' or 'rasch'")

    # Likelihood only at observed cells, so missing entries do not become discrete
    # latent variables (a large speed-up).
    obs_i, obs_j = np.where(~np.isnan(Y_train))
    y_obs_vec = Y_train[obs_i, obs_j].astype(int)

    with pm.Model():
        # Global ability
        psi = pm.Normal("psi", mu=0, sigma=1, shape=N)
        # Benchmark-specific deviations zeta with LKJ-Cholesky covariance (non-centred)
        chol, corr, stds = pm.LKJCholeskyCov(
            'chol_zeta', eta=2.0, n=M,
            sd_dist=pm.Exponential.dist(1.0),
            compute_corr=True
        )
        zeta_raw = pm.Normal("zeta_raw", 0.0, 1.0, shape=(N, M))
        zeta = pm.Deterministic("zeta", pt.dot(zeta_raw, chol.T))  # (N, M)
        # Joint ability
        theta = psi[:, None] + zeta  # (N, M)
        # Item slopes
        if slope_mode == "rasch":
            a = pm.Deterministic("a", pt.ones((J,), dtype="float64"))
        elif slope_mode == "bench":
            a_m = pm.LogNormal("a_m", mu=np.log(1.0), sigma=0.5, shape=M)
            a = pm.Deterministic("a", a_m[bench_idx])
        else:  # "item"
            a = pm.LogNormal("a", mu=np.log(1.0), sigma=0.5, shape=J)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)
        # Logits for every (model, item) pair
        theta_items = theta[:, bench_idx]  # (N, J)
        logits = a[None, :] * theta_items - b[None, :]  # (N, J)
        pm.Bernoulli("y_obs", logit_p=logits[obs_i, obs_j], observed=y_obs_vec)
        print(f"  [MCMC] Start sampling: draws={draws}, tune={tune}, chains={chains}, cores={cores}, engine={engine}, slope_mode={slope_mode}")
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=target_accept,
            return_inferencedata=True,
            progressbar=True,
            nuts_sampler=engine,
            random_seed=random_seed,
            compute_convergence_checks=True,
        )
        print("  [MCMC] Sampling completed")
    return trace

# ======= Convergence and mixing quality assessment =======
def evaluate_mcmc_reliability(trace: az.InferenceData,
                              var_names: list[str],
                              save_dir: Path,
                              rhat_threshold: float = 1.01,
                              ess_threshold: float = 200,
                              make_plots: bool = False) -> pd.DataFrame:
    """Summarise convergence and mixing, and flag suspicious parameters.

    Writes ``az.summary`` for ``var_names`` to ``save_dir / "mcmc_summary.csv"`` and prints
    the parameters whose ``r_hat`` exceeds ``rhat_threshold`` or whose ``ess_bulk`` is
    below ``ess_threshold``. NaN diagnostics compare False and are therefore not flagged.

    Args:
        trace: posterior samples.
        var_names: variables to summarise.
        save_dir: output directory.
        rhat_threshold: r_hat above this value triggers a convergence warning.
        ess_threshold: ess_bulk below this value triggers an effective-sample warning.
        make_plots: also save ``trace_plots.png`` and ``energy_plot.png``. Off by default
            (matching the previous behaviour); trace plots of large vector variables may be slow.

    Returns:
        The summary DataFrame produced by ``az.summary``.
    """
    summary = az.summary(trace, var_names=var_names)
    summary.to_csv(save_dir / "mcmc_summary.csv")

    rhat = summary["r_hat"]
    ess = summary["ess_bulk"]

    high_rhat = rhat > rhat_threshold
    if high_rhat.any():
        print(f"⚠️ Convergence warning (r_hat > {rhat_threshold}):")
        print(rhat[high_rhat])
    else:
        print("✅ Parameters converged well")

    low_ess = ess < ess_threshold
    if low_ess.any():
        print(f"⚠️ Insufficient samples (ess_bulk < {ess_threshold}):")
        print(ess[low_ess])
    else:
        print("✅ Sufficient effective samples")

    if make_plots:
        az.plot_trace(trace, var_names=var_names)
        plt.tight_layout()
        plt.savefig(save_dir / "trace_plots.png", dpi=200)
        plt.close()
        az.plot_energy(trace)
        plt.tight_layout()
        plt.savefig(save_dir / "energy_plot.png", dpi=200)
        plt.close()

    return summary

# ======= Estimate vs true value comparison (when true values are provided) =======
def check_diff(trace: az.InferenceData, true_data: dict, var_names: list[str], save_dir: Path):
    """Scatter posterior-mean estimates against known true values, one panel per variable.

    Meant for runs where the generating parameters are known (e.g. simulations).
    Saves ``est_vs_true.png`` in ``save_dir``; does nothing if ``var_names`` is empty.

    Args:
        trace: posterior samples; every name in ``var_names`` must be in ``trace.posterior``.
        true_data: mapping from variable name to true values (any array-like with as many
            elements as that variable's posterior mean).
        var_names: variables to compare.
        save_dir: output directory.
    """
    if not var_names:
        return
    n_panels = len(var_names)
    fig = plt.figure(figsize=(4 * n_panels, 4))
    try:
        for panel, var in enumerate(var_names, start=1):
            est = np.asarray(
                trace.posterior[var].mean(dim=("chain", "draw")).values, dtype=float
            ).ravel()
            # asarray lets callers pass plain lists as well as numpy arrays.
            true = np.asarray(true_data[var], dtype=float).ravel()
            ax = fig.add_subplot(1, n_panels, panel)
            ax.scatter(true, est, alpha=0.6)
            # Span both ranges so every point can be compared with the y = x line.
            lo = min(np.min(true), np.min(est))
            hi = max(np.max(true), np.max(est))
            ax.plot([lo, hi], [lo, hi], 'r--')
            ax.set_xlabel("True")
            ax.set_ylabel("Estimated")
            ax.set_title(var)
            ax.grid(True)
        fig.tight_layout()
        fig.savefig(save_dir / "est_vs_true.png", dpi=200)
    finally:
        plt.close(fig)

# ======= Model performance evaluation =======
def _binary_metrics(y_true: np.ndarray, p_pred: np.ndarray, y_hat: np.ndarray):
    """Return (accuracy, AUC, 2x2 confusion matrix) for 0/1 labels; AUC is NaN for a single class."""
    acc = accuracy_score(y_true, y_hat)
    # roc_auc_score raises when only one class is present; report NaN instead of crashing.
    auc = roc_auc_score(y_true, p_pred) if np.unique(y_true).size > 1 else float("nan")
    # Fixed labels keep the matrix 2x2 even for single-class subsets.
    cm = confusion_matrix(y_true, y_hat, labels=[0, 1])
    return acc, auc, cm


def _save_scatter_plot(x, y, xlabel: str, ylabel: str, title: str, path: Path):
    """Save a 5x4 scatter plot of ``y`` against ``x`` to ``path``."""
    plt.figure(figsize=(5, 4))
    plt.scatter(x, y, alpha=0.6)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()


def evaluate_model(trace: az.InferenceData,
                   Y_true: np.ndarray,
                   Y_mask: np.ndarray,
                   bench_idx: np.ndarray,
                   unique_bench: list[str],
                   save_dirs: dict,
                   tag: str = "eval"):
    """
    Evaluate plug-in predictions overall and per benchmark; save plots and a text report.

    Predictions use posterior means of psi, zeta, a and b: p = expit(a * (psi + zeta) - b),
    thresholded at 0.5. Metrics are computed on cells selected by ``Y_mask`` that also
    have a non-NaN ground truth; an empty selection is reported as "No data under mask"
    instead of raising, and AUC is NaN for single-class selections.

    Args:
        trace: posterior samples containing psi, zeta, a and b.
        Y_true: (N, J) ground-truth 0/1 matrix (NaN = missing).
        Y_mask: boolean (N, J) mask of cells to score, or None for all observed cells.
        bench_idx: length-J benchmark index of each item.
        unique_bench: benchmark names, position m matching bench_idx == m.
        save_dirs: {"root": Path, <bench name>: Path} output directories.
        tag: prefix for output file names and printed lines.
    """
    bench_idx = np.asarray(bench_idx)
    hat = {var: trace.posterior[var].mean(dim=("chain", "draw")).values
           for var in ['psi', 'zeta', 'a', 'b']}
    psi_hat, zeta_hat, a_hat, b_hat = hat['psi'], hat['zeta'], hat['a'], hat['b']

    theta_items = psi_hat[:, None] + zeta_hat[:, bench_idx]
    logits = a_hat[None, :] * theta_items - b_hat[None, :]
    p_pred = expit(logits)
    y_pred = (p_pred > 0.5).astype(int)

    # Never score cells whose ground truth is missing, even if the caller's mask selects them.
    observed = ~np.isnan(Y_true)
    mask_all = observed if Y_mask is None else (np.asarray(Y_mask, dtype=bool) & observed)

    report_lines = []

    # Overall
    y_true_all = Y_true[mask_all].astype(int)
    if y_true_all.size > 0:
        acc_all, auc_all, cm_all = _binary_metrics(y_true_all, p_pred[mask_all], y_pred[mask_all])
        print(f"[{tag}][Overall] Accuracy: {acc_all:.4f}, AUC: {auc_all:.4f}")
        print("Confusion Matrix:\n", cm_all)
        report_lines.append(f"[{tag}] Overall Accuracy: {acc_all:.6f}, AUC: {auc_all:.6f}\n")
        report_lines.append(f"Confusion Matrix:\n{cm_all}\n")
    else:
        print(f"[{tag}][Overall] No data under mask.")
        report_lines.append(f"[{tag}] Overall: No data under mask.\n")

    # Scatter: true overall accuracy vs psi_hat
    _save_scatter_plot(np.nanmean(Y_true, axis=1), psi_hat,
                       "True Overall Accuracy", "Estimated psi",
                       f"Overall: True Acc vs. psi_hat ({tag})",
                       save_dirs["root"] / f"{tag}_overall_scatter.png")

    # Per-benchmark
    for m, name in enumerate(unique_bench):
        in_bench = bench_idx == m
        _save_scatter_plot(np.nanmean(Y_true[:, in_bench], axis=1),
                           psi_hat + zeta_hat[:, m],
                           f"True Acc ({name})", f"Theta_hat ({name})",
                           f"{name}: True Acc vs. Theta_hat ({tag})",
                           save_dirs[name] / f"{tag}_{name}_scatter.png")

        # Metrics only on masked cells belonging to this benchmark
        mask_m = mask_all & in_bench[None, :]
        y_true_m = Y_true[mask_m].astype(int)
        if y_true_m.size > 0:
            acc_m, auc_m, cm_m = _binary_metrics(y_true_m, p_pred[mask_m], y_pred[mask_m])
            report_lines.append(f"[{name}] Accuracy: {acc_m:.6f}, AUC: {auc_m:.6f}\n")
            report_lines.append(f"[{name}] Confusion Matrix:\n{cm_m}\n")
            print(f"[{tag}][{name}] Accuracy: {acc_m:.4f}, AUC: {auc_m:.4f}")
        else:
            report_lines.append(f"[{name}] No data under mask.\n")
            print(f"[{tag}][{name}] No data under mask.")

    with open(save_dirs["root"] / f"{tag}_report.txt", "w") as f:
        f.writelines(report_lines)


# ======= Correlation Visualization =======
def save_correlation_heatmaps(Y_true: np.ndarray,
                              bench_idx: np.ndarray,
                              trace: az.InferenceData,
                              unique_bench: list[str],
                              save_dirs: dict,
                              tag: str = "corr"):
    """Compare raw cross-benchmark accuracy correlation with the posterior zeta correlation.

    Writes three CSV matrices to ``save_dirs["root"]``:
    ``{tag}_raw_accuracy_corr.csv`` (correlation across models of per-benchmark mean
    accuracy), ``{tag}_post_zeta_corr.csv`` and ``{tag}_post_zeta_cov.csv`` (posterior
    means of the LKJ correlation and of corr * outer(stds, stds)). It also writes a
    side-by-side heatmap image ``{tag}_heatmaps.png``.
    """
    import seaborn as sns

    root = save_dirs["root"]
    n_bench = len(unique_bench)

    # Build the item masks once; row i, column m = model i's mean accuracy on benchmark m.
    bench_masks = [bench_idx == m for m in range(n_bench)]
    per_bench_acc = np.array([[np.nanmean(row[sel]) for sel in bench_masks] for row in Y_true])
    raw_corr = np.corrcoef(per_bench_acc, rowvar=False)

    posterior = trace.posterior
    post_corr = posterior['chol_zeta_corr'].mean(dim=('chain', 'draw')).values
    post_stds = posterior['chol_zeta_stds'].mean(dim=('chain', 'draw')).values
    post_cov = post_corr * np.outer(post_stds, post_stds)

    matrices = (
        ("raw_accuracy_corr", raw_corr),
        ("post_zeta_corr", post_corr),
        ("post_zeta_cov", post_cov),
    )
    for suffix, matrix in matrices:
        np.savetxt(root / f"{tag}_{suffix}.csv", matrix, delimiter=",")

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    panels = (
        (raw_corr, "Raw Accuracy Correlation"),
        (post_corr, "Posterior ζ Correlation"),
    )
    for ax, (matrix, title) in zip(axes, panels):
        sns.heatmap(matrix, annot=True, ax=ax, cmap='vlag', vmin=-1, vmax=1,
                    xticklabels=unique_bench, yticklabels=unique_bench)
        ax.set_title(title)
    plt.tight_layout()
    plt.savefig(root / f"{tag}_heatmaps.png", dpi=200)
    plt.close()

# ======= Save Model Parameters (Consistent with run_benchmark_single.py Format) =======
def save_model_parameters(trace: az.InferenceData,
                         models: list[str],
                         unique_bench: list[str],
                         save_dirs: dict):
    """
    Write posterior-mean parameters in the same CSV layout as run_benchmark_single.py.

    Outputs:
        save_dirs["root"]/theta_hat_psi.csv   -- columns model, theta_hat (global ability psi)
        save_dirs[bench]/theta_hat_zeta.csv   -- columns model, theta_hat (zeta[:, bench])
        save_dirs["root"]/item_params.csv     -- columns a_hat, b_hat, one row per item
    """
    def posterior_mean(var_name):
        # Average over every chain and draw.
        return trace.posterior[var_name].mean(dim=('chain', 'draw')).values

    root_dir = save_dirs["root"]

    pd.DataFrame({"model": models, "theta_hat": posterior_mean('psi')}).to_csv(
        root_dir / "theta_hat_psi.csv", index=False)

    zeta_means = posterior_mean('zeta')
    for col, bench_name in enumerate(unique_bench):
        bench_df = pd.DataFrame({"model": models, "theta_hat": zeta_means[:, col]})
        bench_df.to_csv(save_dirs[bench_name] / "theta_hat_zeta.csv", index=False)

    items = pd.DataFrame({"a_hat": posterior_mean('a'), "b_hat": posterior_mean('b')})
    items.to_csv(root_dir / "item_params.csv", index=False)

    print(f"Model parameters saved to {save_dirs['root']}")

# ======= CLI Main Process: Using Real Data =======
def build_inputs_from_csv(csv_path: Path,
                          sample_ratio: float,
                          seed: int = 42,
                          sample_per_bench: int | None = None,
                          train_per_bench: int | None = None,
                          val_per_bench: int | None = None,
                          test_per_bench: int | None = None):
    """
    Read the merged is_correct matrix, sample items per benchmark, and build model inputs.

    Rows of the CSV are models (first column used as the index); columns are items named
    ``{BENCH}_{qid}_is_correct``. Only CEVAL / CSQA / MMLU columns are kept (in that fixed
    benchmark order).

    Sampling modes, all seeded with ``np.random.default_rng(seed)``:
      * Fixed set sampling: enabled when any of train_per_bench / val_per_bench /
        test_per_bench is given. For each benchmark, draw up to train+val+test items.
        Fill train first, then val, then test when fewer items are available.
        sample_ratio and sample_per_bench are ignored.
      * sample_per_bench: draw exactly that many items per benchmark (all if fewer);
        sample_ratio is ignored.
      * Otherwise, draw max(1, int(n_items * sample_ratio)) items per benchmark.

    Returns:
        (N, J, M, models, unique_bench, bench_idx, J_list, Y_full, train_pos, val_pos, test_pos)
        Y_full is an integer 0/1 matrix when the CSV is complete. If it has missing
        cells, it is a float matrix with NaN kept in place, so downstream code treats
        them as unobserved instead of as garbage integers. train_pos / val_pos /
        test_pos are column positions in Y_full (fixed set mode) or None.

    Raises:
        ValueError: if fixed set sampling has a non-positive total, no target-benchmark
            column exists, or the sampled matrix contains values other than 0, 1 or missing.
    """
    df = pd.read_csv(csv_path, index_col=0)
    # Only recognize the CEVAL/CSQA/MMLU benchmarks, strict binning
    allowed_benches = ["CEVAL", "CSQA", "MMLU"]

    def parse_bench(col: str) -> str:
        # Column name format: {BENCH}_{qid}_is_correct; take the prefix before the first underscore
        return col.split("_", 1)[0]

    # Filter out non-allowed BENCH columns
    col_benches = [parse_bench(c) for c in df.columns]
    allowed_col_idx = [i for i, b in enumerate(col_benches) if b in allowed_benches]
    excluded_col_idx = [i for i, b in enumerate(col_benches) if b not in allowed_benches]
    if len(excluded_col_idx) > 0:
        excl_names = [df.columns[i] for i in excluded_col_idx[:10]]
        print(f"[Sanity] Found and excluded {len(excluded_col_idx)} columns of non-target BENCH (first 10 examples): {excl_names}")
    # Keep only target BENCH columns (original order preserved)
    df = df.iloc[:, allowed_col_idx]
    models = df.index.tolist()
    # Benchmark labels in fixed order CEVAL -> CSQA -> MMLU
    col2bench = [parse_bench(col) for col in df.columns]
    unique_bench = allowed_benches[:]
    bench_idx_full = np.array([unique_bench.index(b) for b in col2bench], dtype=int)
    J_list_full = [sum(b == ub for b in col2bench) for ub in unique_bench]
    bench_counts = {ub: J_list_full[i] for i, ub in enumerate(unique_bench)}
    print("[Sanity] bench_counts =", bench_counts)
    if df.shape[1] == 0:
        raise ValueError(f"No columns of the target benchmarks {allowed_benches} found in {csv_path}")

    # Target sampling mode
    use_setwise = any(x is not None for x in (train_per_bench, val_per_bench, test_per_bench))
    if use_setwise:
        t_train = int(train_per_bench or 0)
        t_val = int(val_per_bench or 0)
        t_test = int(test_per_bench or 0)
        t_total = t_train + t_val + t_test
        if t_total <= 0:
            raise ValueError("When fixed set sampling is enabled, train_per_bench + val_per_bench + test_per_bench must be > 0")

    # Sample column indices per benchmark. The loop order fixes the RNG stream,
    # which keeps results reproducible for a given seed.
    rng = np.random.default_rng(seed)
    selected_cols = []
    selected_bench_idx = []
    train_cols_list = []
    val_cols_list = []
    test_cols_list = []
    for m, ub in enumerate(unique_bench):
        idx_m = np.where(bench_idx_full == m)[0]
        if len(idx_m) == 0:
            continue
        if use_setwise:
            k_total = min(t_total, len(idx_m))
            if k_total <= 0:
                continue
            choose = rng.choice(idx_m, size=k_total, replace=False)
            # Split train/val/test, filling train first when items are scarce
            k_tr = min(t_train, k_total)
            k_v = min(t_val, max(0, k_total - k_tr))
            k_te = min(t_test, max(0, k_total - k_tr - k_v))
            train_cols_list.append(choose[:k_tr])
            val_cols_list.append(choose[k_tr:k_tr + k_v])
            test_cols_list.append(choose[k_tr + k_v:k_tr + k_v + k_te])
            selected_cols.append(choose)
            selected_bench_idx.append(np.full(k_total, m, dtype=int))
        else:
            if sample_per_bench is not None:
                k = min(int(sample_per_bench), len(idx_m))
                if k <= 0:
                    continue
            else:
                k = max(1, int(len(idx_m) * sample_ratio))
                k = min(k, len(idx_m))
            choose = rng.choice(idx_m, size=k, replace=False)
            selected_cols.append(choose)
            selected_bench_idx.append(np.full(k, m, dtype=int))

    empty = np.array([], dtype=int)
    selected_cols = np.concatenate(selected_cols) if selected_cols else empty
    selected_bench_idx = np.concatenate(selected_bench_idx) if selected_bench_idx else empty
    if use_setwise:
        train_cols = np.concatenate(train_cols_list) if train_cols_list else empty
        val_cols = np.concatenate(val_cols_list) if val_cols_list else empty
        test_cols = np.concatenate(test_cols_list) if test_cols_list else empty
    else:
        train_cols = val_cols = test_cols = None

    # Print the actual sampling count for each BENCH
    if use_setwise:
        eff_train = {ub: int(np.sum(np.isin(selected_cols, train_cols) & (selected_bench_idx == i))) for i, ub in enumerate(unique_bench)}
        eff_val = {ub: int(np.sum(np.isin(selected_cols, val_cols) & (selected_bench_idx == i))) for i, ub in enumerate(unique_bench)}
        eff_test = {ub: int(np.sum(np.isin(selected_cols, test_cols) & (selected_bench_idx == i))) for i, ub in enumerate(unique_bench)}
        print("[Sampling] per_bench={train/val/test} =", eff_train, eff_val, eff_test, "(mode: fixed sets)")
    else:
        effective_counts = {ub: int(np.sum(selected_bench_idx == i)) for i, ub in enumerate(unique_bench)}
        print("[Sampling] effective_per_bench =", effective_counts, "(mode:",
              ("fixed" if sample_per_bench is not None else f"ratio={sample_ratio}"), ")")

    # Subset
    df_sub = df.iloc[:, selected_cols]
    bench_idx = selected_bench_idx
    J_list = [int(np.sum(bench_idx == i)) for i in range(len(unique_bench))]

    # Casting NaN straight to int yields an arbitrary integer, so convert via float,
    # validate, and only switch to int when nothing is missing.
    values = df_sub.to_numpy(dtype=float)
    missing = np.isnan(values)
    if not np.isin(values[~missing], (0.0, 1.0)).all():
        raise ValueError("is_correct matrix must contain only 0/1 values (or missing cells)")
    if missing.any():
        print(f"[Sanity] {int(missing.sum())} missing cells kept as NaN (treated as unobserved)")
        Y_full = values
    else:
        Y_full = values.astype(int)
    N, J = Y_full.shape
    M = len(unique_bench)

    # Map original column indices to positions in the subset column space
    if use_setwise:
        col_map = {orig_idx: new_pos for new_pos, orig_idx in enumerate(selected_cols.tolist())}
        train_pos = np.array([col_map[i] for i in train_cols], dtype=int)
        val_pos = np.array([col_map[i] for i in val_cols], dtype=int)
        test_pos = np.array([col_map[i] for i in test_cols], dtype=int)
    else:
        train_pos = val_pos = test_pos = None

    return N, J, M, models, unique_bench, bench_idx, J_list, Y_full, train_pos, val_pos, test_pos

def main():
    """Command-line entry point for the multi-benchmark IRT pipeline.

    Steps:
      1. Read and sample the merged is_correct CSV.
      2. Build the train / validation / test split.
      3. Run NUTS.
      4. Write convergence diagnostics.
      5. Evaluate predictive performance on each non-empty split.
      6. Save correlation heatmaps.
      7. Save posterior-mean parameters.
    All outputs go under --outputdir, with one sub-directory per benchmark.
    """
    parser = argparse.ArgumentParser(description="Run multi-benchmark IRT on merged is_correct matrix.")
    parser.add_argument("--input_csv", type=str,
                        default="/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/data_result_raw/merged_mix_benchmark/merged_is_correct_matrix.csv",
                        help="Path to merged *_is_correct matrix CSV.")
    parser.add_argument("--outputdir", type=str,
                        default="yourpath/result_mixbenchmark",
                        help="Directory to write results.")
    parser.add_argument("--sample_ratio", type=float, default=0.1,
                        help="Per-benchmark column sampling ratio in (0,1].")
    parser.add_argument("--sample_per_bench", type=int, default=None,
                        help="If set, sample exactly this many items per benchmark (overrides sample_ratio).")
    parser.add_argument("--train_ratio", type=float, default=0.9,
                        help="Train ratio for masking.")
    parser.add_argument("--draws", type=int, default=1000, help="MCMC draws.")
    parser.add_argument("--tune", type=int, default=1000, help="MCMC tune.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--cores", type=int, default=1, help="Number of CPU cores for sampling (set 1 to ensure visible tqdm).")
    parser.add_argument("--chains", type=int, default=2, help="Number of MCMC chains (reduce to speed up).")
    parser.add_argument("--engine", type=str, default="pymc", choices=["pymc","numpyro","blackjax"], help="NUTS backend.")
    parser.add_argument("--slope_mode", type=str, default="item", choices=["item","bench","rasch"], help="Discrimination parameterization: per-item, per-benchmark, or Rasch (a=1).")
    parser.add_argument("--target_accept", type=float, default=0.9, help="NUTS target_accept.")
    parser.add_argument("--train_per_bench", type=int, default=None, help="Per-benchmark number of items for TRAIN split.")
    parser.add_argument("--val_per_bench", type=int, default=None, help="Per-benchmark number of items for VALID split.")
    parser.add_argument("--test_per_bench", type=int, default=None, help="Per-benchmark number of items for TEST split.")
    args = parser.parse_args()

    cores = max(1, int(args.cores))
    # Still exported to the environment, as earlier versions of this script did.
    os.environ["PYMC_CORES_OVERRIDE"] = str(cores)
    print(f"[Config] Using cores={cores} (can be modified via --cores)")

    outdir = Path(args.outputdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # 1) Build input
    print("[Step 1/7] Reading input data and building matrix...")
    (N, J, M, models, unique_bench, bench_idx, J_list, Y_full,
     train_pos, val_pos, test_pos) = build_inputs_from_csv(
        Path(args.input_csv), sample_ratio=args.sample_ratio, seed=args.seed,
        sample_per_bench=args.sample_per_bench,
        train_per_bench=args.train_per_bench, val_per_bench=args.val_per_bench,
        test_per_bench=args.test_per_bench,
    )
    print(f"N={N}, J={J}, M={M}")
    print("unique_bench =", unique_bench)
    print("J_list =", J_list)

    # One sub-directory per modelled benchmark; evaluate_model and save_model_parameters
    # write into save_dirs[<benchmark name>].
    save_dirs = {"root": outdir}
    for bench in unique_bench:
        bench_dir = outdir / bench
        bench_dir.mkdir(parents=True, exist_ok=True)
        save_dirs[bench] = bench_dir

    # 2) Split
    print("[Step 2/7] Splitting train/test sets...")
    use_setwise = any(x is not None for x in (args.train_per_bench, args.val_per_bench, args.test_per_bench))
    if use_setwise:
        # Only train columns are observed by the model; validation/test are scored via column masks.
        Y_true = Y_full.copy()

        def cols_to_mask(cols):
            """Expand column positions into a boolean mask shaped like Y_true."""
            mask = np.zeros(Y_true.shape, dtype=bool)
            if cols is not None and len(cols) > 0:
                mask[:, cols] = True
            return mask

        Y_train_mask = cols_to_mask(train_pos)
        Y_val_mask = cols_to_mask(val_pos)
        Y_test_mask = cols_to_mask(test_pos)
        Y_train = Y_full.astype(float)  # astype copies; Y_full is untouched
        Y_train[~Y_train_mask] = np.nan
    else:
        Y_train, Y_test_mask, Y_true = split_train_test(Y_full, train_ratio=args.train_ratio, seed=args.seed)
        Y_val_mask = None
        Y_train_mask = ~np.isnan(Y_train)

    # 3) Sampling
    print("[Step 3/7] Starting MCMC sampling...")
    # NOTE: --seed drives item sampling and the train/test split only; no seed is
    # passed to the sampler in this call.
    trace = build_MCMC_multi(N=N, J=J, M=M, bench_idx=bench_idx, Y_train=Y_train,
                             draws=args.draws, tune=args.tune,
                             chains=args.chains, cores=cores,
                             engine=args.engine, slope_mode=args.slope_mode,
                             target_accept=args.target_accept)

    # 4) Diagnostics
    print("[Step 4/7] Sampling completed, starting convergence diagnosis...")
    evaluate_mcmc_reliability(trace, ["psi", "zeta", "a", "b"], save_dirs["root"])

    # 5) Evaluation. Empty splits are skipped: scoring an empty mask would make the
    # metric functions raise and abort the run before the parameters are saved.
    print("[Step 5/7] Evaluating train/val/test set performance...")
    splits = (
        ("Train", "train", Y_train_mask),
        ("Validation", "valid", Y_val_mask),
        ("Test", "test", Y_test_mask),
    )
    for label, tag, mask in splits:
        if mask is None:
            continue
        if not mask.any():
            print(f"\n===== {label} set evaluation skipped (no items) =====")
            continue
        print(f"\n===== {label} set evaluation =====")
        evaluate_model(trace, Y_true, mask, bench_idx, unique_bench, save_dirs, tag=tag)

    # 6) Correlation visualization
    print("[Step 6/7] Saving correlation analysis visualization...")
    save_correlation_heatmaps(Y_true, bench_idx, trace, unique_bench, save_dirs, tag="corr")

    # 7) Save model parameters
    print("[Step 7/7] Saving model parameters...")
    save_model_parameters(trace, models, unique_bench, save_dirs)

# Run the CLI pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()