#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import multiprocessing
# Select the "spawn" start method for child processes; force=True replaces any
# start method that was already configured in this interpreter.
# NOTE(review): presumably chosen so parallel MCMC chains start from a clean
# interpreter instead of a forked one - confirm against the target platform.
multiprocessing.set_start_method("spawn", force=True)

import os
# Limit each numerical backend's thread pool to a single thread. These
# variables are assigned before numpy/scipy/pymc are imported further down.
# NOTE(review): presumably this avoids thread oversubscription when several
# sampling processes run at once - confirm.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# NOTE: the string below follows executable statements, so Python treats it as
# a plain expression statement and it does NOT become the module's __doc__.
"""
Mixed-metric Continuous IRT Model for Multi-metric Benchmark Evaluation
- Input: Multiple CSV files, each for one metric (rows=questions, cols=models)
- Output: Data splits, MCMC traces, sample-level predictions, MSE metrics, plots
- Model: Mixed-metric continuous IRT with global ability + metric-specific offsets + LKJ covariance prior
"""

import os
import time
import pickle
import warnings
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pymc as pm
import arviz as az
from scipy.special import logit, expit

# Global Config (applied once at import time, affects the whole process)
# Preferred font families for matplotlib text, in the order listed.
plt.rcParams["font.family"] = ["Arial", "Helvetica"]
# Render minus signs on axes with the ASCII hyphen rather than the Unicode minus glyph.
plt.rcParams["axes.unicode_minus"] = False
# Default resolution for figures created after this point.
plt.rcParams["figure.dpi"] = 150
# Silences every warning category for the rest of the run, including
# warnings raised by third-party libraries.
warnings.filterwarnings("ignore")

# =========================
# 1. Robust CSV Reading & Preprocessing for Multiple Metrics
# =========================
def robust_read_csv(path: str) -> pd.DataFrame:
    """
    Read a CSV file, trying several text encodings in turn.

    The encodings utf-8, utf-8-sig, gbk and latin1 are attempted in that
    order and the first successful parse is returned.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the file could not be read. The message names the file
            and the last underlying error is attached as ``__cause__`` so the
            original traceback is not lost.
    """
    encodings = ("utf-8", "utf-8-sig", "gbk", "latin1")
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except OSError as e:
            # Missing file, permission problem, ...: another encoding cannot
            # help, so stop retrying and report the failure right away.
            last_err = e
            break
        except Exception as e:
            # Decoding/parsing failed with this encoding; try the next one.
            last_err = e
    raise ValueError(f"Failed to read CSV '{path}': {last_err}") from last_err

def preprocess_single_metric(
    df: pd.DataFrame,
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6,
) -> Tuple[pd.DataFrame, np.ndarray, List[str], List[str]]:
    """
    Preprocess one metric's table (rows = questions, columns = models).

    Steps:
    - Optionally move the question-id column into the index
    - Coerce every cell to a number (unparseable cells become NaN)
    - Drop questions whose share of exact zeros exceeds the threshold
    - Treat NaN as 0 and clip all values into [epsilon, 1 - epsilon]
    - Transpose to (models x questions)

    Args:
        df: Raw metric table.
        question_id_col: Column holding the question ids. A string is used as
            a column label. An integer is used as a label when such a label
            exists, otherwise as a positional column index.
        zero_row_ratio_threshold: Questions with a zero ratio above this value
            are removed.
        epsilon: Clipping margin that keeps values strictly inside (0, 1).

    Returns:
        (work_t, Y, model_names, filtered_question_names) where ``work_t`` is
        the transposed DataFrame and ``Y`` is ``work_t.values``.
    """
    work = df.copy()
    if question_id_col is not None:
        # set_index() interprets its argument as a label, so a positional
        # integer (e.g. 0 for "first column") must be translated first.
        if isinstance(question_id_col, int) and question_id_col not in work.columns:
            question_id_col = work.columns[question_id_col]
        work = work.set_index(question_id_col)
    n_questions_before = len(work)

    # Coerce before counting zeros so that textual zeros ("0") are detected too.
    numeric = work.apply(pd.to_numeric, errors="coerce")
    zero_ratio = (numeric == 0).sum(axis=1) / max(1, numeric.shape[1])
    numeric = numeric[zero_ratio <= zero_row_ratio_threshold]
    filtered_question_names = numeric.index.tolist()
    print(f"Filtered questions: {n_questions_before} → {len(filtered_question_names)}")

    # NaN -> 0, then clipping lifts zeros (and anything below) up to epsilon.
    numeric = numeric.fillna(0.0).clip(lower=epsilon, upper=1 - epsilon)
    work_t = numeric.T
    model_names = work_t.index.tolist()
    Y = work_t.values
    return work_t, Y, model_names, filtered_question_names

def preprocess_multiple_metrics(
    csv_paths: List[str],
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6,
) -> Tuple[np.ndarray, List[str], List[str], List[str]]:
    """
    Read and preprocess multiple metric CSV files and align them.

    Args:
        csv_paths: One CSV path per metric; must not be empty.
        question_id_col: Forwarded to ``preprocess_single_metric``.
        zero_row_ratio_threshold: Forwarded to ``preprocess_single_metric``.
        epsilon: Forwarded to ``preprocess_single_metric``.

    Returns:
      - Y: np.ndarray shape (N_models, J_items, M_metrics)
      - model_names: sorted list of model names (intersection across metrics)
      - question_names: sorted list of question names (intersection across metrics)
      - metric_names: list of metric names (file names without extension)

    Raises:
        ValueError: If ``csv_paths`` is empty.
    """
    if not csv_paths:
        raise ValueError("csv_paths must contain at least one metric CSV file")

    metric_dfs = []
    metric_names = []

    for path in csv_paths:
        metric_name = os.path.splitext(os.path.basename(path))[0]
        metric_names.append(metric_name)
        print(f"\nProcessing metric: {metric_name}")
        df = robust_read_csv(path)
        # Only the transposed frame (models x questions) is needed here.
        df_t, _, _, _ = preprocess_single_metric(
            df, question_id_col, zero_row_ratio_threshold, epsilon
        )
        metric_dfs.append(df_t)

    # Models are the row labels of every per-metric frame.
    common_models = sorted(set.intersection(*(set(d.index) for d in metric_dfs)))
    print(f"Model intersection size: {len(common_models)}")

    # Questions are the column labels of every per-metric frame.
    common_questions = sorted(set.intersection(*(set(d.columns) for d in metric_dfs)))
    print(f"Question intersection size: {len(common_questions)}")

    # Each frame is already (models x questions); selecting with the shared
    # label lists gives every metric the same row/column order. No further
    # transpose is applied, otherwise the axes would become (questions x models).
    aligned = [d.loc[common_models, common_questions].values for d in metric_dfs]

    # Stack to (M_metrics, N_models, J_questions)
    Ys_aligned = np.array(aligned)
    print(f"Ys_aligned shape (M, N, J): {Ys_aligned.shape}")

    # Move the metric axis last: (N_models, J_questions, M_metrics)
    Y = np.transpose(Ys_aligned, (1, 2, 0))

    # Name lists are returned in the same order as the axes of Y.
    return Y, common_models, common_questions, metric_names


# =========================
# 2. Data Split (Fixed Test + Variable Training)
# =========================
def split_fixed_test_set(
    Y_shape: Tuple[int, int, int],
    test_ratio: float = 0.05,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw the fixed held-out test mask and its complement.

    Y_shape is (N_models, J_items, M_metrics). One uniform number is drawn per
    (model, item) pair from a generator seeded with ``test_seed``; pairs whose
    draw falls below ``test_ratio`` form the test set. The same (N, J) mask is
    meant to be shared by all metrics.

    Returns: (test_mask, remaining_mask), both boolean with shape (N, J).
    """
    n_models, n_items, _ = Y_shape
    uniform_draws = np.random.default_rng(test_seed).random((n_models, n_items))
    held_out = uniform_draws < test_ratio
    pool = np.logical_not(held_out)
    print(f"Fixed test set size: {held_out.sum()} samples ({test_ratio*100:.1f}%)")
    print(f"Remaining pool size: {pool.sum()} samples ({(1-test_ratio)*100:.1f}%)")
    return held_out, pool

def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,
    rep_seed: int
) -> np.ndarray:
    """
    Pick a random training subset out of the non-test pool.

    ``remaining_mask`` is a boolean (N_models, J_items) array marking the
    cells that are available for training. A fraction ``train_subset_ratio``
    of those cells (rounded down) is selected without replacement, using a
    generator seeded with ``rep_seed``.

    Returns a boolean mask of the same shape with the selected cells set.
    """
    generator = np.random.default_rng(rep_seed)

    # Coordinates of every cell that may be used for training.
    pool_rows, pool_cols = np.where(remaining_mask)
    pool_size = len(pool_rows)
    quota = int(pool_size * train_subset_ratio)

    picks = generator.choice(pool_size, size=quota, replace=False)

    chosen = np.zeros_like(remaining_mask, dtype=bool)
    chosen[pool_rows[picks], pool_cols[picks]] = True
    return chosen

def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per repetition listing which (model, question) cells were
    used for training and which belong to the test set.

    Every training mask must have shape (len(model_names), len(question_names));
    otherwise a ValueError is raised before anything is written. Files are
    named ``split_ratio_<ratio>_rep<k>.csv`` inside ``split_dir``, with
    repetitions numbered from 1.
    """
    n_models = len(model_names)
    n_questions = len(question_names)

    # Validate all masks up front so no partial output is produced.
    for mask in train_masks:
        if mask.shape != (n_models, n_questions):
            raise ValueError(f"Training mask shape {mask.shape} does not match data dimensions ({n_models},{n_questions})!")

    def _records(mask: np.ndarray, set_type: str, repetition: int) -> pd.DataFrame:
        """Tabulate the True cells of ``mask`` with their names and labels."""
        rows, cols = np.where(mask)
        return pd.DataFrame({
            "model_idx": rows,
            "model_name": [model_names[r] for r in rows],
            "question_idx": cols,
            "question_name": [question_names[c] for c in cols],
            "set_type": set_type,
            "train_ratio": train_subset_ratio,
            "repetition": repetition
        })

    for repetition, mask in enumerate(train_masks, start=1):
        combined = pd.concat(
            [_records(mask, "train", repetition), _records(test_mask, "test", repetition)],
            ignore_index=True
        )
        file_name = f"split_ratio_{train_subset_ratio:.3f}_rep{repetition}.csv"
        combined.to_csv(os.path.join(split_dir, file_name), index=False)
    print(f"Data split records saved (training ratio: {train_subset_ratio:.3f})")

# =========================
# 3. Prediction Methods (Baselines)
# =========================
def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline: predict, for every cell of a metric, that metric's mean over the
    training cells. Falls back to 0.5 when the training mask is empty.

    Y has shape (N, J, M); train_mask is boolean with shape (N, J).
    Returns an array shaped like Y.
    """
    _, _, n_metrics = Y.shape
    pred = np.zeros_like(Y)
    has_training_data = train_mask.any()

    for m in range(n_metrics):
        # Each metric gets its own constant prediction.
        if has_training_data:
            pred[:, :, m] = Y[:, :, m][train_mask].mean()
        else:
            pred[:, :, m] = 0.5

    return pred

def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline: predict each model's (row's) mean training score, computed
    separately per metric.

    A model without any training cell receives the metric's global training
    mean; if the training mask is entirely empty that fallback is 0.5.

    Y has shape (N, J, M); train_mask is boolean with shape (N, J).
    Returns an array shaped like Y.
    """
    n_models, n_items, n_metrics = Y.shape
    pred = np.zeros_like(Y)
    has_training_data = train_mask.any()

    for m in range(n_metrics):
        fallback = Y[:, :, m][train_mask].mean() if has_training_data else 0.5
        per_model = np.zeros(n_models)
        for i, row_mask in enumerate(train_mask):
            per_model[i] = Y[i, row_mask, m].mean() if row_mask.any() else fallback
        # Repeat each model's value across all of its questions.
        pred[:, :, m] = np.repeat(per_model[:, None], n_items, axis=1)

    return pred

def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline: predict each question's (column's) mean training score, computed
    separately per metric.

    A question without any training cell receives the metric's global training
    mean; if the training mask is entirely empty that fallback is 0.5.

    Y has shape (N, J, M); train_mask is boolean with shape (N, J).
    Returns an array shaped like Y.
    """
    n_models, n_items, n_metrics = Y.shape
    pred = np.zeros_like(Y)
    has_training_data = train_mask.any()

    for m in range(n_metrics):
        fallback = Y[:, :, m][train_mask].mean() if has_training_data else 0.5
        per_question = np.zeros(n_items)
        for j in range(n_items):
            column_mask = train_mask[:, j]
            per_question[j] = Y[column_mask, j, m].mean() if column_mask.any() else fallback
        # Repeat each question's value across all models.
        pred[:, :, m] = np.repeat(per_question[None, :], n_models, axis=0)

    return pred

# =========================
# 4. Mixed-metric Continuous IRT Model Implementation
# =========================
import time
import os
import numpy as np
import pymc as pm
import arviz as az
from scipy.special import logit, expit
from typing import List, Tuple, Optional, Dict

def fit_irt_1pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None,
    seed: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 1PL-style IRT model on the logit scale.

    Model: logit(Y_ij) ~ Normal(theta_i - b_j, sigma=1), with standard-normal
    priors on theta (one per row / model ability) and b (one per column /
    question difficulty). Only cells selected by ``train_mask`` enter the
    likelihood.

    Args:
        Y: Scores in (0, 1), shape (N, J) or (N, J, M). For 3-D input only the
            first metric (index 0) is modelled.
        train_mask: Boolean (N, J) mask of training cells.
        draws: Posterior draws per chain.
        tune: Tuning (warm-up) iterations per chain.
        chains: Number of MCMC chains.
        cores: Number of worker processes; defaults to
            min(available CPUs, chains).
        seed: Random seed passed to the sampler; None leaves it unseeded.

    Returns:
        (trace, params): the full MCMC trace and a dict with the posterior
        means of "theta" (shape N) and "b" (shape J).
    """
    # Banner kept as-is so the existing console output does not change.
    for _ in range(5):
        print('******************************************')

    # Single-metric model: use the first metric of a 3-D input.
    if Y.ndim == 3:
        Y_single = Y[:, :, 0]
    else:
        Y_single = Y

    N, J = Y_single.shape
    # Keep values strictly inside (0, 1) so logit() stays finite.
    Y_clip = np.clip(Y_single, 1e-6, 1 - 1e-6)
    logit_Y = logit(Y_clip)

    # Never request more worker processes than there are chains or CPUs.
    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    with pm.Model() as irt_1pl:
        # Priors
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)  # row (model) ability
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)          # column (question) difficulty

        # Linear predictor on the logit scale: (N,1) - (1,J) -> (N,J)
        mu = theta[:, None] - b[None, :]

        # Likelihood over the training cells only, with fixed unit noise
        pm.Normal("obs", mu=mu[train_mask], sigma=1.0, observed=logit_Y[train_mask])

        start_time = time.time()
        trace = pm.sample(
            draws=draws,               # posterior draws per chain
            tune=tune,                 # warm-up iterations per chain
            chains=chains,             # number of chains
            cores=cores,               # worker processes
            target_accept=0.95,        # target acceptance rate for step-size adaptation
            progressbar=True,          # progress bar is shown
            random_seed=seed,          # reproducibility when a seed is given
            return_inferencedata=True  # return an ArviZ InferenceData object
        )
        print(f"1PL IRT sampling time: {time.time()-start_time:.1f}s")

    # Convergence check: report parameters whose r_hat exceeds 1.01
    r_hat = az.summary(trace, var_names=["theta", "b"])["r_hat"]
    n_suspect = int((r_hat > 1.01).sum())
    if n_suspect:
        print(f"⚠️ 1PL IRT: {n_suspect} parameters with r_hat>1.01 (potential convergence issue)")

    # Posterior means, used later to generate point predictions
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def predict_from_irt_1pl(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Turn 1PL IRT parameters into predictions on the original (0, 1) scale.

    ``params["theta"]`` (length N) and ``params["b"]`` (length J) are combined
    as expit(theta_i - b_j), giving an (N, J) array.
    """
    ability = params["theta"]
    difficulty = params["b"]
    # Lengths are evaluated as in the original implementation (unused otherwise).
    len(ability), len(difficulty)

    # Broadcast (N,1) - (1,J) -> (N,J), then map from logit scale back to (0,1).
    return expit(ability[:, None] - difficulty[None, :])

def predict_from_mixed_metric_irt(params: Dict[str, np.ndarray], J: int) -> np.ndarray:
    """
    Predict scores from the simplified mixed-metric IRT posterior means.

    Uses ``params["psi"]`` (N,), ``params["zeta"]`` (N, M) and ``params["b"]``
    (J,) with no discrimination term: expit(psi_i + zeta_im - b_j).
    The ``J`` argument is accepted for interface compatibility and not used.

    Returns an array of shape (N_models, J_items, M_metrics) with values in (0, 1).
    """
    global_ability = params["psi"][:, None, None]    # (N,1,1)
    metric_offset = params["zeta"][:, None, :]       # (N,1,M)
    difficulty = params["b"][None, :, None]          # (1,J,1)

    # Per-metric ability minus item difficulty, broadcast to (N,J,M),
    # then mapped back to the (0,1) scale.
    return expit((global_ability + metric_offset) - difficulty)


def build_effective_multi_metric_model(N, M, J, Y, train_mask, draws=100, tune=100):
    """
    Build and sample the "effective" multi-metric IRT model.

    Model (logit scale): logit(Y_ijm) ~ Normal(a_j * (psi_i + zeta_im) - b_j, 0.8)
    with N rows (models), J columns (questions) and M metrics. Only cells
    selected by the (N, J) ``train_mask`` are observed; the mask is applied
    to every metric. Sampling uses 4 chains on 4 cores.

    Returns (trace, model).
    """
    # Observations on the logit scale
    observed_logits = logit(Y)

    # Replicate the 2-D training mask along a trailing metric axis -> (N, J, M)
    mask_3d = np.tile(np.expand_dims(train_mask, -1), (1, 1, M))

    sampler_options = dict(
        draws=draws,
        tune=tune,
        chains=4,
        cores=4,
        target_accept=0.95,
        return_inferencedata=True,
        compute_convergence_checks=False,  # convergence checks are skipped here
    )

    with pm.Model() as model:
        # Ability: global part psi plus a metric-specific part zeta
        psi = pm.Normal("psi", mu=0, sigma=1, shape=N)
        zeta = pm.Normal("zeta", mu=0, sigma=0.5, shape=(N, M))

        # Question parameters: discrimination a and difficulty b
        a = pm.Normal("a", mu=1.0, sigma=0.5, shape=J)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)

        # theta_{im} = psi_i + zeta_{im}; shape (N, 1, M), broadcast over J below
        theta = psi[:, None, None] + zeta[:, None, :]

        # mu_{ijm} = a_j * theta_{im} - b_j; shape (N, J, M)
        mu = a[None, :, None] * theta - b[None, :, None]

        # Likelihood restricted to training cells, fixed noise sigma = 0.8
        pm.Normal(
            "y_obs",
            mu=mu[mask_3d],
            sigma=0.8,
            observed=observed_logits[mask_3d],
        )

        trace = pm.sample(**sampler_options)

    return trace, model


def predict_from_effective_multi_metric_model(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Predict the first metric from the effective multi-metric IRT parameters.

    Computes expit(a_j * (psi_i + zeta_im) - b_j) for all metrics from
    ``params["psi"]`` (N,), ``params["zeta"]`` (N, M), ``params["a"]`` (J,)
    and ``params["b"]`` (J,), then returns the slice for metric index 0,
    an (N, J) array.
    """
    # theta_{im} = psi_i + zeta_{im}; shape (N, 1, M)
    ability = params["psi"][:, None, None] + params["zeta"][:, None, :]
    slope = params["a"][None, :, None]        # (1, J, 1)
    difficulty = params["b"][None, :, None]   # (1, J, 1)

    # mu_{ijm} = a_j * theta_{im} - b_j -> (N, J, M), mapped back to (0, 1)
    all_metrics = expit(slope * ability - difficulty)

    # Only the target metric (index 0) is returned
    return all_metrics[:, :, 0]


def predict_from_target_optimized_multi_metric_model(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Predict the first metric from the target-optimized multi-metric IRT parameters.

    Evaluates expit(a_j * (psi_i + zeta_im) - b_j) and returns the (N, J)
    slice for metric index 0.

    NOTE: a function with this same name is defined again further down in
    this file; at import time the later definition replaces this one.
    """
    psi = np.expand_dims(params["psi"], (1, 2))   # (N, 1, 1)
    zeta = np.expand_dims(params["zeta"], 1)      # (N, 1, M)
    a = np.expand_dims(params["a"], (0, 2))       # (1, J, 1)
    b = np.expand_dims(params["b"], (0, 2))       # (1, J, 1)

    # theta_{im} = psi_i + zeta_{im}
    theta = psi + zeta
    # mu_{ijm} = a_j * theta_{im} - b_j, broadcast to (N, J, M)
    mu = a * theta - b

    # Back to the original scale; keep the target metric (index 0) only
    return expit(mu)[:, :, 0]


def build_simple_multi_metric_model(N, M, J, Y, train_mask, draws=100, tune=100):
    """
    Build and sample the simplified multi-metric IRT model.

    Model (logit scale): logit(Y_ijm) ~ Normal(a_j * (psi_i + zeta_im) - b_j, 1.0)
    with N rows (models), J columns (questions) and M metrics. Only cells
    selected by the (N, J) ``train_mask`` are observed; the mask is applied
    to every metric. Sampling uses 12 chains on 12 cores.

    Returns (trace, model).
    """
    # Observations on the logit scale
    observed_logits = logit(Y)

    # Replicate the 2-D training mask along a trailing metric axis -> (N, J, M)
    mask_3d = np.tile(np.expand_dims(train_mask, -1), (1, 1, M))

    sampler_options = dict(
        draws=draws,
        tune=tune,
        chains=12,
        cores=12,
        target_accept=0.95,
        return_inferencedata=True,
        compute_convergence_checks=False,  # convergence checks are skipped here
    )

    with pm.Model() as model:
        # Ability: global part psi plus a metric-specific part zeta
        psi = pm.Normal("psi", mu=0, sigma=1, shape=N)
        zeta = pm.Normal("zeta", mu=0, sigma=0.5, shape=(N, M))

        # Question parameters: discrimination a and difficulty b
        a = pm.Normal("a", mu=1.0, sigma=0.5, shape=J)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)

        # theta_{im} = psi_i + zeta_{im}; shape (N, 1, M), broadcast over J below
        theta = psi[:, None, None] + zeta[:, None, :]

        # mu_{ijm} = a_j * theta_{im} - b_j; shape (N, J, M)
        mu = a[None, :, None] * theta - b[None, :, None]

        # Likelihood restricted to training cells, fixed unit noise
        pm.Normal(
            "y_obs",
            mu=mu[mask_3d],
            sigma=1.0,
            observed=observed_logits[mask_3d],
        )

        trace = pm.sample(**sampler_options)

    return trace, model


def predict_from_simple_multi_metric_model(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Predict the target (first) metric from the simplified multi-metric IRT
    parameters.

    Evaluates expit(a_j * (psi_i + zeta_im) - b_j) for every metric from
    ``params["psi"]``, ``params["zeta"]``, ``params["a"]`` and ``params["b"]``
    and returns the (N, J) slice belonging to metric index 0.
    """
    psi, zeta = params["psi"], params["zeta"]
    a, b = params["a"], params["b"]

    # theta_{im} = psi_i + zeta_{im}; (N,1,1) + (N,1,M) -> (N,1,M)
    theta = psi[:, np.newaxis, np.newaxis] + zeta[:, np.newaxis, :]

    # mu_{ijm} = a_j * theta_{im} - b_j; broadcast to (N, J, M)
    logits = a[np.newaxis, :, np.newaxis] * theta - b[np.newaxis, :, np.newaxis]

    # Map logits back to the original scale and keep only metric 0
    probabilities = expit(logits)
    return probabilities[:, :, 0]


def predict_from_target_optimized_multi_metric_model(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Predict the first metric from the target-optimized multi-metric IRT parameters.

    Evaluates expit(a_j * (psi_i + zeta_im) - b_j) and returns the (N, J)
    slice for metric index 0.

    NOTE: this is the second definition of this name in the file and therefore
    the one that is in effect after the module has been loaded.
    """
    # theta_{im} = psi_i + zeta_{im}; shape (N, 1, M)
    per_metric_ability = params["psi"][:, None, None] + params["zeta"][:, None, :]

    # mu_{ijm} = a_j * theta_{im} - b_j; shape (N, J, M)
    scaled = params["a"][None, :, None] * per_metric_ability
    linear_predictor = scaled - params["b"][None, :, None]

    # Convert from the logit scale back to (0, 1)
    on_original_scale = expit(linear_predictor)

    # Target metric is the first one along the last axis
    return on_original_scale[:, :, 0]

# =========================
# 5. MSE Calculation & Result Saving for Mixed-metric
# =========================
def calculate_mse_multi_metric(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    test_mask: np.ndarray
) -> float:
    """
    Mean squared error over the test cells, pooled across all metrics.

    y_true and y_pred have shape (N, J, M); test_mask is a boolean (N, J)
    array that is applied identically to every metric. The result is rounded
    to 6 decimal places.
    """
    n_metrics = y_true.shape[2]
    # Replicate the 2-D mask along the metric axis -> (N, J, M)
    mask_3d = np.repeat(test_mask[:, :, None], n_metrics, axis=2)

    residuals = y_true[mask_3d] - y_pred[mask_3d]
    return round(np.mean(residuals ** 2), 6)

def calculate_mse_single_metric(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    test_mask: np.ndarray,
    metric_index: int = 0
) -> float:
    """
    Mean squared error over the test cells for one metric only.

    y_true and y_pred have shape (N, J, M); test_mask is a boolean (N, J)
    array. ``metric_index`` selects the slice along the last axis (default 0).
    The result is rounded to 6 decimal places.
    """
    # Select the metric first, then the test cells within it.
    truth = y_true[:, :, metric_index][test_mask]
    guess = y_pred[:, :, metric_index][test_mask]

    squared_error = (truth - guess) ** 2
    return round(np.mean(squared_error), 6)

def save_sample_level_predictions_multi_metric(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    metric_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_mixed_irt: Optional[np.ndarray] = None,
    pred_simple_irt: Optional[np.ndarray] = None,
    pred_1pl_irt: Optional[np.ndarray] = None
):
    """
    Write the test-set predictions of all methods to a CSV file.

    One output row is produced per (model, question, metric) triple of the
    test set, ordered by test cell first and metric second. Optional
    predictors that are not supplied are written as NaN. The 1PL column only
    carries values for metric index 0 and is NaN for every other metric.

    The file is ``predictions_ratio_<ratio>_rep<rep+1>.csv`` inside
    ``pred_dir``, which is created if necessary.
    """
    cell_rows, cell_cols = np.where(test_mask)
    n_samples = len(cell_rows)
    M = len(metric_names)

    def _gather(values: Optional[np.ndarray]) -> np.ndarray:
        """Flatten the test cells of an (N, J, M) array; all-NaN when absent."""
        if values is None:
            return np.full(n_samples * M, np.nan)
        return values[cell_rows, cell_cols, :M].reshape(-1)

    # The 1PL model is single-metric: fill column 0, leave the rest NaN.
    one_pl = np.full((n_samples, M), np.nan)
    if pred_1pl_irt is not None and M > 0:
        one_pl[:, 0] = pred_1pl_irt[cell_rows, cell_cols, 0]

    # Each test cell is repeated once per metric.
    row_per_record = np.repeat(cell_rows, M)
    col_per_record = np.repeat(cell_cols, M)

    pred_df = pd.DataFrame({
        "model_idx": row_per_record,
        "model_name": [model_names[r] for r in row_per_record],
        "question_idx": col_per_record,
        "question_name": [question_names[c] for c in col_per_record],
        "metric_idx": np.tile(np.arange(M), n_samples),
        "metric_name": np.tile(metric_names, n_samples),
        "true_value": _gather(Y),
        "global_mean_pred": _gather(pred_global),
        "model_mean_pred": _gather(pred_row),
        "question_mean_pred": _gather(pred_col),
        "mixed_metric_irt_pred": _gather(pred_mixed_irt),
        "simple_metric_irt_pred": _gather(pred_simple_irt),
        "1pl_irt_pred": one_pl.reshape(-1),
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1
    })

    os.makedirs(pred_dir, exist_ok=True)
    save_path = os.path.join(
        pred_dir, f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"
    )
    pred_df.to_csv(save_path, index=False, encoding="utf-8")

    print(f"✅ Sample-level predictions saved (multi-metric): {save_path}")
    print(f"   Total {len(pred_df)} test set records")

def save_irt_trace_multi_metric(
    trace_dir: str,
    trace: az.InferenceData,
    train_subset_ratio: float,
    rep: int
):
    """
    Persist the full mixed-metric IRT MCMC trace as a NetCDF file.

    The file is written to ``trace_dir`` and named
    ``mixed_metric_irt_ratio_<ratio>_rep<rep+1>.nc``; the directory is
    expected to exist already.
    """
    file_name = f"mixed_metric_irt_ratio_{train_subset_ratio:.3f}_rep{rep+1}.nc"
    save_path = os.path.join(trace_dir, file_name)
    az.to_netcdf(trace, save_path)
    print(f"Mixed-metric IRT trace saved: {save_path}")

# =========================
# 6. Experiment Main Pipeline for Mixed-metric
# =========================
def _posterior_mean_params(trace, var_names) -> Dict[str, np.ndarray]:
    """Return the posterior mean (over chain and draw) of each variable in var_names."""
    posterior = trace.posterior
    return {
        name: posterior[name].mean(dim=["chain", "draw"]).values
        for name in var_names
    }


def run_experiment_mixed_metric(
    input_csv_paths: List[str],
    output_root_dir: str,
    question_id_col: Optional[str | int] = 0,
    test_ratio: float = 0.05,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 3,
    irt_draws: int = 1000,
    irt_tune: int = 1000,
    irt_chains: int = 4,
    irt_cores: Optional[int] = None,
):
    """
    Mixed-metric continuous IRT experiment main pipeline.

    For every training ratio and repetition the pipeline samples a training
    subset, runs the mean baselines, fits (or reloads from an existing trace
    file) the mixed-metric, simple, effective, auxiliary-prior and 1PL IRT
    models, scores each with ``calculate_mse_single_metric`` on metric index 0,
    and writes sample-level predictions. Afterwards it writes an MSE summary
    CSV and an MSE-vs-ratio plot.

    Args:
        input_csv_paths: One CSV path per metric.
        output_root_dir: Root under which the four output folders are created.
        question_id_col: Forwarded to ``preprocess_multiple_metrics``.
        test_ratio: Forwarded to ``split_fixed_test_set``.
        test_seed: Seed of the fixed test split; repetition ``rep`` uses
            ``test_seed + rep`` for training-subset sampling and MCMC.
        train_ratios: Training ratios to evaluate. ``None`` (the default)
            means ``[0.1, 0.2, 0.5, 1.0]``.
        rep_count: Number of repetitions per training ratio.
        irt_draws: Posterior draws per chain.
        irt_tune: Tuning steps per chain.
        irt_chains: Number of MCMC chains.
        irt_cores: Cores for sampling; ``None`` lets the callee / fallback
            choose.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if train_ratios is None:
        train_ratios = [0.1, 0.2, 0.5, 1.0]

    # 1. Create output directories
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
    }
    for d in output_dirs.values():
        os.makedirs(d, exist_ok=True)
    print(f"Output directories created: {list(output_dirs.values())}")

    # 2. Read and preprocess multiple metric data
    print("\n=== Step 1/5: Read & Preprocess Multiple Metrics ===")
    Y, model_names, question_names, metric_names = preprocess_multiple_metrics(
        input_csv_paths,
        question_id_col=question_id_col,
        zero_row_ratio_threshold=0.3,
        epsilon=1e-6
    )
    N, J, M = Y.shape
    print(f"Preprocessed data shape: {N} models × {J} questions × {M} metrics")

    # 3. Split fixed test set
    print("\n=== Step 2/5: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J, M),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # 4. Initialize MSE record dictionary
    # NOTE(review): the plain "simple multi-metric IRT" MSE is computed and
    # printed below but has no entry here, so it is absent from the summary
    # CSV and the plot — confirm that is intended.
    method_names = [
        "Global Mean",
        "Model Mean",
        "Question Mean",
        "Mixed-metric IRT",
        "Effective Multi-metric IRT",
        "Simple Multi-metric IRT (Aux)",
        "1PL IRT (bertscore_F1)",
    ]
    mse_summary = {
        method: {r: [] for r in train_ratios}
        for method in method_names
    }

    # Sampler settings for the auxiliary-prior model do not depend on the
    # ratio or repetition, so they are built once.
    sample_kwargs = {
        "draws": irt_draws,
        "tune": irt_tune,
        "chains": irt_chains,
        "cores": irt_cores if irt_cores is not None else min(os.cpu_count() or 1, irt_chains)
    }

    # 5. Iterate through training ratios and repetitions
    for train_ratio in train_ratios:
        print(f"\n=== Step 3/5: Train Ratio = {train_ratio:.3f} ===")
        train_masks_rep = []

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            rep_seed = test_seed + rep

            # Sample training set
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            train_masks_rep.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # Baseline predictions
            print("Running baseline methods...")
            pred_global = predict_global_mean(Y, train_mask)
            mse_global = calculate_mse_single_metric(Y, pred_global, test_mask, 0)
            pred_row = predict_row_mean(Y, train_mask)
            mse_row = calculate_mse_single_metric(Y, pred_row, test_mask, 0)
            pred_col = predict_col_mean(Y, train_mask)
            mse_col = calculate_mse_single_metric(Y, pred_col, test_mask, 0)

            # Train mixed-metric IRT model (reuse the trace file when present)
            trace_file_path = os.path.join(
                output_dirs["irt_trace"],
                f"mixed_metric_irt_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )
            if os.path.exists(trace_file_path):
                print(f"Found existing trace file, loading from {trace_file_path} ...")
                trace, params = load_trace_and_params(trace_file_path)
            else:
                print("Running mixed-metric continuous IRT model...")
                trace, params = fit_mixed_metric_irt(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores,
                    seed=rep_seed
                )
                save_irt_trace_multi_metric(output_dirs["irt_trace"], trace, train_ratio, rep)

            pred_mixed_irt = predict_from_mixed_metric_irt(params)
            mse_mixed_irt = calculate_mse_single_metric(Y, pred_mixed_irt, test_mask, 0)
            print(f"Mixed-metric IRT MSE: {mse_mixed_irt:.6f}")

            # Train simplified multi-metric IRT model
            trace_file_path_simple = os.path.join(
                output_dirs["irt_trace"],
                f"simple_multimetric_irt_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )
            if os.path.exists(trace_file_path_simple):
                print(f"Found existing simple multimetric trace file, loading from {trace_file_path_simple} ...")
                trace_simple = az.from_netcdf(trace_file_path_simple)
            else:
                print("Running simple multi-metric IRT model...")
                # NOTE(review): irt_chains / irt_cores / rep_seed are not
                # forwarded to this builder — confirm its defaults are intended.
                trace_simple, _ = build_simple_multi_metric_model(
                    N=N, M=M, J=J,
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune
                )
                # Save trace file
                trace_simple.to_netcdf(trace_file_path_simple)
                print(f"Simple multi-metric IRT trace saved: {trace_file_path_simple}")
            # Posterior means are extracted the same way for both branches
            params_simple = _posterior_mean_params(trace_simple, ("psi", "zeta", "a", "b"))

            pred_simple_irt = predict_from_simple_multi_metric_model(params_simple)
            # Expand 2D predictions to 3D, only has values for the first metric (bertscore_F1)
            pred_simple_irt_3d = np.zeros_like(Y)
            pred_simple_irt_3d[:, :, 0] = pred_simple_irt
            mse_simple_irt = calculate_mse_single_metric(Y, pred_simple_irt_3d, test_mask, 0)
            print(f"Simple Multi-metric IRT MSE: {mse_simple_irt:.6f}")

            # Train effective multi-metric IRT model
            trace_file_path_effective = os.path.join(
                output_dirs["irt_trace"],
                f"effective_multimetric_irt_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )
            if os.path.exists(trace_file_path_effective):
                print(f"Found existing effective multimetric trace file, loading from {trace_file_path_effective} ...")
                trace_effective = az.from_netcdf(trace_file_path_effective)
            else:
                print("Running effective multi-metric IRT model...")
                # NOTE(review): as above, chains / cores / seed are not forwarded.
                trace_effective, _ = build_effective_multi_metric_model(
                    N=N, M=M, J=J,
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune
                )
                # Save trace file
                trace_effective.to_netcdf(trace_file_path_effective)
                print(f"Effective multi-metric IRT trace saved: {trace_file_path_effective}")
            params_effective = _posterior_mean_params(trace_effective, ("psi", "zeta", "a", "b"))

            pred_effective_irt = predict_from_effective_multi_metric_model(params_effective)
            # Expand 2D predictions to 3D, only has values for the first metric (bertscore_F1)
            pred_effective_irt_3d = np.zeros_like(Y)
            pred_effective_irt_3d[:, :, 0] = pred_effective_irt
            mse_effective_irt = calculate_mse_single_metric(Y, pred_effective_irt_3d, test_mask, 0)
            print(f"Effective Multi-metric IRT MSE: {mse_effective_irt:.6f}")

            # Train enhanced multi-metric IRT model (using auxiliary metrics as priors)
            print("Running simple multi-metric IRT model (with auxiliary metrics as priors)...")
            # Prepare main metric and auxiliary metric data
            Y_main = Y[:, :, 0]  # Use the first metric (bertscore_F1) as the main metric
            Y_aux_list = [Y[:, :, i] for i in range(1, Y.shape[2])]  # Other metrics as auxiliary metrics

            trace_aux_file_path = os.path.join(
                output_dirs["irt_trace"],
                f"simple_multimetric_irt_aux_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )

            if os.path.exists(trace_aux_file_path):
                print(f"Found existing simple multimetric aux trace file, loading from {trace_aux_file_path} ...")
                trace_aux = az.from_netcdf(trace_aux_file_path)
                params_aux = _posterior_mean_params(trace_aux, ("theta", "b"))
                # Generate predictions: expit(theta_i - b_j) for every model/question pair
                theta_mean = params_aux["theta"].reshape(-1, 1)
                b_mean = params_aux["b"].reshape(1, -1)
                pred_aux_irt = expit(theta_mean - b_mean)
            else:
                trace_aux, params_aux, pred_aux_irt = run_simple_multimetric_irt(
                    Y_main=Y_main,
                    Y_aux_list=Y_aux_list,
                    sample_ratio=1.0,  # Use all training data
                    output_dir=output_dirs["irt_trace"],
                    model_names=model_names,
                    key_names=question_names,
                    sample_kwargs=sample_kwargs,
                    seed=rep_seed,
                    train_mask=train_mask
                )
                # The helper returns None as trace when it has no training
                # samples; there is nothing to persist in that case.
                if trace_aux is not None:
                    trace_aux.to_netcdf(trace_aux_file_path)
                    print(f"Simple multi-metric IRT (aux) trace saved: {trace_aux_file_path}")
                else:
                    print("Simple multi-metric IRT (aux) produced no trace; nothing saved.")

            # Expand 2D predictions to 3D, only has values for the first metric (bertscore_F1)
            pred_aux_irt_3d = np.zeros_like(Y)
            pred_aux_irt_3d[:, :, 0] = pred_aux_irt
            mse_aux_irt = calculate_mse_single_metric(Y, pred_aux_irt_3d, test_mask, 0)
            print(f"Simple Multi-metric IRT (Aux) MSE: {mse_aux_irt:.6f}")

            # Train single-metric 1PL IRT model (using bertscore_F1 metric)
            print("Running 1PL IRT model (bertscore_F1)...")
            trace_1pl_file_path = os.path.join(
                output_dirs["irt_trace"],
                f"1pl_irt_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )
            if os.path.exists(trace_1pl_file_path):
                print(f"Found existing 1PL trace file, loading from {trace_1pl_file_path} ...")
                trace_1pl = az.from_netcdf(trace_1pl_file_path)
                params_1pl = _posterior_mean_params(trace_1pl, ("theta", "b"))
            else:
                trace_1pl, params_1pl = fit_irt_1pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores,
                    seed=rep_seed
                )
                # Save 1PL IRT trace to the same path the cache check looks at
                trace_1pl.to_netcdf(trace_1pl_file_path)
                print(f"1PL IRT trace saved: {trace_1pl_file_path}")

            pred_1pl = predict_from_irt_1pl(params_1pl)
            # Expand 2D predictions to 3D, only has values for the first metric (bertscore_F1)
            pred_1pl_3d = np.zeros_like(Y)
            pred_1pl_3d[:, :, 0] = pred_1pl
            mse_1pl = calculate_mse_single_metric(Y, pred_1pl_3d, test_mask, 0)
            print(f"1PL IRT (bertscore_F1) MSE: {mse_1pl:.6f}")

            # Record MSE
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            mse_summary["Mixed-metric IRT"][train_ratio].append(mse_mixed_irt)
            mse_summary["Effective Multi-metric IRT"][train_ratio].append(mse_effective_irt)
            mse_summary["Simple Multi-metric IRT (Aux)"][train_ratio].append(mse_aux_irt)
            mse_summary["1PL IRT (bertscore_F1)"][train_ratio].append(mse_1pl)

            # Save sample-level predictions
            save_sample_level_predictions_multi_metric(
                pred_dir=output_dirs["predictions"],
                Y=Y,
                test_mask=test_mask,
                model_names=model_names,
                question_names=question_names,
                metric_names=metric_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_mixed_irt=pred_mixed_irt,
                pred_1pl_irt=pred_1pl_3d
            )

            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            print(f"Mixed-metric IRT: {mse_mixed_irt:.6f}")
            print(f"Effective Multi-metric IRT: {mse_effective_irt:.6f}")
            print(f"Simple Multi-metric IRT (Aux): {mse_aux_irt:.6f}")
            print(f"1PL IRT (bertscore_F1): {mse_1pl:.6f}")

        # Save data split records
        record_data_split(
            split_dir=output_dirs["split"],
            test_mask=test_mask,
            train_masks=train_masks_rep,
            train_subset_ratio=train_ratio,
            model_names=model_names,
            question_names=question_names
        )

    # 6. Save MSE summary and plot
    print("\n=== Step 4/5: Generate Final Results ===")

    # Save MSE summary CSV: one row per ratio, one "mean ± std" cell per method
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    rows = []
    for ratio in sorted(mse_summary["Global Mean"].keys()):
        row = {"Train_Ratio": ratio}
        for method in mse_summary.keys():
            mse_list = mse_summary[method][ratio]
            mse_mean = np.mean(mse_list)
            mse_std = np.std(mse_list)
            row[method] = f"{mse_mean:.6f} ± {mse_std:.6f}"
        rows.append(row)
    pd.DataFrame(rows).to_csv(mse_summary_path, index=False)
    print(f"MSE summary saved: {mse_summary_path}")

    # Plot MSE comparison
    def plot_mse_comparison(mse_dict, train_ratios, save_path):
        """Draw one mean-MSE line per method against the training ratios and save it."""
        method_styles = {
            "Global Mean": ("blue", "solid"),
            "Model Mean": ("orange", "dashed"),
            "Question Mean": ("green", "dashdot"),
            "Mixed-metric IRT": ("red", "solid"),
            "Effective Multi-metric IRT": ("cyan", "solid"),
            "Simple Multi-metric IRT (Aux)": ("magenta", "solid"),
            "1PL IRT (bertscore_F1)": ("purple", "solid"),
        }
        plt.figure(figsize=(10, 6))
        for method, mse_list in mse_dict.items():
            color, linestyle = method_styles.get(method, ("black", "solid"))
            plt.plot(
                train_ratios,
                mse_list,
                label=method,
                color=color,
                linestyle=linestyle,
                linewidth=2,
                marker="o",
                markersize=4
            )
        plt.xlabel("Training Data Ratio (Fraction of Remaining Pool)", fontsize=12)
        plt.ylabel("Test Set MSE (Lower = Better)", fontsize=12)
        plt.title("MSE vs. Training Data Ratio for Mixed-metric IRT and Baselines", fontsize=14, pad=20)
        plt.legend(loc="upper right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.xticks(train_ratios, [f"{r:.2f}" for r in train_ratios], fontsize=10)
        plt.yticks(fontsize=10)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"MSE plot saved: {save_path}")

    plot_mse_dict = {
        method: [np.mean(mse_summary[method][r]) for r in train_ratios]
        for method in mse_summary.keys()
    }
    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")

# =========================
# 7. Trace Loading & Model Fitting Helpers (example usage is at the end of the file)
# =========================

def load_trace_and_params(trace_path: str) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Read a saved NetCDF trace and derive posterior-mean parameters from it.

    Args:
        trace_path: Path of the NetCDF file to read.

    Returns:
        ``(trace, params)`` where ``params`` maps "psi", "zeta" and "b" to the
        posterior mean taken over the chain and draw dimensions.
    """
    inference_data = az.from_netcdf(trace_path)
    posterior = inference_data.posterior

    posterior_means = {}
    for var_name in ("psi", "zeta", "b"):
        averaged = posterior[var_name].mean(dim=["chain", "draw"])
        posterior_means[var_name] = averaged.values

    return inference_data, posterior_means

def run_simple_multimetric_irt(
    Y_main: np.ndarray,
    Y_aux_list: List[np.ndarray],
    sample_ratio: float,
    output_dir: str,
    model_names: list,
    key_names: list,
    sample_kwargs: dict,
    seed: int = 20240619,
    train_mask: np.ndarray | None = None,
):
    """
    Simplified multi-metric IRT modeling
    Use main metric for IRT modeling, auxiliary metrics as prior information
    """
    N, J = Y_main.shape
    M = len(Y_aux_list)  # Number of auxiliary metrics
    _ensure_dir(output_dir)
    print(f"\n[Simplified multi-metric IRT] Starting modeling: N={N}, J={J}, auxiliary metrics={M}, training ratio={sample_ratio}")

    # 1. Data preprocessing (logit transformation)
    eps = 1e-5
    Y_main_clip = np.clip(Y_main.astype(np.float64), eps, 1 - eps)
    logit_Y_main = logit(Y_main_clip).astype(np.float64)
    
    # Process auxiliary metrics
    logit_Y_aux_list = []
    for Y_aux in Y_aux_list:
        Y_aux_clip = np.clip(Y_aux.astype(np.float64), eps, 1 - eps)
        logit_Y_aux = logit(Y_aux_clip).astype(np.float64)
        logit_Y_aux_list.append(logit_Y_aux)
    
    if train_mask is None:
        # If no training mask is provided, generate automatically
        rng = np.random.default_rng(seed)
        train_mask = rng.random((N, J)) < sample_ratio
    
    # Ensure train_mask is boolean type
    train_mask = train_mask.astype(bool)
    
    logit_train_main = np.where(train_mask, logit_Y_main, np.nan).astype(np.float64)
    obs_mask = ~np.isnan(logit_train_main)
    
    print(f"[Preprocessing confirmation] Number of non-NaN training samples: {np.sum(obs_mask)}")

    # Check if there are enough training samples
    if np.sum(obs_mask) == 0:
        print("Warning: No training samples available!")
        # Return default predictions
        Y_pred_mean = np.full_like(Y_main, np.nanmean(Y_main))
        return None, {"theta": np.zeros(N), "b": np.zeros(J)}, Y_pred_mean

    # 2. Calculate prior information using auxiliary metrics
    # Calculate auxiliary metrics' means as priors
    aux_means = []
    for Y_aux in Y_aux_list:
        aux_means.append(np.nanmean(Y_aux))
    
    # 3. Build simplified IRT model (based on main metric)
    with pm.Model() as irt_model:
        # Model ability parameters (using auxiliary metrics' means as prior means)
        if aux_means:
            prior_mean = np.mean(aux_means)
            theta = pm.Normal("theta", mu=prior_mean, sigma=1.0, shape=N)
        else:
            theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)
        
        # Question difficulty parameters
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)
        
        # Linear predictor
        mu = theta[:, None] - b[None, :]
        
        # Observation model
        pm.Normal(
            "y_obs", mu=mu[obs_mask], sigma=1.0,
            observed=logit_train_main[obs_mask]
        )
        
        # 4. MCMC sampling
        print(f"[MCMC sampling] Parameters: {sample_kwargs}")
        trace = pm.sample(
            **sample_kwargs,
            random_seed=seed,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )

    # 5. Extract parameter posterior means
    params_mean = {
        "theta": trace.posterior["theta"].mean(("chain", "draw")).values,
        "b": trace.posterior["b"].mean(("chain", "draw")).values
    }
    
    # 6. Generate predictions
    theta_mean = params_mean["theta"].reshape(-1, 1)
    b_mean = params_mean["b"].reshape(1, -1)
    mu_pred = theta_mean - b_mean
    Y_pred_mean = expit(mu_pred)
    
    return trace, params_mean, Y_pred_mean

def fit_mixed_metric_irt(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None,
    seed: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Simplified mixed-metric IRT model
    - Retain global ability parameter psi and metric-specific ability parameter zeta
    - At question level, only retain difficulty parameter b, no centering processing
    - Remove discrimination parameter a and noise parameter sigma
    Y shape: (N_models, J_items, M_metrics), values in (0,1)
    train_mask shape: (N_models, J_items), bool mask for training samples

    Args:
        Y: Score tensor; clipped to (1e-6, 1 - 1e-6) and logit-transformed.
        train_mask: Training cells; a True cell is used for every metric.
        draws: Posterior draws per chain.
        tune: Tuning steps per chain.
        chains: Number of MCMC chains.
        cores: Cores for sampling; None means ``min(cpu_count, chains)``.
        seed: Random seed passed to ``pm.sample``.

    Returns:
        ``(trace, params)`` where ``params`` holds the posterior means of
        "psi", "zeta" and "b".

    Raises:
        ValueError: If ``train_mask`` is not of shape (N_models, J_items).
    """
    N, J, M = Y.shape

    train_mask = np.asarray(train_mask).astype(bool)
    if train_mask.shape != (N, J):
        raise ValueError(
            f"train_mask shape {train_mask.shape} does not match (N, J) = {(N, J)}"
        )

    eps = 1e-6
    Y_clip = np.clip(Y, eps, 1 - eps)
    logit_Y = logit(Y_clip)

    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    # Training observations as (i, j, m) triples. The (N, J) mask is broadcast
    # over the metric axis, and non-finite values (e.g. NaN scores) are left
    # out so they never reach the likelihood. np.nonzero and boolean indexing
    # both walk the array in row-major order, so the index arrays and y_train
    # stay aligned.
    observed_cells = train_mask[:, :, None] & np.isfinite(logit_Y)
    i_train, j_train, m_train = np.nonzero(observed_cells)
    y_train = logit_Y[observed_cells]

    with pm.Model() as model:
        # Parameter priors
        psi = pm.Normal("psi", mu=0, sigma=1, shape=N)              # Global ability

        # Metric-specific ability offsets are drawn independently. A correlated
        # alternative (LKJCholeskyCov with eta=2 and HalfNormal(2.5) scales
        # feeding an MvNormal) is intentionally not active in this version.
        zeta = pm.Normal("zeta", mu=0, sigma=0.5, shape=(N, M))  # Metric-specific ability

        # Simplified model: Only retain question difficulty parameters, no centering
        b = pm.Normal("b", mu=0, sigma=1, shape=J)               # Question difficulty parameter

        # Calculate linear prediction for training samples (remove discrimination parameter a)
        mu_train = (psi[i_train] + zeta[i_train, m_train]) - b[j_train]

        # Likelihood with a fixed noise scale of 1.0 on the logit scale
        pm.Normal("obs", mu=mu_train, sigma=1.0, observed=y_train)

        # Sampling
        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            return_inferencedata=True,
            compute_convergence_checks=False,  # Checked explicitly below
            random_seed=seed
        )
        print(f"Mixed-metric IRT sampling time: {time.time() - start_time:.1f}s")

    # Convergence check. az.summary rounds to two decimals by default, which
    # would hide r_hat values just above the 1.01 threshold, so raw values are
    # requested; only the diagnostics columns are needed.
    diagnostics = az.summary(
        trace,
        var_names=["psi", "zeta", "b"],
        kind="diagnostics",
        round_to="none"
    )
    n_high_r_hat = int((diagnostics["r_hat"] > 1.01).sum())
    if n_high_r_hat:
        print(f"⚠️ Mixed-metric IRT: {n_high_r_hat} parameters with r_hat > 1.01 (potential convergence issue)")

    # Extract parameter posterior means
    params = {
        "psi": trace.posterior["psi"].mean(dim=["chain", "draw"]).values,
        "zeta": trace.posterior["zeta"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values,
    }

    return trace, params

def _ensure_dir(path: str):
    """Create ``path`` (with any missing parents) unless it already is a directory."""
    if os.path.isdir(path):
        return
    os.makedirs(path, exist_ok=True)


if __name__ == "__main__":
    # One response-matrix CSV per metric; the first entry is the main metric.
    INPUT_CSV_PATHS = [
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__bertscore_F1.csv",
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__bertscore_P.csv",
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__bertscore_R.csv",
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__bleu.csv",
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__meteor.csv",
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__rouge1.csv",
        "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/continuous_processed/response_matrix__rougeL.csv",
    ]
    OUTPUT_ROOT_DIR = "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/mix_metric_pack"
    QUESTION_ID_COL = 'row_index'  # If no question ID column, change to None

    # Data split settings
    TEST_RATIO = 0.10
    TRAIN_RATIOS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    # TRAIN_RATIOS = [1.0]
    REP_COUNT = 1

    # Sampler settings
    # NOTE(review): 10 draws / 10 tuning steps look like smoke-test values —
    # confirm before relying on the resulting estimates.
    IRT_DRAWS = 10
    IRT_TUNE = 10
    IRT_CHAINS = 12
    IRT_CORES = 12

    # Collect every argument in one mapping, then launch the pipeline.
    experiment_kwargs = {
        "input_csv_paths": INPUT_CSV_PATHS,
        "output_root_dir": OUTPUT_ROOT_DIR,
        "question_id_col": QUESTION_ID_COL,
        "test_ratio": TEST_RATIO,
        "test_seed": 42,
        "train_ratios": TRAIN_RATIOS,
        "rep_count": REP_COUNT,
        "irt_draws": IRT_DRAWS,
        "irt_tune": IRT_TUNE,
        "irt_chains": IRT_CHAINS,
        "irt_cores": IRT_CORES,
    }
    run_experiment_mixed_metric(**experiment_kwargs)

'''
nohup python /Users/bytedance/Downloads/BIG-LEG_again/all_mix_experiment/mix_metric.py > /Users/bytedance/Downloads/BIG-LEG_again/all_mix_experiment/mix_metric.log 2>&1 &
'''

