#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
XSUM Dataset F1-Score Prediction Experiment
Experiment Design:
- Fixed Test Set: 5% of total data (randomly sampled once)
- Training Set: 10%, 20%, ..., 100% of remaining 95% data (variable ratio)
- Methods: Naive(Global Mean), Average Score(Row Mean), Difficulty(Col Mean), IRT(1PL/2PL/3PL)
- Repetitions: 3 independent runs for each training ratio
- Outputs: Data split records, full MCMC traces, sample-wise predictions, MSE metrics/plots
"""

import os
import pickle
import time
import warnings
from typing import Dict, List, Optional, Tuple, Union

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
from scipy.special import expit, logit

# Global Config (Critical for Consistency & Plot Requirements)
plt.rcParams["font.family"] = ["Arial", "Helvetica"]  # Force Latin fonts so figure text contains no garbled CJK glyphs
plt.rcParams["axes.unicode_minus"] = False  # Make negative signs on axes render correctly
plt.rcParams["figure.dpi"] = 150  # Default figure resolution (DPI) for all plots
warnings.filterwarnings("ignore")  # Silences ALL warnings process-wide (e.g. sampler notices from pymc)

# =========================
# 1. Core I/O & Preprocessing
# =========================
def robust_read_csv(path: str) -> pd.DataFrame:
    """Read a CSV file, trying several text encodings in turn.

    Encodings are attempted in the order utf-8, utf-8-sig, gbk, latin1 and the
    first successful parse is returned.

    Only decoding failures trigger a retry with the next encoding. Any other
    failure (missing file, malformed CSV, ...) cannot be fixed by switching
    encodings, so it is reported immediately instead of being retried.

    Args:
        path: Path of the CSV file to read.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the file cannot be read. The original exception is
            attached as ``__cause__``.
    """
    encodings = ("utf-8", "utf-8-sig", "gbk", "latin1")
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except UnicodeError as err:
            # Wrong encoding guess: remember the error and try the next one.
            last_err = err
        except Exception as err:
            # Not an encoding problem; keep the ValueError contract but
            # preserve the real cause for debugging.
            raise ValueError(f"Failed to read CSV: {str(err)}") from err
    raise ValueError(f"Failed to read CSV: {str(last_err)}") from last_err


def preprocess_matrix(
    df: pd.DataFrame,
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6,
) -> Tuple[pd.DataFrame, np.ndarray, List[str], List[str]]:
    """Turn a (questions x models) score table into a (models x questions) matrix.

    Steps:
        1. Optionally move the question-ID column into the index.
        2. Coerce every remaining cell to a number (non-numeric -> NaN).
        3. Drop questions whose share of exact zeros exceeds
           ``zero_row_ratio_threshold`` (NaN cells are not counted as zeros).
        4. Fill NaN with 0 and clip everything into [epsilon, 1 - epsilon] so
           that a later logit transform stays finite.
        5. Transpose so rows are models and columns are questions.

    Args:
        df: Raw table with one row per question and one column per model.
        question_id_col: Column holding question IDs. May be a column label,
            or an integer position (e.g. ``0`` for the first column) when no
            column carries that integer as its label. ``None`` keeps the
            existing index.
        zero_row_ratio_threshold: Maximum allowed fraction of zeros per question.
        epsilon: Clipping margin away from 0 and 1.

    Returns:
        ``(preprocessed_df, Y, model_names, question_names)`` where
        ``preprocessed_df`` is the transposed frame, ``Y`` its values,
        ``model_names`` its index and ``question_names`` the questions that
        survived filtering.

    Raises:
        KeyError: If ``question_id_col`` matches no column label or position.
    """
    work = df.copy()  # never mutate the caller's frame

    # Step 1: question IDs -> index
    if question_id_col is not None:
        id_col = question_id_col
        is_position = (
            isinstance(id_col, (int, np.integer))
            and not isinstance(id_col, bool)
            and id_col not in work.columns
        )
        if is_position:
            # set_index() only understands labels, so translate a positional
            # index (e.g. 0 = first CSV column) into the actual column label.
            n_cols = work.shape[1]
            if not -n_cols <= id_col < n_cols:
                raise KeyError(
                    f"question_id_col position {id_col} is out of range for {n_cols} columns"
                )
            id_col = work.columns[id_col]
        work = work.set_index(id_col)
    original_question_names = work.index.tolist()

    # Step 2: numeric coercion first, so the zero filter below compares numbers
    work = work.apply(pd.to_numeric, errors="coerce")

    # Step 3: drop questions dominated by zeros
    zero_ratio = (work == 0).sum(axis=1) / max(1, work.shape[1])
    work = work[zero_ratio <= zero_row_ratio_threshold].copy()
    filtered_question_names = work.index.tolist()
    print(f"Filtered questions: {len(original_question_names)} → {len(filtered_question_names)}")

    # Step 4: missing -> 0, then clip into [epsilon, 1 - epsilon]
    # (clipping already maps 0 to epsilon, so no separate replace is needed)
    work = work.fillna(0.0).clip(lower=epsilon, upper=1 - epsilon)

    # Step 5: rows = models, columns = questions
    work_transposed = work.T
    model_names = work_transposed.index.tolist()
    Y = work_transposed.values

    return work_transposed, Y, model_names, filtered_question_names

# =========================
# 2. Data Split (Fixed Test + Variable Training)
# =========================
def split_fixed_test_set(
    Y_shape: Tuple[int, int],
    test_ratio: float = 0.05,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw the fixed test mask once and return it with its complement.

    Every cell is assigned to the test set independently with probability
    ``test_ratio`` using a generator seeded by ``test_seed``, so the same
    arguments always yield the same split.

    Returns:
        ``(test_mask, remaining_mask)``: boolean arrays of shape ``Y_shape``;
        ``remaining_mask`` is the element-wise negation of ``test_mask``.
    """
    generator = np.random.default_rng(test_seed)
    uniform_draws = generator.random(Y_shape)

    # A cell belongs to the test set when its uniform draw falls below the ratio.
    test_mask = uniform_draws < test_ratio
    remaining_mask = np.logical_not(test_mask)

    n_test = test_mask.sum()
    n_pool = remaining_mask.sum()
    print(f"Fixed test set size: {n_test} samples ({test_ratio*100:.1f}%)")
    print(f"Remaining pool size: {n_pool} samples ({(1-test_ratio)*100:.1f}%)")
    return test_mask, remaining_mask


def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,  # Ratio of remaining pool (e.g., 0.1 = 10% of 95%)
    rep_seed: int
) -> np.ndarray:
    """Pick a random training subset out of the non-test pool.

    ``int(pool_size * train_subset_ratio)`` cells are drawn without replacement
    from the cells flagged in ``remaining_mask``, using a generator seeded with
    ``rep_seed``.

    Returns:
        Boolean mask with the shape of ``remaining_mask``; True marks a
        training cell. Cells outside the pool are never selected.
    """
    generator = np.random.default_rng(rep_seed)

    # Flat (row-major) positions of every cell that is still available.
    pool_positions = np.flatnonzero(remaining_mask)
    pool_size = pool_positions.size
    n_pick = int(pool_size * train_subset_ratio)

    # Sampling without replacement: no (model, question) pair is chosen twice.
    chosen = generator.choice(pool_size, size=n_pick, replace=False)

    train_mask = np.zeros_like(remaining_mask, dtype=bool)
    train_mask.flat[pool_positions[chosen]] = True
    return train_mask


def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """Write one CSV per repetition listing every train and test cell.

    Each output row is one (model, question) pair with its indices, names,
    ``set_type`` ("train" or "test"), the training ratio and the 1-based
    repetition number. Files are named
    ``split_ratio_<ratio:.1f>_rep<k>.csv`` inside ``split_dir``, which is
    created if it does not exist.

    NOTE(review): the ratio is formatted with one decimal, so two ratios that
    round to the same tenth (e.g. 0.12 and 0.14) would write to the same file.

    Args:
        split_dir: Output directory.
        test_mask: Boolean (N, J) mask of the fixed test set.
        train_masks: One boolean (N, J) training mask per repetition.
        train_subset_ratio: Training ratio these masks were sampled with.
        model_names: Names of the N models (row labels).
        question_names: Names of the J questions (column labels).

    Raises:
        ValueError: If any mask's shape differs from
            ``(len(model_names), len(question_names))``.
    """
    N = len(model_names)
    J = len(question_names)
    expected_shape = (N, J)

    # Validate everything up front so no partial output is written on bad input.
    if test_mask.shape != expected_shape:
        raise ValueError(
            f"test_mask shape {test_mask.shape} does not match data dimensions ({N},{J})"
        )
    for train_mask in train_masks:
        if train_mask.shape != expected_shape:
            raise ValueError(
                f"训练集掩码形状{train_mask.shape}与数据维度({N},{J})不匹配！"
                f"model_names长度={N}，question_names长度={J}，请检查数据预处理逻辑"
            )

    os.makedirs(split_dir, exist_ok=True)

    def _cells_frame(mask: np.ndarray, set_type: str) -> pd.DataFrame:
        # One row per True cell of `mask`, in row-major order.
        rows, cols = np.where(mask)
        return pd.DataFrame({
            "model_idx": rows,
            "model_name": [model_names[r] for r in rows],
            "question_idx": cols,
            "question_name": [question_names[c] for c in cols],
            "set_type": set_type,
            "train_ratio": train_subset_ratio,
        })

    # The test set is identical for every repetition: build its table once.
    test_base = _cells_frame(test_mask, "test")

    for rep_idx, train_mask in enumerate(train_masks):
        repetition = rep_idx + 1
        train_records = _cells_frame(train_mask, "train").assign(repetition=repetition)
        test_records = test_base.assign(repetition=repetition)

        all_records = pd.concat([train_records, test_records], ignore_index=True)
        save_path = os.path.join(split_dir, f"split_ratio_{train_subset_ratio:.1f}_rep{rep_idx+1}.csv")
        all_records.to_csv(save_path, index=False)
    print(f"数据划分记录已保存（训练比例：{train_subset_ratio:.1f}）")

# =========================
# 3. Prediction Methods (Baselines + IRT Models)
# =========================
def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """Naive baseline: predict every cell with the mean of the training cells.

    Args:
        Y: Score matrix.
        train_mask: Boolean mask, same shape as ``Y``; True marks training cells.

    Returns:
        Float array with the shape of ``Y`` filled with the training mean, or
        with 0.5 when the training set is empty.
    """
    global_mean = float(Y[train_mask].mean()) if train_mask.any() else 0.5
    # Always return floats: np.full_like would inherit Y's dtype and silently
    # truncate the mean (e.g. 0.75 -> 0) for an integer-valued Y.
    return np.full(Y.shape, global_mean, dtype=float)


def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """Average-score baseline: predict each model's cells with that model's training mean.

    A model with no training cell falls back to the global training mean
    (or 0.5 when the training set is empty altogether).

    Returns:
        Float array of shape (N, J); row i is constant and equal to model i's mean.
    """
    N, J = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5

    # One mean per model (row); rows without any observed cell use the fallback.
    per_model = np.array(
        [
            scores[seen].mean() if seen.any() else fallback
            for scores, seen in zip(Y, train_mask)
        ],
        dtype=float,
    )

    # Spread each model's mean across all J questions.
    return np.repeat(per_model[:, None], J, axis=1)


def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """Difficulty baseline: predict each question's cells with that question's training mean.

    A question with no training cell falls back to the global training mean
    (or 0.5 when the training set is empty altogether).

    Returns:
        Float array of shape (N, J); column j is constant and equal to question j's mean.
    """
    N, J = Y.shape
    fallback = Y[train_mask].mean() if train_mask.any() else 0.5

    # Walk the columns by iterating over the transposed views.
    per_question = np.array(
        [
            scores[seen].mean() if seen.any() else fallback
            for scores, seen in zip(Y.T, train_mask.T)
        ],
        dtype=float,
    )

    # Stack the same row of question means once per model.
    return np.repeat(per_question[None, :], N, axis=0)


# -------------------------
# IRT Models (1PL/2PL/3PL) with Full Trace Saving
# -------------------------
def fit_irt_1pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 1PL-style IRT model on the training cells with MCMC.

    Model as implemented: logit(Y_ij) ~ Normal(theta_i - b_j, 1), with
    theta_i ~ Normal(0, 1) (one per row of Y) and b_j ~ Normal(0, 1)
    (one per column of Y). Only cells where train_mask is True enter the
    likelihood.

    Args:
        Y: (N, J) score matrix; values are clipped to [1e-6, 1 - 1e-6] before
            the logit transform.
        train_mask: Boolean (N, J) mask selecting the training cells.
        draws: Posterior draws per chain.
        tune: Tuning iterations per chain.
        chains: Number of chains.
        cores: Number of cores; defaults to min(cpu count, chains).

    Returns:
        (trace, params): the full InferenceData returned by pm.sample and a
        dict with the posterior means of "theta" (N,) and "b" (J,).
    """
    N, J = Y.shape
    # Clip scores away from 0/1 so the logit below stays finite
    Y_clip = np.clip(Y, 1e-6, 1 - 1e-6)
    logit_Y = logit(Y_clip)  # observations are modelled on the logit scale

    # Default core count: no more cores than chains, and at least 1
    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    # Build the 1PL model in PyMC
    with pm.Model() as irt_1pl:
        # 1. Priors
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)  # ability, one per model (row)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)          # difficulty, one per question (column)

        # 2. Linear predictor (expected value on the logit scale)
        mu = theta[:, None] - b[None, :]  # broadcast: (N,1) - (1,J) -> (N,J)

        # 3. Likelihood, restricted to the training cells; noise sd fixed at 1
        pm.Normal("obs", mu=mu[train_mask], sigma=1.0, observed=logit_Y[train_mask])

        # 4. MCMC sampling
        start_time = time.time()
        trace = pm.sample(
            draws=draws,          # posterior draws per chain
            tune=tune,            # tuning (warm-up) iterations per chain
            chains=chains,        # number of chains (several chains are needed for r_hat)
            cores=cores,          # cores used to run chains in parallel
            target_accept=0.95,   # target acceptance rate passed to the sampler
            progressbar=True,    # progress bar is shown
            return_inferencedata=True  # return an ArviZ InferenceData object
        )
        print(f"1PL IRT sampling time: {time.time()-start_time:.1f}s")

    # Convergence check: report how many parameters have r_hat above 1.01
    r_hat = az.summary(trace, var_names=["theta", "b"])["r_hat"]
    if (r_hat > 1.01).any():
        print(f"⚠️ 1PL IRT: {sum(r_hat>1.01)} parameters with r_hat>1.01 (potential convergence issue)")

    # Posterior means (averaged over chains and draws), used for prediction
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params

def fit_irt_2pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 2PL-style IRT model on the training cells with MCMC.

    Model as implemented: logit(Y_ij) ~ Normal(a_j * theta_i - b_j, 1), with
    theta_i ~ Normal(0, 1), a_j ~ LogNormal(0, 0.5) and b_j ~ Normal(0, 1).
    Only cells where train_mask is True enter the likelihood.

    NOTE(review): the linear predictor is the slope-intercept form
    a_j * theta_i - b_j, NOT a_j * (theta_i - b_j). Any code that turns the
    returned parameters into predictions must use the same form, otherwise
    "b" is interpreted on the wrong scale.

    Args:
        Y: (N, J) score matrix; values are clipped to [1e-6, 1 - 1e-6] before
            the logit transform.
        train_mask: Boolean (N, J) mask selecting the training cells.
        draws: Posterior draws per chain.
        tune: Tuning iterations per chain.
        chains: Number of chains.
        cores: Number of cores; defaults to min(cpu count, chains).

    Returns:
        (trace, params): the full InferenceData returned by pm.sample and a
        dict with the posterior means of "theta" (N,), "a" (J,) and "b" (J,).
    """
    N, J = Y.shape
    # Clip scores away from 0/1 so the logit stays finite
    Y_clip = np.clip(Y, 1e-6, 1 - 1e-6)
    logit_Y = logit(Y_clip)

    # Default core count: no more cores than chains, and at least 1
    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    with pm.Model() as irt_2pl:
        # 1. Priors (adds the per-question discrimination a_j compared with 1PL)
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)  # ability, one per model (row)
        a = pm.LogNormal("a", mu=0.0, sigma=0.5, shape=J)        # discrimination; LogNormal keeps a_j > 0
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)          # difficulty / intercept, one per question

        # 2. Linear predictor: a_j * theta_i - b_j (a scales theta only, not b)
        mu = a[None, :] * theta[:, None] - b[None, :]  # (1,J)*(N,1) - (1,J) -> (N,J)

        # 3. Likelihood, restricted to the training cells; noise sd fixed at 1
        pm.Normal("obs", mu=mu[train_mask], sigma=1.0, observed=logit_Y[train_mask])

        # 4. MCMC sampling
        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )
        print(f"2PL IRT sampling time: {time.time()-start_time:.1f}s")

    # Convergence check over theta, a and b: report parameters with r_hat above 1.01
    r_hat = az.summary(trace, var_names=["theta", "a", "b"])["r_hat"]
    if (r_hat > 1.01).any():
        print(f"⚠️ 2PL IRT: {sum(r_hat>1.01)} parameters with r_hat>1.01")

    # Posterior means (averaged over chains and draws)
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "a": trace.posterior["a"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def fit_irt_3pl(
    Y: np.ndarray,
    train_mask: np.ndarray,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    cores: Optional[int] = None
) -> Tuple[az.InferenceData, Dict[str, np.ndarray]]:
    """
    Fit a 3PL-style IRT model on the training cells with MCMC.

    Model as implemented:
        p_ij        = c_j + (1 - c_j) * sigmoid(a_j * theta_i - b_j)
        logit(Y_ij) ~ Normal(logit(p_ij), sigma_j)
    with theta_i ~ Normal(0, 1), a_j ~ Normal(1, 0.5), b_j ~ Normal(0, 1),
    c_j ~ Beta(1, 5) and sigma_j ~ HalfNormal(1). Only cells where train_mask
    is True enter the likelihood.

    Args:
        Y: (N, J) score matrix; values are clipped to [1e-6, 1 - 1e-6] before
            the logit transform.
        train_mask: Boolean (N, J) mask selecting the training cells.
        draws: Posterior draws per chain.
        tune: Tuning iterations per chain.
        chains: Number of chains.
        cores: Number of cores; defaults to min(cpu count, chains).

    Returns:
        (trace, params): the full InferenceData returned by pm.sample and a
        dict with the posterior means of "theta" (N,), "a", "b", "c" and
        "sigma" (each (J,)).
    """
    N, J = Y.shape  # N = number of models (rows), J = number of questions (columns)
    # Clip extreme values, then move observations to the logit scale
    eps = 1e-6
    Y_clip = np.clip(Y, eps, 1 - eps)
    logit_Y = logit(Y_clip)

    # Default core count: no more cores than chains, and at least 1
    if cores is None:
        cores = min(os.cpu_count() or 1, chains)

    # Question (column) index of every training cell, listed in the same
    # row-major order that boolean-mask indexing with train_mask produces.
    train_cols = np.nonzero(train_mask)[1]

    with pm.Model() as irt_3pl:
        # 1. Priors
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)          # ability, one per model
        a = pm.Normal("a", mu=1.0, sigma=0.5, shape=J)                  # discrimination
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)                  # difficulty / intercept
        c = pm.Beta("c", alpha=1.0, beta=5.0, shape=J)                   # guessing (lower asymptote)
        sigma = pm.HalfNormal("sigma", sigma=1.0, shape=J)              # per-question noise sd
        # A bounded alternative for sigma would be Uniform(lower=0.5, upper=1.5).

        # 2. 3PL mean: slope-intercept linear predictor plus guessing floor
        mu_base = a[None, :] * theta[:, None] - b[None, :]
        prob = c[None, :] + (1 - c[None, :]) * pm.math.sigmoid(mu_base)
        prob_clip = pm.math.clip(prob, eps, 1 - eps)                     # keep the logit finite
        mu = pm.math.logit(prob_clip)                                    # back to the logit scale

        # 3. Restrict observations, means and noise sd to the training cells
        logit_y_obs = logit_Y[train_mask].flatten()
        mu_obs = mu[train_mask].flatten()
        # Each observation (i, j) must use sigma_j. Indexing sigma by the
        # column of every training cell guarantees that alignment; repeating
        # each sigma_j N times would yield a question-major layout that does
        # not match the row-major order of the flattened mask.
        sigma_obs = sigma[train_cols]

        # Debug output: shapes of the three aligned 1-D vectors
        print(f"训练集logit观测值形状: {logit_y_obs.shape}")
        print(f"训练集均值形状: {mu_obs.shape}")
        print(f"训练集离散度形状: {sigma_obs.shape}")

        # 4. Likelihood: Normal on the logit scale with per-question sd
        pm.Normal("y_obs", mu=mu_obs, sigma=sigma_obs, observed=logit_y_obs)

        # 5. MCMC sampling
        start_time = time.time()
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )
        print(f"3PL IRT sampling time: {time.time()-start_time:.1f}s")

    # 6. Convergence check over all parameters: report those with r_hat above 1.01
    r_hat = az.summary(trace, var_names=["theta", "a", "b", "c", "sigma"])["r_hat"]
    if (r_hat > 1.01).any():
        print(f"⚠️ 3PL IRT: {sum(r_hat>1.01)} parameters with r_hat>1.01")

    # 7. Posterior means (averaged over chains and draws)
    params = {
        "theta": trace.posterior["theta"].mean(dim=["chain", "draw"]).values,
        "a": trace.posterior["a"].mean(dim=["chain", "draw"]).values,
        "b": trace.posterior["b"].mean(dim=["chain", "draw"]).values,
        "c": trace.posterior["c"].mean(dim=["chain", "draw"]).values,
        "sigma": trace.posterior["sigma"].mean(dim=["chain", "draw"]).values
    }
    return trace, params


def predict_from_irt(params: Dict[str, np.ndarray], model_type: str) -> np.ndarray:
    """
    Generate score predictions from IRT posterior-mean parameters.

    The formulas mirror the linear predictors used by the fitting functions
    in this file:
        1pl: expit(theta_i - b_j)
        2pl: expit(a_j * theta_i - b_j)
        3pl: c_j + (1 - c_j) * expit(a_j * theta_i - b_j)
    The 2PL/3PL fitters estimate ``b`` as an intercept of the slope-intercept
    form ``a*theta - b``; using ``a*(theta - b)`` here would rescale ``b`` by
    ``a`` and produce predictions that do not match the fitted model.

    Args:
        params: Dict with "theta" (N,) and "b" (J,); plus "a" (J,) for
            2pl/3pl and "c" (J,) for 3pl.
        model_type: One of "1pl", "2pl", "3pl".

    Returns:
        (N, J) array of predicted scores in [0, 1].

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type not in ("1pl", "2pl", "3pl"):
        raise ValueError(f"Unsupported IRT model type: {model_type}")

    theta = params["theta"].reshape(-1, 1)  # (N, 1): one ability per model
    b = params["b"].reshape(1, -1)          # (1, J): one difficulty/intercept per question

    if model_type == "1pl":
        return expit(theta - b)

    a = params["a"].reshape(1, -1)          # (1, J): one discrimination per question
    linear = a * theta - b                  # same parameterization as the fitters
    if model_type == "2pl":
        return expit(linear)

    c = params["c"].reshape(1, -1)          # (1, J): guessing floor per question
    return c + (1 - c) * expit(linear)

# =========================
# 4. MSE Calculation & Result Saving
# =========================
def calculate_mse(y_true: np.ndarray, y_pred: np.ndarray, test_mask: np.ndarray) -> float:
    """
    Mean squared error restricted to the test cells, rounded to 6 decimals.

    Args:
        y_true: Observed scores, shape (N, J).
        y_pred: Predicted scores, shape (N, J).
        test_mask: Boolean (N, J) mask; only True cells contribute.

    Returns:
        The rounded MSE (lower means better predictions).
    """
    # Residuals on the held-out cells only
    residuals = y_true[test_mask] - y_pred[test_mask]
    squared_error_mean = np.mean(np.square(residuals))
    return round(squared_error_mean, 6)


def save_sample_level_predictions_old(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_1pl: np.ndarray,
    pred_2pl: np.ndarray,
    pred_3pl: np.ndarray
):
    """
    Legacy writer for per-sample test predictions (all three IRT models required).

    Writes ``predictions_ratio_<ratio:.1f>_rep<rep+1>.csv`` into ``pred_dir``
    (which must already exist). Each row is one test cell with its model and
    question index/name, the true score and the prediction of every method.
    """
    def _on_test(matrix: np.ndarray) -> list:
        # Values of `matrix` at the test cells, in row-major order.
        return matrix[test_mask].tolist()

    row_idx, col_idx = np.nonzero(test_mask)

    columns = {
        "model_idx": row_idx,
        "model_name": [model_names[i] for i in row_idx],
        "question_idx": col_idx,
        "question_name": [question_names[j] for j in col_idx],
        "true_f1": _on_test(Y),
        "pred_global_mean": _on_test(pred_global),
        "pred_model_mean": _on_test(pred_row),
        "pred_question_mean": _on_test(pred_col),
        "pred_irt_1pl": _on_test(pred_1pl),
        "pred_irt_2pl": _on_test(pred_2pl),
        "pred_irt_3pl": _on_test(pred_3pl),
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1,  # stored 1-based
    }

    # File name encodes training ratio and repetition for later filtering
    out_file = os.path.join(pred_dir, f"predictions_ratio_{train_subset_ratio:.1f}_rep{rep+1}.csv")
    pd.DataFrame(columns).to_csv(out_file, index=False)
    print(f"Sample predictions saved: {out_file}")

import pandas as pd
import numpy as np

def save_sample_level_predictions(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_1pl: Optional[np.ndarray] = None,  # None when the 1PL model was not run
    pred_2pl: Optional[np.ndarray] = None,  # None when the 2PL model was not run
    pred_3pl: Optional[np.ndarray] = None   # None when the 3PL model was not run
):
    """
    Write per-sample test predictions to CSV, tolerating skipped IRT models.

    One row is written per test cell: model/question index and name, training
    ratio, 1-based repetition, true value, the three baseline predictions and
    the three IRT predictions. An IRT column is filled with NaN when its
    prediction matrix is None. The file is
    ``predictions_ratio_<ratio:.3f>_rep<rep+1>.csv`` inside ``pred_dir``,
    which is created if needed.
    """
    row_idx, col_idx = np.nonzero(test_mask)
    n_test = len(row_idx)

    def _on_test(matrix: np.ndarray) -> list:
        # Values of a mandatory matrix at the test cells.
        return matrix[test_mask].tolist()

    def _optional_on_test(matrix: Optional[np.ndarray]) -> list:
        # Same as above, but a model that was not run yields an all-NaN column.
        if matrix is None:
            return [np.nan] * n_test
        return matrix[test_mask].tolist()

    pred_df = pd.DataFrame({
        # traceability columns
        "model_idx": row_idx,
        "model_name": [model_names[i] for i in row_idx],
        "question_idx": col_idx,
        "question_name": [question_names[j] for j in col_idx],
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1,
        # ground truth and per-method predictions
        "true_value": _on_test(Y),
        "global_mean_pred": _on_test(pred_global),
        "model_mean_pred": _on_test(pred_row),
        "question_mean_pred": _on_test(pred_col),
        "irt_1pl_pred": _optional_on_test(pred_1pl),
        "irt_2pl_pred": _optional_on_test(pred_2pl),
        "irt_3pl_pred": _optional_on_test(pred_3pl),
    })

    # Make sure the target directory exists before writing
    os.makedirs(pred_dir, exist_ok=True)

    save_path = os.path.join(
        pred_dir, f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"
    )
    pred_df.to_csv(save_path, index=False, encoding="utf-8")

    print(f"✅ 样本级预测结果已保存：{save_path}")
    print(f"   共{len(pred_df)}条测试集记录")


def save_irt_trace(
    trace_dir: str,
    trace: az.InferenceData,
    model_type: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Save a full IRT MCMC trace as a NetCDF file for post-hoc analysis.

    The file is written to
    ``<trace_dir>/irt_<model_type>_ratio_<ratio:.1f>_rep<rep+1>.nc``.
    ``trace_dir`` is created when missing, matching the behaviour of
    ``save_sample_level_predictions``.

    NOTE(review): the ratio is formatted with one decimal, so two ratios that
    round to the same tenth would overwrite each other's trace.

    Args:
        trace_dir: Output directory.
        trace: InferenceData object to persist.
        model_type: Model tag used in the file name (e.g. "1pl").
        train_subset_ratio: Training ratio of this run.
        rep: 0-based repetition index (stored 1-based in the file name).
    """
    # Do not rely on the caller having created the directory beforehand.
    os.makedirs(trace_dir, exist_ok=True)

    save_path = os.path.join(
        trace_dir,
        f"irt_{model_type}_ratio_{train_subset_ratio:.1f}_rep{rep+1}.nc"
    )
    az.to_netcdf(trace, save_path)
    print(f"IRT {model_type} trace saved: {save_path}")


# =========================
# 5. Result Visualization (MSE Comparison Plot)
# =========================
def plot_mse_comparison(
    mse_dict: Dict[str, List[float]],
    train_ratios: List[float],
    save_path: str
):
    """
    Draw one MSE-vs-training-ratio line per method and save the figure as an image.

    Args:
        mse_dict: Maps a method name to its MSE values, one per entry of
            ``train_ratios``. Every key must be one of the six method names
            styled below.
        train_ratios: Training ratios used as x positions and tick labels.
        save_path: Destination file of the figure.
    """
    # Fixed colour / line-style per method so plots are comparable across runs
    style_by_method = {
        "Global Mean": {"color": "blue", "linestyle": "solid"},
        "Model Mean": {"color": "orange", "linestyle": "dashed"},
        "Question Mean": {"color": "green", "linestyle": "dashdot"},
        "IRT-1PL": {"color": "red", "linestyle": "solid"},
        "IRT-2PL": {"color": "purple", "linestyle": "dashed"},
        "IRT-3PL": {"color": "brown", "linestyle": "dashdot"},
    }
    # Settings shared by every line (markers make individual values readable)
    shared_line_kwargs = {"linewidth": 2, "marker": "o", "markersize": 4}

    plt.figure(figsize=(10, 6))
    for method_name in mse_dict:
        method_style = style_by_method[method_name]
        plt.plot(
            train_ratios,
            mse_dict[method_name],
            label=method_name,
            **method_style,
            **shared_line_kwargs,
        )

    # Axis labels, title, legend and grid
    plt.xlabel("Training Data Ratio (Fraction of Remaining Pool)", fontsize=12)
    plt.ylabel("Test Set MSE (Lower = Better)", fontsize=12)
    plt.title("MSE vs. Training Data Ratio for All Prediction Methods", fontsize=14, pad=20)
    plt.legend(loc="upper right", fontsize=10)
    plt.grid(True, alpha=0.3)

    # One x tick per training ratio, labelled with one decimal
    tick_labels = [f"{ratio:.1f}" for ratio in train_ratios]
    plt.xticks(train_ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    # Tight layout, high-resolution export, then release the figure
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"MSE plot saved: {save_path}")


def save_mse_summary(
    mse_summary: Dict[str, Dict[float, List[float]]],
    save_path: str
):
    """
    Save an MSE summary table (mean ± std over repetitions) to CSV.

    The table has one row per training ratio and one column per method, each
    cell formatted as ``"<mean:.6f> ± <std:.6f>"`` (population std). Ratios
    are collected from all methods rather than from one hard-coded method, so
    the function works for any set of method names. A method with no values
    for a ratio gets ``"nan ± nan"``.

    Args:
        mse_summary: ``{method: {train_ratio: [mse per repetition]}}``.
        save_path: Destination CSV path.
    """
    methods = list(mse_summary)
    # Union of ratios over all methods, so no method name is special-cased
    ratios = sorted({ratio for per_ratio in mse_summary.values() for ratio in per_ratio})

    rows = []
    for ratio in ratios:
        row = {"Train_Ratio": ratio}
        for method in methods:
            values = np.asarray(mse_summary[method].get(ratio, []), dtype=float)
            if values.size:
                mse_mean, mse_std = values.mean(), values.std()
            else:
                # Nothing recorded for this method/ratio: report NaN explicitly
                mse_mean = mse_std = float("nan")
            row[method] = f"{mse_mean:.6f} ± {mse_std:.6f}"
        rows.append(row)

    # Explicit column order keeps the header stable even when there are no rows
    pd.DataFrame(rows, columns=["Train_Ratio", *methods]).to_csv(save_path, index=False)
    print(f"MSE summary saved: {save_path}")

# =========================
# 6. Experiment Main Pipeline
# =========================
def run_experiment(
    input_csv_path: str,
    output_root_dir: str,
    question_id_col: Optional[str or int] = 0,  # 若CSV第一列为题目ID，传0或列名
    test_ratio: float = 0.05,
    test_seed: int = 42,
    train_ratios: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    rep_count: int = 3,  # 每个训练比例重复3次
    irt_draws: int = 10,
    irt_tune: int = 10,
    irt_chains: int = 4,
    irt_cores: Optional[int] = None,
    selected_irt_models: List[str] = ["1pl"]  # 新增：指定要使用的IRT模型
):
    """
    实验主流程：整合数据读取、预处理、划分、建模、评估、结果保存
    Input:
        input_csv_path: 输入CSV路径（行=题目，列=模型）
        output_root_dir: 输出根目录（会自动创建子目录）
        question_id_col: 题目ID列（无则传None）
        selected_irt_models: 选择要运行的IRT模型，可选值：["1pl", "2pl", "3pl"]的任意组合
        其他参数：实验设计相关（测试集比例、训练比例列表、重复次数、IRT采样参数）
    """
    # 新增：校验输入的IRT模型是否合法
    valid_irt_models = ["1pl", "2pl", "3pl"]
    for model in selected_irt_models:
        if model not in valid_irt_models:
            raise ValueError(
                f"不支持的IRT模型：{model}！请从以下合法模型中选择：{valid_irt_models}"
            )
    print(f"已选择运行的IRT模型：{selected_irt_models}")

    # -------------------------
    # Step 1: 创建输出目录（按功能分类，结构清晰）
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),  # 数据划分记录
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),  # 样本级预测
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),  # IRT trace
        "metrics": os.path.join(output_root_dir, "04_metrics"),  # MSE汇总与图表
    }
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    print(f"Output directories created: {list(output_dirs.values())}")

    # -------------------------
    # Step 2: 数据读取与预处理
    print("\n=== Step 1/6: Read & Preprocess Data ===")
    raw_df = robust_read_csv(input_csv_path)
    preprocessed_df, Y, model_names, question_names = preprocess_matrix(
        df=raw_df,
        question_id_col=question_id_col,
        zero_row_ratio_threshold=0.3,
        epsilon=1e-6
    )
    N, J = Y.shape
    print(f"Preprocessed data shape: {N} models × {J} questions")

    # -------------------------
    # Step 3: 划分固定测试集（仅运行一次）
    print("\n=== Step 2/6: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # -------------------------
    # Step 4: 初始化MSE记录字典（仅包含选择的IRT模型）
    mse_summary = {
        "Global Mean": {r: [] for r in train_ratios},
        "Model Mean": {r: [] for r in train_ratios},
        "Question Mean": {r: [] for r in train_ratios}
    }
    # 仅为选择的IRT模型添加MSE记录键
    for model in selected_irt_models:
        mse_summary[f"IRT-{model.upper()}"] = {r: [] for r in train_ratios}

    # -------------------------
    # Step 5: 遍历所有训练比例 + 重复次数（核心实验循环）
    for train_ratio in train_ratios:
        print(f"\n=== Step 3/6: Train Ratio = {train_ratio:.3f} ===")
        
        # 存储当前训练比例的3次重复训练集掩码
        train_masks_rep = []
        
        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            rep_seed = test_seed + rep  # 每个重复用不同种子

            # -------------------------
            # Substep 5.1: 采样当前重复的训练集
            print(f"Sampling training subset (ratio={train_ratio:.3f})...")
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            train_masks_rep.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # -------------------------
            # Substep 5.2: 基线方法预测 + MSE计算
            print("Running baseline methods...")
            # 1. 全局均值
            pred_global = predict_global_mean(Y, train_mask)
            mse_global = calculate_mse(Y, pred_global, test_mask)
            # 2. 模型均值
            pred_row = predict_row_mean(Y, train_mask)
            mse_row = calculate_mse(Y, pred_row, test_mask)
            # 3. 题目均值
            pred_col = predict_col_mean(Y, train_mask)
            mse_col = calculate_mse(Y, pred_col, test_mask)

            # -------------------------
            # Substep 5.3: 仅运行选择的IRT模型
            print(f"Running selected IRT models: {selected_irt_models}...")
            # 初始化预测结果变量为None
            pred_1pl = None
            pred_2pl = None
            pred_3pl = None
            
            # 1. 若选择1PL模型
            if "1pl" in selected_irt_models:
                trace_1pl, params_1pl = fit_irt_1pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores
                )
                pred_1pl = predict_from_irt(params_1pl, "1pl")
                mse_1pl = calculate_mse(Y, pred_1pl, test_mask)
                save_irt_trace(output_dirs["irt_trace"], trace_1pl, "1pl", train_ratio, rep)
                print(f"IRT-1PL MSE: {mse_1pl:.6f}")

            # 2. 若选择2PL模型
            if "2pl" in selected_irt_models:
                trace_2pl, params_2pl = fit_irt_2pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores
                )
                pred_2pl = predict_from_irt(params_2pl, "2pl")
                mse_2pl = calculate_mse(Y, pred_2pl, test_mask)
                save_irt_trace(output_dirs["irt_trace"], trace_2pl, "2pl", train_ratio, rep)
                print(f"IRT-2PL MSE: {mse_2pl:.6f}")

            # 3. 若选择3PL模型
            if "3pl" in selected_irt_models:
                trace_3pl, params_3pl = fit_irt_3pl(
                    Y=Y,
                    train_mask=train_mask,
                    draws=irt_draws,
                    tune=irt_tune,
                    chains=irt_chains,
                    cores=irt_cores
                )
                pred_3pl = predict_from_irt(params_3pl, "3pl")
                mse_3pl = calculate_mse(Y, pred_3pl, test_mask)
                save_irt_trace(output_dirs["irt_trace"], trace_3pl, "3pl", train_ratio, rep)
                print(f"IRT-3PL MSE: {mse_3pl:.6f}")

            # -------------------------
            # Substep 5.4: 记录当前重复的MSE
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            
            # 记录选择的IRT模型MSE
            if "1pl" in selected_irt_models:
                mse_summary["IRT-1PL"][train_ratio].append(mse_1pl)
            if "2pl" in selected_irt_models:
                mse_summary["IRT-2PL"][train_ratio].append(mse_2pl)
            if "3pl" in selected_irt_models:
                mse_summary["IRT-3PL"][train_ratio].append(mse_3pl)

            # -------------------------
            # Substep 5.5: 保存当前重复的样本级预测（固定参数传递）
            save_sample_level_predictions(
                pred_dir=output_dirs["predictions"],
                Y=Y,
                test_mask=test_mask,
                model_names=model_names,
                question_names=question_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_1pl=pred_1pl,  # 只传递已计算的预测结果，未选择的模型为None
                pred_2pl=pred_2pl,
                pred_3pl=pred_3pl
            )

            # 打印当前重复的MSE结果
            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            if "1pl" in selected_irt_models:
                print(f"IRT-1PL: {mse_1pl:.6f}")
            if "2pl" in selected_irt_models:
                print(f"IRT-2PL: {mse_2pl:.6f}")
            if "3pl" in selected_irt_models:
                print(f"IRT-3PL: {mse_3pl:.6f}")

        # -------------------------
        # Substep 5.6: 记录当前训练比例的数据划分
        record_data_split(
            split_dir=output_dirs["split"],
            test_mask=test_mask,
            train_masks=train_masks_rep,
            train_subset_ratio=train_ratio,
            model_names=model_names,
            question_names=question_names
        )

    # -------------------------
    # Step 6: 生成最终MSE汇总与可视化
    print("\n=== Step 4/6: Generate Final Results ===")
    # 保存MSE汇总表
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    save_mse_summary(mse_summary, mse_summary_path)

    # 生成MSE对比图
    plot_mse_dict = {}
    # 添加基线方法
    plot_mse_dict["Global Mean"] = [np.mean(mse_summary["Global Mean"][r]) for r in train_ratios]
    plot_mse_dict["Model Mean"] = [np.mean(mse_summary["Model Mean"][r]) for r in train_ratios]
    plot_mse_dict["Question Mean"] = [np.mean(mse_summary["Question Mean"][r]) for r in train_ratios]
    # 添加选择的IRT模型
    for model in selected_irt_models:
        plot_mse_dict[f"IRT-{model.upper()}"] = [
            np.mean(mse_summary[f"IRT-{model.upper()}"][r]) for r in train_ratios
        ]
    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    # -------------------------
    # Step 7: 打印实验完成信息
    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")


# ---------------------------------------------------------------------------
# Experiment configuration.
# These constants stay at module level so they remain importable; the
# experiment itself is only launched when this file is executed as a script.
# ---------------------------------------------------------------------------
INPUT_CSV_PATH = "data/response_matrix__bertscore_F1.csv"  # Input CSV path (rows = questions, columns = models)
OUTPUT_ROOT_DIR = "results/modeling_result_XSUM_F1"  # Output root directory (expected to be created automatically)
QUESTION_ID_COL = 'row_index'  # Question-ID column (pass 0 if the first CSV column holds the IDs; None if there is none)
TEST_SEED = 42  # Seed for the fixed test split, so the test set is reproducible
# NOTE(review): the module docstring describes a 5% test set, but 10% is
# configured here -- confirm which one is intended.
TEST_RATIO = 0.10
TRAIN_RATIOS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# For a quick single-ratio run, use TRAIN_RATIOS = [0.5] instead.
# NOTE(review): the module docstring mentions 3 repetitions per ratio; 1 is
# configured here -- confirm.
REP_COUNT = 1  # Number of repetitions per training ratio
# NOTE(review): 10 draws / 10 tuning steps look like smoke-test settings;
# presumably far too few for converged MCMC estimates -- confirm before
# using the results.
IRT_DRAWS = 10  # Posterior draws per chain for the IRT models
IRT_TUNE = 10  # Tuning (warm-up) steps for the IRT models
IRT_CHAINS = 12  # Number of MCMC chains
IRT_CORES = 12  # Number of cores used for sampling (None = choose automatically, per the original note)
SELECTED_IRT_MODELS = ["1pl"]  # IRT variants to fit, e.g. ["1pl"] or ["1pl", "2pl", "3pl"]

# Guard the entry point: without it, merely importing this module (including
# the re-import of the main module performed by spawn/forkserver worker
# processes during multi-core sampling) would start the whole experiment.
if __name__ == "__main__":
    run_experiment(
        input_csv_path=INPUT_CSV_PATH,
        output_root_dir=OUTPUT_ROOT_DIR,
        question_id_col=QUESTION_ID_COL,
        test_ratio=TEST_RATIO,
        test_seed=TEST_SEED,
        train_ratios=TRAIN_RATIOS,
        rep_count=REP_COUNT,
        irt_draws=IRT_DRAWS,
        irt_tune=IRT_TUNE,
        irt_chains=IRT_CHAINS,
        irt_cores=IRT_CORES,
        selected_irt_models=SELECTED_IRT_MODELS,  # Only the listed IRT variants are run
    )

'''
Usage (the configured data/ and results/ paths are relative to the current working directory):
python /Users/bytedance/Desktop/QileZhang/llm/IRT/eval/Metric/modeling_0912.py
'''