import os
import torch
import numpy as np
import gymnasium as gym

import multiprocessing as mp
from itertools import product

from agents.be_ddqn import BE_DDQN
from agents.be_ppo import BE_PPO
from agents.be_mpc_ddqn import BE_MPC_DDQN
from agents.be_mpc_ppo import BE_MPC_PPO


def run_single_configuration(config):
    """
    Run a single experimental configuration with the given parameters.
    This function will be called by each process.

    Args:
        config (dict): Configuration dictionary. Required keys:
            'agent_type', 'seed', 'noise_level', 'uncert_method',
            'max_steps', 'num_eps' and 'noise_model'. The optional key
            'wandb_flag' (default False) is forwarded to the agent.
            Any 'agent_type' other than 'BE_DDQN', 'BE_MPC_DDQN', 'BE_PPO'
            or 'BE_MPC_PPO' falls back to a BE_DDQN agent with the
            'Random' policy.

    Returns:
        dict: The identifying parameters of the run ('agent_type', 'seed',
        'uncert_method', 'noise_level', 'noise_model') together with the
        two values returned by ``agent.train``: 'perc_states' and
        'solved_step'.
    """
    # Extract parameters from config
    agent_type = config['agent_type']
    seed = config['seed']
    noise_level = config['noise_level']
    uncert_method = config['uncert_method']
    max_steps = config['max_steps']
    num_eps = config['num_eps']
    noise_model = config['noise_model']
    wandb_flag = config.get('wandb_flag', False)

    # Map each known agent type to its (class, policy name) pair. Unknown
    # types (e.g. 'Random') deliberately fall back to the random-policy
    # baseline built on BE_DDQN.
    agent_specs = {
        'BE_DDQN': (BE_DDQN, 'DDQN'),
        'BE_MPC_DDQN': (BE_MPC_DDQN, 'DDQN'),
        'BE_PPO': (BE_PPO, 'PPO'),
        'BE_MPC_PPO': (BE_MPC_PPO, 'PPO'),
    }
    agent_cls, policy = agent_specs.get(agent_type, (BE_DDQN, 'Random'))

    # Set up environment
    env = gym.make("MountainCar-v0")
    try:
        # Pick a device: spread worker processes over the visible GPUs by
        # process ID, or fall back to CPU when CUDA is unavailable.
        # NOTE(review): PIDs are not guaranteed to be evenly distributed
        # modulo the GPU count, so the load balance is only approximate.
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            gpu_id = os.getpid() % max(num_gpus, 1)
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device("cpu")

        # Set seeds for reproducibility
        torch.manual_seed(seed)
        np.random.seed(seed)
        env.reset(seed=seed)

        # Initialize the agent
        agent = agent_cls(env=env, seed=seed, device=device, uncertainty_method=uncert_method, policy=policy,
                          noise_std=noise_level, noise_model=noise_model, wandb_flag=wandb_flag)

        # Run the analysis
        perc_states, solved_step = agent.train(max_steps, num_eps, diffs=True)
    finally:
        # Always release the environment, even if construction or training fails.
        env.close()

    # Return results
    return {
        'agent_type': agent_type,
        'seed': seed,
        'uncert_method': uncert_method,
        'noise_level': noise_level,
        'noise_model': noise_model,
        'perc_states': perc_states,
        'solved_step': solved_step
    }


def main():
    """
    Build the experiment grid for one agent type, run every configuration
    in a process pool, and print a short summary line pair per run.

    The grid is the Cartesian product of seeds, noise levels, uncertainty
    methods and noise models. Results are only printed; nothing is written
    to disk here.
    """
    # Agent under test. run_single_configuration recognises 'BE_DDQN',
    # 'BE_MPC_DDQN', 'BE_PPO' and 'BE_MPC_PPO'; any other value (e.g.
    # 'Random') selects the random-policy baseline.
    agent_type = 'BE_MPC_PPO'

    # Ten even seeds: 0, 2, ..., 18
    seeds = list(range(0, 20, 2))

    # The MPC agents are swept without the "Error" uncertainty method.
    uncert_methods = ["Entropy", "IG"]
    if agent_type not in ('BE_MPC_DDQN', 'BE_MPC_PPO'):
        uncert_methods = ["Error"] + uncert_methods

    noise_levels = [0.01]
    noise_models = ['homoskedastic', 'heteroskedastic']

    # One config dict per point of the parameter grid
    configurations = []
    for seed, noise_level, uncert_method, noise_model in product(seeds, noise_levels, uncert_methods, noise_models):
        configurations.append({
            'agent_type': agent_type,
            'seed': seed,
            'noise_level': noise_level,
            'uncert_method': uncert_method,
            'max_steps': 1_000,
            'num_eps': 1,
            'noise_model': noise_model,
        })

    # Never start more workers than there are jobs or CPU cores
    num_processes = min(mp.cpu_count(), len(configurations))
    print(f"Running experiments using {num_processes} processes")

    # Fan the configurations out over the pool; map() keeps input order
    with mp.Pool(processes=num_processes) as pool:
        results = pool.map(run_single_configuration, configurations)

    # Report each finished run
    for result in results:
        seed, uncert_method, noise_level, noise_model = (
            result[key] for key in ('seed', 'uncert_method', 'noise_level', 'noise_model')
        )

        print(f"Completed run for {uncert_method}, noise level {noise_level}, noise model {noise_model}, seed {seed}")
        print(f"Solved at step: {result['solved_step'] if result['solved_step'] is not None else 'Not solved'}")


if __name__ == "__main__":
    # With the 'spawn' start method each worker re-imports this module, so
    # the sweep must only be launched from the top-level script.
    # 'spawn' (rather than 'fork') is used for CUDA compatibility.
    mp.set_start_method(method='spawn', force=True)
    main()
