import os
import sys
import argparse
import ruamel.yaml
import itertools
import subprocess
import time

##
## Argument Parser 
##
parser = argparse.ArgumentParser(
    description="Launch main.py sampling runs in parallel across CUDA devices.")

# Both options are required: the scheduler below indexes args.devices and
# args.pll_procs by the same device index.
parser.add_argument("--pll_procs", type=int, nargs='+', required=True,
                    help="maximum concurrent runs per device, aligned with --devices")
parser.add_argument("--devices", type=int, nargs='+', required=True,
                    help="CUDA device ids to schedule runs on")

datasets = ["community_small", "ego_small", "grid", "qm9", "zinc250k"]

args = parser.parse_args(sys.argv[1:])
# Fail early with a usage message instead of an IndexError mid-run.
if len(args.pll_procs) < len(args.devices):
    parser.error("--pll_procs needs one value per entry in --devices")

##
# Sweep grid: one main.py run is launched per combination of these values.
add_diff_steps = [0]  # written to constr_config["add_diff_step"]
burnins = [0]  # written to constr_config["burnin"]
rounding_schemes = ["none"]  # written to constr_config["rounding"]
solve_orders = ["cpj"]  # written to constr_config["method"]["solve_order"]
seeds = [42]  # passed as --seed; 123 and 567 were also used before
# Base name of config/constraints/<name>.yaml, which is rewritten per run and
# passed to main.py via --constr_config.
constraint_config = "none"

yaml = ruamel.yaml.YAML()

##
## Parallel subprocesses
##
# processes[i] holds the Popen handles of runs still tracked on args.devices[i];
# pids_to_device maps each launched child's pid to its device index.
processes = {i: [] for i in range(len(args.devices))}
pids_to_device = {}

# Seconds to sleep between capacity checks while every device is full.
POLL_INTERVAL = 1.0
# Seconds to wait after a launch before the shared config file may be rewritten.
LAUNCH_DELAY = 5


def _reap_finished():
    """Stop tracking children that have exited (Popen.poll() is not None)."""
    for devid in processes:
        processes[devid] = [p for p in processes[devid] if p.poll() is None]


def _free_device():
    """Return the lowest device index with a free slot, or None if all are full."""
    for devid, procs in processes.items():
        # Capacities below 1 are treated as 1 so _acquire_device cannot deadlock.
        if len(procs) < max(1, args.pll_procs[devid]):
            return devid
    return None


def _acquire_device():
    """Block until some device has a free slot and return its index."""
    while True:
        _reap_finished()
        devid = _free_device()
        if devid is not None:
            return devid
        time.sleep(POLL_INTERVAL)


# Loaded once: every field changed below is rewritten before each launch, so
# re-reading the file we just wrote would yield the same mapping.
config_path = f"config/constraints/{constraint_config}.yaml"
with open(config_path, "r") as f:
    constr_config = yaml.load(f)
if "method" not in constr_config:
    constr_config["method"] = {'op': 'proj', 'gamma': 0}
constr_config["constraint"] = "None"
constr_config["schedule"] = {'gamma': "none", 'params': [0]}

sweep = itertools.product(datasets, seeds, solve_orders, burnins,
                          add_diff_steps, rounding_schemes)
for dataset, seed, solve_order, burnin, add_diff_step, rounding in sweep:
    # Block for a free slot first; the config is only rewritten once this run
    # can actually be started.
    device_idx = _acquire_device()
    print(dataset, seed, solve_order, burnin, add_diff_step, rounding, flush=True)
    constr_config["add_diff_step"] = add_diff_step
    constr_config["burnin"] = burnin
    constr_config["rounding"] = rounding
    constr_config["method"]["solve_order"] = solve_order
    constr_config['params'] = []
    # Close (and therefore flush) the file before spawning the child, which is
    # handed the same config name via --constr_config.
    with open(config_path, "w") as f:
        yaml.dump(constr_config, f)
    device = 'cuda:' + str(args.devices[device_idx])
    # sys.executable keeps the child on the same interpreter/environment.
    proc = subprocess.Popen([sys.executable, "main.py", "--type", "sample",
                             "--config", f"sample_{dataset}",
                             "--constr_config", constraint_config,
                             "--device", device, "--seed", str(seed)])
    processes[device_idx].append(proc)
    pids_to_device[proc.pid] = device_idx
    # NOTE(review): all runs share one config file, so this delay is the only
    # thing keeping the next rewrite from racing a child that has not read it
    # yet. Per-run config files would remove the race if main.py allows it.
    time.sleep(LAUNCH_DELAY)

# Wait for the remaining runs instead of exiting and leaving them orphaned.
for proc in itertools.chain.from_iterable(processes.values()):
    proc.wait()