#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Improved multi-benchmark IRT modeling with better handling of cross-benchmark information
"""

import os
import time
import warnings
from typing import Tuple, Dict, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pymc as pm
import arviz as az
import pytensor.tensor as pt
from scipy.special import expit, logit
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix

# Global Config: plotting defaults and warning policy, applied once at import time.
plt.rcParams.update({
    "font.family": ["Arial", "Helvetica"],
    "axes.unicode_minus": False,
    "figure.dpi": 150,
})
# NOTE(review): this silences *every* warning process-wide, including PyMC sampler
# diagnostics and NumPy cast warnings (e.g. NaN -> int); consider narrowing it.
warnings.filterwarnings("ignore")

# =========================
# 1. Core I/O & Preprocessing
# =========================
def build_inputs_from_csv(csv_path: str,
                          sample_ratio: float,
                          seed: int = 42,
                          sample_per_bench: int | None = None) -> Tuple[int, int, int, List[str], List[str], np.ndarray, List[int], np.ndarray]:
    """
    Read merged is_correct matrix and build input data for mixed benchmark.
    """
    df = pd.read_csv(csv_path, index_col=0)
    
    # Only recognize CEVAL/CSQA/MMLU three benchmarks
    allowed_benches = ["CEVAL", "CSQA", "MMLU"]
    def parse_bench(col: str) -> str:
        return col.split("_", 1)[0]
    
    # Filter out non-allowed benchmark columns
    allowed_col_idx = [i for i, c in enumerate(df.columns) if parse_bench(c) in allowed_benches]
    excluded_col_idx = [i for i, c in enumerate(df.columns) if parse_bench(c) not in allowed_benches]
    if len(excluded_col_idx) > 0:
        excl_names = [df.columns[i] for i in excluded_col_idx[:10]]
        print(f"[Sanity] Excluded {len(excluded_col_idx)} non-target benchmark columns (first 10 examples): {excl_names}")
    
    # Subset to only target benchmark columns (maintain original order)
    df = df.iloc[:, allowed_col_idx]
    models = df.index.tolist()
    
    # Re-parse benchmark labels (fixed order: CEVAL→CSQA→MMLU)
    col2bench = [parse_bench(col) for col in df.columns]
    unique_bench = allowed_benches[:]  # Fixed three types
    bench_idx_full = np.array([ unique_bench.index(b) for b in col2bench ], dtype=int)
    J_list_full = [ sum(b == ub for b in col2bench) for ub in unique_bench ]
    
    # Print statistics
    bench_counts = {ub: J_list_full[i] for i, ub in enumerate(unique_bench)}
    print("[Sanity] bench_counts =", bench_counts)
    
    # Sample column indices by benchmark (reproducibility)
    rng = np.random.default_rng(seed)
    selected_cols = []
    selected_bench_idx = []
    
    for m, ub in enumerate(unique_bench):
        idx_m = np.where(bench_idx_full == m)[0]
        if len(idx_m) == 0:
            continue
        if sample_per_bench is not None:
            k = min(int(sample_per_bench), len(idx_m))
            if k <= 0:
                continue
        else:
            k = max(1, int(len(idx_m) * sample_ratio))
            k = min(k, len(idx_m))
        choose = rng.choice(idx_m, size=k, replace=False)
        selected_cols.append(choose)
        selected_bench_idx.append(np.full(k, m, dtype=int))
    
    selected_cols = np.concatenate(selected_cols) if len(selected_cols) else np.array([], dtype=int)
    selected_bench_idx = np.concatenate(selected_bench_idx) if len(selected_bench_idx) else np.array([], dtype=int)
    
    # Print actual sampling count for each benchmark
    effective_counts = {ub: int(np.sum(selected_bench_idx == i)) for i, ub in enumerate(unique_bench)}
    print("[Sampling] effective_per_bench =", effective_counts, "(mode:",
          ("fixed" if sample_per_bench is not None else f"ratio={sample_ratio}"), ")")

    # Subset
    df_sub = df.iloc[:, selected_cols]
    bench_idx = selected_bench_idx
    J_list = [ int(np.sum(bench_idx == i)) for i in range(len(unique_bench)) ]

    Y_full = df_sub.values.astype(int)
    N, J = Y_full.shape
    M = len(unique_bench)

    return N, J, M, models, unique_bench, bench_idx, J_list, Y_full


# =========================
# 2. Data Split (Fixed Test + Variable Training)
# =========================
def split_fixed_test_set(
    Y_shape: Tuple[int, int],
    test_ratio: float = 0.05,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw the fixed test mask shared by every training ratio and repetition.

    Each cell independently lands in the test set with probability ``test_ratio``,
    so the realised test size is random rather than exactly ``test_ratio`` of the
    cells. The draw depends only on ``Y_shape`` and ``test_seed``.

    Returns:
        (test_mask, remaining_mask): boolean arrays of shape ``Y_shape``;
        ``remaining_mask`` is the exact complement of ``test_mask``.
    """
    uniform_draws = np.random.default_rng(test_seed).random(Y_shape)
    test_mask = uniform_draws < test_ratio
    remaining_mask = np.logical_not(test_mask)

    n_test = test_mask.sum()
    n_remaining = remaining_mask.sum()
    print(f"Fixed test set size: {n_test} samples ({test_ratio*100:.1f}%)")
    print(f"Remaining pool size: {n_remaining} samples ({(1-test_ratio)*100:.1f}%)")
    return test_mask, remaining_mask


def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,
    rep_seed: int
) -> np.ndarray:
    """
    Pick the training cells for one repetition from the non-test pool.

    Exactly ``int(n_pool * train_subset_ratio)`` cells (truncated) are drawn
    uniformly without replacement from the True cells of ``remaining_mask``.

    Returns:
        Boolean mask with the shape of ``remaining_mask``; True marks a training cell.
    """
    generator = np.random.default_rng(rep_seed)
    pool_rows, pool_cols = np.nonzero(remaining_mask)
    pool_size = pool_rows.size
    n_take = int(pool_size * train_subset_ratio)

    picks = generator.choice(pool_size, size=n_take, replace=False)
    train_mask = np.zeros(remaining_mask.shape, dtype=bool)
    train_mask[pool_rows[picks], pool_cols[picks]] = True
    return train_mask


def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per repetition listing every train and test cell.

    Each file ``split_ratio_{ratio:.3f}_rep{k}.csv`` (k is 1-based) in ``split_dir``
    has one row per selected (model, question) cell, with columns model_idx,
    model_name, question_idx, question_name, set_type ("train"/"test"),
    train_ratio and repetition. ``split_dir`` is created if missing.

    Raises:
        ValueError: if ``test_mask`` or any train mask is not shaped
            (len(model_names), len(question_names)).
    """
    N = len(model_names)
    J = len(question_names)
    expected_shape = (N, J)

    # Validate every mask before writing anything so a bad input leaves no partial output.
    if np.shape(test_mask) != expected_shape:
        raise ValueError(
            f"Test mask shape {np.shape(test_mask)} doesn't match data dimensions ({N},{J})"
        )
    for train_mask in train_masks:
        if train_mask.shape != expected_shape:
            raise ValueError(
                f"Train mask shape {train_mask.shape} doesn't match data dimensions ({N},{J})"
            )

    def _records(mask: np.ndarray, set_type: str) -> pd.DataFrame:
        # One row per True cell of `mask`, without the repetition column.
        rows, cols = np.where(mask)
        return pd.DataFrame({
            "model_idx": rows,
            "model_name": [model_names[r] for r in rows],
            "question_idx": cols,
            "question_name": [question_names[c] for c in cols],
            "set_type": set_type,
            "train_ratio": train_subset_ratio,
        })

    # The test set is identical across repetitions, so build its rows only once.
    test_base = _records(test_mask, "test")
    os.makedirs(split_dir, exist_ok=True)

    for rep_idx, train_mask in enumerate(train_masks):
        rep_no = rep_idx + 1
        all_records = pd.concat(
            [
                _records(train_mask, "train").assign(repetition=rep_no),
                test_base.assign(repetition=rep_no),
            ],
            ignore_index=True,
        )
        save_path = os.path.join(split_dir, f"split_ratio_{train_subset_ratio:.3f}_rep{rep_no}.csv")
        all_records.to_csv(save_path, index=False)
    print(f"Data split records saved (train ratio: {train_subset_ratio:.3f})")


# =========================
# 3. Prediction Methods (Baselines)
# =========================
def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Naive Method: predict every cell with the global mean of the training cells.

    Falls back to 0.5 when ``train_mask`` selects nothing.

    Returns:
        Float array with the shape of ``Y``.
    """
    global_mean = Y[train_mask].mean() if train_mask.any() else 0.5
    # np.full_like would inherit Y's dtype; for an int 0/1 matrix that truncates
    # the mean (e.g. 0.63 -> 0), so always build a float array.
    return np.full(Y.shape, global_mean, dtype=float)


def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Average Score baseline: every cell in a row gets that model's mean training score.

    Rows without training cells fall back to the global training mean, or 0.5
    when ``train_mask`` selects nothing at all.

    Returns:
        Float array with the shape of ``Y``.
    """
    n_rows, n_cols = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5

    row_masks = (train_mask[r, :] for r in range(n_rows))
    per_model = np.array(
        [Y[r, mask].mean() if mask.any() else fallback for r, mask in enumerate(row_masks)],
        dtype=float,
    )
    return np.tile(per_model[:, np.newaxis], (1, n_cols))


def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Difficulty Modeling baseline: every cell in a column gets that question's mean training score.

    Columns without training cells fall back to the global training mean, or 0.5
    when ``train_mask`` selects nothing at all.

    Returns:
        Float array with the shape of ``Y``.
    """
    n_rows, n_cols = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5

    column_masks = (train_mask[:, c] for c in range(n_cols))
    per_question = np.array(
        [Y[mask, c].mean() if mask.any() else fallback for c, mask in enumerate(column_masks)],
        dtype=float,
    )
    return np.tile(per_question[np.newaxis, :], (n_rows, 1))


# =========================
# 4. Improved Multi-Benchmark IRT Models
# =========================
def fit_improved_irt_multibench(
    Y: np.ndarray,
    train_mask: np.ndarray,
    bench_idx: np.ndarray,
    M: int,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None,
    target_accept: float = 0.95,
    random_seed: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit the multi-benchmark joint IRT model on the training cells of ``Y`` with NUTS.

    For model i and item j belonging to benchmark m = bench_idx[j]:

        psi_i   ~ Normal(0, 1)                  global ability
        zeta_i  = zeta_raw_i @ L^T,             benchmark-specific abilities (length M);
                  zeta_raw ~ Normal(0, 1)       L is a Cholesky factor with an LKJ(eta=2)
                                                prior and Exponential(1) scales
        a_m     ~ LogNormal(log 1, 0.5)         discrimination shared by a benchmark's items
        b_j     ~ Normal(0, 1)                  item difficulty
        Y_ij    ~ Bernoulli(logit_p = a_m * (psi_i + zeta_im) - b_j)

    Only cells where ``train_mask`` is True enter the likelihood.

    Args:
        Y: (N, J) response matrix; observed cells are cast to int for the
            Bernoulli likelihood (expected to be 0/1).
        train_mask: boolean mask with the shape of ``Y`` selecting training cells.
        bench_idx: length-J benchmark index (0..M-1) of each item.
        M: number of benchmarks.
        draws, tune, chains, target_accept: forwarded to ``pm.sample``.
        cores: parallel chains; defaults to min(CPU count, chains).
        random_seed: optional seed forwarded to ``pm.sample`` for reproducible
            traces; None keeps the previous unseeded behaviour.

    Returns:
        (trace, params) where ``params`` holds posterior means of "psi" (N,),
        "zeta" (N, M), "a" (J,) and "b" (J,).

    Raises:
        ValueError: if ``train_mask`` selects no cells.
    """
    N, J = Y.shape
    bench_idx = np.asarray(bench_idx, dtype=int)

    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    # Observed (training) cells; the likelihood is defined only over these.
    obs_i, obs_j = np.where(train_mask)
    if obs_i.size == 0:
        raise ValueError("train_mask selects no observed cells; cannot fit the IRT model")
    y_obs_vec = Y[obs_i, obs_j].astype(int)
    obs_bench = bench_idx[obs_j]

    with pm.Model():
        # Global ability for each model
        psi = pm.Normal("psi", mu=0, sigma=1, shape=N)

        # Correlated benchmark-specific abilities, non-centred:
        # zeta = zeta_raw @ L^T with L drawn from an LKJ Cholesky prior.
        chol, _, _ = pm.LKJCholeskyCov(
            "chol_zeta", eta=2.0, n=M,
            sd_dist=pm.Exponential.dist(1.0),
            compute_corr=True
        )
        zeta_raw = pm.Normal("zeta_raw", 0.0, 1.0, shape=(N, M))
        zeta = pm.Deterministic("zeta", pt.dot(zeta_raw, chol.T))  # (N, M)

        # Benchmark-level discrimination, expanded per item and recorded as "a".
        a_m = pm.LogNormal("a_m", mu=np.log(1.0), sigma=0.5, shape=M)
        pm.Deterministic("a", a_m[bench_idx])  # (J,)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)  # Item difficulty

        # Evaluate the linear predictor only at observed cells instead of building
        # the full (N, J) logit matrix and then indexing it.
        theta_obs = psi[obs_i] + zeta[obs_i, obs_bench]
        logits_obs = a_m[obs_bench] * theta_obs - b[obs_j]
        pm.Bernoulli("y_obs", logit_p=logits_obs, observed=y_obs_vec)

        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=target_accept,
            progressbar=True,
            return_inferencedata=True,
            random_seed=random_seed
        )
        print(f"Improved Multi-benchmark IRT sampling time: {time.time()-start_time:.1f}s")

    r_hat = az.summary(trace, var_names=["psi", "zeta", "a", "b"])["r_hat"]
    n_bad = int((r_hat > 1.01).sum())
    if n_bad:
        print(f"⚠️ Improved Multi-benchmark IRT: {n_bad} parameters with r_hat>1.01")

    params = {
        "psi": trace.posterior["psi"].mean(dim=["chain", "draw"]).values,
        "zeta": trace.posterior["zeta"].mean(dim=["chain", "draw"]).values,
        "a": trace.posterior["a"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def predict_from_multibench_irt(params: Dict[str, np.ndarray], bench_idx: np.ndarray) -> np.ndarray:
    """
    Turn point estimates of the multi-benchmark IRT parameters into success probabilities.

    Computes expit(a_j * (psi_i + zeta_{i, bench_idx[j]}) - b_j) for every model i
    and item j. When ``params`` holds posterior means, this is a plug-in estimate;
    it generally differs from averaging expit over posterior draws.

    Args:
        params: "psi" (N,), "zeta" (N, M), "a" (J,), "b" (J,).
        bench_idx: benchmark index of each item (length J).

    Returns:
        (N, J) array of probabilities.
    """
    psi_hat = params["psi"]
    zeta_hat = params["zeta"]
    discrimination = params["a"]
    difficulty = params["b"]
    _n_models, _n_bench = zeta_hat.shape  # zeta must be 2-D: (models, benchmarks)

    ability = psi_hat[:, np.newaxis] + zeta_hat[:, bench_idx]
    eta = discrimination[np.newaxis, :] * ability - difficulty[np.newaxis, :]
    return expit(eta)


# =========================
# 5. MSE Calculation & Result Saving
# =========================
def calculate_mse(y_true: np.ndarray, y_pred: np.ndarray, test_mask: np.ndarray) -> float:
    """
    Return the mean squared error over the cells selected by ``test_mask``, rounded to 6 decimals.

    An empty selection yields NaN (the mean of an empty array).
    """
    residuals = y_true[test_mask] - y_pred[test_mask]
    return round(np.mean(np.square(residuals)), 6)


def save_sample_level_predictions(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_multibench: Optional[np.ndarray] = None
):
    """
    Write every method's prediction for each test cell to one CSV.

    The file ``predictions_ratio_{ratio:.3f}_rep{rep+1}.csv`` (``rep`` is 0-based)
    is created in ``pred_dir``, which is created if missing. The
    ``multibench_irt_pred`` column is NaN when ``pred_multibench`` is None.
    """
    rows_idx, cols_idx = np.where(test_mask)

    if pred_multibench is None:
        irt_values = [np.nan] * len(rows_idx)
    else:
        irt_values = pred_multibench[test_mask].tolist()

    columns = {
        "model_idx": rows_idx,
        "model_name": [model_names[r] for r in rows_idx],
        "question_idx": cols_idx,
        "question_name": [question_names[c] for c in cols_idx],
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1,
        "true_value": Y[test_mask].tolist(),
        "global_mean_pred": pred_global[test_mask].tolist(),
        "model_mean_pred": pred_row[test_mask].tolist(),
        "question_mean_pred": pred_col[test_mask].tolist(),
        "multibench_irt_pred": irt_values,
    }
    pred_df = pd.DataFrame(columns)

    os.makedirs(pred_dir, exist_ok=True)
    out_path = os.path.join(pred_dir, f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv")
    pred_df.to_csv(out_path, index=False, encoding="utf-8")

    print(f"✅ Sample-level predictions saved: {out_path}")
    print(f"   Total {len(pred_df)} test samples")
    print(f"   Total {len(pred_df)} test samples")


def save_irt_trace(
    trace_dir: str,
    trace: az.InferenceData,
    model_type: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Write an MCMC trace to ``trace_dir`` in NetCDF format via ``az.to_netcdf``.

    The file is named ``irt_{model_type}_ratio_{ratio:.3f}_rep{rep+1}.nc``
    (``rep`` is 0-based). This function does not create ``trace_dir`` itself;
    the caller is expected to have done so.
    """
    file_name = f"irt_{model_type}_ratio_{train_subset_ratio:.3f}_rep{rep+1}.nc"
    out_path = os.path.join(trace_dir, file_name)
    az.to_netcdf(trace, out_path)
    print(f"IRT {model_type} trace saved: {out_path}")


def save_model_parameters(
    params_dir: str,
    params: Dict[str, np.ndarray],
    model_names: List[str],
    question_names: List[str],
    unique_bench: List[str],
    trace: Optional[az.InferenceData] = None
):
    """
    Write point estimates of the IRT parameters (and optionally zeta variances) as CSVs.

    Files written to ``params_dir`` (created if missing):
      - theta_hat_psi.csv: columns model, theta_hat (global ability psi).
      - theta_hat_zeta_{bench}.csv: columns model, theta_hat; one file per
        benchmark, taken from column m of params["zeta"].
      - zeta_posterior_variance_{bench}.csv (only when ``trace`` is given):
        columns model, posterior_variance, the variance of
        trace.posterior["zeta"] over chains and draws.
      - item_params.csv: columns question, a_hat, b_hat (one row per item).

    Args:
        params: must contain "psi" (N,), "zeta" (N, M), "a" (J,) and "b" (J,).
        model_names: N model labels.
        question_names: J item labels, used to label rows of item_params.csv.
        unique_bench: M benchmark names, in zeta column order.
        trace: optional posterior trace for the zeta variance files.
    """
    os.makedirs(params_dir, exist_ok=True)

    # Save psi parameters (global ability)
    psi_df = pd.DataFrame({"model": model_names, "theta_hat": params["psi"]})
    psi_df.to_csv(os.path.join(params_dir, "theta_hat_psi.csv"), index=False)

    # Save zeta parameters (benchmark-specific ability), one file per benchmark
    zeta_hat = params["zeta"]
    for m, name in enumerate(unique_bench):
        zeta_df = pd.DataFrame({"model": model_names, "theta_hat": zeta_hat[:, m]})
        zeta_df.to_csv(os.path.join(params_dir, f"theta_hat_zeta_{name}.csv"), index=False)

    # Save zeta posterior variance if trace is provided
    if trace is not None:
        zeta_var = trace.posterior["zeta"].var(dim=["chain", "draw"]).values
        for m, name in enumerate(unique_bench):
            zeta_var_df = pd.DataFrame({"model": model_names, "posterior_variance": zeta_var[:, m]})
            zeta_var_path = os.path.join(params_dir, f"zeta_posterior_variance_{name}.csv")
            zeta_var_df.to_csv(zeta_var_path, index=False)
            print(f"Zeta posterior variance for {name} saved to {zeta_var_path}")

    # Save item parameters, labelled by question so rows can be traced back to items
    item_df = pd.DataFrame({
        "question": question_names,
        "a_hat": params["a"],
        "b_hat": params["b"],
    })
    item_df.to_csv(os.path.join(params_dir, "item_params.csv"), index=False)

    print(f"Model parameters saved to {params_dir}")


# =========================
# 6. Result Visualization
# =========================
def plot_mse_comparison(
    mse_dict: Dict[str, List[float]],
    train_ratios: List[float],
    save_path: str
):
    """
    Plot one line per method (MSE vs. training ratio) and save the figure at 300 dpi.

    Args:
        mse_dict: method name -> MSE values aligned with ``train_ratios``. Every
            name must appear in the style table below (a KeyError is raised otherwise).
        train_ratios: x-axis values; also used as the tick positions.
        save_path: output image path.
    """
    palette = {
        "Global Mean": ("blue", "solid"),
        "Model Mean": ("orange", "dashed"),
        "Question Mean": ("green", "dashdot"),
        "Multi-benchmark IRT": ("red", "solid")
    }
    shared_line_kwargs = {"linewidth": 2, "marker": "o", "markersize": 4}
    tick_labels = [f"{r:.2f}" for r in train_ratios]

    plt.figure(figsize=(10, 6))
    for name, values in mse_dict.items():
        line_color, line_style = palette[name]
        plt.plot(train_ratios, values, label=name, color=line_color,
                 linestyle=line_style, **shared_line_kwargs)

    plt.xlabel("Training Data Ratio", fontsize=12)
    plt.ylabel("Test Set MSE", fontsize=12)
    plt.title("MSE vs. Training Data Ratio for All Prediction Methods", fontsize=14, pad=20)
    plt.legend(loc="upper right", fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.xticks(train_ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"MSE plot saved: {save_path}")


def save_mse_summary(
    mse_summary: Dict[str, Dict[float, List[float]]],
    save_path: str
):
    """
    Write a CSV with one row per train ratio and one "mean ± std" column per method.

    Ratios are taken from the "Global Mean" entry and sorted ascending. The std is
    the population std (NumPy default, ddof=0) over repetitions.
    """
    def _mean_pm_std(values: List[float]) -> str:
        return f"{np.mean(values):.6f} ± {np.std(values):.6f}"

    records = [
        {
            "Train_Ratio": ratio,
            **{method: _mean_pm_std(per_ratio[ratio]) for method, per_ratio in mse_summary.items()},
        }
        for ratio in sorted(mse_summary["Global Mean"])
    ]

    pd.DataFrame(records).to_csv(save_path, index=False)
    print(f"MSE summary saved: {save_path}")


# =========================
# 7. Main Experiment Pipeline
# =========================
def run_improved_mixed_benchmark_experiment(
    input_csv_path: str,
    output_root_dir: str,
    sample_ratio: float = 0.1,
    test_ratio: float = 0.05,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 3,
    irt_draws: int = 1000,
    irt_tune: int = 1000,
    irt_chains: int = 4,
    irt_cores: Optional[int] = None,
    target_accept: float = 0.95
):
    """
    Compare mean baselines with the multi-benchmark IRT model across training ratios.

    Pipeline:
      1. Load and column-sample the is_correct matrix (``build_inputs_from_csv``,
         seeded with ``test_seed``).
      2. Draw one fixed test mask shared by every ratio and repetition.
      3. For each train ratio and each of ``rep_count`` repetitions: sample a
         training subset of the non-test cells (seed ``test_seed + rep``), score
         the global/model/question mean baselines and the IRT model on the test
         cells, and save the trace and sample-level predictions.
      4. Per ratio, save the split records and the last repetition's parameters
         under ``05_model_parameters/ratio_{ratio:.3f}/``. Finally, save an MSE
         summary table and an MSE-vs-ratio plot under ``04_metrics``.

    Args:
        input_csv_path: merged is_correct matrix CSV.
        output_root_dir: root directory for all outputs.
        sample_ratio: per-benchmark column sampling ratio.
        test_ratio: per-cell probability of landing in the fixed test set.
        test_seed: seed for column sampling, the test split and repetition seeds.
        train_ratios: fractions of the non-test cells used for training;
            defaults to 0.1, 0.2, ..., 1.0.
        rep_count: repetitions per ratio (must be >= 1).
        irt_draws, irt_tune, irt_chains, irt_cores, target_accept: MCMC settings
            forwarded to ``fit_improved_irt_multibench``.

    Raises:
        ValueError: if ``rep_count`` < 1.
    """
    if train_ratios is None:
        train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    if rep_count < 1:
        # With zero repetitions nothing is fitted, and saving the "last
        # repetition" parameters would fail on an undefined variable.
        raise ValueError(f"rep_count must be >= 1, got {rep_count}")

    print("Running improved mixed benchmark IRT experiment")

    # Create output directories
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
        "params": os.path.join(output_root_dir, "05_model_parameters")
    }
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    print(f"Output directories created: {list(output_dirs.values())}")

    # Build input data
    print("\n=== Step 1: Read & Preprocess Data ===")
    (N, J, M, model_names, unique_bench, bench_idx, J_list, Y_full) = build_inputs_from_csv(
        input_csv_path, sample_ratio, test_seed
    )
    if N == 0 or J == 0:
        print("No data available, exiting.")
        return

    # Positional labels: build_inputs_from_csv does not return original column names.
    question_names = [f"Q{i}" for i in range(J)]
    print(f"Preprocessed data shape: {N} models × {J} questions")

    # Split fixed test set
    print("\n=== Step 2: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # MSE records: method -> train ratio -> one value per repetition
    method_names = ["Global Mean", "Model Mean", "Question Mean", "Multi-benchmark IRT"]
    mse_summary = {method: {r: [] for r in train_ratios} for method in method_names}

    # Main experiment loop
    for train_ratio in train_ratios:
        print(f"\n=== Train Ratio = {train_ratio:.3f} ===")

        train_masks_rep = []

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            # NOTE(review): repetition 0 reuses test_seed, the seed that also drew the test mask.
            rep_seed = test_seed + rep

            # Sample training subset
            print(f"Sampling training subset (ratio={train_ratio:.3f})...")
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            train_masks_rep.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # Baseline methods prediction + MSE calculation
            print("Running baseline methods...")
            pred_global = predict_global_mean(Y_full, train_mask)
            mse_global = calculate_mse(Y_full, pred_global, test_mask)
            pred_row = predict_row_mean(Y_full, train_mask)
            mse_row = calculate_mse(Y_full, pred_row, test_mask)
            pred_col = predict_col_mean(Y_full, train_mask)
            mse_col = calculate_mse(Y_full, pred_col, test_mask)

            # Run improved multi-benchmark IRT model
            print("Running improved multi-benchmark IRT model...")
            trace_multibench, params_multibench = fit_improved_irt_multibench(
                Y=Y_full,
                train_mask=train_mask,
                bench_idx=bench_idx,
                M=M,
                draws=irt_draws,
                tune=irt_tune,
                chains=irt_chains,
                cores=irt_cores,
                target_accept=target_accept
            )
            pred_multibench = predict_from_multibench_irt(params_multibench, bench_idx)
            mse_multibench = calculate_mse(Y_full, pred_multibench, test_mask)
            save_irt_trace(output_dirs["irt_trace"], trace_multibench, "multibench", train_ratio, rep)
            print(f"Multi-benchmark IRT MSE: {mse_multibench:.6f}")

            # Record MSE for current repetition
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            mse_summary["Multi-benchmark IRT"][train_ratio].append(mse_multibench)

            # Save sample-level predictions
            save_sample_level_predictions(
                pred_dir=output_dirs["predictions"],
                Y=Y_full,
                test_mask=test_mask,
                model_names=model_names,
                question_names=question_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_multibench=pred_multibench
            )

            # Print current repetition results
            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            print(f"Multi-benchmark IRT: {mse_multibench:.6f}")

        # Record data split for current train ratio
        record_data_split(
            split_dir=output_dirs["split"],
            test_mask=test_mask,
            train_masks=train_masks_rep,
            train_subset_ratio=train_ratio,
            model_names=model_names,
            question_names=question_names
        )

        # Save the last repetition's parameters. Use one subdirectory per ratio:
        # the parameter file names do not encode the ratio, so a shared directory
        # would keep only the final ratio's files.
        ratio_params_dir = os.path.join(output_dirs["params"], f"ratio_{train_ratio:.3f}")
        os.makedirs(ratio_params_dir, exist_ok=True)
        save_model_parameters(
            params_dir=ratio_params_dir,
            params=params_multibench,
            model_names=model_names,
            question_names=question_names,
            unique_bench=unique_bench,
            trace=trace_multibench
        )

    # Generate final results
    print("\n=== Generate Final Results ===")
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    save_mse_summary(mse_summary, mse_summary_path)

    # Mean MSE over repetitions, per method and ratio, for the comparison plot
    plot_mse_dict = {
        method: [np.mean(mse_summary[method][r]) for r in train_ratios]
        for method in method_names
    }

    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")


# =========================
# 8. Main Function
# =========================
if __name__ == "__main__":
    import argparse

    # (flag, add_argument options) for every CLI option, in --help order.
    # NOTE(review): --irt_draws/--irt_tune default to 10 here, versus 1000 in
    # run_improved_mixed_benchmark_experiment; 10 tuning steps is very few for NUTS
    # to adapt, so these look like smoke-test settings — confirm before real runs.
    _CLI_OPTIONS = [
        ("--input_csv", {"type": str, "default": "data/merged_is_correct_matrix.csv",
                         "help": "Path to merged is_correct matrix CSV"}),
        ("--output_dir", {"type": str, "default": "results/improved_mixed_benchmark",
                          "help": "Output directory"}),
        ("--sample_ratio", {"type": float, "default": 0.1,
                            "help": "Column sampling ratio within benchmark (0,1]"}),
        ("--test_ratio", {"type": float, "default": 0.05,
                          "help": "Test set ratio"}),
        ("--test_seed", {"type": int, "default": 42,
                         "help": "Random seed"}),
        ("--train_ratios", {"type": float, "nargs": "+", "default": [1.0],
                            "help": "List of training data ratios"}),
        ("--rep_count", {"type": int, "default": 3,
                         "help": "Number of repetitions for each training ratio"}),
        ("--irt_draws", {"type": int, "default": 10,
                         "help": "MCMC draws"}),
        ("--irt_tune", {"type": int, "default": 10,
                        "help": "MCMC tune"}),
        ("--irt_chains", {"type": int, "default": 4,
                          "help": "MCMC chains"}),
        ("--irt_cores", {"type": int, "default": None,
                         "help": "MCMC cores"}),
        ("--target_accept", {"type": float, "default": 0.95,
                             "help": "Target acceptance rate for NUTS sampler"}),
    ]

    parser = argparse.ArgumentParser(description="Run improved mixed benchmark IRT experiment")
    for flag, options in _CLI_OPTIONS:
        parser.add_argument(flag, **options)

    cli_values = vars(parser.parse_args())
    # CLI dest names equal the runner's keyword names, except for the two paths.
    input_csv_path = cli_values.pop("input_csv")
    output_root_dir = cli_values.pop("output_dir")
    run_improved_mixed_benchmark_experiment(
        input_csv_path=input_csv_path,
        output_root_dir=output_root_dir,
        **cli_values
    )
