import argparse
import math
import os
import random
from datetime import datetime  # Added for timestamp
from functools import partial  # Added for wrapping
import json

import numpy as np
import pandas as pd
from joblib import Parallel, delayed  # Added for parallel processing
from tqdm import tqdm
from scipy.special import softmax

# Assuming rescale_residual (RR) and load_uci_dataset are in the path
import rescale_residual as RR
from data_utils import load_uci_dataset  # Make sure this path is correct

from sklearn.ensemble import (
    AdaBoostRegressor,
    GradientBoostingRegressor,
    RandomForestRegressor,
)
from sklearn.linear_model import ElasticNet, LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from k_means_constrained import KMeansConstrained

# Multiplier applied to the negated, per-model summed residual predictions
# before the softmax that builds the training prior in
# run_experiment_for_dataset.
ETA_PRIOR = 1


# =======================================================================
# Helper Functions
# =======================================================================
def set_seed(seed):
    """Seed NumPy's global RNG and then the stdlib ``random`` module with ``seed``."""
    for seeder in (np.random.seed, random.seed):
        seeder(seed)


# =======================================================================
# Wrapper Definitions (from synthetic_exp.py)
# =======================================================================
def wrap_existing_conformal(base_func, split_portion=None):
    """Adapt a prior-agnostic method to the shared aggregator signature.

    Args:
        base_func: Callable accepting ``cal_preds``, ``cal_y``, ``cal_sigma``,
            ``test_pred``, ``test_sigma``, ``test_y`` and ``alpha`` as keywords.
        split_portion: When not ``None``, bound to ``base_func`` as the
            ``split_portion`` keyword; otherwise ``base_func`` is used as is.

    Returns:
        A function that additionally accepts ``b_prior`` (and discards it) and
        returns whatever ``base_func`` returns.
    """
    target = (
        base_func
        if split_portion is None
        else partial(base_func, split_portion=split_portion)
    )

    def final_configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        # b_prior exists only so every aggregator shares one call signature;
        # it is deliberately not forwarded.
        forwarded = {
            "cal_preds": cal_preds,
            "cal_y": cal_y,
            "cal_sigma": cal_sigma,
            "test_pred": test_pred,
            "test_sigma": test_sigma,
            "test_y": test_y,
            "alpha": alpha,
        }
        return target(**forwarded)

    return final_configured_func


def wrap_stable_conformal(base_func, eta=1.0, ignore_prior=False):
    """Bind ``eta`` and the prior policy for a stable-conformal style method.

    Args:
        base_func: Callable taking the shared aggregator keywords plus ``eta``
            and ``b_prior``.
        eta: Value forwarded unchanged as the ``eta`` keyword.
        ignore_prior: If true, the supplied ``b_prior`` is replaced by a uniform
            vector whose length is ``cal_preds.shape[0]``.

    Returns:
        A function with the shared aggregator signature.
    """

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        n_models = cal_preds.shape[0]
        if ignore_prior:
            effective_prior = np.ones(n_models) / n_models
        else:
            effective_prior = b_prior

        call_kwargs = dict(
            cal_preds=cal_preds,
            cal_y=cal_y,
            cal_sigma=cal_sigma,
            test_pred=test_pred,
            test_sigma=test_sigma,
            test_y=test_y,
            alpha=alpha,
            eta=eta,
            b_prior=effective_prior,
        )
        return base_func(**call_kwargs)

    return configured_func


def wrap_adaptive_stable_conformal(base_func, ratio=0.75, ignore_prior=False):
    """Bind the alpha ratio and the prior policy for an adaptive stable method.

    Args:
        base_func: Callable taking the shared aggregator keywords plus
            ``alpha_prime`` and ``b_prior``.
        ratio: ``alpha_prime`` is computed as ``alpha * ratio`` on every call.
        ignore_prior: If true, the supplied ``b_prior`` is replaced by a uniform
            vector whose length is ``cal_preds.shape[0]``.

    Returns:
        A function with the shared aggregator signature.
    """

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        n_models = cal_preds.shape[0]
        if ignore_prior:
            effective_prior = np.ones(n_models) / n_models
        else:
            effective_prior = b_prior

        call_kwargs = dict(
            cal_preds=cal_preds,
            cal_y=cal_y,
            cal_sigma=cal_sigma,
            test_pred=test_pred,
            test_sigma=test_sigma,
            test_y=test_y,
            alpha_prime=alpha * ratio,
            alpha=alpha,
            b_prior=effective_prior,
        )
        return base_func(**call_kwargs)

    return configured_func


def wrap_internal_split_calibration(
    base_func,
    alpha_pre_selection=0.1,
    alpha_post_selection=0.75,
    N_resamples=1,
    preliminary_gamma=0.9,
    aux_split_ratio=0.5,
    ignore_prior=False,
):
    """Bind the fixed settings of a calibrate-after-selection style method.

    Args:
        base_func: Callable taking the shared aggregator keywords, ``b_prior``
            and the five tuning keywords listed below.
        alpha_pre_selection: Forwarded unchanged.
        alpha_post_selection: Forwarded unchanged.
        N_resamples: Forwarded unchanged.
        preliminary_gamma: Forwarded unchanged.
        aux_split_ratio: Forwarded unchanged.
        ignore_prior: If true, the supplied ``b_prior`` is replaced by a uniform
            vector whose length is ``cal_preds.shape[0]``.

    Returns:
        A function with the shared aggregator signature; its ``alpha`` argument
        is passed through to ``base_func`` as ``alpha``.
    """
    # Settings that never change between calls of the configured function.
    fixed_settings = dict(
        alpha_pre_selection=alpha_pre_selection,
        alpha_post_selection=alpha_post_selection,
        N_resamples=N_resamples,
        preliminary_gamma=preliminary_gamma,
        aux_split_ratio=aux_split_ratio,
    )

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        n_models = cal_preds.shape[0]
        if ignore_prior:
            effective_prior = np.ones(n_models) / n_models
        else:
            effective_prior = b_prior

        return base_func(
            cal_preds=cal_preds,
            cal_y=cal_y,
            cal_sigma=cal_sigma,
            test_pred=test_pred,
            test_sigma=test_sigma,
            test_y=test_y,
            alpha=alpha,
            b_prior=effective_prior,
            **fixed_settings,
        )

    return configured_func


def load_tuned_hyperparams(dataset_name):
    """Return the tuned hyperparameters stored for ``dataset_name``.

    Reads ``./results/tuned_hyperparams_skopt_y_scaled.json`` (relative to the
    current working directory) and returns the entry keyed by ``dataset_name``.
    An empty dict is returned when the dataset has no entry, and also when
    reading or parsing fails for any reason, in which case a warning is printed.
    """
    params_path = "./results/tuned_hyperparams_skopt_y_scaled.json"
    try:
        with open(params_path, "r") as f:
            per_dataset = json.load(f)
        return per_dataset.get(dataset_name, {})
    except Exception as e:
        # Best effort: callers fall back to library defaults on an empty dict.
        print(f"Warning: Could not load hyperparameters for {dataset_name}: {e}")
        return {}


def train_models_with_residuals(
    X, y, dataset_name, split_size=0.5, number_partitions=-1, random_seed=42
):
    """
    Train each base regressor together with a residual (error-scale) predictor.

    The data is split once into a model-training part and a residual-training
    part. Every base regressor is fit on the model-training part (or on one
    randomly chosen cluster of it when ``number_partitions`` is positive), then
    a RandomForestRegressor is fit on the residual part to predict that
    regressor's absolute error.

    Args:
        X: Full training features (indexed with boolean masks, so an
            array-like supporting that is expected)
        y: Training targets
        dataset_name: Name of the dataset to load appropriate hyperparameters
        split_size: Proportion of data held out for residual-predictor training
            (passed as ``test_size`` to ``train_test_split``)
        number_partitions: Number of equal-size clusters to partition the
            model-training part into (-1 for homogeneous, i.e. no partitioning)
        random_seed: Currently NOT used by this function. All randomness
            (split, clustering, cluster choice, estimators) comes from the
            global random state, which the caller is expected to have seeded.

    Returns:
        List of (model_name, trained_model, residual_predictor) tuples, in the
        iteration order of the ``base_models`` dict below.

    Raises:
        ValueError: If ``number_partitions`` is neither positive nor -1.
    """
    # Load tuned hyperparameters for this dataset ({} if none are available)
    tuned_params = load_tuned_hyperparams(dataset_name)

    # Split data (uses global random state; no random_state is passed)
    X_train, X_res, y_train, y_res = train_test_split(X, y, test_size=split_size)

    n_p = len(y_train)
    if number_partitions > 0:
        # Constrain cluster sizes to floor/ceil of the even share so the
        # partitions are (near-)equal in size.
        size_low = math.floor(n_p / number_partitions)
        size_high = math.ceil(n_p / number_partitions)
        kmeans = KMeansConstrained(
            n_clusters=number_partitions,
            size_min=size_low,
            size_max=size_high,
            verbose=False,
        )
        kmeans.fit(X_train)
        # One boolean mask over X_train per cluster
        clusters = [kmeans.labels_ == i for i in range(number_partitions)]
    elif number_partitions != -1:
        raise ValueError("number_partitions must be -1 or positive integer")

    # Map display name -> estimator class; instances are created per model below
    base_models = {
        "SVR": SVR,
        "Linear Regression": LinearRegression,
        "Random Forest": RandomForestRegressor,
        "Gradient Boosting": GradientBoostingRegressor,
        "Elastic Net": ElasticNet,
        "Decision Tree": DecisionTreeRegressor,
        "AdaBoost Reg": AdaBoostRegressor,
    }

    trained_pairs = []

    for name, model_class in base_models.items():
        # Get tuned parameters for this model if available
        model_params = tuned_params.get(name, {})

        # Initialize model with tuned parameters; only constructor errors are
        # handled here (errors raised later by fit() are not caught)
        try:
            model = model_class(**model_params)
        except Exception as e:
            print(
                f"Warning: Could not initialize {name} with tuned parameters for {dataset_name}: {e}"
            )
            print(f"Falling back to default parameters for {name}")
            model = model_class()

        # Cluster selection: each model draws its own cluster index from the
        # global NumPy RNG, so different models may train on different clusters
        if number_partitions > 0:
            k = number_partitions
            cluster_idx = np.random.randint(k)
            train_idx = clusters[cluster_idx]
            cX_train = X_train[train_idx]
            cY_train = y_train[train_idx]
        else:
            cX_train, cY_train = X_train, y_train

        model.fit(cX_train, cY_train)

        # Absolute residuals on the held-out residual split
        y_pred = model.predict(X_res)
        residuals = np.abs(y_res - y_pred)

        # Train residual predictor using tuned Random Forest parameters if available
        try:
            # Reuses the "Random Forest" entry of the tuned parameters for
            # every base model's residual predictor
            residual_params = tuned_params.get("Random Forest", {})
            residual_model = RandomForestRegressor(**residual_params)
            print(
                f"Using tuned Random Forest parameters for residual predictor of {name}"
            )
        except Exception as e:
            print(
                f"Warning: Could not initialize residual predictor for {name} with tuned parameters: {e}"
            )
            print(f"Falling back to default parameters for residual predictor")
            residual_model = RandomForestRegressor()

        residual_model.fit(X_res, residuals)

        trained_pairs.append((name, model, residual_model))

    return trained_pairs


def process_test_point_uci(
    test_idx,
    cal_preds_all_models,
    cal_y_all,
    cal_sigma_all_models,
    test_pred_all_models,
    test_sigma_all_models,
    y_test_all,
    configured_aggregators,
    alpha,
    train_prior_for_methods,
):
    """Evaluate every configured aggregator on a single test point.

    Column ``test_idx`` of the test prediction and sigma arrays and entry
    ``test_idx`` of ``y_test_all`` are passed to each aggregator together with
    the full calibration arrays.

    Returns:
        Dict mapping aggregator name to ``(coverage, length)`` as floats, taken
        from the first two items each aggregator returns. An aggregator that
        raises is reported on stdout and recorded as ``(nan, nan)``.
    """
    y_test_point = y_test_all[test_idx]
    # Arguments common to all aggregators for this test point.
    shared_kwargs = dict(
        cal_preds=cal_preds_all_models,
        cal_y=cal_y_all,
        cal_sigma=cal_sigma_all_models,
        test_pred=test_pred_all_models[:, test_idx],
        test_sigma=test_sigma_all_models[:, test_idx],
        test_y=y_test_point,
        alpha=alpha,
        b_prior=train_prior_for_methods,
    )

    outcomes = {}
    for name, alg_func in configured_aggregators.items():
        try:
            coverage, length, *_ = alg_func(**shared_kwargs)
            outcome = (float(coverage), float(length))
        except Exception as e:
            print(
                f"Error in aggregator {name} for test_idx {test_idx} (y_test_point: {y_test_point}): {e}"
            )
            outcome = (np.nan, np.nan)
        outcomes[name] = outcome
    return outcomes


def _split_conformal_quantiles(cal_preds, cal_y, cal_sigma, alpha):
    """Return one calibration-score quantile per model (row of ``cal_preds``).

    Scores are absolute calibration errors divided by the predicted residual
    scale (plus 1e-8 so a zero scale cannot divide by zero). For each model the
    ``ceil((n + 1) * (1 - alpha))``-th smallest score is returned, with the rank
    clamped to ``[1, n]`` where ``n = len(cal_y)``.
    """
    n_cal = len(cal_y)
    scores = np.abs(cal_preds - cal_y[None, :]) / (cal_sigma + 1e-8)
    rank = min(max(math.ceil((n_cal + 1) * (1 - alpha)), 1), n_cal)
    return np.sort(scores, axis=1)[:, rank - 1]


def _print_base_predictor_summary(model_names, indiv_lengths):
    """Print per-model mean interval lengths and pointwise-minimum statistics."""
    n_test = indiv_lengths.shape[1]
    # For every test point: which model gives the shortest interval, and how short.
    min_length_indices = np.argmin(indiv_lengths, axis=0)
    min_lengths = np.min(indiv_lengths, axis=0)
    model_counts = np.bincount(min_length_indices, minlength=len(model_names))

    print("\nBase Conformal Predictor Performance Summary:")
    print("--------------------------------------------")
    print("Average interval lengths per model:")
    for model_name, lengths in zip(model_names, indiv_lengths):
        avg_len = np.mean(lengths)
        print(f"{model_name:20s}: {avg_len:.4f}")

    print("\nPointwise minimum selection:")
    print(f"Average length when selecting minimum: {np.mean(min_lengths):.4f}")

    print("\nNumber of times each model achieved minimum length:")
    for model_name, count in zip(model_names, model_counts):
        percentage = (count / n_test) * 100
        print(f"{model_name:20s}: {count:4d} times ({percentage:5.1f}%)")
    print("--------------------------------------------\n")


def run_experiment_for_dataset(
    dataset_name,
    seed,
    alpha,
    all_configured_aggregators,
    train_ratio,
    cal_ratio,
    test_ratio,
    num_partitions,
    args,
):
    """Run one dataset/seed experiment and collect per-model and aggregator metrics.

    Steps: seed the global RNGs, load the dataset, split it into
    train/calibration/test, standardize X and y using training statistics,
    train the base models with residual predictors, evaluate each base model's
    own conformal interval, then evaluate every configured aggregator on every
    test point in parallel.

    Args:
        dataset_name: Name passed to ``load_uci_dataset`` and to the
            hyperparameter loader.
        seed: Seed for the global RNGs and for both ``train_test_split`` calls.
        alpha: Target miscoverage rate.
        all_configured_aggregators: Dict mapping aggregator name to a callable
            with the shared aggregator signature.
        train_ratio: Fraction of the data used for training.
        cal_ratio: Fraction of the data used for calibration.
        test_ratio: Fraction of the data used for testing.
        num_partitions: Forwarded to ``train_models_with_residuals``.
        args: Parsed CLI namespace; not read by this function (kept for
            interface compatibility with callers).

    Returns:
        ``(individual_results, aggregator_results)``: two flat dicts suitable
        as DataFrame rows. ``(None, None)`` when the dataset cannot be loaded,
        a split is too small, or no models were trained.

    Raises:
        ValueError: If the three ratios do not sum to 1.
    """
    set_seed(seed)

    X, y, ds_name_from_loader = load_uci_dataset(dataset_name)
    if X is None:
        print(f"Skipping dataset {dataset_name} as it could not be loaded.")
        return None, None

    total_points = len(X)
    print(f"\n{'='*60}")
    print(f"Dataset: {ds_name_from_loader}")
    print(f"Total number of points: {total_points}")
    print(
        f"Split ratios - Train: {train_ratio:.2f}, Cal: {cal_ratio:.2f}, Test: {test_ratio:.2f}"
    )
    print(f"Expected split sizes:")
    print(f"  - Training: {int(total_points * train_ratio)} points")
    print(f"  - Calibration: {int(total_points * cal_ratio)} points")
    print(f"  - Testing: {int(total_points * test_ratio)} points")
    print(f"{'='*60}\n")

    # Verify ratios sum to 1
    if not np.isclose(train_ratio + cal_ratio + test_ratio, 1.0):
        raise ValueError(
            f"Ratios must sum to 1, but got train_ratio={train_ratio}, cal_ratio={cal_ratio}, test_ratio={test_ratio}"
        )

    # Split data into train, calibration, and test sets.
    # First split out the test set ...
    X_train_cal, X_test, y_train_cal, y_test = train_test_split(
        X, y, test_size=test_ratio, random_state=seed
    )

    # ... then split the remainder; cal_ratio is rescaled because it is now a
    # fraction of the non-test data rather than of the full dataset.
    cal_ratio_adjusted = cal_ratio / (1 - test_ratio)
    X_train, X_cal, y_train, y_cal = train_test_split(
        X_train_cal, y_train_cal, test_size=cal_ratio_adjusted, random_state=seed
    )

    print(f"Actual split sizes:")
    print(
        f"  - Training: {len(X_train)} points (used for both regressors and residual predictors)"
    )
    print(f"    * Internal split for residual training: {len(X_train)//2} points each")
    print(f"  - Calibration: {len(X_cal)} points")
    print(f"  - Testing: {len(X_test)} points")
    print(f"{'='*60}\n")

    # Standardize features with statistics from the training split only.
    scaler = StandardScaler()
    y_scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_cal = scaler.transform(X_cal)
    X_test = scaler.transform(X_test)

    # Convert y to numpy arrays and ensure they are 2D for the scaler
    y_train_np = y_train.to_numpy().reshape(-1, 1)
    y_cal_np = y_cal.to_numpy().reshape(-1, 1)
    y_test_np = y_test.to_numpy().reshape(-1, 1)

    # Fit Y scaler on training data and transform all Y splits
    y_train = y_scaler.fit_transform(y_train_np).ravel()
    y_cal = y_scaler.transform(y_cal_np).ravel()
    y_test = y_scaler.transform(y_test_np).ravel()

    print(f"Dataset: {ds_name_from_loader}, Seed: {seed}")
    print(
        f"Shapes: X_train: {X_train.shape}, y_train: {y_train.shape}, "
        f"X_cal: {X_cal.shape}, y_cal: {y_cal.shape}, "
        f"X_test: {X_test.shape}, y_test: {y_test.shape}"
    )

    if X_train.shape[0] < 2 or X_cal.shape[0] < 2 or X_test.shape[0] == 0:
        print(
            f"Skipping {ds_name_from_loader} for seed {seed} due to insufficient train/cal/test data after split."
        )
        return None, None

    # Train models using all training data for both regressors and residual predictors
    trained_model_pairs = train_models_with_residuals(
        X_train,
        y_train,
        dataset_name=dataset_name,
        split_size=0.5,  # Use half of training data for residual training
        number_partitions=num_partitions,
        random_seed=seed,
    )
    if not trained_model_pairs:
        print(f"No models trained for {ds_name_from_loader}, seed {seed}. Skipping.")
        return None, None

    model_names = [entry[0] for entry in trained_model_pairs]

    # Rows are models, columns are data points. Predicted residual scales are
    # clipped at zero because the residual predictor may return negatives.
    cal_preds_all = np.array([m[1].predict(X_cal) for m in trained_model_pairs])
    cal_sigma_all = np.clip(
        np.array([m[2].predict(X_cal) for m in trained_model_pairs]), 0, np.inf
    )
    test_pred_all = np.array([m[1].predict(X_test) for m in trained_model_pairs])
    test_sigma_all = np.clip(
        np.array([m[2].predict(X_test) for m in trained_model_pairs]), 0, np.inf
    )

    # Training prior over models: softmax of the negated per-model residual totals.
    # NOTE(review): the variable is named "avg" but holds a SUM over training
    # points, which makes the softmax far more peaked than a mean would.
    # Behavior is kept as is; confirm which one is intended.
    prior_train_preds_sigma = np.array(
        [m[2].predict(X_train) for m in trained_model_pairs]
    )
    avg_prior_train_sigma = np.sum(prior_train_preds_sigma, axis=1)
    train_prior = softmax(ETA_PRIOR * -avg_prior_train_sigma)

    # Per-model conformal quantiles, computed once and shared by the printed
    # summary and the stored per-model metrics.
    indiv_q_hats = _split_conformal_quantiles(
        cal_preds_all, y_cal, cal_sigma_all, alpha
    )

    # TODO: double-check the per-model summary below.
    # Interval length of each model at each test point: 2 * sigma * q_hat.
    indiv_lengths = 2 * test_sigma_all * indiv_q_hats[:, None]
    _print_base_predictor_summary(model_names, indiv_lengths)

    # Two separate result rows that share the same run metadata.
    run_info = {
        "Dataset": ds_name_from_loader,
        "Seed": seed,
        "Alpha": alpha,
        "NumPartitions": num_partitions,
        "N_cal": len(y_cal),
        "N_test": len(y_test),
    }
    individual_results = dict(run_info)
    aggregator_results = dict(run_info)

    # Store individual model results
    for idx, model_name in enumerate(model_names):
        preds = test_pred_all[idx]
        half_width = test_sigma_all[idx, :] * indiv_q_hats[idx]
        individual_results[f"{model_name}_MAE"] = mean_absolute_error(y_test, preds)
        individual_results[f"{model_name}_R2"] = r2_score(y_test, preds)
        individual_results[f"{model_name}_AvgLen"] = 2 * np.mean(half_width)
        individual_results[f"{model_name}_Coverage"] = np.mean(
            np.abs(y_test - preds) <= half_width
        )

    # Each test point is independent, so aggregators are evaluated in parallel.
    results_per_test_point = Parallel(n_jobs=-1, backend="loky")(
        delayed(process_test_point_uci)(
            test_idx=i,
            cal_preds_all_models=cal_preds_all,
            cal_y_all=y_cal,
            cal_sigma_all_models=cal_sigma_all,
            test_pred_all_models=test_pred_all,
            test_sigma_all_models=test_sigma_all,
            y_test_all=y_test,
            configured_aggregators=all_configured_aggregators,
            alpha=alpha,
            train_prior_for_methods=train_prior,
        )
        for i in tqdm(
            range(len(y_test)),
            desc=f"Processing test points for {ds_name_from_loader} (Seed {seed})",
        )
    )

    agg_coverage_sum = {k: 0.0 for k in all_configured_aggregators}
    agg_length_sum = {k: 0.0 for k in all_configured_aggregators}
    valid_counts = {k: 0 for k in all_configured_aggregators}

    # Failed evaluations are recorded as NaN and excluded from the means.
    for point_res_dict in results_per_test_point:
        for name, (cov, leng) in point_res_dict.items():
            if not np.isnan(cov) and not np.isnan(leng):
                agg_coverage_sum[name] += cov
                agg_length_sum[name] += leng
                valid_counts[name] += 1

    # Store aggregator results
    for name in all_configured_aggregators:
        if valid_counts[name] > 0:
            aggregator_results[f"{name}_CovMean"] = (
                agg_coverage_sum[name] / valid_counts[name]
            )
            aggregator_results[f"{name}_LenMean"] = (
                agg_length_sum[name] / valid_counts[name]
            )
        else:
            aggregator_results[f"{name}_CovMean"] = np.nan
            aggregator_results[f"{name}_LenMean"] = np.nan

    return individual_results, aggregator_results


def main(args):
    """Configure all aggregators, run every dataset/seed pair and save CSVs.

    Two CSV files (individual-model metrics and aggregator metrics) are written
    to ``args.output_dir``. Both are rewritten in full after every successful
    run so partial results survive an interrupted session.
    """
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    dataset_tag = "_".join(sorted(d.replace("_", "") for d in args.datasets))[:50]
    if args.num_partitions >= 0:
        partition_tag = f"P{args.num_partitions}"
    else:
        partition_tag = "Homog"
    alpha_tag = str(args.alpha).replace(".", "")

    # Both output files share everything except the leading category label.
    file_suffix = f"{dataset_tag}_{partition_tag}_Alpha{alpha_tag}_{run_stamp}.csv"

    results_dir = args.output_dir
    os.makedirs(results_dir, exist_ok=True)
    output_filepath_individual = os.path.join(
        results_dir, f"UCI_Individual_{file_suffix}"
    )
    output_filepath_aggregator = os.path.join(
        results_dir, f"UCI_Aggregator_{file_suffix}"
    )

    print(f"Individual results will be saved to: {output_filepath_individual}")
    print(f"Aggregator results will be saved to: {output_filepath_aggregator}")

    # Registration order determines the column order of the aggregator CSV.
    all_configured_aggregators = {}

    # 1. Existing (prior-agnostic) methods; only YKsplit takes split_portion.
    existing_methods = (
        ("ModSel", RR.ModSel_rescaleRes),
        ("YKbaseline", RR.YKbaseline_rescaleRes),
        ("YK_adj", RR.YK_adj_rescaleRes),
        ("YKsplit", RR.YKsplit_rescaleRes),
    )
    for name, base_func in existing_methods:
        if name == "YKsplit":
            wrapped = wrap_existing_conformal(
                base_func, split_portion=args.split_portion
            )
        else:
            wrapped = wrap_existing_conformal(base_func)
        all_configured_aggregators[name] = wrapped

    # 2. Stable conformal methods. Only uniform-prior (UP) variants are
    #    registered; train-prior variants are intentionally left out.
    for label, eta in (
        ("StableEta0.1", 0.1),
        ("StableEta0.5", 0.5),
        ("StableEta2.0", 2.0),
    ):
        all_configured_aggregators[f"UP_{label}"] = wrap_stable_conformal(
            RR.stable_conformal, eta=eta, ignore_prior=True
        )

    # 3. Adaptive stable conformal methods (UP variants only).
    for label, ratio in (("AdaStable0.50", 0.50), ("AdaStable0.90", 0.90)):
        all_configured_aggregators[f"UP_{label}"] = wrap_adaptive_stable_conformal(
            RR.adaptive_stable_conformal, ratio=ratio, ignore_prior=True
        )

    # 4. Internal split calibration methods ("InstCal", UP variants only).
    instcal_shared = {
        "base_func": RR.calibrate_after_selection_resampling,
        "alpha_pre_selection": 0.1,
        "N_resamples": args.N_resamples,
        "preliminary_gamma": 1 - args.alpha,
        "aux_split_ratio": 0.5,
    }
    for label, alpha_post in (("InstCal0.2", 0.2), ("InstCal1.0", 1.0)):
        all_configured_aggregators[f"{label}-UP"] = wrap_internal_split_calibration(
            **instcal_shared, alpha_post_selection=alpha_post, ignore_prior=True
        )

    print(f"Configured {len(all_configured_aggregators)} aggregators.")

    all_individual_results = []
    all_aggregator_results = []
    banner = "=" * 60

    for dataset_name in args.datasets:
        for seed_val in range(args.num_seeds):
            print(f"\n{banner}")
            print(
                f"STARTING Dataset: {dataset_name}, Seed: {seed_val+1}/{args.num_seeds}"
            )
            print(f"{banner}")

            individual_results, aggregator_results = run_experiment_for_dataset(
                dataset_name=dataset_name,
                seed=seed_val,
                alpha=args.alpha,
                all_configured_aggregators=all_configured_aggregators,
                train_ratio=args.train_ratio,
                cal_ratio=args.cal_ratio,
                test_ratio=args.test_ratio,
                num_partitions=args.num_partitions,
                args=args,
            )

            if not (individual_results and aggregator_results):
                print(
                    f"Skipped run for Dataset: {dataset_name}, Seed: {seed_val} - no results to append."
                )
                continue

            all_individual_results.append(individual_results)
            all_aggregator_results.append(aggregator_results)

            # Rewrite both files with everything collected so far.
            pd.DataFrame(all_individual_results).to_csv(
                output_filepath_individual, index=False, mode="w", header=True
            )
            print(
                f"Individual results updated and saved to {output_filepath_individual}"
            )

            pd.DataFrame(all_aggregator_results).to_csv(
                output_filepath_aggregator, index=False, mode="w", header=True
            )
            print(
                f"Aggregator results updated and saved to {output_filepath_aggregator}"
            )

    if all_individual_results and all_aggregator_results:
        print(f"\nAll experiments complete. Final results saved in:")
        print(f"Individual results: {output_filepath_individual}")
        print(f"Aggregator results: {output_filepath_aggregator}")
    else:
        print("No results were generated in any run.")


if __name__ == "__main__":
    # Command-line entry point: declare the options, validate the split
    # ratios, then hand the parsed namespace to main().
    parser = argparse.ArgumentParser(
        description="Run UCI Conformal Prediction Experiments"
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    # The three *_ratio options must sum to 1.
    cli_options = (
        (
            "--datasets",
            {
                "nargs": "+",
                "default": ["ABALONE", "BIKE_SHARING", "CALIFORNIA_HOUSING"],
                "help": "List of UCI datasets to use",
            },
        ),
        (
            "--num_seeds",
            {"type": int, "default": 10, "help": "Number of random seeds"},
        ),
        (
            "--alpha",
            {"type": float, "default": 0.1, "help": "Target miscoverage rate"},
        ),
        (
            "--train_ratio",
            {
                "type": float,
                "default": 0.8,
                "help": "Proportion of data for training regressors and residual predictors",
            },
        ),
        (
            "--cal_ratio",
            {
                "type": float,
                "default": 0.1,
                "help": "Proportion of data for calibration",
            },
        ),
        (
            "--test_ratio",
            {
                "type": float,
                "default": 0.1,
                "help": "Proportion of data for testing",
            },
        ),
        (
            "--num_partitions",
            {
                "type": int,
                "default": 5,
                "help": "Num partitions for training models (-1 for homogeneous)",
            },
        ),
        (
            "--split_portion",
            {
                "type": float,
                "default": 0.5,
                "help": "Internal split portion for YKsplit method",
            },
        ),
        (
            "--N_resamples",
            {
                "type": int,
                "default": 1,
                "help": "N for resampling in InstCal methods",
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": "uci_results",
                "help": "Directory to save results",
            },
        ),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Fail fast on inconsistent ratios before any dataset is loaded.
    ratio_total = args.train_ratio + args.cal_ratio + args.test_ratio
    if not np.isclose(ratio_total, 1.0):
        raise ValueError(
            f"train_ratio ({args.train_ratio}), cal_ratio ({args.cal_ratio}), and test_ratio ({args.test_ratio}) must sum to 1"
        )

    main(args)
