#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Run IRT models on a single discretely scored (0/1 is_correct) benchmark, writing outputs
in the same format as modeling_single.py.
Extended with the irt_cite method for comparison.

This script fits mean baselines, IRT models and irt_cite on one benchmark's response
matrix and saves the results in the same format as modeling_single.py.
"""

import os
import time
import warnings
from typing import Tuple, Dict, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pymc as pm
import arviz as az
from scipy.special import expit, logit
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix

# irt_cite imports
import requests
import pickle
from scipy.optimize import minimize

# Global Config
# Matplotlib defaults shared by every figure this script produces.
plt.rcParams.update({
    "font.family": ["Arial", "Helvetica"],
    "axes.unicode_minus": False,
    "figure.dpi": 150,
})
# NOTE(review): this silences *all* warnings raised through the `warnings`
# module, process-wide, including any sampler warnings.
warnings.filterwarnings("ignore")

# =========================
# 1. Core I/O & Preprocessing (Adapted from run_benchmark.py)
# =========================
def build_inputs_single_bench(csv_path: str,
                              bench_name: str,
                              sample_ratio: float,
                              seed: int = 42) -> Tuple[int, int, List[str], np.ndarray]:
    """
    Load the merged is_correct matrix and keep a random subset of one benchmark's columns.

    A column belongs to a benchmark when the text before its first underscore
    equals ``bench_name`` (e.g. ``CEVAL_123`` -> ``CEVAL``).

    Args:
        csv_path: CSV whose first column (model names) becomes the index.
        bench_name: Benchmark prefix to select.
        sample_ratio: Fraction of the benchmark's columns to keep; at least one
            column is always kept and never more than are available.
        seed: Seed for the column sub-sampling generator.

    Returns:
        (N, J, models, Y): Y is the (N, J) matrix cast to int, columns kept in
        their original order. When no column matches, (0, 0, [], a (0, 0) array).
    """
    df = pd.read_csv(csv_path, index_col=0)

    matching = [pos for pos, col in enumerate(df.columns)
                if col.split("_", 1)[0] == bench_name]
    n_available = len(matching)
    if not matching:
        print(f"[WARN] No columns found for {bench_name}, skipping.")
        return 0, 0, [], np.zeros((0,0))

    n_keep = min(max(1, int(n_available * sample_ratio)), n_available)
    picked = np.random.default_rng(seed).choice(matching, size=n_keep, replace=False)
    subset = df.iloc[:, np.sort(picked)]

    Y_full = subset.values.astype(int)
    N, J = Y_full.shape
    print(f"[{bench_name}] After sampling: J={J} (original {n_available} columns, sample ratio {sample_ratio:.3f})")
    return N, J, subset.index.tolist(), Y_full


# =========================
# 2. Data Split (Fixed Test + Variable Training)
# =========================
def split_fixed_test_set(
    Y_shape: Tuple[int, int],
    test_ratio: float = 0.05,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw a reproducible, fixed test set over the response grid.

    Each cell is independently put in the test set with probability
    ``test_ratio`` using a generator seeded with ``test_seed``. The same
    arguments therefore always yield the same split.

    Returns:
        (test_mask, remaining_mask): boolean arrays of shape ``Y_shape``;
        ``remaining_mask`` is the exact complement of ``test_mask``.
    """
    uniform_draws = np.random.default_rng(test_seed).random(Y_shape)
    test_mask = uniform_draws < test_ratio
    remaining_mask = np.logical_not(test_mask)

    pct_test = test_ratio * 100
    pct_rest = (1 - test_ratio) * 100
    print(f"Fixed test set size: {test_mask.sum()} samples ({pct_test:.1f}%)")
    print(f"Remaining pool size: {remaining_mask.sum()} samples ({pct_rest:.1f}%)")
    return test_mask, remaining_mask


def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,
    rep_seed: int
) -> np.ndarray:
    """
    Pick a training subset, without replacement, from the non-test cells.

    ``int(n_remaining * train_subset_ratio)`` of the True cells of
    ``remaining_mask`` are chosen uniformly at random. The generator is
    seeded with ``rep_seed``.

    Returns:
        Boolean mask shaped like ``remaining_mask`` (True = training cell).
    """
    generator = np.random.default_rng(rep_seed)
    pool_rows, pool_cols = np.where(remaining_mask)
    pool_size = pool_rows.size
    picks = generator.choice(pool_size, size=int(pool_size * train_subset_ratio), replace=False)

    train_mask = np.zeros_like(remaining_mask, dtype=bool)
    train_mask[pool_rows[picks], pool_cols[picks]] = True
    return train_mask


def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per repetition listing every train and test cell.

    Each file ``split_ratio_<ratio>_rep<k>.csv`` in ``split_dir`` (created if
    missing) has one row per cell. Its columns are:
    model_idx, model_name, question_idx, question_name, set_type
    ("train"/"test"), train_ratio, and repetition (1-based).
    Train rows come first, followed by the fixed test rows.

    Raises:
        ValueError: if ``test_mask`` or any train mask is not (N, J), where
            N = len(model_names) and J = len(question_names).
    """
    N = len(model_names)
    J = len(question_names)

    if test_mask.shape != (N, J):
        raise ValueError(
            f"Test mask shape {test_mask.shape} doesn't match data dimensions ({N},{J})"
        )
    for train_mask in train_masks:
        if train_mask.shape != (N, J):
            raise ValueError(
                f"Train mask shape {train_mask.shape} doesn't match data dimensions ({N},{J})"
            )

    os.makedirs(split_dir, exist_ok=True)

    def _cell_records(mask: np.ndarray, set_type: str) -> pd.DataFrame:
        # One row per True cell of `mask`, in row-major order.
        rows, cols = np.where(mask)
        return pd.DataFrame({
            "model_idx": rows,
            "model_name": [model_names[r] for r in rows],
            "question_idx": cols,
            "question_name": [question_names[c] for c in cols],
            "set_type": set_type,
            "train_ratio": train_subset_ratio,
        })

    # The test set is the same for every repetition, so build its rows once.
    test_records = _cell_records(test_mask, "test")

    for rep_idx, train_mask in enumerate(train_masks):
        repetition = rep_idx + 1
        all_records = pd.concat(
            [
                _cell_records(train_mask, "train").assign(repetition=repetition),
                test_records.assign(repetition=repetition),
            ],
            ignore_index=True,
        )
        save_path = os.path.join(split_dir, f"split_ratio_{train_subset_ratio:.3f}_rep{rep_idx+1}.csv")
        all_records.to_csv(save_path, index=False)
    print(f"Data split records saved (train ratio: {train_subset_ratio:.3f})")


# =========================
# 3. Prediction Methods (Baselines)
# =========================
def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Naive baseline: predict every cell with the mean of the training cells.

    Falls back to 0.5 when there are no training cells. The result is a float
    array of shape ``Y.shape``. It deliberately does not inherit Y's dtype:
    an integer Y (such as the 0/1 correctness matrix) would otherwise truncate
    the mean to 0 or 1.
    """
    global_mean = Y[train_mask].mean() if train_mask.any() else 0.5
    return np.full(Y.shape, global_mean, dtype=float)


def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    "Average Score" baseline: predict each model (row) with its own training mean.

    A row with no training cell uses the global training mean instead, or
    0.5 when there is no training data at all.

    Returns:
        Float array of shape ``Y.shape``; each row repeats that model's mean.
    """
    n_rows, n_cols = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5

    per_row = np.array(
        [Y[r, train_mask[r, :]].mean() if train_mask[r, :].any() else fallback
         for r in range(n_rows)],
        dtype=float,
    )
    return np.tile(per_row[:, np.newaxis], (1, n_cols))


def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    "Difficulty Modeling" baseline: predict each question (column) with its training mean.

    A column with no training cell uses the global training mean instead, or
    0.5 when there is no training data at all.

    Returns:
        Float array of shape ``Y.shape``; each column repeats that question's mean.
    """
    n_rows, n_cols = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5

    per_col = np.array(
        [Y[train_mask[:, c], c].mean() if train_mask[:, c].any() else fallback
         for c in range(n_cols)],
        dtype=float,
    )
    return np.tile(per_col[np.newaxis, :], (n_rows, 1))


# =========================
# 4. IRT Models (1PL/2PL)
# =========================
def fit_irt_1pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None,
    likelihood: str = "bernoulli"
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a Bayesian 1PL IRT model with ``pm.sample`` on the training cells.

    Model: P(Y_ij = 1) = expit(theta_i - b_j), with priors
    theta_i ~ N(0, 1) (model ability) and b_j ~ N(0, 1) (question difficulty).

    Args:
        Y: (N, J) response matrix. Under the default likelihood, the training
            cells must be 0/1.
        train_mask: (N, J) mask of the cells used as observations.
        draws, tune, chains: forwarded to ``pm.sample``.
        cores: sampler processes; defaults to min(cpu_count, chains).
        likelihood: one of:
            - "bernoulli" (default): Y_ij ~ Bernoulli(logit_p=theta_i - b_j).
              This is the model that the downstream prediction
              ``expit(theta - b)`` assumes.
            - "normal": the legacy behaviour, logit(clip(Y_ij)) ~ N(theta_i - b_j, 1).
              Only meaningful for continuous scores in (0, 1).

    Returns:
        (trace, params), where params holds the posterior means "theta" (N,)
        and "b" (J,).

    Raises:
        ValueError: for an unknown ``likelihood``, or non-0/1 training
            responses under the Bernoulli likelihood.
    """
    N, J = Y.shape
    # Only training cells enter the likelihood. Indexing parameters by cell
    # coordinates avoids building the full N x J linear predictor.
    rows, cols = np.nonzero(train_mask)
    y_obs = Y[rows, cols]

    if likelihood == "bernoulli":
        if not np.isin(y_obs, (0, 1)).all():
            raise ValueError("1PL IRT with a Bernoulli likelihood requires 0/1 training responses")
    elif likelihood != "normal":
        raise ValueError(f"Unsupported likelihood: {likelihood} (expected 'bernoulli' or 'normal')")

    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    with pm.Model():
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)

        mu = theta[rows] - b[cols]
        if likelihood == "bernoulli":
            pm.Bernoulli("obs", logit_p=mu, observed=y_obs)
        else:
            logit_obs = logit(np.clip(y_obs.astype(float), 1e-6, 1 - 1e-6))
            pm.Normal("obs", mu=mu, sigma=1.0, observed=logit_obs)

        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )
        print(f"1PL IRT sampling time: {time.time()-start_time:.1f}s")

    r_hat = az.summary(trace, var_names=["theta", "b"])["r_hat"]
    n_unconverged = int((r_hat > 1.01).sum())
    if n_unconverged:
        print(f"⚠️ 1PL IRT: {n_unconverged} parameters with r_hat>1.01")

    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def fit_irt_2pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None,
    likelihood: str = "bernoulli"
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a Bayesian 2PL IRT model with ``pm.sample`` on the training cells.

    Model: P(Y_ij = 1) = expit(a_j * (theta_i - b_j)), with priors
    theta_i ~ N(0, 1) (ability), a_j ~ LogNormal(0, 0.5) (discrimination)
    and b_j ~ N(0, 1) (difficulty). The linear predictor uses exactly this
    a * (theta - b) form, so the posterior means plug directly into the
    2PL branch of ``predict_from_irt``.

    Args:
        Y: (N, J) response matrix. Under the default likelihood, the training
            cells must be 0/1.
        train_mask: (N, J) mask of the cells used as observations.
        draws, tune, chains: forwarded to ``pm.sample``.
        cores: sampler processes; defaults to min(cpu_count, chains).
        likelihood: one of:
            - "bernoulli" (default): Y_ij ~ Bernoulli(logit_p=a_j * (theta_i - b_j)).
            - "normal": the legacy behaviour, regressing logit(clip(Y)) on the
              same predictor with unit noise. Only meaningful for continuous
              scores in (0, 1).

    Returns:
        (trace, params), where params holds the posterior means "theta" (N,),
        "a" (J,) and "b" (J,).

    Raises:
        ValueError: for an unknown ``likelihood``, or non-0/1 training
            responses under the Bernoulli likelihood.
    """
    N, J = Y.shape
    # Only training cells enter the likelihood. Indexing parameters by cell
    # coordinates avoids building the full N x J linear predictor.
    rows, cols = np.nonzero(train_mask)
    y_obs = Y[rows, cols]

    if likelihood == "bernoulli":
        if not np.isin(y_obs, (0, 1)).all():
            raise ValueError("2PL IRT with a Bernoulli likelihood requires 0/1 training responses")
    elif likelihood != "normal":
        raise ValueError(f"Unsupported likelihood: {likelihood} (expected 'bernoulli' or 'normal')")

    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    with pm.Model():
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)
        a = pm.LogNormal("a", mu=0.0, sigma=0.5, shape=J)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)

        # Same parameterisation as the docstring and predict_from_irt: a * (theta - b).
        mu = a[cols] * (theta[rows] - b[cols])
        if likelihood == "bernoulli":
            pm.Bernoulli("obs", logit_p=mu, observed=y_obs)
        else:
            logit_obs = logit(np.clip(y_obs.astype(float), 1e-6, 1 - 1e-6))
            pm.Normal("obs", mu=mu, sigma=1.0, observed=logit_obs)

        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )
        print(f"2PL IRT sampling time: {time.time()-start_time:.1f}s")

    r_hat = az.summary(trace, var_names=["theta", "a", "b"])["r_hat"]
    n_unconverged = int((r_hat > 1.01).sum())
    if n_unconverged:
        print(f"⚠️ 2PL IRT: {n_unconverged} parameters with r_hat>1.01")

    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "a": trace.posterior["a"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def predict_from_irt(params: Dict[str, np.ndarray], model_type: str) -> np.ndarray:
    """
    Turn IRT parameter estimates into an (N, J) matrix of success probabilities.

    - "1pl": expit(theta_i - b_j)
    - "2pl": expit(a_j * (theta_i - b_j))

    Raises:
        ValueError: for any other ``model_type``.
    """
    ability = params["theta"].reshape(-1, 1)
    if model_type == "2pl":
        discrimination = params["a"].reshape(1, -1)
        difficulty = params["b"].reshape(1, -1)
        return expit(discrimination * (ability - difficulty))
    if model_type == "1pl":
        return expit(ability - params["b"].reshape(1, -1))
    raise ValueError(f"Unsupported IRT model type: {model_type}")


# =========================
# 5. irt_cite Methods
# =========================
def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)), used by the irt_cite helpers."""
    denominator = 1 + np.exp(-z)
    return 1 / denominator


def item_curve(theta, a, b):
    """
    Multidimensional item response curve used by the irt_cite method.

    For each latent dimension, the logit contribution ``a * theta - b`` is
    clipped to [-30, 30]. The contributions are then summed over axis 1, and
    the logistic function is applied to the sum.
    """
    per_dim = np.clip(a * theta - b, -30, 30)
    logits = per_dim.sum(axis=1)
    return 1 / (1 + np.exp(-logits))


def fit_theta(responses_test, seen_items, A, B, theta_init=None, eps=1e-10, optimizer="BFGS"):
    """
    Maximum-likelihood estimate of one model's latent ability for irt_cite.

    Args:
        responses_test: 1-D array of that model's responses, indexed by item
            (expected to be 0/1).
        seen_items: indices of the items whose responses enter the likelihood.
        A, B: item parameters; axis 1 is the latent dimension D and axis 2
            indexes items (shape (1, D, J) as built in predict_irt_cite).
        theta_init: optional starting point holding D values (any shape, e.g.
            a previous (1, D, 1) estimate); zeros when None.
        eps: small constant that keeps both log terms finite.
        optimizer: method name forwarded to ``scipy.optimize.minimize``.

    Returns:
        The optimised ability with shape (1, D, 1), ready for ``item_curve``.
    """
    D = A.shape[1]
    # Restrict the data once instead of re-slicing on every objective call.
    y_seen = responses_test[seen_items]
    A_seen = A[:, :, seen_items]
    B_seen = B[:, :, seen_items]

    def neg_log_like(x):
        P = item_curve(x.reshape(1, D, 1), A_seen, B_seen).squeeze()
        log_likelihood = np.sum(y_seen * np.log(P + eps) + (1 - y_seen) * np.log(1 - P + eps))
        return -log_likelihood

    # Honour theta_init as the optimiser's starting point (warm start).
    x0 = np.zeros(D) if theta_init is None else np.asarray(theta_init, dtype=float).reshape(D)
    optimal_theta = minimize(neg_log_like, x0, method=optimizer).x[None, :, None]
    return optimal_theta


def predict_irt_cite(
    Y: np.ndarray,
    train_mask: np.ndarray,
    latent_dim: int = 10,
    seed: Optional[int] = None
) -> np.ndarray:
    """
    Predict success probabilities with the irt_cite method.

    This is a simplified demonstration. The item parameters A
    (discrimination) and B (difficulty) are drawn at random instead of being
    loaded from the tinyBenchmarks dataset, so the scores are placeholders.

    For each model (row) with at least one training cell, an ability vector is
    fitted on that row's training items with ``fit_theta``. ``item_curve`` then
    gives the probability for every item. Rows without training cells are
    predicted as 0.5.

    Args:
        Y: (N, J) response matrix.
        train_mask: (N, J) boolean mask of cells usable for fitting.
        latent_dim: latent dimension D of the synthetic item parameters.
        seed: seed for the synthetic parameters. None draws from NumPy's
            global random state, as before.

    Returns:
        Float array of shape (N, J) with predicted probabilities.
    """
    N, J = Y.shape
    D = latent_dim

    # Synthetic item parameters (placeholders for the tinyBenchmarks ones).
    rng = np.random if seed is None else np.random.default_rng(seed)
    A = rng.normal(0, 1, (1, D, J))  # Discrimination parameters
    B = rng.normal(0, 1, (1, D, J))  # Difficulty parameters

    # Float output: an integer Y must not truncate the probabilities (or 0.5) to 0/1.
    predictions = np.full((N, J), 0.5, dtype=float)

    for i in range(N):
        seen_items = np.flatnonzero(train_mask[i, :])
        if seen_items.size == 0:
            continue  # no training data for this model: keep the 0.5 default
        theta = fit_theta(Y[i, :].astype(float), seen_items, A, B)
        # With the full (1, D, J) parameters, item_curve returns shape (1, J).
        # Per-item slices A[:, :, j] would broadcast to (1, D, D) instead.
        predictions[i, :] = item_curve(theta, A, B)[0]

    return predictions


# =========================
# 6. MSE Calculation & Result Saving
# =========================
def calculate_mse(y_true: np.ndarray, y_pred: np.ndarray, test_mask: np.ndarray) -> float:
    """
    Mean squared error over the test cells only, rounded to 6 decimals.
    """
    residuals = y_true[test_mask] - y_pred[test_mask]
    return round(np.mean(np.square(residuals)), 6)


def save_sample_level_predictions(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_1pl: Optional[np.ndarray] = None,
    pred_2pl: Optional[np.ndarray] = None,
    pred_cite: Optional[np.ndarray] = None
):
    """
    Write every method's test-cell predictions for one (ratio, repetition) run.

    The CSV ``<pred_dir>/predictions_ratio_<ratio>_rep<rep+1>.csv`` has one row
    per test cell. The layout is consistent with modeling_single.py.
    Optional methods passed as ``None`` (not run) are written as NaN.
    ``pred_dir`` is created if it does not exist.
    """
    test_rows, test_cols = np.where(test_mask)
    n_test = len(test_rows)

    def _on_test(pred: Optional[np.ndarray]) -> list:
        # Skipped optional methods contribute an all-NaN column of matching length.
        return [np.nan] * n_test if pred is None else pred[test_mask].tolist()

    pred_df = pd.DataFrame({
        "model_idx": test_rows,
        "model_name": [model_names[r] for r in test_rows],
        "question_idx": test_cols,
        "question_name": [question_names[c] for c in test_cols],
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1,
        "true_value": Y[test_mask].tolist(),
        "global_mean_pred": pred_global[test_mask].tolist(),
        "model_mean_pred": pred_row[test_mask].tolist(),
        "question_mean_pred": pred_col[test_mask].tolist(),
        "irt_1pl_pred": _on_test(pred_1pl),
        "irt_2pl_pred": _on_test(pred_2pl),
        "irt_cite_pred": _on_test(pred_cite)
    })

    os.makedirs(pred_dir, exist_ok=True)
    save_path = os.path.join(pred_dir, f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv")
    pred_df.to_csv(save_path, index=False, encoding="utf-8")

    print(f"✅ Sample-level predictions saved: {save_path}")
    print(f"   Total {len(pred_df)} test samples")


def save_irt_trace(
    trace_dir: str,
    trace: az.InferenceData,
    model_type: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Persist a full IRT sampling trace as a NetCDF file.

    The file is ``<trace_dir>/irt_<model_type>_ratio_<ratio>_rep<rep+1>.nc``.
    This function does not create ``trace_dir``.
    """
    file_name = f"irt_{model_type}_ratio_{train_subset_ratio:.3f}_rep{rep+1}.nc"
    save_path = os.path.join(trace_dir, file_name)
    az.to_netcdf(trace, save_path)
    print(f"IRT {model_type} trace saved: {save_path}")


def save_model_parameters(
    params_dir: str,
    params: Dict[str, np.ndarray],
    model_type: str,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write IRT parameter estimates as CSVs, in the same layout as run_benchmark.py.

    Files written to ``params_dir``:
    - ``theta_hat_<model_type>.csv``: columns model, theta_hat.
    - ``item_params_<model_type>.csv``: columns a_hat, b_hat, one row per
      question.

    For "1pl", a_hat is 1 for every question; any other model_type reads
    ``params["a"]``.
    """
    theta_df = pd.DataFrame({"model": model_names, "theta_hat": params["theta"]})
    theta_path = os.path.join(params_dir, f"theta_hat_{model_type}.csv")
    theta_df.to_csv(theta_path, index=False)

    # 1PL has no discrimination parameter: it is fixed at 1.
    discrimination = np.ones(len(question_names)) if model_type == "1pl" else params["a"]
    item_df = pd.DataFrame({"a_hat": discrimination, "b_hat": params["b"]})
    item_path = os.path.join(params_dir, f"item_params_{model_type}.csv")
    item_df.to_csv(item_path, index=False)

    print(f"Model parameters saved: {theta_path}, {item_path}")


# =========================
# 7. Result Visualization
# =========================
def plot_mse_comparison(
    mse_dict: Dict[str, List[float]],
    train_ratios: List[float],
    save_path: str
):
    """
    Plot test MSE against training ratio, one line per method, and save the figure.

    Known method names get a fixed colour and line style. Any other method is
    drawn as a solid black line. The figure is closed after it is saved.
    """
    styles = {
        "Global Mean": ("blue", "solid"),
        "Model Mean": ("orange", "dashed"),
        "Question Mean": ("green", "dashdot"),
        "IRT-1PL": ("red", "solid"),
        "IRT-2PL": ("purple", "dashed"),
        "IRT-Cite": ("brown", "dotted")
    }
    default_style = ("black", "solid")

    plt.figure(figsize=(10, 6))
    for label, curve in mse_dict.items():
        line_color, line_style = styles.get(label, default_style)
        plt.plot(train_ratios, curve, label=label, color=line_color, linestyle=line_style,
                 linewidth=2, marker="o", markersize=4)

    tick_labels = [f"{r:.2f}" for r in train_ratios]
    plt.xlabel("Training Data Ratio (Fraction of Remaining Pool)", fontsize=12)
    plt.ylabel("Test Set MSE (Lower = Better)", fontsize=12)
    plt.title("MSE vs. Training Data Ratio for All Prediction Methods", fontsize=14, pad=20)
    plt.legend(loc="upper right", fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.xticks(train_ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"MSE plot saved: {save_path}")


def save_mse_summary(
    mse_summary: Dict[str, Dict[float, List[float]]],
    save_path: str
):
    """
    Write a per-ratio table of "mean ± std" MSE strings for every method.

    Rows are the train ratios found under "Global Mean", sorted ascending.
    Columns are Train_Ratio followed by the methods in dict order. The std is
    the population standard deviation (``np.std``) over repetitions.
    """
    def _fmt(values: List[float]) -> str:
        return f"{np.mean(values):.6f} ± {np.std(values):.6f}"

    ratios = sorted(mse_summary["Global Mean"].keys())
    rows = [
        {"Train_Ratio": ratio,
         **{method: _fmt(per_ratio[ratio]) for method, per_ratio in mse_summary.items()}}
        for ratio in ratios
    ]

    pd.DataFrame(rows).to_csv(save_path, index=False)
    print(f"MSE summary saved: {save_path}")


# =========================
# 8. Main Experiment Pipeline
# =========================
def run_single_benchmark_experiment(
    input_csv_path: str,
    output_root_dir: str,
    bench_name: str,
    sample_ratio: float = 0.1,
    test_ratio: float = 0.05,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 3,
    irt_draws: int = 1000,
    irt_tune: int = 1000,
    irt_chains: int = 4,
    irt_cores: Optional[int] = None,
    selected_irt_models: Optional[List[str]] = None
):
    """
    Run the single-benchmark experiment and write all artefacts to disk.

    Pipeline:
    1. Load and column-sample the benchmark.
    2. Draw one fixed test set.
    3. For every training ratio and repetition:
       - sample a training subset from the non-test pool;
       - run the mean baselines, the selected IRT models and irt_cite;
       - record the test MSEs and per-cell predictions.

    Outputs go to numbered sub-directories of ``output_root_dir``: data
    splits, predictions, IRT traces, metrics and model parameters. The output
    format is consistent with modeling_single.py.

    Args:
        input_csv_path: merged is_correct matrix CSV.
        output_root_dir: root directory for all outputs.
        bench_name: benchmark prefix to select.
        sample_ratio: fraction of the benchmark's columns to keep.
        test_ratio: probability of each cell being in the fixed test set.
        test_seed: seed for the column sampling and the test split. Repetition
            ``rep`` samples its training subset with seed ``test_seed + rep``.
        train_ratios: fractions of the non-test pool used for training, each
            in [0, 1]. Defaults to 0.1, 0.2, ..., 1.0.
        rep_count: repetitions per training ratio (>= 1).
        irt_draws, irt_tune, irt_chains, irt_cores: sampler settings.
        selected_irt_models: subset of ["1pl", "2pl"]; defaults to both.

    Raises:
        ValueError: for an unsupported IRT model, a training ratio outside
            [0, 1], or ``rep_count < 1``.
    """
    # None defaults avoid sharing one mutable list object between calls.
    if train_ratios is None:
        train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    if selected_irt_models is None:
        selected_irt_models = ["1pl", "2pl"]

    print(f"Running single benchmark experiment for: {bench_name}")

    # Validate IRT models
    valid_irt_models = ["1pl", "2pl"]
    for model in selected_irt_models:
        if model not in valid_irt_models:
            raise ValueError(
                f"Unsupported IRT model: {model}! Valid models: {valid_irt_models}"
            )
    print(f"Selected IRT models: {selected_irt_models}")

    # Validate the experiment grid up front so a long run cannot fail late
    # (e.g. rep_count=0 used to crash when saving parameters that were never fitted).
    if rep_count < 1:
        raise ValueError(f"rep_count must be >= 1, got {rep_count}")
    invalid_ratios = [r for r in train_ratios if not 0.0 <= r <= 1.0]
    if invalid_ratios:
        raise ValueError(f"Training ratios must lie in [0, 1], got {invalid_ratios}")

    # Create output directories
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
        "params": os.path.join(output_root_dir, "05_model_parameters")
    }
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    print(f"Output directories created: {list(output_dirs.values())}")

    # Build input data
    print("\n=== Step 1: Read & Preprocess Data ===")
    N, J, model_names, Y = build_inputs_single_bench(
        input_csv_path, bench_name, sample_ratio, test_seed
    )
    if N == 0 or J == 0:
        print("No data available, exiting.")
        return

    question_names = [f"Q{i}" for i in range(J)]  # Simple question names
    print(f"Preprocessed data shape: {N} models × {J} questions")

    # Split fixed test set
    print("\n=== Step 2: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # MSE record: method label -> train ratio -> one MSE per repetition
    mse_summary = {
        "Global Mean": {r: [] for r in train_ratios},
        "Model Mean": {r: [] for r in train_ratios},
        "Question Mean": {r: [] for r in train_ratios}
    }
    for model in selected_irt_models:
        mse_summary[f"IRT-{model.upper()}"] = {r: [] for r in train_ratios}
    mse_summary["IRT-Cite"] = {r: [] for r in train_ratios}

    # Main experiment loop
    for train_ratio in train_ratios:
        print(f"\n=== Train Ratio = {train_ratio:.3f} ===")

        train_masks_rep = []
        params_1pl = None
        params_2pl = None

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            rep_seed = test_seed + rep

            # Sample training subset
            print(f"Sampling training subset (ratio={train_ratio:.3f})...")
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            train_masks_rep.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # Baseline methods prediction + MSE calculation
            print("Running baseline methods...")
            pred_global = predict_global_mean(Y, train_mask)
            mse_global = calculate_mse(Y, pred_global, test_mask)
            pred_row = predict_row_mean(Y, train_mask)
            mse_row = calculate_mse(Y, pred_row, test_mask)
            pred_col = predict_col_mean(Y, train_mask)
            mse_col = calculate_mse(Y, pred_col, test_mask)

            # Run selected IRT models (always 1PL before 2PL)
            print(f"Running selected IRT models: {selected_irt_models}...")
            pred_1pl = None
            pred_2pl = None
            mse_1pl = None
            mse_2pl = None

            if "1pl" in selected_irt_models:
                trace_1pl, params_1pl = fit_irt_1pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores
                )
                pred_1pl = predict_from_irt(params_1pl, "1pl")
                mse_1pl = calculate_mse(Y, pred_1pl, test_mask)
                save_irt_trace(output_dirs["irt_trace"], trace_1pl, "1pl", train_ratio, rep)
                print(f"IRT-1PL MSE: {mse_1pl:.6f}")

            if "2pl" in selected_irt_models:
                trace_2pl, params_2pl = fit_irt_2pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores
                )
                pred_2pl = predict_from_irt(params_2pl, "2pl")
                mse_2pl = calculate_mse(Y, pred_2pl, test_mask)
                save_irt_trace(output_dirs["irt_trace"], trace_2pl, "2pl", train_ratio, rep)
                print(f"IRT-2PL MSE: {mse_2pl:.6f}")

            # Run irt_cite method
            print("Running irt_cite method...")
            pred_cite = predict_irt_cite(Y, train_mask)
            mse_cite = calculate_mse(Y, pred_cite, test_mask)
            print(f"IRT-Cite MSE: {mse_cite:.6f}")

            # Record MSE for current repetition
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            if mse_1pl is not None:
                mse_summary["IRT-1PL"][train_ratio].append(mse_1pl)
            if mse_2pl is not None:
                mse_summary["IRT-2PL"][train_ratio].append(mse_2pl)
            mse_summary["IRT-Cite"][train_ratio].append(mse_cite)

            # Save sample-level predictions
            save_sample_level_predictions(
                pred_dir=output_dirs["predictions"],
                Y=Y,
                test_mask=test_mask,
                model_names=model_names,
                question_names=question_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_1pl=pred_1pl,
                pred_2pl=pred_2pl,
                pred_cite=pred_cite
            )

            # Print current repetition results
            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            if mse_1pl is not None:
                print(f"IRT-1PL: {mse_1pl:.6f}")
            if mse_2pl is not None:
                print(f"IRT-2PL: {mse_2pl:.6f}")
            print(f"IRT-Cite: {mse_cite:.6f}")

        # Record data split for current train ratio
        record_data_split(
            split_dir=output_dirs["split"],
            test_mask=test_mask,
            train_masks=train_masks_rep,
            train_subset_ratio=train_ratio,
            model_names=model_names,
            question_names=question_names
        )

        # Save the parameters from this ratio's last repetition. The file names
        # carry no ratio, so each ratio overwrites the previous files. The files
        # left at the end come from the last entry of `train_ratios`.
        if params_1pl is not None:
            save_model_parameters(
                params_dir=output_dirs["params"],
                params=params_1pl,
                model_type="1pl",
                model_names=model_names,
                question_names=question_names
            )
        if params_2pl is not None:
            save_model_parameters(
                params_dir=output_dirs["params"],
                params=params_2pl,
                model_type="2pl",
                model_names=model_names,
                question_names=question_names
            )

    # Generate final results
    print("\n=== Generate Final Results ===")
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    save_mse_summary(mse_summary, mse_summary_path)

    # Mean MSE per method and ratio, in the order of `train_ratios`, for plotting
    plot_mse_dict = {
        method: [np.mean(per_ratio[r]) for r in train_ratios]
        for method, per_ratio in mse_summary.items()
    }

    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")


# =========================
# 9. Main Function
# =========================
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run single benchmark IRT experiment with irt_cite method")
    # NOTE(review): the default --input_csv is a machine-specific absolute path
    # and --output_dir is a placeholder; pass both explicitly on other machines.
    parser.add_argument("--input_csv", type=str,
                        default="/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/data_result_raw/merged_mix_benchmark/merged_is_correct_matrix.csv",
                        help="Path to merged is_correct matrix CSV")
    parser.add_argument("--output_dir", type=str,
                        default="yourpath/result_single_benchmark_with_cite",
                        help="Output directory")
    parser.add_argument("--benchmark", type=str, default="CEVAL",
                        help="Benchmark name (CEVAL/CSQA/MMLU)")
    parser.add_argument("--sample_ratio", type=float, default=0.1,
                        help="Column sampling ratio within benchmark (0,1]")
    parser.add_argument("--test_ratio", type=float, default=0.05,
                        help="Test set ratio")
    parser.add_argument("--test_seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--train_ratios", type=float, nargs="+",
                        default=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                        help="List of training data ratios")
    parser.add_argument("--rep_count", type=int, default=3,
                        help="Number of repetitions for each training ratio")
    parser.add_argument("--irt_draws", type=int, default=1000,
                        help="MCMC draws")
    parser.add_argument("--irt_tune", type=int, default=1000,
                        help="MCMC tune")
    parser.add_argument("--irt_chains", type=int, default=4,
                        help="MCMC chains")
    parser.add_argument("--irt_cores", type=int, default=None,
                        help="MCMC cores")
    # `choices` rejects unknown model names at parse time with a usage message.
    parser.add_argument("--irt_models", type=str, nargs="+",
                        default=["1pl"], choices=["1pl", "2pl"],
                        help="Selected IRT models (1pl/2pl)")

    args = parser.parse_args()

    run_single_benchmark_experiment(
        input_csv_path=args.input_csv,
        output_root_dir=args.output_dir,
        bench_name=args.benchmark,
        sample_ratio=args.sample_ratio,
        test_ratio=args.test_ratio,
        test_seed=args.test_seed,
        train_ratios=args.train_ratios,
        rep_count=args.rep_count,
        irt_draws=args.irt_draws,
        irt_tune=args.irt_tune,
        irt_chains=args.irt_chains,
        irt_cores=args.irt_cores,
        selected_irt_models=args.irt_models
    )