import random
import os
from absl import logging, flags, app
from multiprocessing import Queue, Manager
from pathos import multiprocessing
import traceback
import time

# GPU ids handed out to workers. `_init_device_queue` places these ids in the
# shared queue, and each worker exports the id it takes as CUDA_VISIBLE_DEVICES.
# `run` also uses these ids as keys into its `mujoco_gpus` mapping.
which_gpus = [1]
# Size of the process pool and number of device ids placed in the queue,
# i.e. the maximum number of commands running at the same time.
max_worker_num = 2


# Seeds expanded into one experiment entry each by the `exps` lists below.
seeds = [1]
# Alternative experiment list kept for reference; to use it, swap it with the
# active `exps` definition further down.
# exps = [
#     {
#         "snapshot_dir": f"data/local/experiment/dnc_sac_MetaWorldMT10-v2_distilled_beta0001_{s}",
#         "seed": s,
#     }
#     for s in seeds
# ]

# Active experiment list: one dict per run. `run` appends every key/value pair
# of a dict to the command line as `--<key> <value>`.
exps = [
    {
        "snapshot_dir": f"data/local/experiment/sac_Walker2dMT4-v0_shared_lr0.0003_{s}",
        "seed": s
    }
    for s in seeds
]

# Only referenced by the commented-out `cds_*` experiment lists below.
sharing_quantiles = [0, 0.8]

# Alternative experiment list (one entry per quantile/seed pair), kept for
# reference.
# exps = [
#     {
#         "snapshot_dir": f"data/local/experiment/cds_MetaWorldMT10-v2_unsupervised_sharingquantile_{q}_{s}",
#         "seed": s
#     }
#     for q in sharing_quantiles
#     for s in seeds
# ]

# Alternative experiment list (one entry per quantile/seed pair), kept for
# reference.
# exps = [
#     {
#         "snapshot_dir": f"data/local/experiment/cds_Walker2dMT4-v0_unsupervised_sharingquantile_{q}_{s}",
#         "seed": s
#     }
#     for q in sharing_quantiles
#     for s in seeds
# ]


def _init_device_queue(max_worker_num):
    """Build a manager-backed queue pre-filled with GPU ids.

    The queue receives ``max_worker_num`` entries, taken from the
    module-level ``which_gpus`` list in round-robin order, so ids repeat
    when there are more worker slots than GPUs.

    Args:
        max_worker_num: Number of GPU ids to place in the queue.

    Returns:
        The ``Manager().Queue()`` proxy holding the ids.
    """
    gpu_count = len(which_gpus)
    manager = Manager()
    queue = manager.Queue()
    for slot in range(max_worker_num):
        # Wrap around the GPU list once every GPU has been handed out.
        queue.put(which_gpus[slot % gpu_count])
    return queue


def run():
    """Launch one ``run.py resume`` command per entry in ``exps``.

    For every experiment dict a shell command is assembled, printed, and
    submitted to a process pool of ``max_worker_num`` workers, each of which
    handles a single task (``maxtasksperchild=1``). The call returns once
    every submitted command has finished.

    Each command is prefixed with ``GPUS=<id>``, where the id is looked up in
    ``mujoco_gpus`` using a GPU picked round-robin from ``which_gpus`` by
    experiment index. Every key/value pair of the experiment dict is appended
    as ``--<key> <value>``.

    NOTE(review): the ``GPUS`` value is chosen here by experiment index, while
    ``CUDA_VISIBLE_DEVICES`` is chosen independently in ``_worker`` from the
    device queue. With more than one entry in ``which_gpus`` the two may refer
    to different GPUs -- confirm this is intended.

    Raises:
        KeyError: If an entry of ``which_gpus`` is not a key of ``mujoco_gpus``.
    """
    # Presumably translates a CUDA device index into the id that the MuJoCo
    # renderer uses for the same card -- TODO confirm. Built once because it
    # does not depend on the experiment.
    mujoco_gpus = {0: 2, 1: 3, 2: 1, 3: 0, 4: 6, 5: 7, 6: 5, 7: 4}

    process_pool = multiprocessing.Pool(processes=max_worker_num, maxtasksperchild=1)
    device_queue = _init_device_queue(max_worker_num)

    for i, exp in enumerate(exps):
        gpu = which_gpus[i % len(which_gpus)]
        command = "GPUS={} python run.py resume --gpu=0".format(mujoco_gpus[gpu])
        for param, value in exp.items():
            command += " --{} {}".format(param, value)
        print(command)
        process_pool.apply_async(
            func=_worker,
            args=[command, device_queue],
            # Exceptions raised inside the worker are reported here, in the
            # parent process.
            error_callback=logging.error,
        )

    # No further submissions; block until every queued command has finished.
    process_pool.close()
    process_pool.join()


def _worker(command, device_queue):
    """Run one shell command on a GPU id borrowed from ``device_queue``.

    Sleeps for a random 0-10 seconds, takes a device id from the queue, runs
    ``command`` through ``os.system`` with ``CUDA_VISIBLE_DEVICES`` set to that
    id, and puts the id back on the queue afterwards -- also when an error
    occurs, so a failing task cannot starve the tasks that follow it.

    Args:
        command: Shell command line to execute.
        device_queue: Queue of integer GPU ids shared by all workers.

    Raises:
        Exception: Any error raised while running is logged together with its
            traceback and then re-raised to the caller.
    """
    try:
        # sleep for random seconds to avoid crowded launching
        time.sleep(random.uniform(0, 10))

        device = device_queue.get()
        try:
            logging.set_verbosity(logging.INFO)

            logging.info("command %s" % command)
            status = os.system("CUDA_VISIBLE_DEVICES=%d " % device + command)
            if status != 0:
                # os.system does not raise on failure; surface it in the log
                # instead of silently dropping the exit status.
                logging.error(
                    "command returned non-zero status %s: %s", status, command
                )
        finally:
            # Always hand the device back, otherwise the id is lost and later
            # tasks may block forever waiting on an empty queue.
            device_queue.put(device)
    except Exception:
        logging.error(traceback.format_exc())
        raise


# Only launch when executed as a script. Without the guard, merely importing
# this module (including the re-import done by multiprocessing children under
# the "spawn" start method) would start the whole experiment batch again.
if __name__ == "__main__":
    run()
