from multiprocessing.dummy import Pool # use threads
from subprocess import Popen
import subprocess
import os
import numpy


def launch_training_runs(env_list: "list | None" = None, num_seeds_per_run: int = 1, gpu_list: "list | None" = None, parallel_workers: int = 1):
    """Launch PPO training runs for every (seed, environment, architecture) combination.

    Each run is executed as a ``python train.py --algo ppo ...`` subprocess whose
    stdout/stderr are written to ``../neurips_benchmarks/<env>/<architecture>/stdout.log``.
    Runs are assigned round-robin to the GPUs in ``gpu_list`` through
    ``CUDA_VISIBLE_DEVICES`` and executed ``parallel_workers`` at a time on a
    thread pool. A status line is printed for every finished run.

    Args:
        env_list: Environment ids to train on. Ids containing "Frameskip" use the
            CNN policy variants, all others the MLP variants. Each id needs an
            entry in the per-environment policy_kwargs table below, otherwise its
            runs are reported as failed.
        num_seeds_per_run: Number of seeds per (env, architecture) pair; seeds are
            taken in order from a fixed list of predefined seeds.
        gpu_list: GPU ids assigned round-robin to runs. If empty or None, runs
            inherit the parent's ``CUDA_VISIBLE_DEVICES`` unchanged.
        parallel_workers: Number of runs executing concurrently.

    Raises:
        ValueError: If ``num_seeds_per_run`` exceeds the number of predefined seeds.
    """
    # None defaults (instead of []) avoid a mutable default shared across calls;
    # copying also lets env_list be any iterable, since it is traversed once per seed.
    env_list = [] if env_list is None else list(env_list)
    gpu_list = [] if gpu_list is None else list(gpu_list)

    architectures = ["Masksemble", "Ensemble", "Dropout", "Baseline"]

    random_seed_list = [42, 237467234, 1238123123, 1230000, 9102930123, 63125456]

    # Validate up front, before any directory or thread is created.
    if num_seeds_per_run > len(random_seed_list):
        raise ValueError(
            f"num_seeds_per_run={num_seeds_per_run} exceeds the "
            f"{len(random_seed_list)} predefined random seeds"
        )

    # Policy class name per observation type ("Image" -> CNN, "Vector" -> MLP).
    arch_policy_map = {
        "Image":
        {
            "Masksemble": "MasksemblesCnnPolicy",
            "Ensemble": "EnsembleCnnPolicy",
            "Dropout": "DropoutCnnPolicy",
            "Baseline": "CnnPolicy"
        },
        "Vector":
        {
            "Masksemble": "MasksemblesMlpPolicy",
            "Ensemble": "EnsembleMlpPolicy",
            "Dropout": "DropoutMlpPolicy",
            "Baseline": "MlpPolicy"
        }
    }

    def return_arch_policy_kwargs(args_env):
        """Return a dict mapping architecture -> policy_kwargs string (or None) for ``args_env``.

        Raises:
            ValueError: If no policy_kwargs are configured for ``args_env``.
        """
        # Custom Parameters per environment
        if args_env in ["HalfCheetah-v3"]:
            arch_policy_policy_kwargs = {
                    "Masksemble": "dict(log_std_init=-2,ortho_init=False,net_arch=dict(hidden_dim=384,last_layer_dim_pi=384,last_layer_dim_vf=384,num_masks=4))",
                    "Ensemble": "dict(log_std_init=-2,ortho_init=False,net_arch=dict(hidden_dim=256,last_layer_dim_pi=256,last_layer_dim_vf=256,num_models=4))",
                    "Dropout": "dict(log_std_init=-2,ortho_init=False,net_arch=dict(hidden_dim=256,last_layer_dim_pi=256,last_layer_dim_vf=256,dropout_p=0.2))",
                    "Baseline": "dict(log_std_init=-2, ortho_init=False, net_arch=[dict(pi=[256, 256], vf=[256, 256])])"
            }
        elif args_env in ["Ant-v3"]:
            arch_policy_policy_kwargs = {
                    "Masksemble": "dict(net_arch=dict(hidden_dim=128,last_layer_dim_pi=128,last_layer_dim_vf=128,num_masks=4))",
                    "Ensemble": "dict(net_arch=dict(hidden_dim=128,last_layer_dim_pi=128,last_layer_dim_vf=128,num_models=4))",
                    "Dropout": "dict(net_arch=dict(hidden_dim=128,last_layer_dim_pi=128,last_layer_dim_vf=128,dropout_p=0.2))",
                    "Baseline": "dict(net_arch=[dict(pi=[128, 128], vf=[128, 128])])"
            }
        elif args_env in ["Walker2d-v3"]:
            arch_policy_policy_kwargs = {
                    "Masksemble": "dict(net_arch=dict(hidden_dim=128,last_layer_dim_pi=128,last_layer_dim_vf=128,num_masks=4))",
                    "Ensemble": "dict(net_arch=dict(hidden_dim=128,last_layer_dim_pi=128,last_layer_dim_vf=128,num_models=4))",
                    "Dropout": "dict(net_arch=dict(hidden_dim=128,last_layer_dim_pi=128,last_layer_dim_vf=128,dropout_p=0.2))",
                    "Baseline": "dict(net_arch=[dict(pi=[128, 128], vf=[128, 128])])"
            }
        elif args_env in ["SeaquestNoFrameskip-v4", "MsPacmanNoFrameskip-v4", "BerzerkNoFrameskip-v4"]:
            arch_policy_policy_kwargs = {
                    "Masksemble": "dict(features_extractor_kwargs=dict(num_masks=4))",
                    "Ensemble": "dict(features_extractor_kwargs=dict(num_models=4), num_models=4)",
                    "Dropout": "dict(features_extractor_kwargs=dict(dropout_p=0.2))",
                    # None -> no policy_kwargs override is passed to train.py.
                    "Baseline": None
            }
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(f"No policy_kwargs configured for environment {args_env!r}")
        return arch_policy_policy_kwargs

    def work(args):
        """Run one training subprocess and return ``(args, exit_code, error)``.

        ``exit_code`` is the subprocess return code (None if it could not be run)
        and ``error`` is the stringified exception, or None if no exception occurred.
        """
        env, architecture, arch_policy_type, available_gpu, log_filename, random_seed = args
        env_vars = os.environ.copy()
        if available_gpu is not None:
            env_vars["CUDA_VISIBLE_DEVICES"] = str(available_gpu)
        try:
            # Look the kwargs up once instead of three times.
            policy_kwargs = return_arch_policy_kwargs(env)[architecture]
            if policy_kwargs is not None:
                print("policy_kwargs:" + policy_kwargs)
            run_dir = os.path.join("../neurips_benchmarks", env, architecture)
            # add "--non-deterministic-eval" for atari
            command = [
                "python", "train.py",
                "--algo", "ppo",
                "--env", env,
                "--log-folder", run_dir,
                "--tensorboard-log", run_dir,
                "--seed", str(random_seed),
                "--no-video-callback",
                "--hyperparams", "policy:\"" + arch_policy_map[arch_policy_type][architecture] + "\"",
            ]
            if policy_kwargs is not None:
                command += ["policy_kwargs:" + policy_kwargs]
            # Unbuffered binary log; kept open until the child exits.
            with open(os.path.join(log_filename, "stdout.log"), 'wb', 0) as logfile:
                p = Popen(command, stdout=logfile, stderr=logfile, env=env_vars)
                return args, p.wait(), None
        except Exception as e:
            # Worker boundary: report the failure instead of killing the pool.
            return args, None, str(e)

    os.makedirs("../neurips_benchmarks", exist_ok=True)

    launches = []
    current_gpu_index = 0

    for seed_idx in range(num_seeds_per_run):

        random_seed = random_seed_list[seed_idx]

        for env in env_list:
            # Envs with "Frameskip" in their id get the image (CNN) policies.
            arch_policy_type = "Image" if "Frameskip" in env else "Vector"

            for architecture in architectures:
                # Round-robin GPU assignment; None means "inherit the environment".
                if gpu_list:
                    available_gpu = gpu_list[current_gpu_index]
                    current_gpu_index = (current_gpu_index + 1) % len(gpu_list)
                else:
                    available_gpu = None

                logfile = os.path.join("../neurips_benchmarks", env, architecture)
                os.makedirs(logfile, exist_ok=True)

                launches.append((env, architecture, arch_policy_type, available_gpu, logfile, random_seed))

        print("launches", launches, "random_seed", random_seed)

    # Threads are enough: each worker just blocks on its subprocess.
    # The context manager guarantees the pool is cleaned up.
    with Pool(parallel_workers) as pool:
        for (env, architecture, _, gpu, _, seed), status, error in pool.imap_unordered(work, launches):
            # Results arrive out of order, so identify the run explicitly.
            run = f"{env}/{architecture} (seed {seed}, gpu {gpu})"
            if error is not None:
                print(f"failed to run {run}, reason: {error}")
            elif status != 0:
                print(f"{run} exited with non-zero status {status}")
            else:
                print(f"{run} done, status {status}")


if __name__ == "__main__":
    # Script entry point: run the full benchmark sweep
    # (3 MuJoCo + 3 Atari envs, 3 seeds each) on GPU 0, one run at a time.

    # 1. Launch Training
    benchmark_envs = [
        "Ant-v3",
        "HalfCheetah-v3",
        "Walker2d-v3",
        "MsPacmanNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "BerzerkNoFrameskip-v4",
    ]
    launch_training_runs(
        benchmark_envs,
        num_seeds_per_run=3,
        gpu_list=[0],
        parallel_workers=1,
    )
