#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Force the "spawn" start method for multiprocessing; force=True overrides any
# start method that may already have been selected.
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

# Pin the usual native thread-pool environment variables to one thread each.
# NOTE(review): these assignments run before numpy/pymc are imported further
# down, presumably so the limits are already in place when those libraries
# initialise their thread pools — confirm before reordering the imports.
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

"""
Mixed-metric Continuous IRT Model for Multi-metric Benchmark Evaluation
- Input: Multiple CSV files, each for one metric (rows=questions, cols=models)
- Output: Data splits, MCMC traces, sample-level predictions, MSE metrics, plots
- Model: Mixed-metric continuous IRT with global ability + metric-specific offsets + LKJ covariance prior
"""

import os
import time
import pickle
import warnings
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pymc as pm
import arviz as az
from scipy.special import logit, expit

# Global Config
# Matplotlib defaults applied to every figure created by this module:
# preferred font families, the "axes.unicode_minus" switch turned off, and
# the figure resolution in DPI.
plt.rcParams["font.family"] = ["Arial", "Helvetica"]
plt.rcParams["axes.unicode_minus"] = False
plt.rcParams["figure.dpi"] = 150
# Installs a blanket "ignore" filter, so warnings from every library are
# suppressed for the whole process from this point on.
warnings.filterwarnings("ignore")

# =========================
# 1. Robust CSV Reading & Preprocessing for Multiple Metrics
# =========================
def robust_read_csv(path: str) -> pd.DataFrame:
    """
    Read a CSV file, trying several text encodings in turn.

    The encodings utf-8, utf-8-sig, gbk and latin1 are attempted in that
    order and the first successful parse is returned.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If every attempt fails. The exception from the last
            attempt is attached as ``__cause__`` so its traceback is kept.
    """
    encodings = ("utf-8", "utf-8-sig", "gbk", "latin1")
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except Exception as e:
            # Deliberately broad: any failure simply moves on to the next
            # encoding; the last error is reported if none succeeds.
            last_err = e
    # Chain the underlying error so the real reason (missing file, parse
    # error, decode error) is not reduced to a bare message string.
    raise ValueError(f"Failed to read CSV: {str(last_err)}") from last_err

def preprocess_single_metric(
    df: pd.DataFrame,
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6,
) -> Tuple[pd.DataFrame, np.ndarray, List[str], List[str]]:
    """
    Preprocess one metric's data (rows = questions, columns = models).

    Steps:
    - Optionally move the question-id column into the index
    - Drop questions whose share of exact-zero scores exceeds the threshold
    - Coerce every cell to numeric (unparseable cells become 0.0)
    - Clip values to [epsilon, 1 - epsilon]
    - Transpose to (models x questions)

    Args:
        df: Raw metric table.
        question_id_col: Column holding the question ids. A string (or any
            existing column label) is used as a label. An integer that is
            NOT an existing column label is treated as a column position,
            so ``0`` means "the first column". ``None`` keeps the current
            index.
        zero_row_ratio_threshold: Questions with a zero ratio strictly above
            this value are removed.
        epsilon: Lower clip bound; the upper bound is ``1 - epsilon``.

    Returns:
        (work_t, Y, model_names, filtered_question_names) where ``work_t`` is
        the transposed DataFrame and ``Y`` is ``work_t.values``.

    Raises:
        IndexError: If an integer ``question_id_col`` is out of range.
        KeyError: If a label ``question_id_col`` is not a column.
    """
    work = df.copy()
    if question_id_col is not None:
        # DataFrame.set_index only understands labels. An integer such as the
        # pipeline default 0 would raise KeyError on a CSV with string
        # headers, so resolve it positionally unless it really is a label.
        is_position = (
            isinstance(question_id_col, (int, np.integer))
            and not isinstance(question_id_col, bool)
            and question_id_col not in work.columns
        )
        id_col = work.columns[question_id_col] if is_position else question_id_col
        work = work.set_index(id_col)
    original_question_names = work.index.tolist()
    # Share of exact zeros per question, measured on the raw (uncoerced) cells.
    zero_ratio = (work == 0).sum(axis=1) / max(1, work.shape[1])
    work = work[zero_ratio <= zero_row_ratio_threshold].copy()
    filtered_question_names = work.index.tolist()
    print(f"Filtered questions: {len(original_question_names)} → {len(filtered_question_names)}")

    work = work.apply(pd.to_numeric, errors="coerce").fillna(0.0)
    # Keep values strictly inside (0, 1) so a later logit transform is finite.
    work = work.replace(0.0, epsilon).clip(lower=epsilon, upper=1 - epsilon)
    work_t = work.T
    model_names = work_t.index.tolist()
    Y = work_t.values
    return work_t, Y, model_names, filtered_question_names

def preprocess_multiple_metrics(
    csv_paths: List[str],
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6,
) -> Tuple[np.ndarray, List[str], List[str], List[str]]:
    """
    Read and preprocess multiple metric CSV files and align them.

    Each file is read with ``robust_read_csv`` and cleaned with
    ``preprocess_single_metric``; only the models and questions present in
    every metric are kept, in sorted order.

    Args:
        csv_paths: One CSV path per metric; the metric name is the file name
            without its extension.
        question_id_col: Forwarded to ``preprocess_single_metric``.
        zero_row_ratio_threshold: Forwarded to ``preprocess_single_metric``.
        epsilon: Forwarded to ``preprocess_single_metric``.

    Return:
      - Y: np.ndarray shape (N_models, J_items, M_metrics)
      - model_names: list of model names (intersection across metrics)
      - question_names: list of question names (intersection across metrics)
      - metric_names: list of metric names (from file names)

    Raises:
        ValueError: If ``csv_paths`` is empty.
    """
    if not csv_paths:
        raise ValueError("csv_paths must contain at least one metric CSV file")

    metric_dfs = []
    metric_names = []

    for path in csv_paths:
        metric_name = os.path.splitext(os.path.basename(path))[0]
        metric_names.append(metric_name)
        print(f"\nProcessing metric: {metric_name}")
        df = robust_read_csv(path)
        # Only the transposed frame (models x questions) is needed here.
        df_t, _, _, _ = preprocess_single_metric(
            df, question_id_col, zero_row_ratio_threshold, epsilon
        )
        metric_dfs.append(df_t)

    # Get intersection of models (row indices) across all metrics
    common_models = sorted(set.intersection(*(set(df.index) for df in metric_dfs)))
    print(f"Model intersection size: {len(common_models)}")

    # Get intersection of questions (column indices) across all metrics
    common_questions = sorted(set.intersection(*(set(df.columns) for df in metric_dfs)))
    print(f"Question intersection size: {len(common_questions)}")

    # Each frame is already (models x questions), so the selected values are
    # stacked as they are; transposing here would swap the two axes.
    Ys_aligned = np.array([
        df.loc[common_models, common_questions].values for df in metric_dfs
    ])
    print(f"Ys_aligned shape (M, N, J): {Ys_aligned.shape}")

    # Move the metric axis last: (M, N, J) -> (N_models, J_questions, M_metrics)
    Y = np.transpose(Ys_aligned, (1, 2, 0))

    # Names are returned in the same order as the first two axes of Y.
    return Y, common_models, common_questions, metric_names


# =========================
# 2. Data Split (Fixed Test + Variable Training)
# =========================
def split_fixed_test_set(
    Y_shape: Tuple[int, int, int],
    test_ratio: float = 0.05,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw the fixed test mask and its complement (the remaining pool).

    ``Y_shape`` is (N_models, J_items, M_metrics). Every (model, item) cell
    is put in the test set independently with probability ``test_ratio``
    using a generator seeded with ``test_seed``; the same mask applies to
    all metrics.

    Returns: (test_mask, remaining_mask), both boolean arrays of shape (N, J).
    """
    n_models, n_items, _ = Y_shape
    generator = np.random.default_rng(test_seed)
    uniform_draws = generator.random((n_models, n_items))

    test_mask = uniform_draws < test_ratio
    remaining_mask = np.logical_not(test_mask)

    n_test = test_mask.sum()
    n_remaining = remaining_mask.sum()
    print(f"Fixed test set size: {n_test} samples ({test_ratio*100:.1f}%)")
    print(f"Remaining pool size: {n_remaining} samples ({(1-test_ratio)*100:.1f}%)")
    return test_mask, remaining_mask

def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,
    rep_seed: int
) -> np.ndarray:
    """
    Pick a random training subset out of the remaining (non-test) pool.

    ``int(pool_size * train_subset_ratio)`` cells are drawn without
    replacement from the True cells of ``remaining_mask`` using a generator
    seeded with ``rep_seed``.

    Returns a boolean train_mask with the same shape as ``remaining_mask``.
    """
    generator = np.random.default_rng(rep_seed)

    # Coordinates of every cell that is still available for training.
    pool_rows, pool_cols = np.nonzero(remaining_mask)
    pool_size = len(pool_rows)
    n_pick = int(pool_size * train_subset_ratio)
    picked = generator.choice(pool_size, size=n_pick, replace=False)

    train_mask = np.zeros_like(remaining_mask, dtype=bool)
    train_mask[pool_rows[picked], pool_cols[picked]] = True
    return train_mask

def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per repetition listing the train and test cells.

    Every training mask is validated against (len(model_names),
    len(question_names)) before anything is written. Each output file
    ``split_ratio_<ratio>_rep<k>.csv`` in ``split_dir`` holds the training
    rows followed by the test rows, with index, name, set type, ratio and
    repetition columns.

    Raises:
        ValueError: If any training mask has the wrong shape.
    """
    n_models = len(model_names)
    n_questions = len(question_names)

    # Fail before writing any file if one of the masks is mis-shaped.
    for mask in train_masks:
        if mask.shape != (n_models, n_questions):
            raise ValueError(f"Training mask shape {mask.shape} does not match data dimensions ({n_models},{n_questions})!")

    def _records(mask, set_type, repetition):
        # One row per True cell of `mask`, labelled with names and metadata.
        rows, cols = np.where(mask)
        return pd.DataFrame({
            "model_idx": rows,
            "model_name": [model_names[r] for r in rows],
            "question_idx": cols,
            "question_name": [question_names[c] for c in cols],
            "set_type": set_type,
            "train_ratio": train_subset_ratio,
            "repetition": repetition
        })

    for rep_no, mask in enumerate(train_masks, start=1):
        train_records = _records(mask, "train", rep_no)
        test_records = _records(test_mask, "test", rep_no)
        combined = pd.concat([train_records, test_records], ignore_index=True)
        file_name = f"split_ratio_{train_subset_ratio:.3f}_rep{rep_no}.csv"
        combined.to_csv(os.path.join(split_dir, file_name), index=False)
    print(f"Data split records saved (training ratio: {train_subset_ratio:.3f})")

# =========================
# 3. Prediction Methods (Baselines)
# =========================
def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline: predict, for every cell of a metric, that metric's mean over
    the training cells. Falls back to 0.5 when the training mask is empty.

    Y has shape (N, J, M); train_mask has shape (N, J). Returns an array
    shaped and typed like Y.
    """
    _, _, n_metrics = Y.shape
    out = np.zeros_like(Y)

    for k in range(n_metrics):
        # Each metric gets its own constant fill value.
        if train_mask.any():
            fill = Y[:, :, k][train_mask].mean()
        else:
            fill = 0.5
        out[:, :, k] = fill

    return out

def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline: predict each model's mean training score, per metric.

    A model with no training cells receives the metric's global training
    mean instead (0.5 if the whole training mask is empty). Y has shape
    (N, J, M); train_mask has shape (N, J). Returns an array shaped like Y.
    """
    n_models, _, n_metrics = Y.shape
    out = np.zeros_like(Y)

    for k in range(n_metrics):
        fallback = Y[:, :, k][train_mask].mean() if train_mask.any() else 0.5
        per_model = np.array(
            [
                Y[i, train_mask[i, :], k].mean() if train_mask[i, :].any() else fallback
                for i in range(n_models)
            ],
            dtype=float,
        )
        # Broadcasting the column vector repeats each model's mean across items.
        out[:, :, k] = per_model[:, None]

    return out

def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline: predict each question's mean training score, per metric.

    A question with no training cells receives the metric's global training
    mean instead (0.5 if the whole training mask is empty). Y has shape
    (N, J, M); train_mask has shape (N, J). Returns an array shaped like Y.
    """
    _, n_items, n_metrics = Y.shape
    out = np.zeros_like(Y)

    for k in range(n_metrics):
        fallback = Y[:, :, k][train_mask].mean() if train_mask.any() else 0.5
        per_item = np.array(
            [
                Y[train_mask[:, j], j, k].mean() if train_mask[:, j].any() else fallback
                for j in range(n_items)
            ],
            dtype=float,
        )
        # Broadcasting the row vector repeats each question's mean across models.
        out[:, :, k] = per_item[None, :]

    return out

# =========================
# 4. Mixed-metric Continuous IRT Model Implementation
# =========================
import time
import os
import numpy as np
import pymc as pm
import arviz as az
from scipy.special import logit, expit
from typing import List, Tuple, Optional, Dict

def fit_irt_1pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None,
    seed: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a single-metric 1PL-style IRT model with PyMC.

    Model as implemented: logit(Y_ij) ~ Normal(theta_i - b_j, sigma=1) with
    theta_i ~ Normal(0, 1) (row / model ability) and b_j ~ Normal(0, 1)
    (column / question difficulty). Only cells where ``train_mask`` is True
    enter the likelihood.

    Args:
        Y: Scores in (0, 1), shape (N, J) or (N, J, M). For a 3-D input only
            the first metric ``Y[:, :, 0]`` is modelled.
        train_mask: Boolean array of shape (N, J) selecting training cells.
        draws: Posterior draws per chain.
        tune: Tuning (warm-up) iterations per chain.
        chains: Number of MCMC chains.
        cores: Number of cores; defaults to min(cpu count, chains).
        seed: Random seed passed to ``pm.sample``.

    Returns:
        (trace, params) where ``trace`` is the full InferenceData and
        ``params`` holds the posterior means ``"theta"`` (N,) and ``"b"`` (J,).

    Raises:
        ValueError: If ``train_mask`` does not have shape (N, J).
    """
    banner = '******************************************'
    for _ in range(5):
        print(banner)

    # Take the first metric (assuming bertscore_F1) for single-metric modeling
    Y_single = Y[:, :, 0] if Y.ndim == 3 else Y

    N, J = Y_single.shape
    if train_mask.shape != (N, J):
        # Catch a mis-shaped mask here instead of as an indexing error
        # raised from inside the model block.
        raise ValueError(
            f"train_mask shape {train_mask.shape} does not match data shape ({N},{J})"
        )

    # Clip before the logit so 0 and 1 do not map to -inf / +inf.
    Y_clip = np.clip(Y_single, 1e-6, 1 - 1e-6)
    logit_Y = logit(Y_clip)

    # Never request more cores than there are chains to run.
    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    with pm.Model() as irt_1pl:
        # Priors
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)  # row (model) ability
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)          # column (question) difficulty

        # Linear predictor on the logit scale: (N,1) - (1,J) -> (N,J)
        mu = theta[:, None] - b[None, :]

        # Likelihood on training cells only; the noise scale is fixed at 1.
        pm.Normal("obs", mu=mu[train_mask], sigma=1.0, observed=logit_Y[train_mask])

        start_time = time.time()
        trace = pm.sample(
            draws=draws,               # posterior draws per chain
            tune=tune,                 # tuning iterations per chain
            chains=chains,             # several chains are needed for the r_hat check below
            cores=cores,               # chains run in parallel across this many cores
            target_accept=0.95,        # target acceptance rate for step-size adaptation
            progressbar=True,          # progress bar is shown
            random_seed=seed,          # reproducibility
            return_inferencedata=True  # ArviZ InferenceData
        )
        print(f"1PL IRT sampling time: {time.time()-start_time:.1f}s")

    # Convergence check: flag parameters whose r_hat exceeds 1.01.
    r_hat = az.summary(trace, var_names=["theta", "b"])["r_hat"]
    n_bad = int((r_hat > 1.01).sum())
    if n_bad > 0:
        print(f"⚠️ 1PL IRT: {n_bad} parameters with r_hat>1.01 (potential convergence issue)")

    # Posterior means, used as point estimates for prediction.
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def predict_from_irt_1pl(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Turn 1PL IRT point estimates into predicted scores.

    ``params["theta"]`` (length N) and ``params["b"]`` (length J) are
    combined as expit(theta_i - b_j), giving an (N, J) array in (0, 1).
    """
    ability = params["theta"]
    difficulty = params["b"]

    # Outer difference via broadcasting: (N,1) - (1,J) -> (N,J)
    logits = ability[:, np.newaxis] - difficulty[np.newaxis, :]

    # Back from the logit scale to the original score scale.
    return expit(logits)

def predict_from_mixed_metric_irt(params: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Turn mixed-metric IRT point estimates into predicted scores.

    Uses ``params["psi"]`` (N,), ``params["zeta"]`` (N, M) and
    ``params["b"]`` (J,) and returns expit(psi_i + zeta_im - b_j) with
    shape (N_models, J_items, M_metrics), values in (0,1).
    """
    global_ability = params["psi"]
    metric_offset = params["zeta"]
    difficulty = params["b"]

    # Per-model, per-metric ability: (N,1) + (N,M) -> (N,M)
    ability = global_ability[:, np.newaxis] + metric_offset

    # No discrimination parameter: (N,1,M) - (1,J,1) -> (N,J,M)
    logits = ability[:, np.newaxis, :] - difficulty[np.newaxis, :, np.newaxis]
    return expit(logits)

# =========================
# 5. MSE Calculation & Result Saving for Mixed-metric
# =========================
def calculate_mse_multi_metric(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    test_mask: np.ndarray
) -> float:
    """
    Mean squared error over the test cells, pooled across all metrics.

    y_true and y_pred have shape (N, J, M); test_mask is a boolean (N, J)
    array applied identically to every metric. The result is rounded to
    six decimal places.
    """
    n_metrics = y_true.shape[2]
    # Repeat the 2-D mask along a new last axis so it selects all metrics.
    mask_3d = np.repeat(test_mask[:, :, np.newaxis], n_metrics, axis=2)
    residuals = y_true[mask_3d] - y_pred[mask_3d]
    return round(np.mean(residuals ** 2), 6)

def save_sample_level_predictions_multi_metric(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    metric_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_mixed_irt: Optional[np.ndarray] = None,
    pred_1pl_irt: Optional[np.ndarray] = None
):
    """
    Write the test-set predictions of every method to one CSV file.

    The file ``predictions_ratio_<ratio>_rep<rep+1>.csv`` is created in
    ``pred_dir`` (the directory is created if needed). It has one row per
    (model, question, metric) triple of the test set, ordered by test cell
    and then by metric. ``pred_mixed_irt`` is written as NaN when it is
    None; the 1PL column is NaN everywhere except metric index 0, and NaN
    throughout when ``pred_1pl_irt`` is None.
    """
    test_rows, test_cols = np.where(test_mask)
    n_samples = len(test_rows)
    M = len(metric_names)

    # Every (model, question, metric) triple, in output row order.
    triples = [
        (test_rows[s], test_cols[s], m)
        for s in range(n_samples)
        for m in range(M)
    ]

    def _gather(values):
        # Flatten a (N, J, M) array into the row order of `triples`.
        return [values[i, j, m] for i, j, m in triples]

    true_values = _gather(Y)
    global_preds = _gather(pred_global)
    row_preds = _gather(pred_row)
    col_preds = _gather(pred_col)

    if pred_mixed_irt is None:
        mixed_irt_preds = [np.nan for _ in triples]
    else:
        mixed_irt_preds = _gather(pred_mixed_irt)

    if pred_1pl_irt is None:
        irt_1pl_preds = [np.nan for _ in triples]
    else:
        # The 1PL model only covers the first metric (bertscore_F1).
        irt_1pl_preds = [
            pred_1pl_irt[i, j, m] if m == 0 else np.nan
            for i, j, m in triples
        ]

    repeated_rows = np.repeat(test_rows, M)
    repeated_cols = np.repeat(test_cols, M)
    pred_df = pd.DataFrame({
        "model_idx": repeated_rows,
        "model_name": [model_names[r] for r in repeated_rows],
        "question_idx": repeated_cols,
        "question_name": [question_names[c] for c in repeated_cols],
        "metric_idx": np.tile(np.arange(M), n_samples),
        "metric_name": np.tile(metric_names, n_samples),
        "true_value": true_values,
        "global_mean_pred": global_preds,
        "model_mean_pred": row_preds,
        "question_mean_pred": col_preds,
        "mixed_metric_irt_pred": mixed_irt_preds,
        "1pl_irt_pred": irt_1pl_preds,
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1
    })

    os.makedirs(pred_dir, exist_ok=True)
    save_path = os.path.join(
        pred_dir, f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"
    )
    pred_df.to_csv(save_path, index=False, encoding="utf-8")

    print(f"✅ Sample-level prediction results saved (multi-metric): {save_path}")
    print(f"   Total {len(pred_df)} test set records")

def save_irt_trace_multi_metric(
    trace_dir: str,
    trace: az.InferenceData,
    train_subset_ratio: float,
    rep: int
):
    """
    Persist the full mixed-metric IRT trace as a NetCDF file.

    The file is written to ``trace_dir`` as
    ``mixed_metric_irt_ratio_<ratio>_rep<rep+1>.nc``.
    """
    file_name = f"mixed_metric_irt_ratio_{train_subset_ratio:.3f}_rep{rep+1}.nc"
    save_path = os.path.join(trace_dir, file_name)
    az.to_netcdf(trace, save_path)
    print(f"Mixed-metric IRT trace saved: {save_path}")

# =========================
# 6. Experiment Main Pipeline for Mixed-metric
# =========================
def run_experiment_mixed_metric(
    input_csv_paths: List[str],
    output_root_dir: str,
    question_id_col: Optional[Union[str, int]] = 0,
    test_ratio: float = 0.05,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 3,
    irt_draws: int = 1000,
    irt_tune: int = 1000,
    irt_chains: int = 4,
    irt_cores: Optional[int] = None,
):
    """
    Mixed-metric continuous IRT experiment main pipeline.

    For every training ratio and repetition this samples a training mask,
    evaluates the three mean baselines, the mixed-metric IRT model and a
    single-metric 1PL IRT model on the fixed test set, and writes traces,
    sample-level predictions and split records under ``output_root_dir``.
    Existing trace files are reused instead of re-sampling.

    Args:
        input_csv_paths: One CSV per metric (rows = questions, cols = models).
        output_root_dir: Root directory for all outputs.
        question_id_col: Question-id column forwarded to preprocessing.
        test_ratio: Fraction of (model, question) cells held out as test set.
        test_seed: Seed of the test split; repetition ``k`` uses
            ``test_seed + k`` for training-subset sampling and MCMC.
        train_ratios: Fractions of the remaining pool used for training.
            Defaults to [0.1, 0.2, 0.5, 1.0].
        rep_count: Repetitions per training ratio.
        irt_draws, irt_tune, irt_chains, irt_cores: MCMC settings passed to
            both IRT fits.

    NOTE(review): ``fit_mixed_metric_irt`` and ``load_trace_and_params`` are
    not defined in the part of the file visible here — presumably they live
    elsewhere in the module; confirm.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if train_ratios is None:
        train_ratios = [0.1, 0.2, 0.5, 1.0]

    # 1. Create output directories
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
    }
    for d in output_dirs.values():
        os.makedirs(d, exist_ok=True)
    print(f"Output directories created: {list(output_dirs.values())}")

    # 2. Read and preprocess multiple metric data
    print("\n=== Step 1/5: Read & Preprocess Multiple Metrics ===")
    Y, model_names, question_names, metric_names = preprocess_multiple_metrics(
        input_csv_paths,
        question_id_col=question_id_col,
        zero_row_ratio_threshold=0.3,
        epsilon=1e-6
    )
    N, J, M = Y.shape
    print(f"Preprocessed data shape: {N} models × {J} questions × {M} metrics")

    # 3. Split fixed test set
    print("\n=== Step 2/5: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J, M),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # 4. Initialize MSE record dictionary
    mse_summary = {
        "Global Mean": {r: [] for r in train_ratios},
        "Model Mean": {r: [] for r in train_ratios},
        "Question Mean": {r: [] for r in train_ratios},
        "Mixed-metric IRT": {r: [] for r in train_ratios},
        "1PL IRT (bertscore_F1)": {r: [] for r in train_ratios}
    }

    # 5. Iterate through training ratios and repetitions
    for train_ratio in train_ratios:
        print(f"\n=== Step 3/5: Train Ratio = {train_ratio:.3f} ===")
        train_masks_rep = []

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            rep_seed = test_seed + rep

            # Sample training set
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            train_masks_rep.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # Baseline predictions
            print("Running baseline methods...")
            pred_global = predict_global_mean(Y, train_mask)
            mse_global = calculate_mse_multi_metric(Y, pred_global, test_mask)
            pred_row = predict_row_mean(Y, train_mask)
            mse_row = calculate_mse_multi_metric(Y, pred_row, test_mask)
            pred_col = predict_col_mean(Y, train_mask)
            mse_col = calculate_mse_multi_metric(Y, pred_col, test_mask)

            # Train mixed-metric IRT model (or reuse a previously saved trace)
            trace_file_path = os.path.join(
                output_dirs["irt_trace"],
                f"mixed_metric_irt_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )
            if os.path.exists(trace_file_path):
                print(f"Found existing trace file, loading from {trace_file_path} ...")
                trace, params = load_trace_and_params(trace_file_path)
            else:
                print("Running mixed-metric continuous IRT model...")
                trace, params = fit_mixed_metric_irt(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores,
                    seed=rep_seed
                )
                save_irt_trace_multi_metric(output_dirs["irt_trace"], trace, train_ratio, rep)

            pred_mixed_irt = predict_from_mixed_metric_irt(params)
            mse_mixed_irt = calculate_mse_multi_metric(Y, pred_mixed_irt, test_mask)
            print(f"Mixed-metric IRT MSE: {mse_mixed_irt:.6f}")

            # Train single-metric 1PL IRT model (using bertscore_F1 metric)
            print("Running 1PL IRT model (bertscore_F1)...")
            trace_1pl_file_path = os.path.join(
                output_dirs["irt_trace"],
                f"1pl_irt_ratio_{train_ratio:.3f}_rep{rep+1}.nc"
            )
            if os.path.exists(trace_1pl_file_path):
                print(f"Found existing 1PL trace file, loading from {trace_1pl_file_path} ...")
                trace_1pl = az.from_netcdf(trace_1pl_file_path)
                # Extract parameters from trace
                params_1pl = {
                    "theta": trace_1pl.posterior["theta"].mean(dim=["chain", "draw"]).values,
                    "b": trace_1pl.posterior["b"].mean(dim=["chain", "draw"]).values
                }
            else:
                trace_1pl, params_1pl = fit_irt_1pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores,
                    seed=rep_seed
                )
                # Save 1PL IRT trace to the same path that is checked above
                save_path_1pl = trace_1pl_file_path
                trace_1pl.to_netcdf(save_path_1pl)
                print(f"1PL IRT trace saved: {save_path_1pl}")

            pred_1pl = predict_from_irt_1pl(params_1pl)
            # Expand 2D prediction to 3D, only has values for the first metric (bertscore_F1)
            pred_1pl_3d = np.zeros_like(Y)
            pred_1pl_3d[:, :, 0] = pred_1pl
            # Score the 1PL model on the first metric only. The other slices
            # of pred_1pl_3d are zero placeholders, not predictions, and
            # including them would add Y**2 terms to the MSE.
            mse_1pl = calculate_mse_multi_metric(
                Y[:, :, :1], pred_1pl_3d[:, :, :1], test_mask
            )
            print(f"1PL IRT (bertscore_F1) MSE: {mse_1pl:.6f}")

            # Record MSE
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            mse_summary["Mixed-metric IRT"][train_ratio].append(mse_mixed_irt)
            mse_summary["1PL IRT (bertscore_F1)"][train_ratio].append(mse_1pl)

            # Save sample-level predictions
            save_sample_level_predictions_multi_metric(
                pred_dir=output_dirs["predictions"],
                Y=Y,
                test_mask=test_mask,
                model_names=model_names,
                question_names=question_names,
                metric_names=metric_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_mixed_irt=pred_mixed_irt,
                pred_1pl_irt=pred_1pl_3d
            )

            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            print(f"Mixed-metric IRT: {mse_mixed_irt:.6f}")
            print(f"1PL IRT (bertscore_F1): {mse_1pl:.6f}")

        # Save data split records
        record_data_split(
            split_dir=output_dirs["split"],
            test_mask=test_mask,
            train_masks=train_masks_rep,
            train_subset_ratio=train_ratio,
            model_names=model_names,
            question_names=question_names
        )

    # 6. Save MSE summary and plotting



