import argparse
import json
import shlex
import subprocess
from os.path import join
from os import cpu_count


def run_experiment(command: str) -> int:
    """Run a single experiment command in a child process and wait for it.

    The command string is split into an argument list with ``shlex.split``
    and executed without a shell.

    Args:
        command: Full command line of the experiment to run.

    Returns:
        The exit code of the child process. A non-zero code is reported on
        stdout but does not raise, so one failing experiment does not abort
        the remaining ones.
    """
    command_list = shlex.split(command)
    print(f"running {command}")
    print('cpu count', cpu_count())
    completed = subprocess.run(command_list)
    if completed.returncode != 0:
        # Surface failures instead of silently discarding the exit status.
        print(f"command failed with exit code {completed.returncode}: {command}")
    return completed.returncode


def generate_commands(command: str, logging_dir: str, params: dict, num_seeds: int) -> list:
    """Expand a parameter grid into one command line per (combination, seed).

    Parameters whose value is a list are swept over; every other parameter is
    appended once to the base command as ``key value``. Each swept value also
    adds a sub-folder to the logging path: ``<value>`` for ``--env-name`` and
    ``<key without leading dashes><value>`` for any other key. Finally every
    combination is repeated for each seed in ``range(num_seeds)`` with
    ``--seed`` and ``--logging-path <path>/seed_<s>`` appended.

    Args:
        command: Base command line shared by all experiments.
        logging_dir: Root directory under which per-run paths are built.
        params: Mapping from command-line flag to a fixed value or a list of
            values to sweep over.
        num_seeds: Number of seeds to generate for every combination.

    Returns:
        The generated command lines. The first swept key varies fastest and
        the seed varies slowest.
    """
    fixed = {key: val for key, val in params.items() if not isinstance(val, list)}
    swept = {key: val for key, val in params.items() if isinstance(val, list)}

    # Start from the base command extended with all fixed parameters.
    base = " ".join([command] + [str(tok) for key, val in fixed.items() for tok in (key, val)])
    combos = [(base, logging_dir)]

    for key, values in swept.items():
        expanded = []
        for val in values:
            if key == '--env-name':
                folder = str(val)
            else:
                # lstrip('-') equals key[2:] for '--key' flags, but does not
                # mangle keys that have a different number of leading dashes.
                folder = key.lstrip('-') + str(val)
            for cmd, path in combos:
                expanded.append((f'{cmd} {key} {val}', join(path, folder)))
        combos = expanded

    commands = []
    for seed in range(num_seeds):
        for cmd, path in combos:
            seed_path = join(path, f'seed_{seed}')
            commands.append(f'{cmd} --seed {seed} --logging-path {seed_path}')
    return commands


if __name__ == '__main__':
    argp = argparse.ArgumentParser()
    argp.add_argument('--command', type=str, required=True)
    argp.add_argument('--logging-dir', type=str, required=True)
    argp.add_argument('--params', type=json.loads, required=True)
    argp.add_argument('--num-seeds', type=int, required=True)
    argp.add_argument('--num-workers', type=int, required=True)
    args = argp.parse_args()
    if not isinstance(args.params, dict):
        argp.error('--params must be a JSON object mapping flags to values')

    # Generating commands
    commands = generate_commands(args.command, args.logging_dir, args.params, args.num_seeds)

    # Launching commands
    print('Launching all the following commands')
    for command in commands:
        print(command)
    if args.num_workers > 0:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Each worker thread just blocks on its child process, so the pool
        # size caps how many experiments run at the same time.
        with ThreadPoolExecutor(max_workers=args.num_workers,
                                thread_name_prefix="benchmark-worker-") as executor:
            futures = {executor.submit(run_experiment, command): command for command in commands}
            for future in as_completed(futures):
                error = future.exception()
                if error is not None:
                    # Keep going: one command that cannot be launched must
                    # not prevent the others from running.
                    print(f"could not run {futures[future]}: {error!r}")
    else:
        print(f"not running the experiments because --num-workers is set to {args.num_workers}; just printing the commands to run")
