import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import linear_model
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import os
import gc
import click
import json

from dataset import load_dataset

# Configuration
# NOTE(review): USE_PARALLEL is not read anywhere in this file; every CLI
# command below always runs through a multiprocessing Pool.
USE_PARALLEL = True  # Set to False to use sequential version
# Upper bound on worker processes; the commands use min(mp.cpu_count(), MAX_PROCESSES).
MAX_PROCESSES = 1000
OUTPUT_DIR = "outputs"  # root directory for result CSV files

# Module-level default device: CUDA when a GPU is visible, otherwise CPU.
# With the "spawn" start method used by the commands, worker processes
# re-import this module, so this message likely prints once per worker.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


class PinballLoss(nn.Module):
    """Mean pinball (quantile) loss.

    For the residual ``r = target - pred`` and quantile level ``q``, each element
    contributes ``q * r`` when ``r >= 0`` and ``(q - 1) * r`` otherwise; the
    forward pass returns the mean of these contributions.
    """

    def __init__(self, quantile):
        super().__init__()
        # Quantile level, e.g. 0.05 for the 5th percentile.
        self.quantile = quantile

    def forward(self, pred, target):
        """Return the mean pinball loss of ``pred`` against ``target``.

        Args:
            pred: predicted values.
            target: observed values (broadcast against ``pred``).
        """
        q = self.quantile
        residual = target - pred
        under_prediction = q * residual
        over_prediction = (q - 1) * residual
        return torch.where(residual >= 0, under_prediction, over_prediction).mean()


class QuantileLinearModel(nn.Module):
    """Linear model for quantile regression: one ``nn.Linear(input_dim, 1)`` layer."""

    def __init__(self, input_dim):
        super(QuantileLinearModel, self).__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return one prediction per row of ``x``.

        Only the trailing size-1 output dimension is removed, so a batch holding
        a single sample still yields shape ``(1,)`` rather than a 0-d scalar.
        """
        return self.linear(x).squeeze(-1)



class QuantileTwoLayerModel(nn.Module):
    """One-hidden-layer ReLU network for quantile regression.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1)``.
    ``hidden_dim`` defaults to 10, the width this model always used.
    """

    def __init__(self, input_dim, hidden_dim=10):
        super(QuantileTwoLayerModel, self).__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return one prediction per row of ``x`` (batch dimension kept even for one sample)."""
        x = self.linear(x)
        x = self.relu(x)
        # squeeze(-1) drops only the output dimension, never a size-1 batch.
        return self.linear2(x).squeeze(-1)


class QuantileThreeLayerModel(nn.Module):
    """Two-hidden-layer ReLU network for quantile regression.

    Architecture: ``Linear(input_dim, h) -> ReLU -> Linear(h, h) -> ReLU -> Linear(h, 1)``
    with ``h = hidden_dim`` (default 10, the width this model always used).
    """

    def __init__(self, input_dim, hidden_dim=10):
        super(QuantileThreeLayerModel, self).__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # ReLU has no parameters, so one module is reused after both hidden layers.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return one prediction per row of ``x`` (batch dimension kept even for one sample)."""
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear3(x)
        # squeeze(-1) drops only the output dimension, never a size-1 batch.
        return x.squeeze(-1)



def _build_quantile_model(input_dim, model_name="linear"):
    """Instantiate the quantile network selected by ``model_name``.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    if model_name == "linear":
        return QuantileLinearModel(input_dim)
    if model_name == "two_layer":
        return QuantileTwoLayerModel(input_dim)
    if model_name == "three_layer":
        return QuantileThreeLayerModel(input_dim)
    raise ValueError(f"Invalid model: {model_name}")


def _build_optimizer(parameters, lr, optim_name="Adam"):
    """Create the torch optimizer selected by ``optim_name`` over ``parameters``.

    Raises:
        NotImplementedError: for ``"VR"`` (variance reduction is not implemented).
        ValueError: for any other unknown ``optim_name``.
    """
    if optim_name == "SGD":
        return optim.SGD(parameters, lr=lr)
    if optim_name == "MomentumSGD":
        return optim.SGD(parameters, lr=lr, momentum=0.9)
    if optim_name == "Adam":
        return optim.Adam(parameters, lr=lr)
    if optim_name == "AdamW":
        return optim.AdamW(parameters, lr=lr)
    if optim_name == "VR":
        # variance reduction
        # optimizer = VR(parameters, lr=lr)
        raise NotImplementedError("VR optimizer not implemented")
    raise ValueError(f"Invalid optimizer: {optim_name}")


def train_quantile_model(X_train, y_train, quantile, epochs=1000, lr=0.01, batch_size=64, patience=50, device=None, kwargs=None):
    """Train a quantile-regression network with the pinball loss.

    Mini-batches are taken in the original row order (no shuffling). Early
    stopping tracks the best *per-batch* training loss and ends training
    entirely (all epochs) once ``patience`` consecutive batches fail to beat it.

    Args:
        X_train: 2-D array of training features (one row per sample).
        y_train: training targets, one per row of ``X_train``.
        quantile: quantile level (e.g. 0.05 for the 5th percentile).
        epochs: maximum number of passes over the data (must be >= 1).
        lr: learning rate.
        batch_size: mini-batch size.
        patience: consecutive non-improving batches tolerated before stopping.
        device: torch device; defaults to CUDA when available, else CPU.
        kwargs: optional mapping. ``"model"`` is one of ``"linear"`` (default),
            ``"two_layer"`` or ``"three_layer"``. ``"optim"`` is one of
            ``"SGD"``, ``"MomentumSGD"``, ``"Adam"`` (default) or ``"AdamW"``.

    Returns:
        ``(model, last_loss)``, where ``last_loss`` is the pinball loss of the
        last mini-batch processed, not an epoch average.

    Raises:
        ValueError: on empty training data, ``epochs < 1``, or an unknown
            model/optimizer name.
        NotImplementedError: for ``kwargs["optim"] == "VR"``.
    """
    if kwargs is None:
        kwargs = {}
    # Without at least one batch there is no loss to return.
    if len(X_train) == 0:
        raise ValueError("train_quantile_model needs at least one training sample")
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Initialize device if not provided
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert to PyTorch tensors and move to device
    X_tensor = torch.FloatTensor(X_train).to(device)
    y_tensor = torch.FloatTensor(y_train).to(device)

    model = _build_quantile_model(X_train.shape[1], kwargs.get("model", "linear")).to(device)
    criterion = PinballLoss(quantile)
    optimizer = _build_optimizer(model.parameters(), lr, kwargs.get("optim", "Adam"))

    best_loss = float('inf')
    patience_counter = 0
    last_loss = None
    stop_training = False

    for _ in range(epochs):
        for start in range(0, len(X_tensor), batch_size):
            X_batch = X_tensor[start:start + batch_size]
            y_batch = y_tensor[start:start + batch_size]

            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

            # One host synchronisation per batch; reused below and as the return value.
            last_loss = loss.item()
            if last_loss < best_loss:
                best_loss = last_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                stop_training = True
                break
        # Early stopping must also leave the epoch loop, not just the batch loop.
        if stop_training:
            break

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, last_loss

def _transform_dataset(X_train, y_train, X_test, y_test, train_ratio, calibration_ratio, seed=None):

    # reshape the data
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)
    y_test = np.asarray(y_test)

    # input dimensions
    n_train_original = X_train.shape[0]
    in_shape = X_train.shape[1]

    # set seed for splitting the data into proper train and calibration
    if seed is not None:
        np.random.seed(seed)
    idx = np.random.permutation(n_train_original)

    # divide the data into proper training set and calibration set
    n_train = int(np.floor(n_train_original * train_ratio))
    n_calibration = int(np.floor(n_train_original * calibration_ratio))
    
    # Ensure non-overlapping splits
    idx_train = idx[:n_train]
    idx_cal = idx[-n_calibration:]

    # zero mean and unit variance scaling of the train and test features
    scalerX = StandardScaler()
    scalerX = scalerX.fit(X_train[idx_train])
    X_train = scalerX.transform(X_train)
    X_test = scalerX.transform(X_test)
    
    # scale the labels by dividing each by the mean absolute response
    mean_ytrain = np.mean(np.abs(y_train[idx_train]))
    y_train = np.squeeze(y_train) / mean_ytrain
    y_test = np.squeeze(y_test) / mean_ytrain

    X_train = X_train.astype(np.float32)
    X_test = X_test.astype(np.float32)
    y_train = y_train.astype(np.float32)
    y_test = y_test.astype(np.float32)
    idx_train = idx_train.astype(np.int32)
    # idx_cal = idx_cal.astype(np.int32)
    return X_train, X_test, y_train, y_test, idx_train, idx_cal


def successive_halving_lr0_tuning(X_train, y_train, gamma, epochs=1,
                                  lr0_candidates=None, n_iterations=3,
                                  eta=2, min_budget=1, debug=False, device=None, kwargs=None):
    """
    Successive Halving algorithm for tuning the learning-rate hyperparameter lr0.

    Every round trains one fresh model per surviving candidate and scores it by
    the last mini-batch pinball loss returned by ``train_quantile_model``. The
    best-scoring candidates advance to the next round. A model is then
    retrained with the winning learning rate, using the same ``kwargs``, so the
    requested model/optimizer is honoured.

    Parameters:
    -----------
    X_train, y_train : training data
    gamma : quantile level
    epochs : number of training epochs per candidate
    lr0_candidates : learning rates to try (array-like); if None, 20 values
        log-spaced from 1e-5 to 1e0
    n_iterations : number of successive halving rounds (must be >= 1)
    eta : reduction factor; each non-final round keeps
        max(min_budget, len(candidates) // eta) candidates
    min_budget : minimum number of candidates to keep per round
    debug : print progress when True
    device : torch device forwarded to train_quantile_model
    kwargs : model/optimizer options forwarded to train_quantile_model

    Returns:
    --------
    model : model retrained with best_lr0
    best_lr0 : best learning rate found in the final round
    best_loss : its loss in the final round
    tuning_history : one dict per round with keys 'iteration', 'candidates',
        'lr' (that round's best candidate) and 'losses'

    Raises:
    -------
    ValueError : if n_iterations < 1 or lr0_candidates is empty
    """
    if kwargs is None:
        kwargs = {}
    # Without at least one round no best candidate would ever be selected.
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    def log(text):
        if debug:
            print(text)

    if lr0_candidates is None:
        # Log-uniform grid from 1e-5 to 1e0
        lr0_candidates = np.logspace(-5, 0, 20)
    # Copy into a float array so the fancy indexing below also works for list input.
    current_candidates = np.array(lr0_candidates, dtype=float)
    if current_candidates.size == 0:
        raise ValueError("lr0_candidates must not be empty")

    tuning_history = []

    log(f"Starting Successive Halving with {len(current_candidates)} candidates")
    log(f"Learning rate range: [{current_candidates.min():.2e}, {current_candidates.max():.2e}]")

    for iteration in range(n_iterations):
        log(f"\n--- Iteration {iteration + 1} ---")
        log(f"Evaluating {len(current_candidates)} candidates")

        candidate_losses = []
        for lr0 in current_candidates:
            _, final_loss = train_quantile_model(
                X_train,
                y_train,
                quantile=gamma,
                epochs=epochs,
                lr=lr0,
                device=device,
                kwargs=kwargs,
            )
            candidate_losses.append(final_loss)
            log(f"  lr0={lr0:.2e}: final_loss={final_loss:.6f}")

        round_best = int(np.argmin(candidate_losses))
        tuning_history.append({
            'iteration': iteration + 1,
            'candidates': current_candidates.copy(),
            'lr': current_candidates[round_best],
            'losses': list(candidate_losses),
        })

        if iteration < n_iterations - 1:
            # Keep the lowest-loss candidates (lower pinball loss is better).
            n_keep = max(min_budget, len(current_candidates) // eta)
            current_candidates = current_candidates[np.argsort(candidate_losses)[:n_keep]]
            log(f"Kept top {len(current_candidates)} candidates for next round")
        else:
            best_lr0 = current_candidates[round_best]
            best_loss = candidate_losses[round_best]
            log(f"\nFinal best: lr0={best_lr0:.2e}, loss={best_loss:.6f}")

    # Retrain with the same model/optimizer options used during tuning.
    model, _ = train_quantile_model(
        X_train, y_train, quantile=gamma, epochs=epochs, lr=best_lr0, device=device, kwargs=kwargs
    )

    return model, best_lr0, best_loss, tuning_history


def _seed_everything(seed):
    """Seed ``random``, NumPy's global RNG and torch; on CUDA also make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _predict_numpy(model, features, device):
    """Evaluate ``model`` on ``features`` without autograd and return a 1-D NumPy array."""
    with torch.no_grad():
        inputs = torch.FloatTensor(features).to(device)
        # atleast_1d guards against models that squeeze a one-row batch to a 0-d scalar.
        return np.atleast_1d(model(inputs).cpu().numpy())


def run_experiment(
    dataset_name,
    alpha=0.1,
    calibration_ratio=0.2,
    test_ratio=0.2,
    train_ratio=0.5,
    n_iterations=3,
    lr0_candidates=np.logspace(-5, 0, 20),
    kwargs=None,
    seed=42,
    device=None
):
    """Run one conformalized quantile regression (CQR) experiment.

    The dataset is split into train/test with a fixed seed (42). The training
    part is then split into proper-train and calibration sets using ``seed``.
    Lower/upper quantile models at ``alpha/2`` and ``1 - alpha/2`` are tuned by
    successive halving. The calibration set then provides the conformal
    correction ``q`` that widens the test intervals.

    Args:
        dataset_name: name passed to ``load_dataset``.
        alpha: target miscoverage level.
        calibration_ratio, train_ratio: fractions of the training pool.
        test_ratio: fraction of the data held out for testing.
        n_iterations, lr0_candidates: successive-halving settings.
        kwargs: model/optimizer options (mapping or list of key/value pairs);
            also copied into the result record.
        seed: seed for the proper-train/calibration split and training.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        dict with the run configuration plus: ``crossing_mean``, ``coverage``,
        ``interval_width``, ``predicted_interval_width``, ``delta`` (per-point
        widths, an array), ``q``, ``n_iterations``, ``len(lr0_candidates)``,
        and ``best_losses`` (JSON string keyed by quantile).
    """
    # Copy so the caller's object is never mutated; None means "no options".
    kwargs = dict(kwargs) if kwargs is not None else {}
    # Initialize device if not provided
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    quantiles = [alpha / 2, 1 - alpha / 2]
    data = {"seed": seed, "dataset_name": dataset_name, "alpha": alpha, "calibration_ratio": calibration_ratio, "train_ratio": train_ratio,
        **kwargs,
    }

    # The train/test split always uses seed 42, so every run on a dataset
    # shares one test set; `seed` only drives the proper-train/calibration
    # permutation and the model training below.
    data_seed = 42
    _seed_everything(data_seed)

    X, y = load_dataset(dataset_name)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, random_state=data_seed)
    X_train, X_test, y_train, y_test, idx_train, idx_cal = _transform_dataset(
        X_train, y_train, X_test, y_test,
        train_ratio=train_ratio, calibration_ratio=calibration_ratio, seed=seed,
    )

    _seed_everything(seed)

    X_cal, y_cal = X_train[idx_cal], y_train[idx_cal]
    predictions_calib = {}
    predictions_test = {}
    best_losses_dict = {}
    for quantile in quantiles:
        model, _, best_loss, _ = successive_halving_lr0_tuning(
            X_train[idx_train],
            y_train[idx_train],
            gamma=quantile,
            epochs=1,
            device=device,
            n_iterations=n_iterations,
            lr0_candidates=lr0_candidates,
            kwargs=kwargs,
        )
        best_losses_dict[quantile] = best_loss
        predictions_test[quantile] = _predict_numpy(model, X_test, device)
        predictions_calib[quantile] = _predict_numpy(model, X_cal, device)

    # CQR nonconformity score: signed distance of y outside [t_lo, t_hi]
    # (negative when y lies inside the interval).
    t_lo = predictions_calib[quantiles[0]]
    t_hi = predictions_calib[quantiles[1]]
    cqr_nc = np.maximum(t_lo - y_cal, y_cal - t_hi)
    # NOTE(review): split-conformal coverage guarantees usually use the
    # ceil((n + 1) * (1 - alpha)) / n empirical quantile; this uses plain
    # 1 - alpha with method="higher". Confirm this is intended.
    q = np.quantile(cqr_nc, 1 - alpha, method="higher")

    t_lo_test = predictions_test[quantiles[0]]
    t_hi_test = predictions_test[quantiles[1]]

    # NOTE(review): despite its name this is the fraction of test points whose
    # upper prediction is strictly above the lower one, i.e. the NON-crossing
    # rate. Kept unchanged so existing result files stay comparable.
    data['crossing_mean'] = np.mean(t_hi_test > t_lo_test).item()

    # Reorder each (lo, hi) pair so crossed quantile predictions still form an interval.
    t_lo_test, t_hi_test = np.minimum(t_lo_test, t_hi_test), np.maximum(t_lo_test, t_hi_test)

    # Coverage (percentage of test points within the conformalized interval)
    coverage = np.mean((y_test >= t_lo_test - q) & (y_test <= t_hi_test + q))
    data['coverage'] = coverage.item()
    data['interval_width'] = (t_hi_test - t_lo_test).mean().item()
    data['predicted_interval_width'] = (t_hi_test - t_lo_test + 2 * q).mean().item()
    # Per-test-point widths (an array); it ends up in a single CSV cell when
    # the records are written with pandas.
    data['delta'] = t_hi_test - t_lo_test
    data['q'] = q.item()
    data["n_iterations"] = n_iterations
    data["len(lr0_candidates)"] = len(lr0_candidates)
    data["best_losses"] = json.dumps(best_losses_dict)

    # Clear GPU memory and run garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return data



def run_single_experiment(args):
    """Unpack one experiment tuple and run it; used as the ``multiprocessing`` worker.

    Args:
        args: ``(dataset_name, alpha, calibration_ratio, train_ratio,
            n_iterations, lr0_candidates, kwargs, seed)``. A 7-element tuple
            without ``kwargs`` is also accepted and treated as empty kwargs.

    Returns:
        The result record produced by ``run_experiment``.

    Raises:
        ValueError: if ``args`` has neither 7 nor 8 elements.
        Any exception raised by ``run_experiment`` is printed and re-raised.
    """
    if len(args) == 7:
        # Tuples without kwargs: insert empty options just before the seed.
        args = (*args[:6], [], args[6])
    if len(args) != 8:
        raise ValueError(f"Expected 8 experiment arguments, got {len(args)}")

    (
        dataset_name,
        alpha,
        calibration_ratio,
        train_ratio,
        n_iterations,
        lr0_candidates,
        kwargs,
        seed,
    ) = args

    try:
        # Initialize CUDA device in subprocess
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        return run_experiment(
            dataset_name=dataset_name,
            alpha=alpha,
            calibration_ratio=calibration_ratio,
            test_ratio=0.2,
            train_ratio=train_ratio,
            n_iterations=n_iterations,
            lr0_candidates=lr0_candidates,
            kwargs=kwargs,
            seed=seed,
            device=device
        )
    except Exception as e:
        # Log which configuration failed, then propagate with the original traceback.
        print(f"Error in experiment {args}: {e}")
        raise



# Root click command group; each function decorated with @cli.command() in this
# file becomes a subcommand of this script.
@click.group()
def cli():
    pass


################
@cli.command()
def grid_budget_exploration():
    """Grid over seeds, calibration/train ratios and alphas on meps_20, meps_19 and cal_housing."""
    mp.set_start_method("spawn", force=True)

    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Define experiment parameters
    seeds = list(range(20, 50))
    calibration_ratios = [0.05, 0.1, 0.15, 0.2]
    train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alphas = [0.01, 0.05, 0.1, 0.2]

    filename = "grid_budget_exploration_cqr"  # NOTE:
    output_file = f"{OUTPUT_DIR}/{filename}.csv"

    experiment_args = []
    for dataset_name in ["meps_20", "meps_19", "cal_housing"]:
        # Create all experiment combinations
        for seed in seeds:
            for calibration_ratio in calibration_ratios:
                for train_ratio in train_ratios:
                    for alpha in alphas:
                        for n_iterations, lr0_candidates in [
                            (20, np.logspace(-6, 0, 80)),
                        ]:
                            # run_single_experiment unpacks 8 fields; kwargs
                            # must sit between lr0_candidates and seed.
                            experiment_args.append(
                                (
                                    dataset_name,
                                    alpha,
                                    calibration_ratio,
                                    train_ratio,
                                    n_iterations,
                                    lr0_candidates,
                                    [],  # kwargs: default model and optimizer
                                    seed,
                                )
                            )

    print(f"Total experiments to run: {len(experiment_args)}")

    # One worker per CPU, capped by MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    # Run experiments in parallel (imap keeps results in submission order)
    with mp.Pool(processes=num_processes) as pool:
        results = list(
            tqdm(
                pool.imap(run_single_experiment, experiment_args),
                total=len(experiment_args),
                desc="Running experiments",
            )
        )

    # Drop any None results
    results = [r for r in results if r is not None]

    # Convert to DataFrame and save
    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")
    return df


@cli.command()
def different_optimizers_datasets():
    # Compare optimizers across five datasets; results are written to
    # <OUTPUT_DIR>/different_optimizers_cqr/result_extended_optimizers.csv.
    mp.set_start_method("spawn", force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    taskname = "different_optimizers_cqr"  # NOTE:
    task_dir = f"{OUTPUT_DIR}/{taskname}"
    os.makedirs(task_dir, exist_ok=True)
    output_file = f"{task_dir}/result_extended_optimizers.csv"

    # Sweep definition. Other optimizers accepted by train_quantile_model:
    # "SGD", "MomentumSGD", "Adam".
    seeds = range(20)
    dataset_names = ["meps_19", "meps_20", "cal_housing", "abalone", "cpu_small"]
    calibration_ratios = [0.2]
    train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alphas = [0.01, 0.05, 0.1, 0.2]
    optimizers = ["AdamW"]
    n_iterations = 10

    experiment_args = [
        (
            dataset_name,
            alpha,
            calibration_ratio,
            train_ratio,
            n_iterations,
            np.logspace(-6, 0, 20),
            [("optim", optimizer)],
            seed,
        )
        for seed in seeds
        for dataset_name in dataset_names
        for calibration_ratio in calibration_ratios
        for train_ratio in train_ratios
        for alpha in alphas
        for optimizer in optimizers
    ]

    total = len(experiment_args)
    print(f"Total experiments to run: {total}")

    # One worker per CPU, capped by MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        outcomes = tqdm(
            pool.imap_unordered(run_single_experiment, experiment_args),
            total=total,
            desc="Running experiments",
        )
        # Keep only real result records.
        results = [record for record in outcomes if record is not None]

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df


@cli.command()
def different_models_datasets():
    # Compare linear / two-layer / three-layer quantile models trained with SGD;
    # results are written to <OUTPUT_DIR>/different_models_cqr/result_sgd.csv.
    mp.set_start_method("spawn", force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    taskname = "different_models_cqr"  # NOTE:
    task_dir = f"{OUTPUT_DIR}/{taskname}"
    os.makedirs(task_dir, exist_ok=True)
    output_file = f"{task_dir}/result_sgd.csv"

    # Sweep definition.
    seeds = range(20)
    dataset_names = ["cal_housing", "abalone", "cpu_small"]
    calibration_ratios = [0.2]
    train_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    alphas = [0.01, 0.05, 0.1, 0.2]
    model_names = ["linear", "two_layer", "three_layer"]
    n_iterations = 10

    experiment_args = [
        (
            dataset_name,
            alpha,
            calibration_ratio,
            train_ratio,
            n_iterations,
            np.logspace(-6, 0, 20),
            [("model", model_name), ("optim", "SGD")],
            seed,
        )
        for seed in seeds
        for dataset_name in dataset_names
        for calibration_ratio in calibration_ratios
        for train_ratio in train_ratios
        for alpha in alphas
        for model_name in model_names
    ]

    total = len(experiment_args)
    print(f"Total experiments to run: {total}")

    # One worker per CPU, capped by MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        outcomes = tqdm(
            pool.imap_unordered(run_single_experiment, experiment_args),
            total=total,
            desc="Running experiments",
        )
        # Keep only real result records.
        results = [record for record in outcomes if record is not None]

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df


@cli.command()
def data_allocation_guidance():
    # Sweep how the training pool is divided between proper training and
    # calibration (calibration_ratio = 1 - train_ratio) for very small alphas on
    # cpu_small with the linear model; results are written to
    # <OUTPUT_DIR>/data_allocation_guidance_cqr/result_var_train_1.csv.
    mp.set_start_method("spawn", force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    taskname = "data_allocation_guidance_cqr"  # NOTE:
    task_dir = f"{OUTPUT_DIR}/{taskname}"
    os.makedirs(task_dir, exist_ok=True)
    output_file = f"{task_dir}/result_var_train_1.csv"

    # Sweep definition.
    seeds = range(20)
    dataset_names = ["cpu_small"]
    train_ratios = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
    alphas = [0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009]
    n_iterations = 10

    experiment_args = [
        (
            dataset_name,
            alpha,
            1 - train_ratio,  # calibration receives everything not used for training
            train_ratio,
            n_iterations,
            np.logspace(-6, 0, 20),
            [("model", "linear")],
            seed,
        )
        for seed in seeds
        for dataset_name in dataset_names
        for train_ratio in train_ratios
        for alpha in alphas
    ]

    total = len(experiment_args)
    print(f"Total experiments to run: {total}")

    # One worker per CPU, capped by MAX_PROCESSES.
    num_processes = min(mp.cpu_count(), MAX_PROCESSES)
    print(f"Using {num_processes} processes")

    with mp.Pool(processes=num_processes) as pool:
        outcomes = tqdm(
            pool.imap_unordered(run_single_experiment, experiment_args),
            total=total,
            desc="Running experiments",
        )
        # Keep only real result records.
        results = [record for record in outcomes if record is not None]

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    print(f"Completed {len(results)} experiments successfully")
    print(f"Results saved to {output_file}")

    return df

# Script entry point: dispatch to the click subcommand named on the command line.
if __name__ == "__main__":
    cli()
