import os
import numpy as np
import torch
import gymnasium as gym

from maze_envs.agents.be_mpc_sac import BE_MPC_SAC


def run_single_configuration(config):
    """
    Run a single experimental configuration with the given parameters.
    This function will be called by each process.

    Args:
        config (dict): Configuration dictionary containing all necessary
            parameters. The following keys are read and must be present:
            'seed', 'noise_level', 'uncert_method', 'max_steps', 'num_eps',
            'horizon', 'num_rollouts', 'maze_env', 'maze_structure',
            'mpc_flag', 'be_flag', 'wandb_flag'.

    Returns:
        dict: The keys 'seed', 'uncert_method' and 'noise_level' echo the
        configuration; 'perc_states' and 'solved_step' are the two values
        returned by ``agent.train`` (their exact meaning is defined by
        ``BE_MPC_SAC.train`` -- NOTE(review): confirm against the agent).

    Raises:
        KeyError: If any of the required configuration keys is missing.
    """
    # Extract parameters from config
    seed = config['seed']
    noise_level = config['noise_level']
    uncert_method = config['uncert_method']
    max_steps = config['max_steps']
    num_eps = config['num_eps']
    horizon = config['horizon']
    num_rollouts = config['num_rollouts']

    # Flags
    maze_env = config['maze_env']
    maze_structure = config['maze_structure']
    mpc_flag = config['mpc_flag']
    be_flag = config['be_flag']
    wandb_flag = config['wandb_flag']

    env = gym.make(maze_env,
                   maze_map=maze_structure,
                   max_episode_steps=max_steps,
                   continuing_task=False,
                   reset_target=False)

    # Everything after the environment is created runs under try/finally so
    # the environment is released even when agent construction or training
    # raises; otherwise each worker process would leak its environment.
    try:
        # Determine device - for parallel processing, spread the workers over
        # the available GPUs, or fall back to CPU if no GPU is available.
        if torch.cuda.is_available():
            # Get the number of available GPUs
            num_gpus = torch.cuda.device_count()
            # Use current process ID to determine which GPU to use
            gpu_id = (os.getpid() % num_gpus) if num_gpus > 0 else 0
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device("cpu")

        # Set seeds for reproducibility
        torch.manual_seed(seed)
        np.random.seed(seed)
        env.reset(seed=seed)
        # The action space owns a separate PRNG that reset(seed=...) does not
        # touch; seed it as well so sampled actions are reproducible.
        env.action_space.seed(seed)

        # Initialize the agent
        agent = BE_MPC_SAC(
            env=env,
            seed=seed,
            device=device,
            uncertainty_method=uncert_method,
            noise_std=noise_level,
            mpc=mpc_flag,
            bayes_exp=be_flag,
            wandb_flag=wandb_flag,
            horizon=horizon,
            num_rollouts=num_rollouts
        )

        # Run the analysis
        perc_states, solved_step = agent.train(max_steps, num_eps, diffs=True)
    finally:
        env.close()

    # Return results
    return {
        'seed': seed,
        'uncert_method': uncert_method,
        'noise_level': noise_level,
        'perc_states': perc_states,
        'solved_step': solved_step
    }