import os
import argparse
import traceback
from time import sleep
from concurrent.futures import ProcessPoolExecutor as Pool


def run(map_name, seed=0, n_expert_episodes=1, algos=None, train_expert=False):
    """
    Run the experiments with the given map name and a list of algorithms.

    Builds the environment for ``map_name``. Then it either trains with the
    "qmix" algorithm without expert data (``train_expert=True``), or collects
    an expert buffer once and trains every algorithm in ``algos`` against it.

    NOTE(review): hyperparameters and the device are read from the
    module-global ``args``, which is only assigned inside the ``__main__``
    guard. Pool workers presumably see it through the ``fork`` start method;
    confirm before running where ``spawn`` is the default (macOS/Windows).

    :param map_name: str; the substrings "academy", "miner" and "simple"
        select the GRF, Miner and MPE wrappers, anything else is treated as
        a SMACv2 map
    :param seed: int
    :param n_expert_episodes: int
    :param algos: list of algorithm names to train; ``None`` means none
    :param train_expert: bool
    :return: None
    """
    algos = [] if algos is None else algos

    print("map_name:", map_name)
    env = _make_env(map_name, seed)

    import torch
    import setproctitle
    from trainer.runner import Runner_QMIX_SMAC
    torch.set_num_threads(1)
    if args.device != "cpu":
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    try:
        if train_expert:
            runner = Runner_QMIX_SMAC(args, map_name, env, seed, n_expert_episodes, "qmix")
            runner.init(use_expert=False)
            runner.run()
        else:
            # MPE ("simple") maps use a smaller expert buffer and collect
            # with deterministic=False; other maps use the collector default.
            if "simple" in map_name:
                max_experts = 1000
                collect_kwargs = {"verbose": True, "deterministic": False}
            else:
                max_experts = 10000
                collect_kwargs = {"verbose": True}
            collector = Runner_QMIX_SMAC(args, map_name, env, seed, n_expert_episodes)
            ex_buffer = collector.collect_expert_buffer(max_experts, **collect_kwargs)

            for algo in algos:
                setproctitle.setproctitle(f"{map_name}-{n_expert_episodes}-{algo}-seed{seed}")
                runner = Runner_QMIX_SMAC(args, map_name, env, seed, n_expert_episodes, algo)
                runner.init(max_experts, ex_buffer)
                runner.run()
    except Exception:
        # Log the failure here instead of propagating it to the caller's
        # future.result(); print_exc() writes the traceback itself (it
        # returns None, so it must not be wrapped in print()).
        traceback.print_exc()


def _make_env(map_name, seed):
    """
    Build the environment wrapper selected by substrings of ``map_name``.

    :param map_name: str
    :param seed: int; not forwarded to the SMACv2 wrapper
    :return: the constructed environment wrapper
    """
    if "academy" in map_name:
        from trainer.grf_env import GRFWrapper
        return GRFWrapper(map_name, seed)
    if "miner" in map_name:
        from trainer.miner_env import MinerWrapper
        # Expected form: "miner_<mode>_<n_agents>_vs_<n_enemies>".
        _, mode, n_agents, _, n_enemies = map_name.split("_")
        config = {"mode": mode, "n_agents": int(n_agents), "n_enemies": int(n_enemies)}
        return MinerWrapper(config, seed)
    if "simple" in map_name:
        from trainer.mpe_env import MPEWrapper
        return MPEWrapper(map_name, seed)
    from trainer.utils import read_map_info
    from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper as StarCraft2Env
    return StarCraft2Env(**read_map_info(map_name))
        

def main(key=None):
    """
    Launch, in a process pool, every experiment whose map name contains ``key``.

    Each matching map is run for seeds 0-3 and for a key-dependent list of
    expert-episode counts, training the full list of imitation algorithms.

    :param key: str, substring matched against map names
        (e.g. "protoss", "miner", "simple"); must not be None
    :raises TypeError: if ``key`` is None
    :return: None
    """
    if key is None:
        raise TypeError("main() requires a map-name key such as 'protoss', 'miner' or 'simple'")

    smac_scenarios = {
        "protoss": ["5_vs_5", "10_vs_10"],
        "terran": ["5_vs_5", "10_vs_10"],
        "zerg": ["5_vs_5", "10_vs_10"],
    }
    all_maps = [
        f"{race}_{scenario}"
        for race, scenarios in smac_scenarios.items()
        for scenario in scenarios
    ]
    all_maps += [
        "miner_easy_2_vs_2",
        "miner_medium_2_vs_2",
        "miner_hard_2_vs_2",
        "simple_speaker_listener",
        "simple_spread",
        "simple_reference",
    ]

    selected_maps = [name for name in all_maps if key in name]
    if not selected_maps:
        print(f"No map matches key {key!r}; nothing to run.")
        return

    # Number of worker processes per environment family. Increase it for
    # environments that need more computation.
    # Tested on a machine with 100GB RAM and 1 NVIDIA V100 GPU.
    if "miner" in key:
        n_workers = 6
    elif "simple" in key:
        n_workers = 12
    else:
        n_workers = 3

    if "simple" in key:
        episodes = [128, 64, 32, 16, 8, 4, 2, 1]
    else:
        episodes = [4096, 128, 2048, 256, 1024, 512]

    algos = ["mifq-soft", "mifq-dqn", "sqil", "gail", "iiq", "iqvdn", "airl", "bc"]
    # Alternative ablation set:
    # algos = ["mifq-soft-sigmoid", "mifq-soft-tanh", "mifq-dqn-sigmoid", "mifq-dqn-tanh"]

    with Pool(n_workers) as p:
        tasks = []
        for seed in [0, 1, 2, 3]:
            for n_expert_episodes in episodes:
                for map_name in selected_maps:
                    tasks.append(p.submit(run, map_name, seed, n_expert_episodes, algos))
                    # Presumably staggers worker/environment start-up.
                    sleep(1)

        # Surface any exception that escaped a worker.
        for task in tasks:
            task.result()


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX and VDN in SMAC environment")
    parser.add_argument("--task", type=str, default="protoss", help="Task name")
    parser.add_argument("--max_train_steps", type=int, default=1e6, help="Maximum number of training steps")
    parser.add_argument("--evaluate_freq", type=int, default=10000, help="Evaluate the policy every 'evaluate_freq' steps")
    parser.add_argument("--evaluate_times", type=int, default=32, help="Evaluate times")
    parser.add_argument("--save_freq", type=int, default=1e5, help="Save frequency")
    parser.add_argument("--device", type=str, default="cuda", help="Device")
    parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size (the number of episodes)")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--target_update_freq", type=int, default=4, help="Update frequency of the target network")
    parser.add_argument("--tau", type=int, default=0.1, help="If use soft update")
    parser.add_argument("--sc2_path", type=str, default="/common/home/users/t/tvbui/StarCraftII/", help="SC2 path")
    args = parser.parse_args()
    os.environ["SC2PATH"] = args.sc2_path
    main(args.task)