#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
from typing import Dict, List, Optional, Tuple, Union
from scipy.special import logit, expit

# No CJK font is configured; matplotlib's default font is used as-is.
# Turn off the Unicode minus glyph so negative tick labels are drawn with a
# plain hyphen-minus.
plt.rcParams["axes.unicode_minus"] = False

# =========================
# 1. I/O & Preprocessing
# =========================

def robust_read_csv(file_path: str) -> pd.DataFrame:
    """
    Read a CSV file, retrying with GBK when the default decoding fails.

    The file is first parsed with pandas' defaults. Only a
    ``UnicodeDecodeError`` triggers the GBK retry; any other error raised by
    the first attempt (missing file, parser error, ...) propagates unchanged.
    Delimiters are not auto-detected: both attempts use pandas' default
    separator.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The parsed DataFrame.

    Raises:
        IOError: If the GBK retry fails as well. The underlying exception is
            chained as ``__cause__`` so the real reason stays visible.
    """
    try:
        return pd.read_csv(file_path)
    except UnicodeDecodeError:
        # Not decodable with the default encoding; retry below with GBK.
        pass

    try:
        return pd.read_csv(file_path, encoding="gbk")
    except Exception as e:
        raise IOError(f"Cannot read file {file_path}: {str(e)}") from e


def preprocess_matrix(
    df: pd.DataFrame,
    question_id_col: Optional[Union[str, int]] = None,
    zero_row_ratio_threshold: float = 0.3,
    epsilon: float = 1e-6
) -> Tuple[pd.DataFrame, np.ndarray, List[str], List[str]]:
    """
    Turn a raw response DataFrame into a clipped float matrix plus its labels.

    The optional ID column is dropped, the remaining values are cast to
    float64, and every value is clipped into ``[epsilon, 1 - epsilon]`` so a
    later logit transform stays finite. No rows are removed and no other
    normalisation is applied. NaN entries pass through the clip unchanged.

    Args:
        df: Raw table as read from CSV.
        question_id_col: Column to drop before building the matrix. A string
            (or an integer that is an actual column label) is dropped by
            label. An integer that is not a column label is interpreted as a
            zero-based column position. A label or position that does not
            exist is ignored.
        zero_row_ratio_threshold: Accepted for interface compatibility; not
            used by this function.
        epsilon: Clipping margin applied at both ends of ``[0, 1]``.

    Returns:
        ``(df_processed, Y, model_names, question_names)`` where
        ``df_processed`` is the table without the ID column, ``Y`` is the
        clipped float64 matrix, ``model_names`` are the column labels of
        ``df_processed`` and ``question_names`` are its index labels.
    """
    # 1. Resolve the ID column. A positional integer such as 0 is not a
    #    label of a string-named header, so a label-based drop with
    #    errors="ignore" would silently keep the ID column in the matrix.
    id_label = question_id_col
    if (
        isinstance(question_id_col, (int, np.integer))
        and not isinstance(question_id_col, bool)
        and question_id_col not in df.columns
        and -df.shape[1] <= question_id_col < df.shape[1]
    ):
        id_label = df.columns[question_id_col]

    if id_label is not None:
        df_processed = df.drop(columns=[id_label], errors="ignore")
    else:
        df_processed = df.copy()

    # 2. Labels: columns are reported as model names, the index as question names
    model_names = list(df_processed.columns)
    question_names = list(df_processed.index)

    # 3. Convert to a float matrix
    Y = df_processed.values.astype(np.float64)

    # 4. Keep values strictly inside (0, 1) so logit() never returns +/-inf
    Y = np.clip(Y, epsilon, 1 - epsilon)

    return df_processed, Y, model_names, question_names


def split_fixed_test_set(
    Y_shape: Tuple[int, int],
    test_ratio: float = 0.1,
    test_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw a reproducible entry-wise test mask for a matrix of a given shape.

    Each cell is assigned to the test set independently with probability
    ``test_ratio``. A private ``RandomState`` seeded with ``test_seed`` is
    used, which yields the same numbers as seeding the global NumPy generator
    would, but leaves the process-wide random state untouched.

    Args:
        Y_shape: ``(rows, columns)`` of the response matrix.
        test_ratio: Per-cell probability of landing in the test set.
        test_seed: Seed that fixes the test set across runs.

    Returns:
        ``(test_mask, remaining_mask)``: two boolean arrays of shape
        ``Y_shape`` that are exact complements of each other.
    """
    rng = np.random.RandomState(test_seed)
    N, J = Y_shape
    # True = test cell
    test_mask = rng.random_sample((N, J)) < test_ratio
    # Everything that is not a test cell stays available for training
    remaining_mask = ~test_mask

    return test_mask, remaining_mask


def sample_training_subset(
    remaining_mask: np.ndarray,
    train_subset_ratio: float,
    rep_seed: int
) -> np.ndarray:
    """
    Sample training cells from the pool left over after the test split.

    A uniform random matrix of the same shape as ``remaining_mask`` is drawn
    from a private ``RandomState`` seeded with ``rep_seed``. A cell is used
    for training when its draw is below ``train_subset_ratio`` and the cell
    is in ``remaining_mask``. The numbers are identical to what seeding the
    global NumPy generator would give, but the process-wide random state is
    not modified.

    NOTE: if ``rep_seed`` equals the seed passed to ``split_fixed_test_set``
    for a matrix of the same shape, both functions see the same random
    matrix, so the selected training fraction is smaller than requested (and
    empty when ``train_subset_ratio <= test_ratio``). Callers should use a
    different seed.

    Args:
        remaining_mask: Boolean array, True where a cell may be used for
            training.
        train_subset_ratio: Per-cell probability of selecting a remaining
            cell.
        rep_seed: Seed for this repetition.

    Returns:
        Boolean training mask with the shape of ``remaining_mask``.
    """
    rng = np.random.RandomState(rep_seed)
    # Draw for every cell; cells outside the remaining pool are masked out below
    train_subset_mask = rng.random_sample(remaining_mask.shape) < train_subset_ratio
    train_mask = remaining_mask & train_subset_mask

    return train_mask


def record_data_split(
    split_dir: str,
    test_mask: np.ndarray,
    train_masks: List[np.ndarray],
    train_subset_ratio: float,
    model_names: List[str],
    question_names: List[str]
):
    """
    Write one CSV per repetition describing which cells were test or train.

    Each file holds an integer matrix with the shape of ``test_mask``:
    ``0`` = unused, ``1`` = test cell, ``2`` = training cell. Training codes
    are written after test codes, so a cell present in both masks is stored
    as ``2``. Rows are labelled with ``question_names`` and columns with
    ``model_names``.

    Args:
        split_dir: Output directory; created if it does not exist.
        test_mask: Boolean mask of the fixed test set.
        train_masks: One boolean training mask per repetition, in repetition
            order. File names use the 1-based position in this list.
        train_subset_ratio: Training ratio, used in the file name with one
            decimal place (ratios that round to the same value share a file
            name).
        model_names: Column labels of the saved matrix.
        question_names: Row labels of the saved matrix.
    """
    os.makedirs(split_dir, exist_ok=True)
    test_cells = np.asarray(test_mask, dtype=bool)

    for rep, train_mask in enumerate(train_masks):
        # An integer buffer is required: zeros_like on a boolean mask would
        # produce a bool array in which 1 and 2 both collapse to True.
        split_mask = np.zeros(test_cells.shape, dtype=np.int8)
        split_mask[test_cells] = 1  # Test set
        split_mask[np.asarray(train_mask, dtype=bool)] = 2  # Training set

        split_df = pd.DataFrame(split_mask, index=question_names, columns=model_names)
        save_path = os.path.join(split_dir, f"split_ratio_{train_subset_ratio:.1f}_rep{rep+1}.csv")
        split_df.to_csv(save_path)

    print(f"Data splits saved to {split_dir}")

# =========================
# 2. Simple Multi-Metric IRT Model
# =========================

def run_simple_multimetric_irt(
    Y_main: np.ndarray,
    Y_aux_list: List[np.ndarray],
    sample_ratio: float,
    output_dir: str,
    model_names: list,
    key_names: list,
    sample_kwargs: dict,
    seed: int = 20240619,
    train_mask: np.ndarray | None = None,
):
    """
    Simplified multi-metric IRT modeling
    Use main metric for IRT modeling, auxiliary metrics as prior information
    """
    N, J = Y_main.shape
    M = len(Y_aux_list)  # Number of auxiliary metrics
    _ensure_dir(output_dir)
    print(f"\n[Simplified Multi-Metric IRT] Starting modeling: N={N}, J={J}, auxiliary metrics={M}, training ratio={sample_ratio}")

    # 1. Data preprocessing (logit transformation)
    eps = 1e-5
    Y_main_clip = np.clip(Y_main.astype(np.float64), eps, 1 - eps)
    logit_Y_main = logit(Y_main_clip).astype(np.float64)
    
    # Process auxiliary metrics
    logit_Y_aux_list = []
    for Y_aux in Y_aux_list:
        Y_aux_clip = np.clip(Y_aux.astype(np.float64), eps, 1 - eps)
        logit_Y_aux = logit(Y_aux_clip).astype(np.float64)
        logit_Y_aux_list.append(logit_Y_aux)
    
    if train_mask is None:
        # If no training mask provided, generate automatically
        rng = np.random.default_rng(seed)
        train_mask = rng.random((N, J)) < sample_ratio
    
    # Ensure train_mask is boolean type
    train_mask = train_mask.astype(bool)
    
    logit_train_main = np.where(train_mask, logit_Y_main, np.nan).astype(np.float64)
    obs_mask = ~np.isnan(logit_train_main)
    
    print(f"[Preprocessing confirmation] Non-NaN training samples: {np.sum(obs_mask)}")

    # Check if there are enough training samples
    if np.sum(obs_mask) == 0:
        print("Warning: No training samples available!")
        # Return default prediction values
        Y_pred_mean = np.full_like(Y_main, np.nanmean(Y_main))
        return None, {"theta": np.zeros(N), "b": np.zeros(J)}, Y_pred_mean

    # 2. Use auxiliary metrics to calculate prior information
    # Calculate mean of auxiliary metrics as prior
    aux_means = []
    for Y_aux in Y_aux_list:
        aux_means.append(np.nanmean(Y_aux))
    
    # 3. Build simplified IRT model (based on main metric)
    with pm.Model() as irt_model:
        # Model ability parameters (use auxiliary metric means as prior means)
        if aux_means:
            prior_mean = np.mean(aux_means)
            theta = pm.Normal("theta", mu=prior_mean, sigma=1.0, shape=N)
        else:
            theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)
        
        # Question difficulty parameters
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)
        
        # Linear predictor
        mu = theta[:, None] - b[None, :]
        
        # Observation model
        pm.Normal(
            "y_obs", mu=mu[obs_mask], sigma=1.0,
            observed=logit_train_main[obs_mask]
        )
        
        # 4. MCMC sampling
        print(f"[MCMC Sampling] Parameters: {sample_kwargs}")
        trace = pm.sample(
            **sample_kwargs,
            random_seed=seed,
            target_accept=0.95,
            progressbar=True,
            return_inferencedata=True
        )

    # 5. Extract posterior mean of parameters
    params_mean = {
        "theta": trace.posterior["theta"].mean(("chain", "draw")).values,
        "b": trace.posterior["b"].mean(("chain", "draw")).values
    }
    
    # 6. Generate predictions
    theta_mean = params_mean["theta"].reshape(-1, 1)
    b_mean = params_mean["b"].reshape(1, -1)
    mu_pred = theta_mean - b_mean
    Y_pred_mean = expit(mu_pred)
    
    return trace, params_mean, Y_pred_mean

# =========================
# 3. Prediction Methods
# =========================

def predict_global_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline that predicts one constant for every cell.

    The constant is the plain mean of all values of ``Y`` selected by
    ``train_mask`` (zeros are included, NaNs are not skipped). The result
    has the shape and dtype of ``Y``.
    """
    training_values = Y[train_mask]
    constant = training_values.mean()
    return np.full_like(Y, constant)


def predict_row_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline that predicts, for each row, the mean of its training cells.

    A row without any training cell falls back to the NaN-ignoring mean of
    the whole row.
    """
    predictions = np.zeros_like(Y)
    for row_idx, row in enumerate(Y):
        seen = row[train_mask[row_idx]]
        if len(seen) > 0:
            row_value = np.mean(seen)
        else:
            # No training cell in this row: use every value of the row
            row_value = np.nanmean(row)
        predictions[row_idx, :] = row_value
    return predictions


def predict_col_mean(Y: np.ndarray, train_mask: np.ndarray) -> np.ndarray:
    """
    Baseline that predicts, for each column, the mean of its training cells.

    A column without any training cell falls back to the NaN-ignoring mean
    of the whole column.
    """
    predictions = np.zeros_like(Y)
    for col_idx, column in enumerate(Y.T):
        seen = column[train_mask[:, col_idx]]
        if len(seen) > 0:
            col_value = np.mean(seen)
        else:
            # No training cell in this column: use every value of the column
            col_value = np.nanmean(column)
        predictions[:, col_idx] = col_value
    return predictions

# =========================
# 4. Result Saving & Visualization
# =========================

def save_irt_trace(
    trace_dir: str,
    trace: az.InferenceData,
    model_type: str,
    train_subset_ratio: float,
    rep: int
):
    """
    Write a sampler trace to a NetCDF file inside ``trace_dir``.

    The directory is created when missing. The file name encodes the model
    type, the training ratio with three decimals, and the 1-based repetition
    number (``rep`` itself is zero-based).
    """
    os.makedirs(trace_dir, exist_ok=True)
    file_name = f"irt_{model_type}_ratio_{train_subset_ratio:.3f}_rep{rep+1}.nc"
    trace_path = os.path.join(trace_dir, file_name)
    trace.to_netcdf(trace_path)
    print(f"IRT trace saved: {trace_path}")


def save_sample_level_predictions(
    pred_dir: str,
    Y: np.ndarray,
    test_mask: np.ndarray,
    model_names: List[str],
    question_names: List[str],
    train_subset_ratio: float,
    rep: int,
    pred_global: np.ndarray,
    pred_row: np.ndarray,
    pred_col: np.ndarray,
    pred_irt: np.ndarray,
):
    """
    Save one CSV row per test cell with the true value and all predictions.

    Labels are attached to the axis whose length they match. If ``Y`` has
    shape ``(len(question_names), len(model_names))`` the rows are labelled
    with ``question_names`` and the columns with ``model_names``. Otherwise
    (including the ambiguous square case) rows are labelled with
    ``model_names`` and columns with ``question_names``, and test cells whose
    index has no label are skipped. Looking rows up in ``model_names``
    unconditionally would mislabel or drop cells for a matrix whose rows are
    questions.

    Args:
        pred_dir: Output directory; created if it does not exist.
        Y: True values.
        test_mask: Boolean mask selecting the cells to export.
        model_names: Labels written to the ``model`` column.
        question_names: Labels written to the ``question`` column.
        train_subset_ratio: Stored per row and used in the file name.
        rep: Zero-based repetition index; stored and named 1-based.
        pred_global: Global-mean predictions, same shape as ``Y``.
        pred_row: Row-mean predictions, same shape as ``Y``.
        pred_col: Column-mean predictions, same shape as ``Y``.
        pred_irt: IRT predictions, same shape as ``Y``.
    """
    os.makedirs(pred_dir, exist_ok=True)

    # Coordinates of the test cells, in row-major order
    test_rows, test_cols = np.where(test_mask)

    n_models, n_questions = len(model_names), len(question_names)
    rows_are_questions = (
        n_models != n_questions and Y.shape == (n_questions, n_models)
    )
    if rows_are_questions:
        question_idx, model_idx = test_rows, test_cols
    else:
        model_idx, question_idx = test_rows, test_cols

    # Drop cells that have no label on either axis
    labelled = (model_idx < n_models) & (question_idx < n_questions)
    test_rows, test_cols = test_rows[labelled], test_cols[labelled]
    model_idx, question_idx = model_idx[labelled], question_idx[labelled]

    pred_df = pd.DataFrame({
        "model_idx": model_idx,
        "model": [model_names[i] for i in model_idx],
        "question_idx": question_idx,
        "question": [question_names[j] for j in question_idx],
        "train_ratio": train_subset_ratio,
        "repetition": rep + 1,
        "true_value": Y[test_rows, test_cols],
        "global_mean_pred": pred_global[test_rows, test_cols],
        "model_mean_pred": pred_row[test_rows, test_cols],
        "question_mean_pred": pred_col[test_rows, test_cols],
        "irt_pred": pred_irt[test_rows, test_cols],
    })

    save_path = os.path.join(
        pred_dir,
        f"predictions_ratio_{train_subset_ratio:.3f}_rep{rep+1}.csv"
    )
    pred_df.to_csv(save_path, index=False)
    print(f"Sample-level predictions saved: {save_path}")


def calculate_mse(Y: np.ndarray, Y_pred: np.ndarray, mask: np.ndarray) -> float:
    """
    Mean squared error over the cells selected by ``mask``, skipping NaNs.

    Cells where either the true value or the prediction is NaN are left out
    of the average, so a single missing value no longer turns the whole
    score into NaN.

    Args:
        Y: True values.
        Y_pred: Predictions; must have the shape of ``Y``.
        mask: Mask selecting the cells to score. It is coerced to boolean so
            a 0/1 integer mask is not misread as fancy indices.

    Returns:
        The MSE as a float, or NaN when no selected cell has both values.

    Raises:
        ValueError: If ``Y`` and ``Y_pred`` have different shapes.
    """
    if Y.shape != Y_pred.shape:
        raise ValueError(f"Shape mismatch: Y {Y.shape} vs Y_pred {Y_pred.shape}")

    selected = np.asarray(mask, dtype=bool)
    squared_errors = (Y[selected] - Y_pred[selected]) ** 2

    # NaN in either operand yields a NaN squared error; drop those pairs
    squared_errors = squared_errors[~np.isnan(squared_errors)]
    if squared_errors.size == 0:
        return np.nan

    return float(np.mean(squared_errors))


def plot_mse_comparison(
    mse_dict: Dict[str, List[float]],
    train_ratios: List[float],
    save_path: str
):
    """
    Draw one MSE-versus-training-ratio line per method and save the figure.

    Each entry of ``mse_dict`` maps a method name to a list of MSE values
    aligned with ``train_ratios``. Known methods get a fixed colour and line
    style; unknown ones are drawn black and solid. NaN values are reported
    with a printed warning and are passed to matplotlib unchanged.
    """
    fallback_style = ("black", "solid")
    # method name -> (colour, line style)
    styles_by_method = {
        "Global Mean": ("gray", "dashed"),
        "Model Mean": ("blue", "solid"),
        "Question Mean": ("green", "solid"),
        "Simple Multi-Metric IRT": ("red", "solid")
    }

    plt.figure(figsize=(10, 6))
    for method_name, values in mse_dict.items():
        line_color, line_style = styles_by_method.get(method_name, fallback_style)
        # Work on a copy so the caller's list is never touched
        series = np.copy(values)
        if np.isnan(series).any():
            print(f"Warning: {method_name} contains NaN values in MSE list")
        plt.plot(
            train_ratios,
            series,
            label=method_name,
            color=line_color,
            linestyle=line_style,
            linewidth=2,
            marker="o",
            markersize=4
        )

    # Axis labels, title, legend and grid
    plt.xlabel("Training Data Ratio (Fraction of Remaining Pool)", fontsize=12)
    plt.ylabel("Test Set MSE (Lower = Better)", fontsize=12)
    plt.title("MSE vs. Training Data Ratio for All Prediction Methods", fontsize=14, pad=20)
    plt.legend(loc="upper right", fontsize=10)
    plt.grid(True, alpha=0.3)

    # One tick per evaluated ratio, shown with one decimal
    tick_labels = [f"{ratio:.1f}" for ratio in train_ratios]
    plt.xticks(train_ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"MSE plot saved: {save_path}")


def save_mse_summary(
    mse_summary: Dict[str, Dict[float, List[float]]],
    save_path: str
):
    """
    Save a "mean ± standard deviation" MSE table, one row per training ratio.

    The set of ratios is taken from the ``"Global Mean"`` entry when present,
    otherwise from the first method in ``mse_summary``, so the function also
    works for other method sets. Every method must provide an entry for each
    of those ratios. A ratio with no recorded values is written as
    ``nan ± nan``.

    Args:
        mse_summary: ``{method: {train_ratio: [mse per repetition]}}``.
        save_path: Destination CSV path.

    Raises:
        KeyError: If a method has no entry for one of the reference ratios.
    """
    methods = list(mse_summary)
    if "Global Mean" in mse_summary:
        reference = mse_summary["Global Mean"]
    elif methods:
        reference = mse_summary[methods[0]]
    else:
        reference = {}

    rows = []
    for ratio in sorted(reference.keys()):
        row = {"Train_Ratio": ratio}
        for method in methods:
            mse_list = mse_summary[method][ratio]
            if len(mse_list) > 0:
                mse_mean = np.mean(mse_list)
                mse_std = np.std(mse_list)
            else:
                # Avoid NumPy's empty-slice warning; the output is the same
                mse_mean = mse_std = np.nan
            row[method] = f"{mse_mean:.6f} ± {mse_std:.6f}"
        rows.append(row)

    pd.DataFrame(rows).to_csv(save_path, index=False)
    print(f"MSE summary saved: {save_path}")

# =========================
# 5. Helper Functions
# =========================

def _ensure_dir(directory: str) -> None:
    """Create ``directory`` (and missing parents); do nothing if it exists."""
    os.makedirs(name=directory, exist_ok=True)


def load_multiple_metrics(
    metrics_dir: str,
    metric_names: List[str],
    question_id_col: Optional[Union[str, int]] = None
) -> Tuple[List[np.ndarray], List[str], List[str], List[str]]:
    """
    Load one response matrix per metric and check that their labels agree.

    For each metric the file ``response_matrix__<metric>.csv`` in
    ``metrics_dir`` is read with ``robust_read_csv`` and passed through
    ``preprocess_matrix``. The labels of the first metric become the
    reference; every later metric must match them exactly.

    Returns:
        ``(matrices, model_names, question_names, metric_names)`` with the
        matrices in the order of ``metric_names``. Both label lists are None
        when ``metric_names`` is empty.

    Raises:
        FileNotFoundError: If a metric file does not exist.
        ValueError: If a metric's labels differ from the reference labels.
    """
    matrices = []
    reference_models = None
    reference_questions = None

    for name in metric_names:
        path = os.path.join(metrics_dir, f"response_matrix__{name}.csv")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Metric file not found: {path}")

        _, matrix, models, questions = preprocess_matrix(
            df=robust_read_csv(path),
            question_id_col=question_id_col,
            zero_row_ratio_threshold=0.3,
            epsilon=1e-6
        )

        if reference_models is None:
            # First metric defines the labels all others are compared with
            reference_models, reference_questions = models, questions
        elif models != reference_models:
            raise ValueError(f"Inconsistent model names: {name}")
        elif questions != reference_questions:
            raise ValueError(f"Inconsistent question names: {name}")

        matrices.append(matrix)

    return matrices, reference_models, reference_questions, metric_names

# =========================
# 6. Main Experiment Pipeline
# =========================

def run_simple_multimetric_experiment(
    metrics_dir: str,
    output_root_dir: str,
    selected_metrics: List[str],
    main_metric_idx: int = 0,  # Main metric index
    question_id_col: Optional[Union[str, int]] = 0,
    test_ratio: float = 0.10,
    test_seed: int = 42,
    train_ratios: Optional[List[float]] = None,
    rep_count: int = 1,
    mcmc_draws: int = 100,
    mcmc_tune: int = 100,
    mcmc_chains: int = 4,
    mcmc_cores: Optional[int] = None
):
    """
    Run the full experiment: load metrics, split, fit, score, and save.

    One fixed test mask is drawn with ``test_seed``. For every training
    ratio and repetition a training mask is sampled from the remaining
    cells, three mean baselines and the IRT model are fitted on the main
    metric, and their test-set MSE is recorded. Traces, per-cell
    predictions, split masks, an MSE summary table and an MSE plot are
    written below ``output_root_dir``.

    Args:
        metrics_dir: Directory holding ``response_matrix__<metric>.csv`` files.
        output_root_dir: Root directory for all outputs.
        selected_metrics: Metric names to load.
        main_metric_idx: Position in ``selected_metrics`` of the metric that
            is modelled; negative positions count from the end. All other
            metrics are auxiliary.
        question_id_col: ID column forwarded to the loader.
        test_ratio: Per-cell probability of belonging to the test set.
        test_seed: Seed of the fixed test split.
        train_ratios: Training ratios to evaluate. Defaults to
            ``[0.2, 0.5, 1.0]``.
        rep_count: Repetitions per training ratio.
        mcmc_draws: Draws per chain.
        mcmc_tune: Tuning iterations per chain.
        mcmc_chains: Number of chains.
        mcmc_cores: Cores for sampling; defaults to ``mcmc_chains``.

    Raises:
        IndexError: If ``main_metric_idx`` is out of range.
    """
    # A None default avoids sharing one mutable list between calls
    if train_ratios is None:
        train_ratios = [0.2, 0.5, 1.0]
    else:
        train_ratios = list(train_ratios)

    print(f"Running simple multi-metric experiment with metrics: {selected_metrics}")
    print(f"Main metric index: {main_metric_idx}")

    # 1. Load multiple metrics data
    print("\n=== Step 1/6: Loading Multiple Metrics ===")
    Y_list, model_names, key_names, metric_names = load_multiple_metrics(
        metrics_dir=metrics_dir,
        metric_names=selected_metrics,
        question_id_col=question_id_col
    )

    # Normalise the index so a negative value cannot leave the main metric
    # inside the auxiliary list as well.
    n_metrics = len(Y_list)
    if not -n_metrics <= main_metric_idx < n_metrics:
        raise IndexError(
            f"main_metric_idx {main_metric_idx} out of range for {n_metrics} metrics"
        )
    main_metric_idx %= n_metrics

    Y_main = Y_list[main_metric_idx]
    Y_aux_list = [Y for i, Y in enumerate(Y_list) if i != main_metric_idx]

    N, J = Y_main.shape
    M = len(Y_aux_list)
    print(f"Main metric shape: {N} models × {J} questions")
    print(f"Auxiliary metrics count: {M}")

    # 2. Create output directories
    output_dirs = {
        "split": os.path.join(output_root_dir, "01_data_split"),
        "predictions": os.path.join(output_root_dir, "02_sample_predictions"),
        "irt_trace": os.path.join(output_root_dir, "03_irt_traces"),
        "metrics": os.path.join(output_root_dir, "04_metrics"),
        "params": os.path.join(output_root_dir, "05_params")
    }
    for dir_path in output_dirs.values():
        _ensure_dir(dir_path)

    # 3. Split fixed test set
    print("\n=== Step 2/6: Split Fixed Test Set ===")
    test_mask, remaining_mask = split_fixed_test_set(
        Y_shape=(N, J),
        test_ratio=test_ratio,
        test_seed=test_seed
    )

    # 4. Initialize MSE record: method -> ratio -> list of per-repetition MSE
    method_names = [
        "Global Mean",
        "Model Mean",
        "Question Mean",
        "Simple Multi-Metric IRT",
    ]
    mse_summary = {
        method: {r: [] for r in train_ratios} for method in method_names
    }

    # 5. MCMC sampling parameters
    sample_kwargs = {
        'draws': mcmc_draws,
        'tune': mcmc_tune,
        'chains': mcmc_chains,
        'cores': mcmc_cores or mcmc_chains
    }

    # 6. Iterate through all training ratios + repetitions
    for train_ratio in train_ratios:
        print(f"\n=== Step 3/6: Train Ratio = {train_ratio:.3f} ===")
        ratio_train_masks = []

        for rep in range(rep_count):
            print(f"\n--- Repetition {rep+1}/{rep_count} ---")
            # Offset by one: with rep_seed == test_seed the training draw
            # would reuse the exact random matrix of the test split, which
            # shrinks the training set (to nothing when
            # train_ratio <= test_ratio).
            rep_seed = test_seed + 1 + rep

            # 6.1 Sample training subset for current repetition
            print(f"Sampling training subset (ratio={train_ratio:.3f})...")
            train_mask = sample_training_subset(
                remaining_mask=remaining_mask,
                train_subset_ratio=train_ratio,
                rep_seed=rep_seed
            )
            ratio_train_masks.append(train_mask)
            print(f"Training samples: {train_mask.sum()}, Test samples: {test_mask.sum()}")

            # 6.2 Calculate baseline predictions for main metric
            pred_global = predict_global_mean(Y_main, train_mask)
            pred_row = predict_row_mean(Y_main, train_mask)
            pred_col = predict_col_mean(Y_main, train_mask)

            # 6.3 Run simplified multi-metric IRT model
            print("Running simple multi-metric IRT model...")
            trace, params_mean, Y_pred_irt = run_simple_multimetric_irt(
                Y_main=Y_main,
                Y_aux_list=Y_aux_list,
                sample_ratio=train_ratio,
                output_dir=output_dirs["params"],
                model_names=model_names,
                key_names=key_names,
                sample_kwargs=sample_kwargs,
                train_mask=train_mask,
                seed=rep_seed
            )

            # The trace is None when the model had nothing to fit
            if trace is not None:
                save_irt_trace(
                    trace_dir=output_dirs["irt_trace"],
                    trace=trace,
                    model_type="simple_multimetric_irt",
                    train_subset_ratio=train_ratio,
                    rep=rep
                )

            # 6.4 Calculate MSE and save results
            mse_global = calculate_mse(Y_main, pred_global, test_mask)
            mse_row = calculate_mse(Y_main, pred_row, test_mask)
            mse_col = calculate_mse(Y_main, pred_col, test_mask)
            mse_irt = calculate_mse(Y_main, Y_pred_irt, test_mask)

            rep_mse = {
                "Global Mean": mse_global,
                "Model Mean": mse_row,
                "Question Mean": mse_col,
                "Simple Multi-Metric IRT": mse_irt,
            }
            for method, value in rep_mse.items():
                mse_summary[method][train_ratio].append(value)

            # Save sample-level predictions
            save_sample_level_predictions(
                pred_dir=output_dirs["predictions"],
                Y=Y_main,
                test_mask=test_mask,
                model_names=model_names,
                question_names=key_names,
                train_subset_ratio=train_ratio,
                rep=rep,
                pred_global=pred_global,
                pred_row=pred_row,
                pred_col=pred_col,
                pred_irt=Y_pred_irt
            )

            print(f"\nRep {rep+1} MSE Results:")
            print(f"Baseline: Global={mse_global:.6f} | Model={mse_row:.6f} | Question={mse_col:.6f}")
            print(f"Simple Multi-Metric IRT: {mse_irt:.6f}")

        # 6.5 Record the split of every repetition once per ratio, so each
        #     repetition gets its own correctly numbered file.
        if ratio_train_masks:
            record_data_split(
                split_dir=output_dirs["split"],
                test_mask=test_mask,
                train_masks=ratio_train_masks,
                train_subset_ratio=train_ratio,
                model_names=model_names,
                question_names=key_names
            )

    # 7. Generate final MSE summary and visualization
    print("\n=== Step 4/6: Generate Final Results ===")

    # Save MSE summary table
    mse_summary_path = os.path.join(output_dirs["metrics"], "mse_summary.csv")
    save_mse_summary(mse_summary, mse_summary_path)

    # Average over repetitions for plotting
    plot_mse_dict = {
        method: [np.mean(mse_summary[method][r]) for r in train_ratios]
        for method in method_names
    }

    mse_plot_path = os.path.join(output_dirs["metrics"], "mse_vs_train_ratio.png")
    plot_mse_comparison(plot_mse_dict, train_ratios, mse_plot_path)

    # 8. Done. Posterior-mean parameters are not written to disk here.
    print("\n=== Experiment Completed ===")
    print(f"All results saved to: {output_root_dir}")

# =========================
# 7. Configuration & Execution
# =========================

if __name__ == "__main__":
    # Experiment configuration, passed straight through as keyword arguments
    experiment_config = dict(
        # Directory with one response-matrix CSV per metric
        metrics_dir="data/continuous_processed",
        # Root directory for every output of this run
        output_root_dir="results/modeling_result_simple_multimetric_bertscore_f1",
        # Metrics to load; the entry at main_metric_idx is the modelled one
        selected_metrics=["bertscore_F1", "bleu", "rouge1"],
        main_metric_idx=0,  # bertscore_F1
        # Name of the ID column to drop from each CSV
        question_id_col='row_index',
        # Fixed test split
        test_ratio=0.10,
        test_seed=42,
        # Training ratios to sweep and repetitions per ratio
        train_ratios=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        rep_count=1,
        # Sampler settings: draws and tuning steps per chain, chains, cores
        mcmc_draws=100,
        mcmc_tune=100,
        mcmc_chains=4,
        mcmc_cores=4,
    )

    # Run simplified multi-metric experiment
    run_simple_multimetric_experiment(**experiment_config)