#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
from typing import Dict, List, Optional, Tuple, Union
from scipy.special import logit, expit

# Keep matplotlib's default font; no CJK (Chinese) font is configured here.
plt.rcParams["axes.unicode_minus"] = False  # Draw minus signs as ASCII hyphens instead of the Unicode minus glyph

# =========================
# 1. I/O & Preprocessing
# =========================

def robust_read_csv(file_path: str) -> pd.DataFrame:
    """
    Read a CSV file, trying several text encodings in turn.

    The encodings "utf-8", "utf-8-sig", "gbk" and "latin1" are attempted in
    that order and the first successful parse is returned. Only the encoding
    is varied; the separator is always pandas' default.

    Parameters:
    - file_path: Path of the CSV file to read

    Returns:
    - The parsed DataFrame

    Raises:
    - OSError (e.g. FileNotFoundError, PermissionError): re-raised at once,
      because trying another encoding cannot fix a file-system problem
    - IOError: if every encoding attempt failed; the last underlying error is
      attached as the cause
    """
    encodings = ("utf-8", "utf-8-sig", "gbk", "latin1")
    last_error: Optional[Exception] = None

    for encoding in encodings:
        try:
            return pd.read_csv(file_path, encoding=encoding)
        except OSError:
            # Missing/unreadable file: report the real problem instead of
            # masking it as an encoding failure.
            raise
        except Exception as exc:
            # Decoding or parsing failed under this encoding; remember why and
            # fall through to the next candidate.
            last_error = exc

    raise IOError(
        f"Unable to read file {file_path}: all encoding attempts failed"
        f" (last error: {last_error!r})"
    ) from last_error


def preprocess_matrix(
    df: pd.DataFrame,
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6
) -> Tuple[pd.DataFrame, np.ndarray, List[str], List[str]]:
    """
    Convert a response-matrix DataFrame into a clipped float array plus labels.

    Parameters:
    - df: Input table; every column except the optional ID column must be
      convertible to float
    - question_id_col: Column to drop before conversion. A value that matches
      a column label is dropped by label; otherwise a non-negative integer is
      treated as a column position. A value matching neither is ignored.
    - zero_row_ratio_threshold: Currently unused; kept for interface
      compatibility
    - epsilon: Values are clipped to [epsilon, 1 - epsilon] so that a later
      logit transform stays finite

    Returns:
    - df_processed: The DataFrame without the ID column (values not clipped)
    - Y: float64 array of the clipped values
    - model_names: Column labels of df_processed
    - question_names: Index labels of df_processed

    NOTE(review): labels here treat columns as models and rows as questions,
    while other functions in this file index model names by row. Confirm the
    intended orientation of the input files.
    """
    # 1. Remove the question ID column (if present)
    if question_id_col is None:
        df_processed = df.copy()
    elif question_id_col in df.columns:
        df_processed = df.drop(columns=[question_id_col])
    elif (
        isinstance(question_id_col, (int, np.integer))
        and not isinstance(question_id_col, bool)
        and 0 <= question_id_col < df.shape[1]
    ):
        # An integer that is not a column label is interpreted positionally;
        # a label-only drop would silently keep the ID column in the matrix.
        df_processed = df.drop(columns=[df.columns[question_id_col]])
    else:
        # Unknown column: keep the data unchanged (same as errors="ignore")
        df_processed = df.copy()

    # 2. Extract model names (column labels) and question names (index labels)
    model_names = list(df_processed.columns)
    question_names = list(df_processed.index)

    # 3. Convert to a float array and keep values strictly inside (0, 1)
    Y = df_processed.to_numpy(dtype=np.float64)
    Y = np.clip(Y, epsilon, 1 - epsilon)

    return df_processed, Y, model_names, question_names


def split_fixed_test_set(
    Y_shape: Tuple[int, int],
    test_ratio: float = 0.1,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw a reproducible boolean test mask and its complement.

    Each cell of an array with shape Y_shape is marked as test when a uniform
    draw from a generator seeded with test_seed is below test_ratio.

    Returns:
    - (test_mask, remaining_mask), where remaining_mask is the negation of
      test_mask
    """
    generator = np.random.default_rng(test_seed)
    n_rows, n_cols = Y_shape
    uniform_draws = generator.random((n_rows, n_cols))
    is_test = uniform_draws < test_ratio
    return is_test, ~is_test


def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,
    rep_seed: int
) -> np.ndarray:
    """
    Pick a random training subset from the cells allowed by remaining_mask.

    A uniform draw (generator seeded with rep_seed) is made for every cell;
    a cell is selected for training when its draw is below train_subset_ratio
    and remaining_mask is set for it.
    """
    generator = np.random.default_rng(rep_seed)
    selected = generator.random(remaining_mask.shape) < train_subset_ratio
    return remaining_mask & selected


def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per training mask that encodes the role of every cell.

    Cell codes: 0 = unused, 1 = test set, 2 = training set. Where a cell is
    flagged in both masks, the training code wins because it is written last.

    Parameters:
    - split_dir: Output directory (created if missing)
    - test_mask: Boolean test mask
    - train_masks: One boolean training mask per repetition, each with the
      same shape as test_mask; file names are numbered from rep1 in list order
    - train_subset_ratio: Training ratio, used in the file name only
    - model_names: Column labels of the written table
    - question_names: Row labels of the written table

    NOTE(review): the ratio is formatted with one decimal in the file name,
    so ratios that round to the same value (e.g. 0.11 and 0.14) overwrite
    each other.
    """
    os.makedirs(split_dir, exist_ok=True)
    test_mask = np.asarray(test_mask, dtype=bool)

    for rep, train_mask in enumerate(train_masks):
        # Integer codes are required: a boolean array (as produced by
        # zeros_like on the mask) would collapse 1 and 2 into True.
        split_codes = np.zeros(test_mask.shape, dtype=np.int8)
        split_codes[test_mask] = 1  # Test set
        split_codes[np.asarray(train_mask, dtype=bool)] = 2  # Training set

        split_df = pd.DataFrame(split_codes, index=question_names, columns=model_names)
        save_path = os.path.join(split_dir, f"split_ratio_{train_subset_ratio:.1f}_rep{rep+1}.csv")
        split_df.to_csv(save_path)

    print(f"Data splits saved to {split_dir}")

# =========================
# 2. Simple Multi-Metric IRT Model
# =========================

def build_simple_multi_metric_model(N, M, J, Y, train_mask, draws=100, tune=100):
    """Build simplified multi-metric IRT model"""
    # Data processing
    logit_Y = logit(Y)
    
    # Expand training mask to 3D
    train_mask_3d = np.tile(train_mask[..., None], (1, 1, M))
    
    with pm.Model() as model:
        # Model ability parameters: global ability psi + metric-specific ability zeta
        psi = pm.Normal("psi", mu=0, sigma=1, shape=N)  # Global ability
        zeta = pm.Normal("zeta", mu=0, sigma=0.5, shape=(N, M))  # Metric-specific ability
        
        # Item parameters
        a = pm.Normal("a", mu=1.0, sigma=0.5, shape=J)  # Discrimination
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)  # Difficulty
        
        # Calculate ability parameters (for each metric)
        # theta_{ijm} = psi_i + zeta_{im}
        theta = psi[:, None, None] + zeta[:, None, :]  # (N, 1, 1) + (N, 1, M) -> (N, 1, M) -> broadcast to (N, J, M)
        
        # Calculate predicted values
        # mu_{ijm} = a_j * theta_{ijm} - b_j
        mu = a[None, :, None] * theta - b[None, :, None]  # (1, J, 1) * (N, J, M) - (1, J, 1) -> (N, J, M)
        
        # Use only training data
        logit_Y_obs = logit_Y[train_mask_3d]
        mu_obs = mu[train_mask_3d]
        
        # Likelihood function
        pm.Normal("y_obs", mu=mu_obs, sigma=1.0, observed=logit_Y_obs)
        
        # MCMC sampling
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=12,
            cores=12,
            target_accept=0.95,
            return_inferencedata=True,
            compute_convergence_checks=False  # Simplified convergence checks
        )
    
    return trace, model


def run_simple_multimetric_irt(
    Y_3d: np.ndarray,
    sample_ratio: float,
    output_dir: str,
    metric_names: list,
    model_names: list,
    key_names: list,
    sample_kwargs: dict,
    seed: int = 20240619,
    train_mask_2d: np.ndarray | None = None,
):
    """
    Simplified multi-metric IRT modeling
    """
    N, J, M = Y_3d.shape
    _ensure_dir(output_dir)
    print(f"\n[Simplified multi-metric IRT] Starting modeling: N={N}, J={J}, M={M}, training ratio={sample_ratio}")

    # 1. Data preprocessing (logit transformation)
    eps = 1e-5
    Y_3d_clip = np.clip(Y_3d.astype(np.float64), eps, 1 - eps)
    logit_Y_3d = logit(Y_3d_clip).astype(np.float64)
    
    if train_mask_2d is None:
        # If no training mask is provided, generate automatically
        rng = np.random.default_rng(seed)
        train_mask_2d = rng.random((N, J)) < sample_ratio
    
    # Ensure train_mask is boolean type
    train_mask_2d = train_mask_2d.astype(bool)
    
    print(f"[Preprocessing confirmation] Training sample count: {np.sum(train_mask_2d)}")

    # Check if there are enough training samples
    train_sample_count = np.sum(train_mask_2d)
    print(f"[Preprocessing confirmation] Training sample count: {train_sample_count}")
    
    if train_sample_count == 0:
        print("Warning: No training samples available!")
        # Return default prediction values
        Y_pred_mean = np.full_like(Y_3d, np.nanmean(Y_3d))
        return None, None, Y_pred_mean
    
    # If there are too few training samples, use a simplified method
    if train_sample_count < 100:  # Threshold can be adjusted as needed
        print("Warning: Too few training samples, using fallback prediction method!")
        # Use a combination of row means and column means as prediction
        Y_pred_mean = np.full_like(Y_3d, np.nanmean(Y_3d))  # Default value
        
        # Calculate row means (model means)
        row_means = np.zeros((N, M))
        for i in range(N):
            for m in range(M):
                row_train_values = Y_3d[i, train_mask_2d[i, :], m]
                if len(row_train_values) > 0:
                    row_means[i, m] = np.mean(row_train_values)
                else:
                    row_means[i, m] = np.nanmean(Y_3d[i, :, m])
        
        # Calculate column means (item means)
        col_means = np.zeros((J, M))
        for j in range(J):
            for m in range(M):
                col_train_values = Y_3d[train_mask_2d[:, j], j, m]
                if len(col_train_values) > 0:
                    col_means[j, m] = np.mean(col_train_values)
                else:
                    col_means[j, m] = np.nanmean(Y_3d[:, j, m])
        
        # Combine row means and column means as prediction
        global_mean = np.nanmean(Y_3d)
        for i in range(N):
            for j in range(J):
                for m in range(M):
                    # Use the average of row means and column means as prediction
                    if not np.isnan(row_means[i, m]) and not np.isnan(col_means[j, m]):
                        Y_pred_mean[i, j, m] = (row_means[i, m] + col_means[j, m]) / 2
                    elif not np.isnan(row_means[i, m]):
                        Y_pred_mean[i, j, m] = row_means[i, m]
                    elif not np.isnan(col_means[j, m]):
                        Y_pred_mean[i, j, m] = col_means[j, m]
                    else:
                        Y_pred_mean[i, j, m] = global_mean
                        
        return None, None, Y_pred_mean

    # 2. Build simplified multi-metric IRT model
    print("Building simple multi-metric IRT model...")
    try:
        trace, model = build_simple_multi_metric_model(N, M, J, Y_3d, train_mask_2d, 
                                                     draws=sample_kwargs['draws'], 
                                                     tune=sample_kwargs['tune'])
    except Exception as e:
        print(f"Error building model: {e}")
        # Return default prediction values
        Y_pred_mean = np.full_like(Y_3d, np.nanmean(Y_3d))
        return None, None, Y_pred_mean
    
    # 3. Extract parameter posterior means
    params_mean = {
        "psi": trace.posterior["psi"].mean(("chain", "draw")).values,
        "zeta": trace.posterior["zeta"].mean(("chain", "draw")).values,
        "a": trace.posterior["a"].mean(("chain", "draw")).values,
        "b": trace.posterior["b"].mean(("chain", "draw")).values
    }
    
    # 4. Generate predicted values
    # theta_{ijm} = psi_i + zeta_{im}
    theta = params_mean["psi"][:, None, None] + params_mean["zeta"][:, None, :]  # (N, 1, 1) + (N, 1, M) -> (N, 1, M) -> broadcast to (N, J, M)
    
    # mu_{ijm} = a_j * theta_{ijm} - b_j
    mu = params_mean["a"][None, :, None] * theta - params_mean["b"][None, :, None]  # (1, J, 1) * (N, J, M) - (1, J, 1) -> (N, J, M)
    
    # Transform logit predictions back to original scale
    Y_pred_mean = expit(mu)
    
    return trace, params_mean, Y_pred_mean

# =========================
# 3. Prediction Methods
# =========================

def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Constant baseline: predict every cell as the mean of the training cells.

    The result has the same shape and dtype as Y.
    """
    train_values = Y[train_mask]
    fill_value = np.mean(train_values)
    return np.full_like(Y, fill_value)


def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Row-mean baseline: predict each cell as the training mean of its row.

    Parameters:
    - Y: 2D array of observed values
    - train_mask: Boolean mask of the same shape marking training cells

    Returns:
    - float64 array shaped like Y. A row without any training cell is filled
      with the mean over all training cells (NaN if there are none), so that
      held-out values never influence a prediction.
    """
    Y = np.asarray(Y, dtype=np.float64)
    mask = np.asarray(train_mask, dtype=bool)

    # Training-only fallback for rows that have no training cell
    fallback = np.mean(Y[mask]) if mask.any() else np.nan

    pred_row = np.empty(Y.shape, dtype=np.float64)
    for i in range(Y.shape[0]):
        row_train_values = Y[i, mask[i, :]]
        pred_row[i, :] = np.mean(row_train_values) if row_train_values.size else fallback
    return pred_row


def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Column-mean baseline: predict each cell as the training mean of its column.

    Parameters:
    - Y: 2D array of observed values
    - train_mask: Boolean mask of the same shape marking training cells

    Returns:
    - float64 array shaped like Y. A column without any training cell is
      filled with the mean over all training cells (NaN if there are none),
      so that held-out values never influence a prediction.
    """
    Y = np.asarray(Y, dtype=np.float64)
    mask = np.asarray(train_mask, dtype=bool)

    # Training-only fallback for columns that have no training cell
    fallback = np.mean(Y[mask]) if mask.any() else np.nan

    pred_col = np.empty(Y.shape, dtype=np.float64)
    for j in range(Y.shape[1]):
        col_train_values = Y[mask[:, j], j]
        pred_col[:, j] = np.mean(col_train_values) if col_train_values.size else fallback
    return pred_col

# =========================
# 4. Result Saving & Visualization
# =========================

def save_irt_trace(
    trace_dir: str,
    trace: az.InferenceData,
    model_type: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Write an MCMC trace to a NetCDF file inside trace_dir.

    The file name encodes the model type, the training ratio (three decimals)
    and the 1-based repetition number. The directory is created if missing.
    """
    _ensure_dir(trace_dir)
    file_name = f"irt_{model_type}_ratio_{train_subset_ratio:.3f}_rep{rep+1}.nc"
    trace_path = os.path.join(trace_dir, file_name)
    trace.to_netcdf(trace_path)
    print(f"IRT trace saved: {trace_path}")


def save_sample_level_predictions(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_irt: np.ndarray,
    metric_idx: int = 0
):
    """
    Write one CSV row per test cell with its true value and all predictions.

    Parameters:
    - pred_dir: Output directory (created if missing)
    - Y: 3D array; the true value is read at [row, col, metric_idx]
    - test_mask: 2D boolean mask selecting the cells to export
    - model_names: Labels looked up by row index
    - question_names: Labels looked up by column index
    - train_subset_ratio, rep: Written to each row and used in the file name
      (repetition is stored 1-based)
    - pred_global, pred_row, pred_col, pred_irt: 2D prediction arrays
    - metric_idx: Metric slice of Y to report

    Cells whose row/column index has no matching label are left out; the
    number of such cells is printed. The header is written even when no cell
    is exported.

    NOTE(review): preprocess_matrix derives model names from the columns and
    question names from the rows, whereas this function looks models up by
    row and questions by column. Confirm the intended orientation.
    """
    os.makedirs(pred_dir, exist_ok=True)

    columns = [
        "model_idx", "model", "question_idx", "question", "train_ratio",
        "repetition", "true_value", "global_mean_pred", "model_mean_pred",
        "question_mean_pred", "irt_pred",
    ]

    # Locate the test cells and keep those that have labels on both axes
    test_rows, test_cols = np.nonzero(test_mask)
    in_range = (test_rows < len(model_names)) & (test_cols < len(question_names))
    n_skipped = int(np.count_nonzero(~in_range))
    if n_skipped:
        print(
            f"Warning: skipped {n_skipped} test samples whose indices fall "
            f"outside the model/question name lists"
        )
    rows = test_rows[in_range]
    cols = test_cols[in_range]
    n_records = len(rows)

    pred_df = pd.DataFrame(
        {
            "model_idx": rows,
            "model": [model_names[i] for i in rows],
            "question_idx": cols,
            "question": [question_names[j] for j in cols],
            "train_ratio": np.full(n_records, train_subset_ratio),
            "repetition": np.full(n_records, rep + 1),
            "true_value": Y[rows, cols, metric_idx],
            "global_mean_pred": pred_global[rows, cols],
            "model_mean_pred": pred_row[rows, cols],
            "question_mean_pred": pred_col[rows, cols],
            "irt_pred": pred_irt[rows, cols],
        },
        columns=columns,
    )

    save_path = os.path.join(
        pred_dir,
        f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"
    )
    pred_df.to_csv(save_path, index=False)
    print(f"Sample-level predictions saved: {save_path}")


def calculate_mse(Y: np.ndarray, Y_pred: np.ndarray, mask: np.ndarray) -> float:
    """
    Mean squared error between Y and Y_pred over the cells selected by mask.

    Cells whose true value is NaN are excluded. A NaN prediction is not
    excluded, so a predictor that failed on a scored cell yields NaN rather
    than an optimistic score.

    Returns:
    - The MSE, or NaN when no cell with a non-NaN true value is selected

    Raises:
    - ValueError: if Y and Y_pred have different shapes
    """
    if Y.shape != Y_pred.shape:
        raise ValueError(f"Shape mismatch: Y {Y.shape} vs Y_pred {Y_pred.shape}")

    # Extract true values and predictions for the masked region
    y_true = np.asarray(Y[mask], dtype=np.float64)
    y_pred = np.asarray(Y_pred[mask], dtype=np.float64)

    # Score only cells with a known ground truth
    has_truth = ~np.isnan(y_true)
    if not has_truth.any():
        return np.nan

    squared_errors = (y_true[has_truth] - y_pred[has_truth]) ** 2
    return float(np.mean(squared_errors))


def plot_mse_comparison(
    mse_dict: Dict[str, List[float]],
    train_ratios: List[float],
    save_path: str
):
    """
    Draw one MSE-versus-training-ratio line per method and save the figure.

    Known method names get a fixed colour and line style; any other name is
    drawn as a solid black line. A warning is printed for a method whose MSE
    list contains NaN, but the line is plotted regardless.
    """
    fallback_style = ("black", "solid")
    style_by_method = {
        "Global Mean": ("gray", "dashed"),
        "Model Mean": ("blue", "solid"),
        "Question Mean": ("green", "solid"),
        "Simple Multi-Metric IRT": ("red", "solid")
    }

    plt.figure(figsize=(10, 6))
    for method_name, values in mse_dict.items():
        line_color, line_style = style_by_method.get(method_name, fallback_style)
        series = np.copy(values)
        if np.any(np.isnan(series)):
            print(f"Warning: {method_name} contains NaN values in MSE list")
        plt.plot(
            train_ratios,
            series,
            label=method_name,
            color=line_color,
            linestyle=line_style,
            linewidth=2,
            marker="o",
            markersize=4
        )

    # Axis labels, title and legend
    plt.xlabel("Training Data Ratio (Fraction of Remaining Pool)", fontsize=12)
    plt.ylabel("Test Set MSE (Lower = Better)", fontsize=12)
    plt.title("MSE vs. Training Data Ratio for All Prediction Methods", fontsize=14, pad=20)
    plt.legend(loc="upper right", fontsize=10)
    plt.grid(True, alpha=0.3)

    # One tick per training ratio, labelled with one decimal
    tick_labels = ["{:.1f}".format(ratio) for ratio in train_ratios]
    plt.xticks(train_ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"MSE plot saved: {save_path}")


def save_mse_summary(
    mse_summary: Dict[str, Dict[float, List[float]]],
    save_path: str
):
    """
    Save an MSE summary table to CSV (mean ± standard deviation per cell).

    Parameters:
    - mse_summary: Mapping method name -> {train ratio -> list of MSE values}
    - save_path: Destination CSV path

    The table has one row per training ratio (sorted ascending, taken from
    all methods) and one column per method in the mapping's order. A method
    with no values for a ratio is reported as "nan ± nan".
    """
    methods = list(mse_summary)
    # Collect ratios from every method instead of relying on one fixed key
    ratios = sorted({ratio for per_ratio in mse_summary.values() for ratio in per_ratio})

    rows = []
    for ratio in ratios:
        row = {"Train_Ratio": ratio}
        for method in methods:
            values = np.asarray(mse_summary[method].get(ratio, []), dtype=np.float64)
            if values.size:
                mse_mean, mse_std = np.mean(values), np.std(values)
            else:
                # Avoid the empty-slice warning; the cell still reads nan
                mse_mean = mse_std = np.nan
            row[method] = f"{mse_mean:.6f} ± {mse_std:.6f}"
        rows.append(row)

    pd.DataFrame(rows, columns=["Train_Ratio", *methods]).to_csv(save_path, index=False)
    print(f"MSE summary saved: {save_path}")

# =========================
# 5. Helper Functions
# =========================

def _ensure_dir(directory: str):
    """Ensure directory exists"""
    os.makedirs(directory, exist_ok=True)


def load_multiple_metrics(
    metrics_dir: str,
    metric_names: List[str],
    question_id_col: Optional[Union[str, int]] = None
) -> Tuple[np.ndarray, List[str], List[str], List[str]]:
    """
    Load one response matrix per metric and stack them along a last axis.

    For every metric the file "response_matrix__<metric>.csv" inside
    metrics_dir is read and preprocessed. All metrics must yield the same
    model names and the same question names as the first one.

    Returns:
    - (Y_3d, model_names, question_names, metric_names), where Y_3d stacks
      the per-metric matrices with the metric as the last axis

    Raises:
    - FileNotFoundError: if a metric file does not exist
    - ValueError: if a metric's model or question names differ from those of
      the first metric
    """
    matrices = []
    reference_models = None
    reference_questions = None

    for metric_name in metric_names:
        csv_path = os.path.join(metrics_dir, f"response_matrix__{metric_name}.csv")
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Metric file not found: {csv_path}")

        # Read the file and turn it into a clipped matrix plus labels
        _, matrix, models, questions = preprocess_matrix(
            df=robust_read_csv(csv_path),
            question_id_col=question_id_col,
            zero_row_ratio_threshold=0.3,
            epsilon=1e-6
        )

        if reference_models is None:
            # The first metric defines the labels all others must match
            reference_models, reference_questions = models, questions
        elif models != reference_models:
            raise ValueError(f"Model names mismatch: {metric_name}")
        elif questions != reference_questions:
            raise ValueError(f"Question names mismatch: {metric_name}")

        matrices.append(matrix)

    stacked = np.stack(matrices, axis=-1)
    return stacked, reference_models, reference_questions, metric_names

# =========================
# 6. Main Experiment Pipeline
# =========================



def run_simple_multimetric_experiment(
    metrics_dir: str,
    output_root_dir: str,
    selected_metrics: List[str],
    main_metric_idx: int = 0,  # Main metric index
    question_id_col: Optional[Union[str, int]] = 0,
    test_ratio: float = 0.10,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 1,
    mcmc_draws: int = 100,
    mcmc_tune: int = 100,
    mcmc_chains: int = 12,
    mcmc_cores: Optional[int] = None
):
    """
    Run the simplified multi-metric joint modeling experiment end to end.

    Steps: load the metric matrices, draw one fixed test mask, then for every
    training ratio and repetition sample a training mask, compute the three
    mean baselines and the IRT prediction for the main metric, record the
    test-set MSEs, and finally write the MSE summary table and plot.

    Parameters:
    - metrics_dir: Directory containing the per-metric CSV files
    - output_root_dir: Root directory for all outputs (subdirectories are
      created)
    - selected_metrics: Names of the metrics to load
    - main_metric_idx: Index (into selected_metrics) of the evaluated metric
    - question_id_col: ID column removed from each CSV before conversion
    - test_ratio, test_seed: Size and seed of the fixed test split
    - train_ratios: Training ratios to evaluate; None means 0.1, 0.2, ..., 1.0
    - rep_count: Repetitions per training ratio
    - mcmc_draws, mcmc_tune, mcmc_chains, mcmc_cores: Sampler settings handed
      to the IRT step; mcmc_cores falls back to mcmc_chains when not set
    """
    if train_ratios is None:
        # Built per call so the default list is never shared between calls
        train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    print(f"Running simple multi-metric experiment with metrics: {selected_metrics}")
    print(f"Main metric index: {main_metric_idx}")

    # 1. Load multi-metric data
    print("\n=== Step 1/6: Loading Multiple Metrics ===")
    Y_3d, model_names, key_names, metric_names = load_multiple_metrics(
        metrics_dir=metrics_dir,
        metric_names=selected_metrics,
        question_id_col=question_id_col
    )

    N, J, M = Y_3d.shape
    print(f"Loaded 3D data shape: {N} models × {J} questions × {M} metrics")

    # 2. Create output directories
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
        "params": os.path.join(output_root_dir, "05_params")
    }
    for dir_path in output_dirs.values():
        _ensure_dir(dir_path)

    # 3. Split fixed test set (one 2D mask shared by all metrics)
    print("\n=== Step 2/6: Split Fixed Test Set ===")
    test_mask_2d, remaining_mask_2d = split_fixed_test_set(
        Y_shape=(N, J),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # 4. Initialize MSE record dictionary
    method_labels = ("Global Mean", "Model Mean", "Question Mean", "Simple Multi-Metric IRT")
    mse_summary = {label: {r: [] for r in train_ratios} for label in method_labels}

    # 5. MCMC sampling parameters
    sample_kwargs = {
        'draws': mcmc_draws,
        'tune': mcmc_tune,
        'chains': mcmc_chains,
        'cores': mcmc_cores or mcmc_chains
    }

    # The evaluated metric slice does not change between iterations
    Y_main = Y_3d[:, :, main_metric_idx]

    # 6. Iterate through all training ratios + repetition counts
    for train_ratio in train_ratios:
        print(f"\n=== Step 3/6: Train Ratio = {train_ratio:.3f} ===")

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            # The repetition seed must differ from test_seed: with an equal
            # seed the training draw reproduces the uniform matrix used for
            # the test split, so only draws in [test_ratio, train_ratio)
            # would be selected (none at all when train_ratio <= test_ratio).
            rep_seed = test_seed + rep + 1

            # 6.1 Sample training set for current repetition
            print(f"Sampling training subset (ratio={train_ratio:.3f})...")
            train_mask_2d = sample_training_subset(
                remaining_mask=remaining_mask_2d,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            print(f"Training samples: {train_mask_2d.sum()}, Test samples: {test_mask_2d.sum()}")

            # 6.2 Calculate baseline predictions for main metric
            pred_global = predict_global_mean(Y_main, train_mask_2d)
            pred_row = predict_row_mean(Y_main, train_mask_2d)
            pred_col = predict_col_mean(Y_main, train_mask_2d)

            # 6.3 Run simplified multi-metric IRT model
            print("Running simple multi-metric IRT model...")
            trace, params_mean, Y_pred_irt_3d = run_simple_multimetric_irt(
                Y_3d=Y_3d,
                sample_ratio=train_ratio,
                output_dir=output_dirs["params"],
                metric_names=metric_names,
                model_names=model_names,
                key_names=key_names,
                sample_kwargs=sample_kwargs,
                train_mask_2d=train_mask_2d,
                seed=rep_seed
            )

            # Persist the trace and parameter summaries when MCMC actually ran
            if trace is not None:
                save_irt_trace(
                    trace_dir=output_dirs["irt_trace"],
                    trace=trace,
                    model_type="simple_multimetric_irt",
                    train_subset_ratio=train_ratio,
                    rep=rep
                )

                # Save zeta posterior variance to CSV file
                save_zeta_posterior_variance(
                    trace=trace,
                    model_names=model_names,
                    output_dir=output_dirs["params"],
                    train_subset_ratio=train_ratio,
                    rep=rep
                )

                # Save all parameter posterior estimates to CSV file
                save_parameter_posterior_estimates(
                    trace=trace,
                    model_names=model_names,
                    question_names=key_names,
                    output_dir=output_dirs["params"],
                    train_subset_ratio=train_ratio,
                    rep=rep
                )

            # 6.4 Calculate MSE (main metric only) and save results
            Y_pred_irt_main = Y_pred_irt_3d[:, :, main_metric_idx]
            mse_global = calculate_mse(Y_main, pred_global, test_mask_2d)
            mse_row = calculate_mse(Y_main, pred_row, test_mask_2d)
            mse_col = calculate_mse(Y_main, pred_col, test_mask_2d)
            mse_irt = calculate_mse(Y_main, Y_pred_irt_main, test_mask_2d)

            # Record MSE
            mse_summary["Global Mean"][train_ratio].append(mse_global)
            mse_summary["Model Mean"][train_ratio].append(mse_row)
            mse_summary["Question Mean"][train_ratio].append(mse_col)
            mse_summary["Simple Multi-Metric IRT"][train_ratio].append(mse_irt)

            # Save sample-level predictions
            save_sample_level_predictions(
                pred_dir=output_dirs["predictions"],
                Y=Y_3d,
                test_mask=test_mask_2d,
                model_names=model_names,
                question_names=key_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_irt=Y_pred_irt_main
            )

            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            print(f"Simple Multi-Metric IRT: {mse_irt:.6f}")

            # 6.5 Record data split (first repetition only; the writer numbers
            # files by list position, so later repetitions would overwrite it)
            if rep == 0:
                record_data_split(
                    split_dir=output_dirs["split"],
                    test_mask=test_mask_2d,
                    train_masks=[train_mask_2d],
                    train_subset_ratio=train_ratio,
                    model_names=model_names,
                    question_names=key_names
                )

    # 7. Generate final MSE summary and visualization
    print("\n=== Step 4/6: Generate Final Results ===")

    # Save MSE summary table
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    save_mse_summary(mse_summary, mse_summary_path)

    # Generate MSE comparison plot from the per-ratio mean over repetitions
    plot_mse_dict = {
        label: [np.mean(mse_summary[label][r]) for r in train_ratios]
        for label in method_labels
    }

    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    # 8. Done (model parameters were already saved per repetition above)
    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")


def save_parameter_posterior_estimates(
    trace: az.InferenceData,
    model_names: List[str],
    question_names: List[str],
    output_dir: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Write the posterior means of psi, zeta, a and b to four CSV files.

    Parameters:
    - trace: MCMC sampling results
    - model_names: Labels for the first axis of psi and zeta
    - question_names: Labels for the entries of a and b
    - output_dir: Output directory (created if missing)
    - train_subset_ratio: Training subset ratio (part of the file names)
    - rep: Repetition index (file names use rep + 1)

    An index without a matching label is named "model_<i>" or "question_<j>".
    """
    _ensure_dir(output_dir)

    # Average every parameter over chains and draws
    posterior_mean = trace.posterior.mean(dim=["chain", "draw"])
    file_suffix = f"_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"

    def _label(names, position, prefix):
        # Fall back to a generated label when the name list is too short
        return names[position] if position < len(names) else f"{prefix}_{position}"

    def _write(records, stem, title):
        # Store one parameter table and report where it went
        path = os.path.join(output_dir, f"{stem}_posterior_estimates{file_suffix}")
        pd.DataFrame(records).to_csv(path, index=False)
        print(f"{title} posterior estimates saved: {path}")

    # psi: global ability, one value per model
    psi_values = posterior_mean["psi"].values
    _write(
        [
            {"model_idx": i, "model": _label(model_names, i, "model"), "psi_hat": float(value)}
            for i, value in enumerate(psi_values)
        ],
        "psi",
        "Psi",
    )

    # zeta: metric-specific ability, one value per (model, metric)
    zeta_values = posterior_mean["zeta"].values
    n_models, n_metrics = zeta_values.shape[0], zeta_values.shape[1]
    _write(
        [
            {
                "model_idx": i,
                "model": _label(model_names, i, "model"),
                "metric_idx": m,
                "zeta_hat": float(zeta_values[i, m]),
            }
            for i in range(n_models)
            for m in range(n_metrics)
        ],
        "zeta",
        "Zeta",
    )

    # a: item discrimination, one value per question
    a_values = posterior_mean["a"].values
    _write(
        [
            {"question_idx": j, "question": _label(question_names, j, "question"), "a_hat": float(value)}
            for j, value in enumerate(a_values)
        ],
        "a",
        "A",
    )

    # b: item difficulty, one value per question
    b_values = posterior_mean["b"].values
    _write(
        [
            {"question_idx": j, "question": _label(question_names, j, "question"), "b_hat": float(value)}
            for j, value in enumerate(b_values)
        ],
        "b",
        "B",
    )



def save_zeta_posterior_variance(
    trace: az.InferenceData,
    model_names: List[str],
    output_dir: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Write the posterior variance of every zeta entry to a CSV file.

    Parameters:
    - trace: MCMC sampling results
    - model_names: Labels for the first axis of zeta
    - output_dir: Output directory (created if missing)
    - train_subset_ratio: Training subset ratio (part of the file name)
    - rep: Repetition index (file name uses rep + 1)

    The variance is taken over the chain and draw dimensions. An index
    without a matching label is named "model_<i>".
    """
    _ensure_dir(output_dir)

    # Variance across all posterior samples, one value per (model, metric)
    variance = trace.posterior["zeta"].var(dim=["chain", "draw"])
    n_models, n_metrics = variance.shape
    variance_values = variance.values

    records = [
        {
            "model_idx": i,
            "model": model_names[i] if i < len(model_names) else f"model_{i}",
            "metric_idx": m,
            "posterior_variance": float(variance_values[i, m]),
        }
        for i in range(n_models)
        for m in range(n_metrics)
    ]

    file_name = f"zeta_posterior_variance_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"
    save_path = os.path.join(output_dir, file_name)
    pd.DataFrame(records).to_csv(save_path, index=False)
    print(f"Zeta posterior variance saved: {save_path}")

# =========================
# 7. Configuration & Execution
# =========================

if __name__ == "__main__":
    # Experiment configuration, passed to the pipeline as keyword arguments.
    # The full metric list also contained "rouge2"; it is not part of the
    # current selection.
    experiment_config = dict(
        metrics_dir="data/continuous_processed",  # Multi-metric data directory
        output_root_dir="results/modeling_result_simple_approach_multimetric_bertscore_f1_std",  # Output root directory
        selected_metrics=["bertscore_F1", "bertscore_P", "bertscore_R", "bleu", "meteor", "rouge1", "rougeL"],
        main_metric_idx=0,  # Main metric index (bertscore_F1)
        question_id_col='row_index',  # Question ID column
        test_ratio=0.10,  # Test set ratio
        test_seed=42,  # Fixed test set seed
        train_ratios=[1.0],  # Training set ratio list
        rep_count=1,  # Repetition count for each training ratio
        mcmc_draws=10,  # MCMC samples per chain
        mcmc_tune=10,  # MCMC warmup iterations
        mcmc_chains=4,  # MCMC parallel chains
        mcmc_cores=4,  # MCMC parallel cores
    )

    # Run simplified multi-metric experiment
    run_simple_multimetric_experiment(**experiment_config)

