import os
import argparse
import gzip
import json
import numpy as np
import pandas as pd
import openml
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
import datetime
import multiprocessing as mp
from functools import partial
import time
from tqdm import tqdm

# Configure the OpenML cache directory (a path relative to the current working
# directory) and make sure it exists before anything is downloaded.
openml.config.cache_directory = "./openml_cache"
os.makedirs(openml.config.cache_directory, exist_ok=True)

# Base directories for dataset storage: one tree for the full train/test export
# and one for the per-segment / cumulative exports.
FULL_DATASET_DIR = Path("full_datasets")
SEGMENTED_DATASET_DIR = Path("segmented_datasets")

# Module-level lock used around directory creation and around updates of the
# shared summary / failed-task files.
# NOTE(review): as a plain module global this only synchronizes worker
# processes that inherit this very object (e.g. the "fork" start method);
# confirm the behaviour on platforms that default to "spawn".
file_lock = mp.Lock()


def read_task_list(file_path: str) -> List[int]:
    """Parse a task-list file into a list of OpenML task IDs.

    Every line is stripped of surrounding whitespace. Blank lines and lines
    starting with ``#`` are ignored. A remaining line that cannot be converted
    with ``int()`` is reported on stdout and skipped.

    Args:
        file_path: Path of the text file holding one task ID per line.

    Returns:
        The task IDs in file order.
    """
    parsed_ids: List[int] = []
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            # Guard clause: nothing to parse on empty or comment lines.
            if not entry or entry.startswith('#'):
                continue
            try:
                parsed_ids.append(int(entry))
            except ValueError:
                print(f"Skipping invalid task ID: {entry}")
    return parsed_ids


def calculate_optimal_segment_size(total_instances: int, min_size: int = 500, max_size: int = 1000) -> int:
    """
    Pick the segment size that splits the dataset with the smallest remainder.

    Candidate sizes run from min_size up to max_size (capped at
    total_instances). The first candidate that divides total_instances exactly
    is returned right away; otherwise the smallest candidate achieving the
    minimal remainder wins. If no candidate exists (e.g. min_size > max_size),
    min_size is returned.

    A dataset smaller than min_size is kept as a single segment, i.e. the
    function returns total_instances.
    """
    # Tiny dataset: one segment holding everything.
    if total_instances < min_size:
        print(f"Dataset is smaller than minimum segment size ({min_size}). Using one segment with {total_instances} instances.")
        return total_instances

    largest_candidate = min(max_size, total_instances)
    chosen_size, smallest_leftover = min_size, total_instances

    for candidate in range(min_size, largest_candidate + 1):
        full_segments, leftover = divmod(total_instances, candidate)

        # A candidate that yields no complete segment is not usable.
        if full_segments == 0:
            continue

        # Exact division cannot be beaten, stop searching.
        if leftover == 0:
            return candidate

        # Strict comparison keeps the smallest size among equal remainders.
        if leftover < smallest_leftover:
            chosen_size, smallest_leftover = candidate, leftover

    return chosen_size


def preprocess_dataset(task_id: int, min_segment_size: int = 500, max_segment_size: int = 1000,
                      test_ratio: float = 0.15, val_ratio: float = 0.15, 
                      random_seed: int = 42) -> Dict:
    """Download an OpenML task's dataset, encode it and plan its splits.

    Numerical columns are median-imputed and standardized; categorical columns
    are imputed with the constant ``'missing'`` and one-hot encoded with the
    first level dropped. For a "Supervised Classification" task the target is
    label-encoded to int64; every other task type is treated as regression and
    the target is cast to float32.

    Args:
        task_id: OpenML task ID to fetch.
        min_segment_size: Lower bound handed to calculate_optimal_segment_size.
        max_segment_size: Upper bound handed to calculate_optimal_segment_size.
        test_ratio: Fraction of rows held out as the global test set.
        val_ratio: Not used here; stored in the result for the per-segment
            validation splits made when the dataset is saved.
        random_seed: Seed for the train/test split.

    Returns:
        A dict with the transformed arrays ("X", "y"), the train/test index
        lists, the segment plan ("segment_size", "num_segments",
        "instances_per_segment") and descriptive metadata.

    Errors raised by OpenML or scikit-learn propagate to the caller; only a
    failed *stratified* split is recovered from, by splitting without
    stratification.
    """
    print(f"Processing OpenML task {task_id}...")

    # Get the OpenML task and dataset
    task = openml.tasks.get_task(task_id=task_id)
    dataset = task.get_dataset()

    # Extract data details
    dataset_name = f"openml_task_{task_id}"
    print(f"Dataset name: {dataset.name}")

    # Get the data in dataframe format.
    # NOTE(review): the target is the dataset's default target attribute; if
    # the task defines its own target that differs, this ignores it — confirm
    # which one is intended.
    X, y, categorical_indicator, feature_names = dataset.get_data(
        dataset_format='dataframe',
        target=dataset.default_target_attribute
    )

    # Unwrap pandas containers; anything without ``.values`` is used as-is.
    y_values = y.values if hasattr(y, 'values') else y

    # Separate numerical and categorical columns (positional indices)
    categorical_features = [i for i, is_cat in enumerate(categorical_indicator) if is_cat]
    numerical_features = [i for i, is_cat in enumerate(categorical_indicator) if not is_cat]

    # Create preprocessing pipelines
    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler())
    ])

    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
        ('onehot', OneHotEncoder(drop='first', sparse_output=False, handle_unknown='ignore'))
    ])

    # Combine transformers
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, numerical_features),
            ('cat', categorical_transformer, categorical_features)
        ],
        remainder='drop'
    )

    # Fit and transform the data.
    # NOTE(review): the preprocessor is fitted on all rows, including those
    # later assigned to the test split — confirm this is acceptable.
    X_transformed = preprocessor.fit_transform(X).astype(np.float32)

    # Determine task type and transform target
    if task.task_type == "Supervised Classification":
        le = LabelEncoder()
        y_transformed = le.fit_transform(y_values)
        n_unique_labels = len(np.unique(y_transformed))
        target_type = "binary" if n_unique_labels == 2 else "classification"
        num_classes = n_unique_labels
        y_transformed = y_transformed.astype(np.int64)
    else:  # Regression
        y_transformed = y_values.astype(np.float32)
        target_type = "regression"
        num_classes = 1

    # Create the fixed test set (test_ratio of the data).
    indices = np.arange(len(X_transformed))
    if target_type != "regression":
        # Only the stratified attempt is guarded: a failure here usually means
        # some class has too few samples, and a plain split is a valid fallback.
        try:
            train_indices, test_indices = train_test_split(
                indices, test_size=test_ratio, random_state=random_seed,
                stratify=y_transformed
            )
        except ValueError as e:
            print(f"Warning: Stratification failed, falling back to non-stratified split: {e}")
            train_indices, test_indices = train_test_split(
                indices, test_size=test_ratio, random_state=random_seed
            )
    else:
        # No stratification involved, so an error here is a real error and is
        # not retried with identical arguments.
        train_indices, test_indices = train_test_split(
            indices, test_size=test_ratio, random_state=random_seed
        )

    # Calculate optimal segment size
    train_size = len(train_indices)
    segment_size = calculate_optimal_segment_size(train_size, min_segment_size, max_segment_size)

    # For small datasets segment_size can equal train_size: exactly 1 segment.
    num_segments = max(1, train_size // segment_size)

    # Instances that do not fill a whole segment are handed out round-robin,
    # one at a time, so the segments stay as uniform as possible.
    leftover = train_size % segment_size
    instances_per_segment = [segment_size] * num_segments
    for i in range(leftover):
        instances_per_segment[i % num_segments] += 1

    print(f"Optimal segment size: {segment_size}, Segments: {num_segments}")
    if leftover > 0:
        print(f"Redistributed {leftover} instances across segments for balance")
        print(f"Segment sizes: {instances_per_segment}")

    # Feature information: the column count after one-hot encoding. No
    # categorical indices remain once everything is encoded numerically.
    num_features = X_transformed.shape[1]
    cat_idx = []

    # Create a dataset dictionary
    dataset_info = {
        "task_id": task_id,
        "dataset_name": dataset_name,
        "original_name": dataset.name,
        "X": X_transformed,
        "y": y_transformed,
        "cat_idx": cat_idx,
        "feature_names": feature_names,
        "target_type": target_type,
        "num_classes": num_classes,
        "num_features": num_features,
        "num_instances": len(X_transformed),
        "train_indices": train_indices.tolist(),
        "test_indices": test_indices.tolist(),
        "segment_size": segment_size,
        "num_segments": num_segments,
        "instances_per_segment": instances_per_segment,
        "test_ratio": test_ratio,
        "val_ratio": val_ratio
    }

    return dataset_info


def save_full_dataset(dataset_info: Dict) -> Path:
    """Write the whole dataset plus its train/test split to FULL_DATASET_DIR.

    Creates ``FULL_DATASET_DIR/<dataset_name>`` and stores, as gzip-compressed
    ``.npy`` files, the train and test arrays, the complete X/y arrays and the
    two index arrays. All remaining entries of ``dataset_info`` go into
    ``metadata.json`` together with the train/test sizes.

    Args:
        dataset_info: Dict holding at least "dataset_name", "X", "y",
            "train_indices" and "test_indices".

    Returns:
        The directory the files were written to.
    """
    target_dir = FULL_DATASET_DIR / dataset_info["dataset_name"]

    # Directory creation is serialized through the shared lock.
    with file_lock:
        target_dir.mkdir(parents=True, exist_ok=True)

    features = dataset_info["X"]
    targets = dataset_info["y"]
    train_idx = np.array(dataset_info["train_indices"])
    test_idx = np.array(dataset_info["test_indices"])

    # File stem -> array. Insertion order is the order the files are written.
    arrays_by_stem = {
        "X_train": features[train_idx],
        "y_train": targets[train_idx],
        "X_test": features[test_idx],
        "y_test": targets[test_idx],
        "X_full": features,
        "y_full": targets,
        "train_indices": train_idx,
        "test_indices": test_idx,
    }
    for stem, array in arrays_by_stem.items():
        with gzip.GzipFile(target_dir / f"{stem}.npy.gz", "w") as gz_file:
            np.save(gz_file, array)

    # Everything except the bulky arrays and index lists goes into the metadata.
    excluded_keys = ("X", "y", "train_indices", "test_indices")
    metadata = {
        key: value for key, value in dataset_info.items() if key not in excluded_keys
    }
    metadata["train_size"] = len(train_idx)
    metadata["test_size"] = len(test_idx)

    with open(target_dir / "metadata.json", "w") as meta_file:
        json.dump(metadata, meta_file, indent=4)

    print(f"Saved full dataset to {target_dir}")
    return target_dir


def _write_npy_gz(path: Path, array: np.ndarray) -> None:
    """Write *array* to *path* as a gzip-compressed ``.npy`` file."""
    with gzip.GzipFile(path, "w") as f:
        np.save(f, array)


def _split_train_val(y_part: np.ndarray, val_ratio: float, target_type: str,
                     seed: int, label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Return (train_positions, val_positions) into a chunk of ``len(y_part)`` rows."""
    positions = np.arange(len(y_part))

    if not val_ratio > 0:
        # No validation requested: every row trains, validation is empty.
        return positions, positions[:0]

    if target_type != "regression":
        # Only the stratified attempt is guarded; a failure there typically
        # means a class is too small in this chunk.
        try:
            train_pos, val_pos = train_test_split(
                positions,
                test_size=val_ratio,
                random_state=seed,
                stratify=y_part
            )
            return train_pos, val_pos
        except ValueError:
            print(f"Warning: Stratification failed for {label}, using non-stratified split")

    train_pos, val_pos = train_test_split(
        positions,
        test_size=val_ratio,
        random_state=seed
    )
    return train_pos, val_pos


def _write_split_arrays(directory: Path, X_part: np.ndarray, y_part: np.ndarray,
                        train_pos: np.ndarray, val_pos: np.ndarray) -> None:
    """Create *directory* and write the train/val arrays of one chunk into it."""
    with file_lock:
        directory.mkdir(exist_ok=True)

    # Indexing with an empty integer array yields an empty array that keeps the
    # dtype and trailing shape of the source, so X_val stays 2-D even when no
    # validation rows exist.
    _write_npy_gz(directory / "X_train.npy.gz", X_part[train_pos])
    _write_npy_gz(directory / "y_train.npy.gz", y_part[train_pos])
    _write_npy_gz(directory / "X_val.npy.gz", X_part[val_pos])
    _write_npy_gz(directory / "y_val.npy.gz", y_part[val_pos])


def save_segmented_dataset(dataset_info: Dict) -> Path:
    """Save per-segment and cumulative train/val splits of the training data.

    Under ``SEGMENTED_DATASET_DIR/<dataset_name>`` this writes the global test
    set, a ``metadata.json``, and for every segment index ``i``:

    * ``segments/segment_i``: the i-th slice of the training indices, split
      into train/val with seed ``42 + i``;
    * ``cumulative/cumulative_i``: all training indices up to and including
      segment ``i``, split into train/val with the fixed seed ``42``.

    Splits are stratified for non-regression targets, falling back to a plain
    split if stratification fails. With ``val_ratio <= 0`` every row is used
    for training and the validation arrays are empty.

    Args:
        dataset_info: Dict holding "dataset_name", "X", "y", "train_indices",
            "test_indices", "instances_per_segment", "num_segments",
            "val_ratio", "target_type" and "num_features".

    Returns:
        The dataset's output directory.
    """
    dataset_name = dataset_info["dataset_name"]
    X = dataset_info["X"]
    y = dataset_info["y"]

    # Get train indices and segment information
    train_indices = np.array(dataset_info["train_indices"])
    test_indices = np.array(dataset_info["test_indices"])
    instances_per_segment = dataset_info["instances_per_segment"]
    num_segments = dataset_info["num_segments"]
    val_ratio = dataset_info["val_ratio"]
    target_type = dataset_info["target_type"]

    output_dir = SEGMENTED_DATASET_DIR / dataset_name

    # Create directories using lock to avoid race conditions
    with file_lock:
        output_dir.mkdir(parents=True, exist_ok=True)

    # Save global test set
    X_test, y_test = X[test_indices], y[test_indices]
    _write_npy_gz(output_dir / "X_test.npy.gz", X_test)
    _write_npy_gz(output_dir / "y_test.npy.gz", y_test)

    # Save metadata for the whole dataset
    metadata = {k: v for k, v in dataset_info.items()
                if k not in ["X", "y", "train_indices", "test_indices"]}
    metadata.update({
        "train_size": len(train_indices),
        "test_size": len(test_indices)
    })

    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=4)

    # Create segments and cumulative segments directories
    segments_dir = output_dir / "segments"
    with file_lock:
        segments_dir.mkdir(exist_ok=True)

    cumulative_dir = output_dir / "cumulative"
    with file_lock:
        cumulative_dir.mkdir(exist_ok=True)

    # Create and save segments
    start_idx = 0
    for seg_idx in range(num_segments):
        seg_size = instances_per_segment[seg_idx]
        end_idx = start_idx + seg_size

        # --- Individual segment -------------------------------------------
        segment_indices = train_indices[start_idx:end_idx]
        seg_X, seg_y = X[segment_indices], y[segment_indices]

        # Different seed for each segment
        seg_train_indices, seg_val_indices = _split_train_val(
            seg_y, val_ratio, target_type,
            seed=42 + seg_idx, label=f"segment {seg_idx}"
        )

        seg_dir = segments_dir / f"segment_{seg_idx}"
        _write_split_arrays(seg_dir, seg_X, seg_y, seg_train_indices, seg_val_indices)

        seg_metadata = {
            "segment_idx": seg_idx,
            "start_idx": start_idx,
            "end_idx": end_idx,
            "segment_size": seg_size,
            "train_size": len(seg_train_indices),
            "val_size": len(seg_val_indices),
            "global_indices": segment_indices.tolist(),
            "segment_train_indices": seg_train_indices.tolist(),
            "segment_val_indices": seg_val_indices.tolist(),
            "num_features": dataset_info["num_features"]
        }
        with open(seg_dir / "metadata.json", "w") as f:
            json.dump(seg_metadata, f, indent=4)

        # --- Cumulative segment (all data up to this point) ---------------
        cum_indices = train_indices[:end_idx]
        cum_X, cum_y = X[cum_indices], y[cum_indices]

        # Keep a consistent seed for every cumulative split
        cum_train_indices, cum_val_indices = _split_train_val(
            cum_y, val_ratio, target_type,
            seed=42, label=f"cumulative segment {seg_idx}"
        )

        cum_dir = cumulative_dir / f"cumulative_{seg_idx}"
        _write_split_arrays(cum_dir, cum_X, cum_y, cum_train_indices, cum_val_indices)

        cum_metadata = {
            "segment_idx": seg_idx,
            "num_segments_included": seg_idx + 1,
            "total_instances": len(cum_X),
            "train_size": len(cum_train_indices),
            "val_size": len(cum_val_indices),
            "global_indices": cum_indices.tolist(),
            "cumulative_train_indices": cum_train_indices.tolist(),
            "cumulative_val_indices": cum_val_indices.tolist(),
            "num_features": dataset_info["num_features"]
        }
        with open(cum_dir / "metadata.json", "w") as f:
            json.dump(cum_metadata, f, indent=4)

        # Update start index for next segment
        start_idx = end_idx

    print(f"Saved segmented dataset with {num_segments} segments to {output_dir}")
    return output_dir


def process_single_task(task_id: int, min_segment_size: int, max_segment_size: int, 
                       test_ratio: float, val_ratio: float, process_id: int) -> Dict:
    """Run the full pipeline for one OpenML task inside a worker process.

    The task is preprocessed, then saved both as a full dataset and as a
    segmented dataset. Any exception is caught and turned into a "failed"
    result so the calling worker keeps running.

    Returns:
        On success a dict with "status" == "success", selected dataset
        metadata, the two output directories and the processing time in
        seconds. On failure a dict with "task_id", "status" == "failed" and
        "error" (message plus traceback).
    """
    try:
        print(f"Process {process_id}: Starting to process task {task_id}")
        started_at = time.time()

        # Step 1: download + preprocess
        info = preprocess_dataset(
            task_id=task_id,
            min_segment_size=min_segment_size,
            max_segment_size=max_segment_size,
            test_ratio=test_ratio,
            val_ratio=val_ratio
        )

        # Step 2: persist the full and the segmented layout
        full_output = save_full_dataset(info)
        segmented_output = save_segmented_dataset(info)

        elapsed = time.time() - started_at

        # Assemble the summary; keys are inserted in the order they are reported.
        summary = {"task_id": task_id, "status": "success"}
        copied_keys = (
            "dataset_name",
            "original_name",
            "num_features",
            "num_instances",
            "num_segments",
            "segment_size",
            "instances_per_segment",
            "target_type",
        )
        for key in copied_keys:
            summary[key] = info[key]
        summary["full_dir"] = str(full_output)
        summary["segmented_dir"] = str(segmented_output)
        summary["processing_time"] = elapsed

        print(f"Process {process_id}: Completed task {task_id} in {elapsed:.2f} seconds")
        return summary

    except Exception as exc:
        # Deliberately broad: this is the per-task boundary of a worker.
        import traceback
        failure_text = f"Error processing task {task_id}: {str(exc)}\n{traceback.format_exc()}"
        print(failure_text)
        return {"task_id": task_id, "status": "failed", "error": failure_text}


def update_results_file(result: Dict, results_file: str = "processed_datasets_summary.json"):
    """Insert or replace one task's result in the JSON summary file.

    The whole read-modify-write cycle runs under ``file_lock``. The summary is
    a JSON object keyed by the task ID as a string. If the existing file cannot
    be read or parsed, or does not contain a JSON object, a warning is printed
    and the summary is rebuilt starting from this result.

    Args:
        result: Result dict; must contain "task_id".
        results_file: Path of the JSON summary file.
    """
    with file_lock:
        results = {}
        if os.path.exists(results_file):
            try:
                with open(results_file, 'r') as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and decoding errors.
                # Unlike a bare ``except`` this lets KeyboardInterrupt and
                # SystemExit through.
                print(f"Warning: could not read {results_file} ({exc}); starting a new summary")
            else:
                if isinstance(loaded, dict):
                    results = loaded
                else:
                    print(f"Warning: {results_file} does not hold a JSON object; starting a new summary")

        # JSON object keys are strings, so store the task ID as one.
        results[str(result["task_id"])] = result

        with open(results_file, 'w') as f:
            json.dump(results, f, indent=4)


def update_failed_tasks_file(task_id: int, error: str, file_path: str = "failed_tasks.txt"):
    """Append one failed task to the failure log.

    Each entry starts with the task ID, so it stays identifiable even when
    ``error`` itself does not mention the task, followed by the error text and
    a blank line. The append runs under ``file_lock``.

    Args:
        task_id: ID of the task that failed.
        error: Error description to record.
        file_path: Path of the failure log.
    """
    with file_lock, open(file_path, 'a') as f:
        f.write(f"Task ID: {task_id}\n{error}\n\n")


def _drain_task_queue(worker_id, task_queue, results, process_task, total_tasks, shared_lock):
    """Worker entry point: process queued tasks until the queue is empty.

    Defined at module level (not as a closure) so it can be pickled when the
    start method is "spawn". ``shared_lock`` replaces this process's
    module-level ``file_lock`` so that all workers synchronize on the lock
    owned by the parent, whatever the start method.
    """
    import queue  # local import: only needed for the queue.Empty exception

    global file_lock
    file_lock = shared_lock

    while True:
        try:
            task_id, task_index = task_queue.get(block=False)
        except queue.Empty:
            # Normal termination: another worker may have taken the last item.
            break
        except Exception as e:
            # The queue itself is unusable (e.g. the manager is gone): stop.
            print(f"Process {worker_id}: Error fetching or processing task: {e}")
            break

        print(f"Process {worker_id}: Processing task {task_index+1}/{total_tasks}: ID={task_id}")
        try:
            result = process_task(task_id=task_id, process_id=worker_id)

            # Update the shared results
            if result["status"] == "success":
                update_results_file(result)
                results.append(1)  # Count successful tasks
            else:
                update_failed_tasks_file(task_id, result.get("error", "Unknown error"))
                results.append(0)  # Count failed tasks
        except Exception as e:
            # One bad task must not make this worker abandon the rest of the
            # queue; it is counted as failed because it never reports success.
            print(f"Process {worker_id}: Error fetching or processing task: {e}")


def process_all_datasets(task_list_file: str, min_segment_size: int = 500, 
                        max_segment_size: int = 1000, test_ratio: float = 0.15,
                        val_ratio: float = 0.15, max_tasks: Optional[int] = None,
                        num_workers: int = 4) -> Dict:
    """Process all datasets from the task list using multiprocessing.

    Task IDs are read from ``task_list_file`` and placed on a shared queue
    drained by up to ``num_workers`` processes. Each worker records successes
    in ``processed_datasets_summary.json`` and failures in
    ``failed_tasks.txt``; both files are re-initialized at the start of the
    run.

    Args:
        task_list_file: Path of the file with one task ID per line.
        min_segment_size: Minimum segment size passed to each task.
        max_segment_size: Maximum segment size passed to each task.
        test_ratio: Test-set fraction passed to each task.
        val_ratio: Validation fraction passed to each task.
        max_tasks: If truthy, only the first ``max_tasks`` IDs are processed;
            ``None`` or ``0`` means no limit.
        num_workers: Upper bound on the number of worker processes.

    Returns:
        The contents of the summary file after the run ({} if unreadable).
    """
    # Create base directories
    FULL_DATASET_DIR.mkdir(parents=True, exist_ok=True)
    SEGMENTED_DATASET_DIR.mkdir(parents=True, exist_ok=True)

    # Read task IDs
    task_ids = read_task_list(task_list_file)
    if max_tasks:
        task_ids = task_ids[:max_tasks]

    print(f"Processing {len(task_ids)} tasks from {task_list_file} using {num_workers} worker processes")

    # Initialize empty results and failed tasks files
    with open("processed_datasets_summary.json", "w") as f:
        json.dump({}, f)

    with open("failed_tasks.txt", "w") as f:
        f.write("# Failed Tasks Log - Contains task IDs and detailed error information\n")
        f.write("# Generated on: " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n")
        f.write("="*80 + "\n\n")

    # Don't create more workers than tasks
    num_workers = min(num_workers, len(task_ids))

    # Create partial function with fixed parameters
    process_task_partial = partial(
        process_single_task,
        min_segment_size=min_segment_size,
        max_segment_size=max_segment_size,
        test_ratio=test_ratio,
        val_ratio=val_ratio
    )

    # The manager hosts the task queue and the shared result list; the context
    # manager shuts its server process down when the run is over.
    with mp.Manager() as manager:
        task_queue = manager.Queue()
        for i, task_id in enumerate(task_ids):
            task_queue.put((task_id, i))

        # Shared list of 1 (success) / 0 (failure) markers
        results = manager.list()

        processes = []
        start_time = time.time()

        try:
            for i in range(num_workers):
                p = mp.Process(
                    target=_drain_task_queue,
                    args=(i, task_queue, results, process_task_partial,
                          len(task_ids), file_lock)
                )
                processes.append(p)
                p.start()

            # Wait for all processes to complete
            for p in processes:
                p.join()

        except KeyboardInterrupt:
            print("Keyboard interrupt detected. Terminating worker processes...")
            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()  # reap the terminated worker

        end_time = time.time()
        total_time = end_time - start_time

        # Must be read while the manager is still running.
        success_count = sum(results)

    # Load the final results
    try:
        with open("processed_datasets_summary.json", "r") as f:
            final_results = json.load(f)
    except (OSError, ValueError):
        final_results = {}

    # Every task that did not report success counts as failed.
    failed_count = len(task_ids) - success_count

    print(f"\nProcessing complete in {total_time:.2f} seconds")
    print(f"Successfully processed {success_count} datasets")
    print(f"Failed to process {failed_count} datasets")

    if failed_count > 0:
        print("See failed_tasks.txt for details on failed tasks")

    return final_results


if __name__ == "__main__":
    # Command-line interface, declared as (flag, type, default, help) rows.
    parser = argparse.ArgumentParser(description="OpenML Dataset Processor with Multiprocessing")
    option_specs = [
        ("--task_list", str, "openml_import.txt",
         "File containing OpenML task IDs"),
        ("--min_segment_size", int, 500,
         "Minimum size of data segments (default: 500)"),
        ("--max_segment_size", int, 1000,
         "Maximum size of data segments (default: 1000)"),
        ("--test_ratio", float, 0.15,
         "Ratio of data for test set (default: 0.15)"),
        ("--val_ratio", float, 0.15,
         "Ratio of data for validation in each segment (default: 0.15)"),
        ("--max_tasks", int, None,
         "Maximum number of tasks to process from the list"),
        ("--num_workers", int, mp.cpu_count(),
         f"Number of worker processes (default: {mp.cpu_count()}, the number of CPU cores)"),
    ]
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()

    # Limit every numerical backend to a single thread so the worker processes
    # do not oversubscribe the CPU.
    # NOTE(review): numpy and scikit-learn are already imported at the top of
    # this file when these variables are set; some backends read them only at
    # load time, so they may not take effect here — confirm, or set them before
    # the imports.
    for thread_var in (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ):
        os.environ[thread_var] = "1"

    # Process datasets with multiprocessing
    process_all_datasets(
        args.task_list,
        min_segment_size=args.min_segment_size,
        max_segment_size=args.max_segment_size,
        test_ratio=args.test_ratio,
        val_ratio=args.val_ratio,
        max_tasks=args.max_tasks,
        num_workers=args.num_workers,
    )