#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
XSUM Dataset F1-Score Prediction Experiment
Experiment Design:
- Fixed Test Set: 5% of total data (randomly sampled once)
- Training Set: 10%, 20%, ..., 100% of remaining 95% data (variable ratio)
- Methods: Naive(Global Mean), Average Score(Row Mean), Difficulty(Col Mean), IRT(1PL/2PL/3PL)
- Repetitions: 3 independent runs for each training ratio
- Outputs: Data split records, full MCMC traces, sample-wise predictions, MSE metrics/plots
"""

import os
import time
import pickle
import warnings
from typing import Tuple, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pymc as pm
import arviz as az
from scipy.special import expit, logit

# Global Config (Critical for Consistency & Plot Requirements)
# Plot settings applied once at import time so every figure looks the same.
plt.rcParams.update({
    "font.family": ["Arial", "Helvetica"],  # English fonts only (avoids garbled CJK glyphs in charts)
    "axes.unicode_minus": False,            # Render the minus sign as ASCII '-' so it always displays
    "figure.dpi": 150,                      # On-screen figure resolution
})
# Silence all warnings process-wide (e.g. sampler notices); note this also hides real warnings.
warnings.filterwarnings("ignore")

# =========================
# 1. Core I/O & Preprocessing
# =========================
def robust_read_csv(path: str) -> pd.DataFrame:
    """
    Read a CSV file, trying several text encodings in turn.

    Encodings are tried in order: utf-8, utf-8-sig, gbk, latin1; the first one
    that parses successfully is returned.

    Args:
        path: Path of the CSV file.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: if no attempt succeeds. The last underlying exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    encodings = ("utf-8", "utf-8-sig", "gbk", "latin1")  # Cover common encoding formats
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except OSError as e:
            # Missing file / permission problems do not depend on the encoding,
            # so retrying with the next one cannot help.
            last_err = e
            break
        except Exception as e:  # decode and parse errors: try the next encoding
            last_err = e
    raise ValueError(f"Failed to read CSV: {str(last_err)}") from last_err


def preprocess_matrix(
    df: pd.DataFrame,
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6,
) -> Tuple[pd.DataFrame, np.ndarray, List[str], List[str]]:
    """
    Preprocess input data (rows=questions, cols=models) into model-centric format.

    Steps:
    1. Optionally move the question-ID column into the index.
    2. Coerce every score cell to numeric (unparseable cells become NaN).
    3. Drop questions whose fraction of exact-zero scores exceeds
       ``zero_row_ratio_threshold`` (NaN cells are not counted as zeros).
    4. Fill remaining NaN with 0 and clip to [epsilon, 1 - epsilon] so that a
       later logit transform stays finite.
    5. Transpose to (rows=models, cols=questions).

    Args:
        df: Raw table, one row per question and one column per model
            (plus the optional question-ID column).
        question_id_col: Column holding question IDs. A column label is used
            as-is; an ``int`` that is not itself a column label is treated as
            a column position (e.g. ``0`` = first column). ``None`` keeps the
            existing index as question names.
        zero_row_ratio_threshold: Maximum allowed fraction of zero scores per
            question; questions above it are removed.
        epsilon: Clipping margin for the (0, 1) score range.

    Returns:
        (preprocessed_df, Y_array, model_names, question_names) where
        ``preprocessed_df`` and ``Y_array`` have shape (n_models, n_questions)
        and ``question_names`` lists the questions that survived filtering.
    """
    work = df.copy()  # Never modify the caller's DataFrame

    # Step 1: Set question ID as index (if provided)
    if question_id_col is not None:
        id_col = question_id_col
        # pd.read_csv produces string column labels, so an int such as 0 would
        # match no label; interpret it as a column position instead.
        if isinstance(id_col, int) and id_col not in work.columns:
            id_col = work.columns[id_col]
        work = work.set_index(id_col)
    # Record names in both branches (previously undefined when no ID column was given)
    original_question_names = work.index.tolist()

    # Step 2: Numeric conversion first, so textual "0" cells count as zeros below
    work = work.apply(pd.to_numeric, errors="coerce")

    # Step 3: Filter questions with excessive zeros (no discrimination between models)
    zero_ratio = (work == 0).sum(axis=1) / max(1, work.shape[1])
    work = work[zero_ratio <= zero_row_ratio_threshold].copy()
    filtered_question_names = work.index.tolist()
    print(f"Filtered questions: {len(original_question_names)} → {len(filtered_question_names)}")

    # Step 4: Missing values become 0, then everything is clipped into
    # (epsilon, 1-epsilon) because logit(0) / logit(1) are infinite.
    work = work.fillna(0.0)
    work = work.replace(0.0, epsilon).clip(lower=epsilon, upper=1 - epsilon)

    # Step 5: Transpose to subjects (models) x items (questions)
    work_transposed = work.T
    model_names = work_transposed.index.tolist()
    Y = work_transposed.values

    return work_transposed, Y, model_names, filtered_question_names

# =========================
# 2. Data Split (Fixed Test + Variable Training)
# =========================
def split_fixed_test_set(
    Y_shape: Tuple[int, int],
    test_ratio: float = 0.05,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw the fixed test set once for the whole experiment.

    Every cell independently lands in the test set with probability
    ``test_ratio``, so the realised test size is only approximately
    ``test_ratio`` of all cells. ``test_seed`` makes the split reproducible.

    Returns:
        (test_mask, remaining_mask): complementary boolean arrays of shape
        ``Y_shape``; True marks test cells and non-test cells respectively.
    """
    generator = np.random.default_rng(test_seed)
    uniform_draws = generator.random(Y_shape)
    test_mask = uniform_draws < test_ratio
    remaining_mask = np.logical_not(test_mask)  # Pool from which training subsets are drawn
    print(f"Fixed test set size: {test_mask.sum()} samples ({test_ratio*100:.1f}%)")
    print(f"Remaining pool size: {remaining_mask.sum()} samples ({(1-test_ratio)*100:.1f}%)")
    return test_mask, remaining_mask


def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,  # Ratio of remaining pool (e.g., 0.1 = 10% of 95%)
    rep_seed: int
) -> np.ndarray:
    """
    Draw one repetition's training cells from the non-test pool.

    Exactly ``int(pool_size * train_subset_ratio)`` distinct cells are picked
    uniformly without replacement from the True cells of ``remaining_mask``.

    Returns:
        train_mask: boolean array shaped like ``remaining_mask`` (True = training cell).
    """
    generator = np.random.default_rng(rep_seed)
    pool_rows, pool_cols = np.where(remaining_mask)  # model / question indices of pool cells
    pool_size = pool_rows.size
    n_pick = int(pool_size * train_subset_ratio)
    picked = generator.choice(pool_size, size=n_pick, replace=False)

    train_mask = np.zeros(remaining_mask.shape, dtype=bool)
    train_mask[pool_rows[picked], pool_cols[picked]] = True
    return train_mask


def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per repetition listing every train and test (model, question) cell.

    Each CSV has columns model_idx, model_name, question_idx, question_name,
    set_type ("train"/"test"), train_ratio, repetition (1-based), with the
    training rows first.
    File name: ``split_ratio_{ratio:.1f}_rep{k}.csv`` inside ``split_dir``
    (created if missing).
    NOTE(review): the ratio is rendered with one decimal, so ratios that agree
    to one decimal (e.g. 0.12 and 0.14) would overwrite each other.

    Raises:
        ValueError: if ``test_mask`` or any training mask is not shaped
            (len(model_names), len(question_names)). Nothing is written then.
    """
    N = len(model_names)
    J = len(question_names)
    # Validate every mask up front so name lookups below can never go out of range
    if test_mask.shape != (N, J):
        raise ValueError(
            f"Test set mask shape {test_mask.shape} does not match data dimensions ({N},{J})! "
            f"model_names length={N}, question_names length={J}, please check data preprocessing logic"
        )
    for train_mask in train_masks:
        if train_mask.shape != (N, J):
            raise ValueError(
                f"Training set mask shape {train_mask.shape} does not match data dimensions ({N},{J})! "
                f"model_names length={N}, question_names length={J}, please check data preprocessing logic"
            )

    os.makedirs(split_dir, exist_ok=True)

    def _records(rows: np.ndarray, cols: np.ndarray, set_type: str) -> pd.DataFrame:
        # One row per masked cell, resolved to human-readable names
        return pd.DataFrame({
            "model_idx": rows,
            "model_name": [model_names[r] for r in rows],
            "question_idx": cols,
            "question_name": [question_names[c] for c in cols],
            "set_type": set_type,
            "train_ratio": train_subset_ratio,
        })

    # The test set is identical for every repetition: build it once
    test_base = _records(*np.where(test_mask), "test")

    for rep_idx, train_mask in enumerate(train_masks):
        rep = rep_idx + 1
        train_records = _records(*np.where(train_mask), "train").assign(repetition=rep)
        test_records = test_base.assign(repetition=rep)
        all_records = pd.concat([train_records, test_records], ignore_index=True)
        save_path = os.path.join(split_dir, f"split_ratio_{train_subset_ratio:.1f}_rep{rep_idx+1}.csv")
        all_records.to_csv(save_path, index=False)
    print(f"Data split records saved (training ratio: {train_subset_ratio:.1f})")

# =========================
# 3. Prediction Methods (Baselines + IRT Models)
# =========================
def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Naive baseline: predict every cell with the mean of all training cells.

    Args:
        Y: Score matrix (n_models, n_questions).
        train_mask: Boolean matrix of the same shape; True = training cell.

    Returns:
        Array shaped like ``Y`` filled with the training mean (0.5 when the
        training set is empty). Floating ``Y`` keeps its dtype; integer/bool
        ``Y`` yields float64 so the mean is not truncated.
    """
    global_mean = Y[train_mask].mean() if train_mask.any() else 0.5
    # np.full_like on an integer array would truncate e.g. 0.5 to 0
    out_dtype = Y.dtype if np.issubdtype(Y.dtype, np.floating) else np.float64
    return np.full(Y.shape, global_mean, dtype=out_dtype)


def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Average-score baseline: predict each model's cells with that model's training mean.

    A model with no training cells falls back to the global training mean
    (or 0.5 when the whole training set is empty).

    Returns:
        (n_models, n_questions) array; every row is constant.
    """
    n_models, n_questions = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5
    per_model = np.array(
        [
            Y[i, train_mask[i, :]].mean() if train_mask[i, :].any() else fallback
            for i in range(n_models)
        ],
        dtype=float,
    )
    # Broadcast each model's mean across all questions
    return np.repeat(per_model[:, np.newaxis], n_questions, axis=1)


def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Difficulty baseline: predict each question's cells with that question's training mean.

    A question with no training cells falls back to the global training mean
    (or 0.5 when the whole training set is empty).

    Returns:
        (n_models, n_questions) array; every column is constant.
    """
    n_models, n_questions = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5
    per_question = np.array(
        [
            Y[train_mask[:, j], j].mean() if train_mask[:, j].any() else fallback
            for j in range(n_questions)
        ],
        dtype=float,
    )
    # Broadcast each question's mean down all models
    return np.repeat(per_question[np.newaxis, :], n_models, axis=0)

# -------------------------
# IRT Models (1PL/2PL/3PL) with Full Trace Saving
# -------------------------
def fit_irt_1pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 1PL (Rasch-style) IRT model on the training cells with PyMC.

    Model (logit scale, training cells only):
        logit(Y_ij) ~ Normal(theta_i - b_j, 1)
    with theta_i ~ Normal(0, 1) (model ability) and b_j ~ Normal(0, 1)
    (question difficulty).

    Args:
        Y: Score matrix (n_models, n_questions); clipped to [1e-6, 1 - 1e-6]
            before the logit transform.
        train_mask: Boolean matrix of the same shape; only True cells enter
            the likelihood.
        draws, tune, chains: Forwarded to ``pm.sample``.
        cores: Parallel processes; defaults to min(CPU count, chains).

    Returns:
        (trace, params): the full ArviZ InferenceData and a dict of posterior
        means ``theta`` (n_models,) and ``b`` (n_questions,).

    Prints a warning when any theta/b has r_hat > 1.01 (NaN r_hat values are
    not counted by that comparison).
    """
    N, J = Y.shape
    # Clip F1 scores (avoid logit overflow, consistent with preprocessing)
    Y_clip = np.clip(Y, 1e-6, 1 - 1e-6)
    logit_Y = logit(Y_clip)  # Model the scores on the logit scale (Normal likelihood below)

    # Default to one process per chain, capped by the CPU count
    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    # Build 1PL model (based on PyMC Bayesian framework)
    with pm.Model() as irt_1pl:
        # 1. Priors: ability and difficulty both standard normal
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)  # Model ability
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)          # Question difficulty

        # 2. Linear predictor (expectation on logit scale)
        mu = theta[:, None] - b[None, :]  # Broadcasting operation: (N,1) - (1,J) → (N,J)

        # 3. Likelihood with fixed unit noise, using training cells only
        pm.Normal("obs", mu=mu[train_mask], sigma=1.0, observed=logit_Y[train_mask])

        # 4. MCMC sampling of the posterior
        start_time = time.time()
        trace = pm.sample(
            draws=draws,          # Posterior draws kept per chain (after tuning); not the same as effective sample size
            tune=tune,            # Tuning (warm-up) iterations per chain; discarded
            chains=chains,        # Independent chains (needed for r_hat)
            cores=cores,          # Processes used to run chains in parallel
            target_accept=0.95,   # Step-size adaptation target; values near 1 mean smaller steps (fewer divergences, slower sampling)
            progressbar=True,    # Show the sampling progress bar
            return_inferencedata=True  # Return trace in ArviZ format (for subsequent analysis)
        )
        print(f"1PL IRT sampling time: {time.time()-start_time:.1f}s")

    # Convergence check: r_hat > 1.01 flags parameters whose chains disagree
    r_hat = az.summary(trace, var_names=["theta", "b"])["r_hat"]
    if (r_hat > 1.01).any():
        print(f"⚠️ 1PL IRT: {sum(r_hat>1.01)} parameters with r_hat>1.01 (potential convergence issue)")

    # Posterior means over all chains and draws (used for point predictions)
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params

def fit_irt_2pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 2PL IRT model on the training cells with PyMC.

    Model (logit scale, training cells only), slope-intercept form:
        logit(Y_ij) ~ Normal(a_j * theta_i - b_j, 1)
    NOTE: this is NOT ``a_j * (theta_i - b_j)``; ``b_j`` enters as an
    intercept that is not scaled by the discrimination ``a_j``.
    Priors: theta_i ~ Normal(0, 1) (ability), a_j ~ LogNormal(0, 0.5)
    (discrimination, always > 0), b_j ~ Normal(0, 1).

    Args:
        Y: Score matrix (n_models, n_questions); clipped to [1e-6, 1 - 1e-6]
            before the logit transform.
        train_mask: Boolean matrix of the same shape; only True cells enter
            the likelihood.
        draws, tune, chains: Forwarded to ``pm.sample``.
        cores: Parallel processes; defaults to min(CPU count, chains).

    Returns:
        (trace, params): the full ArviZ InferenceData and a dict of posterior
        means ``theta`` (n_models,), ``a`` and ``b`` (n_questions,).
    """
    N, J = Y.shape
    Y_clip = np.clip(Y, 1e-6, 1 - 1e-6)
    logit_Y = logit(Y_clip)

    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    with pm.Model() as irt_2pl:
        # 1. Priors (compared to 1PL, adds per-question discrimination a_j)
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)  # Model ability (same as 1PL)
        a = pm.LogNormal("a", mu=0.0, sigma=0.5, shape=J)        # Question discrimination: LogNormal ensures a_j>0
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)          # Question intercept/difficulty (same prior as 1PL)

        # 2. Linear predictor: discrimination scales ability only, b_j is an unscaled intercept
        mu = a[None, :] * theta[:, None] - b[None, :]  # (1,J)*(N,1) - (1,J) → (N,J)

        # 3. Likelihood (same as 1PL, fit only with training set)
        pm.Normal("obs", mu=mu[train_mask], sigma=1.0, observed=logit_Y[train_mask])

        # 4. MCMC sampling
        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )
        print(f"2PL IRT sampling time: {time.time()-start_time:.1f}s")

    # Convergence check over theta, a and b (NaN r_hat values are not counted)
    r_hat = az.summary(trace, var_names=["theta", "a", "b"])["r_hat"]
    if (r_hat > 1.01).any():
        print(f"⚠️ 2PL IRT: {sum(r_hat>1.01)} parameters with r_hat>1.01")

    # Extract posterior means
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "a": trace.posterior["a"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def fit_irt_3pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 3PL-style IRT model on the training cells with PyMC.

    Model (training cells only), slope-intercept form:
        p_ij        = c_j + (1 - c_j) * sigmoid(a_j * theta_i - b_j)
        logit(Y_ij) ~ Normal(logit(clip(p_ij)), sigma_j)
    Priors: theta_i ~ Normal(0, 1), a_j ~ Normal(1, 0.5) (not constrained
    positive), b_j ~ Normal(0, 1), c_j ~ Beta(1, 5) (guessing floor),
    sigma_j ~ HalfNormal(1) (per-question dispersion).

    Args:
        Y: Score matrix (n_models, n_questions); clipped to [1e-6, 1 - 1e-6]
            before the logit transform.
        train_mask: Boolean matrix of the same shape; only True cells enter
            the likelihood.
        draws, tune, chains: Forwarded to ``pm.sample``.
        cores: Parallel processes; defaults to min(CPU count, chains).

    Returns:
        (trace, params): the full ArviZ InferenceData and a dict of posterior
        means ``theta`` (n_models,) and ``a``, ``b``, ``c``, ``sigma`` (n_questions,).
    """
    N, J = Y.shape
    # Clip extreme values so the logit transform stays finite
    eps = 1e-6
    Y_clip = np.clip(Y, eps, 1 - eps)
    logit_Y = logit(Y_clip)

    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    # Question index of every training cell. np.nonzero enumerates cells in
    # row-major (model, question) order - the same order boolean indexing uses
    # for logit_Y[train_mask] and mu[train_mask] - so entry k belongs to obs k.
    _, train_cols = np.nonzero(train_mask)

    with pm.Model() as irt_3pl:
        # 1. Priors
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)          # Model ability
        a = pm.Normal("a", mu=1.0, sigma=0.5, shape=J)                  # Discrimination (may go negative)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)                  # Intercept / difficulty
        c = pm.Beta("c", alpha=1.0, beta=5.0, shape=J)                  # Guessing parameter in (0, 1)
        sigma = pm.HalfNormal("sigma", sigma=1.0, shape=J)              # Per-question dispersion on the logit scale
        # Alternative bounded dispersion prior:
        # sigma = pm.Uniform("sigma", lower=0.5, upper=1.5, shape=J)

        # 2. 3PL mean: base linear term, guessing floor, back to the logit scale
        mu_base = a[None, :] * theta[:, None] - b[None, :]
        prob = c[None, :] + (1 - c[None, :]) * pm.math.sigmoid(mu_base)  # sigmoid == expit on tensors
        prob_clip = pm.math.clip(prob, eps, 1 - eps)                     # Keep logit finite
        mu = pm.math.logit(prob_clip)

        # 3. Training-set vectors (all 1D, aligned element by element)
        logit_y_obs = logit_Y[train_mask].flatten()
        mu_obs = mu[train_mask].flatten()
        # Each observation uses the dispersion of its own question. (The former
        # sigma.repeat(N)[train_mask.flatten()] paired cell (i, j) with
        # sigma[(i*J + j) // N] instead of sigma[j].)
        sigma_obs = sigma[train_cols]

        # Debug shape check: all three must have length train_mask.sum()
        print(f"Training set logit observation shape: {logit_y_obs.shape}")
        print(f"Training set mean shape: {mu_obs.shape}")
        print(f"Training set dispersion shape: {sigma_obs.shape}")

        # 4. Logit-normal likelihood
        pm.Normal("y_obs", mu=mu_obs, sigma=sigma_obs, observed=logit_y_obs)

        # 5. MCMC sampling
        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )
        print(f"3PL IRT sampling time: {time.time()-start_time:.1f}s")

    # 6. Convergence check over every sampled parameter (NaN r_hat is not counted)
    r_hat = az.summary(trace, var_names=["theta", "a", "b", "c", "sigma"])["r_hat"]
    if (r_hat > 1.01).any():
        print(f"⚠️ 3PL IRT: {sum(r_hat>1.01)} parameters with r_hat>1.01")

    # 7. Posterior means over all chains and draws
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "a": trace.posterior["a"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values,
        "c": trace.posterior["c"].mean(dim=["chain", "draw"]).values,
        "sigma": trace.posterior["sigma"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def predict_from_irt(params: Dict[str, np.ndarray], model_type: str) -> np.ndarray:
    """
    Generate F1 predictions from IRT posterior point estimates.

    The fits use the linear predictor ``theta_i - b_j`` (1PL) and the
    slope-intercept form ``a_j * theta_i - b_j`` (2PL/3PL), so predictions use
    the same forms:
        1pl: expit(theta_i - b_j)
        2pl: expit(a_j * theta_i - b_j)
        3pl: c_j + (1 - c_j) * expit(a_j * theta_i - b_j)
    Using ``a_j * (theta_i - b_j)`` here would rescale the fitted intercept by
    ``a_j`` and no longer match the fitted model.

    Args:
        params: ``theta`` (n_models,) plus ``b`` and, for 2PL/3PL, ``a``
            (and ``c`` for 3PL), each (n_questions,).
        model_type: "1pl", "2pl" or "3pl".

    Returns:
        (n_models, n_questions) predictions in [0, 1].

    Raises:
        ValueError: for an unsupported ``model_type``.
    """
    theta = params["theta"].reshape(-1, 1)  # (N,1): ability of each model
    if model_type == "1pl":
        b = params["b"].reshape(1, -1)      # (1,J): difficulty of each question
        return expit(theta - b)
    elif model_type == "2pl":
        a = params["a"].reshape(1, -1)      # (1,J): discrimination of each question
        b = params["b"].reshape(1, -1)
        return expit(a * theta - b)
    elif model_type == "3pl":
        a = params["a"].reshape(1, -1)
        b = params["b"].reshape(1, -1)
        c = params["c"].reshape(1, -1)      # (1,J): guessing floor of each question
        return c + (1 - c) * expit(a * theta - b)
    else:
        raise ValueError(f"Unsupported IRT model type: {model_type}")

# =========================
# 4. MSE Calculation & Result Saving
# =========================
def calculate_mse(y_true: np.ndarray, y_pred: np.ndarray, test_mask: np.ndarray) -> float:
    """
    Mean squared error restricted to the test cells.

    Args:
        y_true: Observed scores (N, J).
        y_pred: Predicted scores (N, J).
        test_mask: Boolean matrix; True marks the cells that are scored.

    Returns:
        MSE rounded to 6 decimal places (lower is better).
    """
    residuals = y_true[test_mask] - y_pred[test_mask]
    squared = np.square(residuals)
    return round(squared.mean(), 6)


def save_sample_level_predictions(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_1pl: Optional[np.ndarray] = None,
    pred_2pl: Optional[np.ndarray] = None,
    pred_3pl: Optional[np.ndarray] = None
):
    """
    Write the test-set predictions of one (train ratio, repetition) run to CSV.

    One CSV row per test cell (model, question) holding the true value and
    every method's prediction. IRT columns whose prediction matrix is None
    (model not run) are filled with NaN. ``rep`` is 0-based and is written
    as ``rep + 1``. ``pred_dir`` is created if it does not exist.
    """
    idx_rows, idx_cols = np.where(test_mask)
    n_cells = len(idx_rows)

    def at_test_cells(matrix):
        # Optional IRT outputs: a missing model becomes an all-NaN column
        if matrix is None:
            return [np.nan] * n_cells
        return matrix[test_mask].tolist()

    columns = {
        "model_idx": idx_rows,
        "model_name": [model_names[r] for r in idx_rows],
        "question_idx": idx_cols,
        "question_name": [question_names[c] for c in idx_cols],
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1,
        "true_value": Y[test_mask].tolist(),
        "global_mean_pred": pred_global[test_mask].tolist(),
        "model_mean_pred": pred_row[test_mask].tolist(),
        "question_mean_pred": pred_col[test_mask].tolist(),
        "irt_1pl_pred": at_test_cells(pred_1pl),
        "irt_2pl_pred": at_test_cells(pred_2pl),
        "irt_3pl_pred": at_test_cells(pred_3pl),
    }
    pred_df = pd.DataFrame(columns)

    os.makedirs(pred_dir, exist_ok=True)
    save_path = os.path.join(pred_dir, f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv")
    pred_df.to_csv(save_path, index=False, encoding="utf-8")

    print(f"Sample predictions saved: {save_path}")
    print(f"Total samples: {len(pred_df)}")


def save_irt_trace(
    trace_dir: str,
    trace: az.InferenceData,
    model_type: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Save a full IRT MCMC trace to NetCDF (ArviZ's native format) so posterior
    summaries (median, credible intervals, ...) can be recomputed later.

    File name: ``irt_{model_type}_ratio_{ratio:.1f}_rep{rep+1}.nc`` inside
    ``trace_dir``, which is created if missing. ``rep`` is 0-based.
    NOTE(review): the ratio is rendered with one decimal, so ratios that agree
    to one decimal (e.g. 0.12 and 0.14) would overwrite each other.
    """
    # Create the directory like save_sample_level_predictions does, so the
    # function does not depend on the caller having created it.
    os.makedirs(trace_dir, exist_ok=True)
    save_path = os.path.join(
        trace_dir,
        f"irt_{model_type}_ratio_{train_subset_ratio:.1f}_rep{rep+1}.nc"
    )
    az.to_netcdf(trace, save_path)
    print(f"IRT {model_type} trace saved: {save_path}")


# =========================
# 5. Result Visualization (MSE Comparison Plot)
# =========================
def plot_mse_comparison(
    mse_dict: Dict[str, List[float]],
    train_ratios: List[float],
    save_path: str
):
    """
    Plot test-set MSE against training-data ratio, one line per method, and save it.

    Args:
        mse_dict: Method name -> MSE values, one per entry of ``train_ratios``.
            The six built-in method names get a fixed colour/line style; any
            other name uses matplotlib's default colour cycle with a solid line
            (previously such names raised KeyError).
        train_ratios: Training data ratios used as x values and tick labels.
        save_path: Output image path (saved at 300 DPI).
    """
    method_styles = {
        "Global Mean": ("blue", "solid"),
        "Model Mean": ("orange", "dashed"),
        "Question Mean": ("green", "dashdot"),
        "IRT-1PL": ("red", "solid"),
        "IRT-2PL": ("purple", "dashed"),
        "IRT-3PL": ("brown", "dashdot")
    }

    fig = plt.figure(figsize=(10, 6))
    try:
        for method, mse_list in mse_dict.items():
            style = method_styles.get(method)
            if style is None:
                line_kwargs = {"linestyle": "solid"}  # colour comes from the default cycle
            else:
                line_kwargs = {"color": style[0], "linestyle": style[1]}
            plt.plot(
                train_ratios,
                mse_list,
                label=method,
                linewidth=2,
                marker="o",
                markersize=4,
                **line_kwargs
            )

        plt.xlabel("Training Data Ratio (Fraction of Remaining Pool)", fontsize=12)
        plt.ylabel("Test Set MSE (Lower = Better)", fontsize=12)
        plt.title("MSE vs. Training Data Ratio for All Prediction Methods", fontsize=14, pad=20)
        plt.legend(loc="upper right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.xticks(train_ratios, [f"{r:.1f}" for r in train_ratios], fontsize=10)
        plt.yticks(fontsize=10)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)
    print(f"MSE plot saved: {save_path}")


def save_mse_summary(
    mse_summary: Dict[str, Dict[float, List[float]]],
    save_path: str
):
    """
    Save an MSE summary CSV: one row per training ratio, one column per method.

    Each method cell is ``"mean ± std"`` (6 decimals) over that ratio's
    repetitions. The std is numpy's default population std (ddof=0). Rows
    follow the ratios of "Global Mean" when that method is present; otherwise
    the union of all methods' ratios is used. A method with no values for a
    ratio is written as ``"nan ± nan"``.

    Args:
        mse_summary: Method name -> {train ratio -> list of per-repetition MSEs}.
        save_path: Output CSV path.
    """
    reference = mse_summary.get("Global Mean")
    if reference is not None:
        ratios = sorted(reference.keys())
    else:
        ratios = sorted({r for per_ratio in mse_summary.values() for r in per_ratio})

    rows = []
    for ratio in ratios:
        row = {"Train_Ratio": ratio}
        for method, per_ratio in mse_summary.items():
            mse_list = per_ratio.get(ratio, [])
            if len(mse_list) == 0:
                # Avoid numpy's "Mean of empty slice" warning; same text it would produce
                mse_mean = mse_std = float("nan")
            else:
                mse_mean = np.mean(mse_list)
                mse_std = np.std(mse_list)
            row[method] = f"{mse_mean:.6f} ± {mse_std:.6f}"
        rows.append(row)

    pd.DataFrame(rows).to_csv(save_path, index=False)
    print(f"MSE summary saved: {save_path}")

# =========================
# 6. Experiment Main Pipeline
# =========================
def _fit_and_score_irt(
    model: str,
    Y: np.ndarray,
    train_mask,
    test_mask,
    trace_dir: str,
    train_ratio: float,
    rep: int,
    draws: int,
    tune: int,
    chains: int,
    cores: Optional[int],
):
    """Fit one IRT variant on the training cells, save its trace, and return (predictions, test MSE)."""
    # Resolve the fitter lazily so only the selected variant's fit function is referenced.
    if model == "1pl":
        fit_fn = fit_irt_1pl
    elif model == "2pl":
        fit_fn = fit_irt_2pl
    else:
        fit_fn = fit_irt_3pl

    trace, params = fit_fn(
        Y=Y,
        train_mask=train_mask,
        draws=draws,
        tune=tune,
        chains=chains,
        cores=cores
    )
    pred = predict_from_irt(params, model)
    mse = calculate_mse(Y, pred, test_mask)
    save_irt_trace(trace_dir, trace, model, train_ratio, rep)
    print(f"IRT-{model.upper()} MSE: {mse:.6f}")
    return pred, mse


def run_experiment(
    input_csv_path: str,
    output_root_dir: str,
    question_id_col: "str | int | None" = 0,
    test_ratio: float = 0.05,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 3,
    irt_draws: int = 1000,
    irt_tune: int = 1000,
    irt_chains: int = 4,
    irt_cores: Optional[int] = None,
    selected_irt_models: Optional[List[str]] = None
):
    """
    Run the full prediction experiment and write all artifacts to disk.

    Pipeline:
    1. Read and preprocess the response matrix.
    2. Split off a fixed test set.
    3. For every training ratio and repetition, sample a training subset and
       score the baseline predictors and the selected IRT models on the fixed
       test set. Sample-level predictions, IRT traces and data splits are saved.
    4. Write an MSE summary CSV and an MSE-vs-ratio plot to ``04_metrics``.

    Args:
        input_csv_path: Path of the input response-matrix CSV.
        output_root_dir: Root directory for all outputs; its sub-directories
            are created if missing.
        question_id_col: Question-ID column forwarded to ``preprocess_matrix``
            (column name, column index, or None).
        test_ratio: Fraction of the data held out as the fixed test set.
        test_seed: Seed for the test split. Repetition ``rep`` samples its
            training subset with seed ``test_seed + rep``.
        train_ratios: Fractions of the remaining pool used for training.
            Defaults to 0.1, 0.2, ..., 1.0.
        rep_count: Number of repetitions per training ratio (must be >= 1).
        irt_draws, irt_tune, irt_chains, irt_cores: MCMC settings forwarded
            to the IRT fitting functions.
        selected_irt_models: Subset of ["1pl", "2pl", "3pl"] to run.
            Defaults to all three.

    Raises:
        ValueError: If an unsupported IRT model is requested or rep_count < 1.
    """
    # None sentinels avoid shared mutable default arguments.
    if train_ratios is None:
        train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    if selected_irt_models is None:
        selected_irt_models = ["1pl", "2pl", "3pl"]

    # Validate inputs before doing any I/O
    valid_irt_models = ["1pl", "2pl", "3pl"]
    for model in selected_irt_models:
        if model not in valid_irt_models:
            raise ValueError(
                f"Unsupported IRT model: {model}! Please select from the following valid models: {valid_irt_models}"
            )
    if rep_count < 1:
        # With zero repetitions every MSE list stays empty and the summary/plot would be all-NaN.
        raise ValueError(f"rep_count must be >= 1, got {rep_count}")
    print(f"Selected IRT models to run: {selected_irt_models}")

    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
    }
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    print(f"Output directories created: {list(output_dirs.values())}")

    print("\n=== Step 1/6: Read & Preprocess Data ===")
    raw_df = robust_read_csv(input_csv_path)
    preprocessed_df, Y, model_names, question_names = preprocess_matrix(
        df=raw_df,
        question_id_col=question_id_col,
        zero_row_ratio_threshold=0.3,
        epsilon=1e-6
    )
    N, J = Y.shape
    print(f"Preprocessed data shape: {N} models × {J} questions")

    print("\n=== Step 2/6: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # method -> {train_ratio -> [MSE per repetition]}; insertion order drives CSV/plot order.
    mse_summary = {
        "Global Mean": {r: [] for r in train_ratios},
        "Model Mean": {r: [] for r in train_ratios},
        "Question Mean": {r: [] for r in train_ratios}
    }
    for model in selected_irt_models:
        mse_summary[f"IRT-{model.upper()}"] = {r: [] for r in train_ratios}

    for train_ratio in train_ratios:
        print(f"\n=== Step 3/6: Train Ratio = {train_ratio:.3f} ===")

        train_masks_rep = []

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            rep_seed = test_seed + rep

            print(f"Sampling training subset (ratio={train_ratio:.3f})...")
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            train_masks_rep.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # Baseline predictors
            print("Running baseline methods...")
            pred_global = predict_global_mean(Y, train_mask)
            mse_global = calculate_mse(Y, pred_global, test_mask)
            pred_row = predict_row_mean(Y, train_mask)
            mse_row = calculate_mse(Y, pred_row, test_mask)
            pred_col = predict_col_mean(Y, train_mask)
            mse_col = calculate_mse(Y, pred_col, test_mask)

            # Selected IRT models, always run in the fixed order 1PL -> 2PL -> 3PL.
            print(f"Running selected IRT models: {selected_irt_models}...")
            irt_preds = {m: None for m in valid_irt_models}  # unselected models stay None
            irt_mses = {}
            for model in valid_irt_models:
                if model not in selected_irt_models:
                    continue
                irt_preds[model], irt_mses[model] = _fit_and_score_irt(
                    model=model,
                    Y=Y,
                    train_mask=train_mask,
                    test_mask=test_mask,
                    trace_dir=output_dirs["irt_trace"],
                    train_ratio=train_ratio,
                    rep=rep,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores
                )

            # Record MSE for the current repetition
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            for model, mse in irt_mses.items():
                mse_summary[f"IRT-{model.upper()}"][train_ratio].append(mse)

            # Save sample-level predictions for the current repetition
            save_sample_level_predictions(
                pred_dir=output_dirs["predictions"],
                Y=Y,
                test_mask=test_mask,
                model_names=model_names,
                question_names=question_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_1pl=irt_preds["1pl"],
                pred_2pl=irt_preds["2pl"],
                pred_3pl=irt_preds["3pl"]
            )

            # Print MSE results for the current repetition
            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            for model, mse in irt_mses.items():
                print(f"IRT-{model.upper()}: {mse:.6f}")

        # Record the data split (fixed test mask + all repetition train masks) for this ratio
        record_data_split(
            split_dir=output_dirs["split"],
            test_mask=test_mask,
            train_masks=train_masks_rep,
            train_subset_ratio=train_ratio,
            model_names=model_names,
            question_names=question_names
        )

    # Final MSE summary table and visualization
    print("\n=== Step 4/6: Generate Final Results ===")
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    save_mse_summary(mse_summary, mse_summary_path)

    # Mean MSE over repetitions, per method, in the order of train_ratios
    plot_mse_dict = {
        method: [np.mean(per_ratio[r]) for r in train_ratios]
        for method, per_ratio in mse_summary.items()
    }
    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")


# =========================
# 7. Script Entry Point
# =========================
INPUT_CSV_PATH = "/public/big-leg/expriment/continuous_processed/response_matrix__bertscore_F1.csv"  # Input CSV path (rows=questions, columns=models)
OUTPUT_ROOT_DIR = "/public/big-leg/expriment/modeling_result_XSUM_F1"  # Output root directory (will be created automatically)
QUESTION_ID_COL = 'row_index'  # Question ID column: its column name (as here), 0 for the first column, or None if there is none
TEST_SEED = 42  # Fixed test set seed (ensure test set reproducibility)
TEST_RATIO = 0.10  # Fraction held out as the fixed test set (run_experiment's default is 0.05)
TRAIN_RATIOS = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.00]
REP_COUNT = 1  # Number of repetitions for each training ratio
# NOTE(review): 10 draws / 10 tuning steps per chain is extremely short for MCMC;
# suitable for smoke tests, but confirm convergence before trusting the IRT results.
IRT_DRAWS = 10  # Sampling number per chain for IRT model
IRT_TUNE = 10  # Warmup iterations for IRT model
IRT_CHAINS = 256  # Parallel chains for IRT model
IRT_CORES = 256  # Parallel cores for IRT model (None=auto-adapt)
SELECTED_IRT_MODELS = ["1pl"]  # Can be modified to ["1pl"] or ["1pl", "2pl", "3pl"]

# Guard the entry point so importing this module does not launch the experiment.
if __name__ == "__main__":
    run_experiment(
        input_csv_path=INPUT_CSV_PATH,
        output_root_dir=OUTPUT_ROOT_DIR,
        question_id_col=QUESTION_ID_COL,
        test_ratio=TEST_RATIO,
        test_seed=TEST_SEED,
        train_ratios=TRAIN_RATIOS,
        rep_count=REP_COUNT,
        irt_draws=IRT_DRAWS,
        irt_tune=IRT_TUNE,
        irt_chains=IRT_CHAINS,
        irt_cores=IRT_CORES,
        selected_irt_models=SELECTED_IRT_MODELS  # Pass selected model list
    )

# Example invocation:
#   python /Users/bytedance/Desktop/QileZhang/llm/IRT/eval/Metric/modeling_0912.py