# initial idea from:
# https://stackoverflow.com/questions/53422761/distributing-jobs-evenly-across-multiple-gpus-with-multiprocessing-pool
import time
import argparse
import shlex
import glob
import itertools
from multiprocessing import Pool, current_process, Queue
from subprocess import Popen


def parse_args():
    """Parse the command-line options of the benchmark launcher.

    Returns:
        argparse.Namespace with the attributes ``command`` (str),
        ``configs_dirs`` (list of str), ``num_seeds`` (int),
        ``num_gpus`` (int) and ``dry_run`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--command",
        type=str,
        default="python train.py",
        help="the command to run",
    )
    parser.add_argument(
        "--configs_dirs",
        nargs="+",
        default=["configs"],
        help="dirs with configs to run",
    )
    parser.add_argument(
        "--num_seeds", type=int, default=4, help="the number of random seeds"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=4,
        help="the number of gpus to run benchmark on",
    )
    # Every other option is spelled with underscores, so accept `--dry_run`
    # as well as the original `--dry-run`; both set `args.dry_run`.
    parser.add_argument(
        "--dry-run",
        "--dry_run",
        dest="dry_run",
        action="store_true",
        help="only print the commands that would be executed",
    )

    return parser.parse_args()


def init_worker(shared_queue: Queue) -> None:
    """Pool initializer: publish the shared queue as a module global.

    Stores ``shared_queue`` under the module-level name ``queue`` so that
    ``run_process`` can reach it from inside the worker process.

    Args:
        shared_queue: the queue object handed to every worker.
    """
    global queue
    queue = shared_queue


def run_command(command: str):
    """Run ``command`` as a child process and wait for it to finish.

    The string is tokenized with ``shlex.split`` and executed without a
    shell.

    Args:
        command: full command line to execute.

    Raises:
        subprocess.CalledProcessError: if the child exits with a non-zero
            status.
    """
    # Local import: the module header only imports `Popen` from subprocess.
    from subprocess import CalledProcessError

    print(f"Running command: {command}")
    argv = shlex.split(command)
    with Popen(argv) as process:
        return_code = process.wait()

    # An explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let failed runs pass silently.
    if return_code != 0:
        raise CalledProcessError(return_code, argv)


def run_process(base_command: str):
    """Run one benchmark command on a GPU id taken from the shared queue.

    Takes a GPU id from the module-level ``queue`` (installed by
    ``init_worker``), appends ``--device cuda:<id>`` to ``base_command``
    and executes it through ``run_command``. The id is put back on the
    queue afterwards, whether or not the command succeeded.

    Args:
        base_command: command line without the ``--device`` option.
    """
    # `queue` is only read here, so no `global` declaration is required.
    device_index = queue.get()
    try:
        worker_pid = current_process().ident
        print(f"{worker_pid}: starting process on GPU {device_index}")
        run_command(f"{base_command} --device cuda:{device_index}")
        print(f"{worker_pid}: finished")
    finally:
        # Always hand the GPU back, even if the command raised.
        queue.put(device_index)


def _download_all_datasets():
    """Request every benchmark dataset once, before any job is started.

    Calls ``gym.make(name).get_dataset()`` for each
    ``<task>-<quality>-v2`` environment name and discards the result.
    NOTE(review): presumably this makes d4rl download and cache the data
    so the parallel jobs do not each fetch it -- confirm against d4rl.
    """
    import gym
    # Not referenced below. NOTE(review): presumably imported for its side
    # effect of registering the environments with gym -- confirm.
    import d4rl

    tasks = ("halfcheetah", "walker2d", "hopper")
    qualities = ("medium", "medium-replay", "medium-expert")

    for task in tasks:
        for quality in qualities:
            gym.make(f"{task}-{quality}-v2").get_dataset()


def _build_commands(base_command: str, configs_dirs, seeds) -> list:
    """Build one command line per (configs dir, seed, config file).

    Args:
        base_command: command prefix, e.g. ``"python train.py"``.
        configs_dirs: iterable of directories searched for ``*.yaml`` files.
        seeds: iterable of integer seeds.

    Returns:
        List of command strings ordered by directory, then seed, then
        config path.
    """
    seeds = list(seeds)
    commands = []
    for configs_dir in configs_dirs:
        # Glob once per directory rather than once per seed, and sort the
        # result because glob returns paths in arbitrary order.
        config_paths = sorted(glob.glob(f"{configs_dir}/*.yaml"))
        for seed in seeds:
            for config_path in config_paths:
                # The command is later tokenized with shlex.split, so the
                # path must be quoted to survive spaces and other shell
                # metacharacters. Plain paths are left unchanged by quote().
                parts = [
                    base_command,
                    "--config_path",
                    shlex.quote(config_path),
                    "--train_seed",
                    str(seed),
                ]
                commands.append(" ".join(parts))
    return commands


def main():
    """Generate all benchmark commands and run them across the GPUs.

    With the dry-run flag set, the commands are only printed. Otherwise
    the datasets are requested up front and the commands are executed by
    a pool of ``num_gpus`` workers, each of which takes a GPU id from a
    shared queue for the duration of one command.
    """
    # generate all commands to run
    args = parse_args()
    commands = _build_commands(
        args.command, args.configs_dirs, range(args.num_seeds)
    )

    if args.dry_run:
        print("All commands to be executed:")
        print(*commands, sep="\n")
        return

    # pre-download datasets (optional)
    _download_all_datasets()

    # initialize the queue with the GPU ids
    shared_queue = Queue()
    for gpu_id in range(args.num_gpus):
        shared_queue.put(gpu_id)

    # run all commands on all available gpus; the context manager makes
    # sure the workers are torn down if one of the commands fails
    with Pool(
        processes=args.num_gpus,
        initializer=init_worker,
        initargs=(shared_queue,),
    ) as pool:
        for _ in pool.imap_unordered(run_process, commands):
            pass

        pool.close()
        pool.join()


if __name__ == '__main__':
    main()


