import multiprocessing as mp
from itertools import product

from maze_envs.utils.utils import run_single_configuration


def main():
    """Run every experiment configuration in parallel and print a summary.

    Builds the Cartesian product of ``seeds`` x ``noise_levels`` x
    ``uncert_methods``. Each combination is paired with a fixed set of shared
    settings (maze layout, step budget, MPC/planner options). Each resulting
    configuration dict is dispatched to ``run_single_configuration`` on a
    process pool, and one summary is printed per finished run.

    Nothing is written to disk here. Any persistence (e.g. W&B logging when
    ``wandb_flag`` is set) presumably happens inside
    ``run_single_configuration`` -- confirm there.
    """
    # Experiment grid: 10 seeds (20, 22, ..., 38).
    seeds = [x * 2 for x in range(10, 20)]
    uncert_methods = ['IG']
    noise_levels = [0.01]

    maze_env = 'PointMaze_Medium-v3'
    # NOTE(review): presumably 1 = wall, 0 = free cell, 'g' = goal and
    # 'r' = reset/start cell (Gymnasium-Robotics maze_map convention) --
    # confirm against maze_envs.
    maze_str = [[1, 1, 1, 1, 1, 1, 1],
                [1, 0, 0, 1, 0, 'g', 1],
                [1, 0, 0, 0, 0, 0, 1],
                [1, 1, 0, 1, 0, 1, 1],
                [1, 0, 0, 0, 0, 0, 1],
                [1, 'r', 1, 0, 0, 0, 1],
                [1, 1, 1, 1, 1, 1, 1]]

    # One config dict per (seed, noise_level, uncert_method) combination.
    # All dicts share the same maze_str list object. With the 'spawn' start
    # method, each dict is pickled separately when sent to a worker, so
    # workers never alias each other's copy.
    configurations = [
        {
            'seed': seed,
            'noise_level': noise_level,
            'uncert_method': uncert_method,
            'max_steps': 8_000,
            'num_eps': 1,
            'maze_structure': maze_str,
            'maze_env': maze_env,
            'mpc_flag': True,
            'be_flag': True,
            'wandb_flag': True,
            'horizon': 64,
            'num_rollouts': 16
        }
        for seed, noise_level, uncert_method in product(seeds, noise_levels, uncert_methods)
    ]

    # An empty grid would otherwise request a pool of 0 processes, which
    # multiprocessing rejects with ValueError.
    if not configurations:
        print("No experiment configurations to run")
        return

    # Never start more workers than there are configurations or CPUs.
    num_processes = min(len(configurations), mp.cpu_count())
    print(f"Running experiments using {num_processes} processes")

    # pool.map blocks until every run has finished and returns the results
    # in the same order as `configurations`.
    with mp.Pool(processes=num_processes) as pool:
        results = pool.map(run_single_configuration, configurations)

    # Report the outcome of each run. Each result is expected to be a dict
    # with 'seed', 'uncert_method', 'noise_level' and 'solved_step' keys;
    # 'solved_step' is None when the maze was not solved.
    for result in results:
        seed = result['seed']
        uncert_method = result['uncert_method']
        noise_level = result['noise_level']
        solved_step = result['solved_step']

        print(f"Completed run for {uncert_method}, noise level {noise_level}, seed {seed}")
        print(f"Solved at step: {solved_step if solved_step is not None else 'Not solved'}")


if __name__ == "__main__":
    # Only configure multiprocessing when executed as a script, not on import.
    # 'spawn' launches each worker in a fresh interpreter instead of forking
    # the parent, which is important for CUDA compatibility. force=True
    # replaces any start method that was already set.
    mp.set_start_method(method='spawn', force=True)
    main()
