import os
import sys
import argparse
import ruamel.yaml
import itertools
import subprocess
import time

##
## Argument Parser 
##
parser = argparse.ArgumentParser(
    description="Sweep sampling runs of main.py over constraint / gamma-schedule "
                "settings, spreading the runs over several CUDA devices.")

# These four are required: the script cannot build config names or a device
# schedule without them (previously omitting them crashed later with
# AttributeError / TypeError).
parser.add_argument("--dataset", type=str, required=True,
                    help="dataset name; selects config sample_<dataset>")
parser.add_argument("--constraint", type=str, required=True,
                    help="constraint name; selects config/constraints/master_<constraint>.yaml")
parser.add_argument("--eq", action='store_true',
                    help="value stored under 'eq' in the constraint config")
parser.add_argument("--pll_procs", type=int, nargs='+', required=True,
                    help="max concurrent jobs per device, one value per --devices entry")
parser.add_argument("--schedule_gammas", type=str, nargs='+', default=["poly", "fixed"],
                    help="gamma schedules to sweep")
parser.add_argument("--devices", type=int, nargs='+', required=True,
                    help="CUDA device ids to launch jobs on")
parser.add_argument("--seed", type=int, default=42)

args = parser.parse_args(sys.argv[1:])

# The launch loop indexes args.pll_procs by device position, so every device
# needs a capacity entry; fail here instead of with an IndexError mid-sweep.
if len(args.pll_procs) < len(args.devices):
    parser.error("--pll_procs needs one value per entry of --devices")

dataset_config = f"sample_{args.dataset}"

# Some constraints have a per-dataset config file (<constraint>_<dataset>.yaml);
# the rest share a single <constraint>.yaml.
if (args.constraint in ("valency", "atomCount", "molWeight", "regression")
        or args.constraint.startswith("prop")):
    constraint_config = f"{args.constraint}_{args.dataset}"
else:
    constraint_config = args.constraint

##
## Config values
##
# add_diff_steps = [0, 100, 500, 1000, 1500, 2000, 2500]
# if args.constraint == "none":
#     gammas = [0]
#     rounding_schemes = ['none']
# else:
#     gammas = [0.1, 0.25, 0.5, 0.75, 0.9, 1.0]
#     rounding_schemes = ["none", "randomized", "repeated"]

# gammas = [0.0, 0.2, 0.5, 0.6]
# Fixed sweep axes; every combination of these is launched for every
# schedule configuration built below.
add_diff_steps = [0]
gammas = [1.0]  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
burnins = [0]
rounding_schemes = ["none"]
solve_orders = ["cpj"]
seeds = [args.seed]  # , 123, 567]

# if args.constraint == 'valency':
#     add_diff_steps = [0, 500]
#     gammas = [0.0, 0.1, 0.25, 0.5] #, 0.75, 1.0]
#     rounding_schemes = ["none", "heurval"]
#     solve_orders = ["cpj"]
# elif args.constraint.startswith('eig'):
#     gammas = [0.0, 0.25, 0.5, 0.75, 1.0]
#     rounding_schemes = ["none", "randomized"]
# elif args.constraint == 'budget':
#     add_diff_steps = [0, 500]
#     rounding_schemes = ["none"]
#     gammas = [0.25, 0.5, 0.75, 1.0]
#     burnins = [0] #100, 250, 500, 750, 900]
#     solve_orders = ["cpj"]

# Gamma-schedule sweep. The two lists are kept index-aligned:
# schedule_gammas[i] is a schedule name, schedule_gamma_params[i] its params.
schedules = args.schedule_gammas
schedule_gamma_params = []
schedule_gammas = []

for schedule in schedules:
    n_params_before = len(schedule_gamma_params)
    if schedule == 'cyclical':
        cycle_sizes = [10, 20, 50, 100, 200, 500]
        schedule_gamma_params += [[x] for x in cycle_sizes]
    elif schedule == 'steprise':
        gammas = [1.]
        # Each entry is [init, mulp, step].
        # Grid axes considered: init in {1e-4, 1e-3, 5e-3, 1e-2}, mulp in {2, 4, 8},
        # step in {10, 20, 50, 100}.
        # NOTE(review): the second group of rows repeats most of the first
        # exactly (identical runs at the same seed); perhaps a different step
        # was intended there — confirm.
        schedule_gamma_params += [[1e-4, 8, 10], [1e-3, 8, 10],
                                  [1e-4, 4, 10], [1e-3, 4, 10],
                                  [1e-4, 2, 10], [1e-3, 2, 10], [5e-3, 2, 10],
                                  [1e-4, 8, 10], [1e-3, 8, 10],
                                  [1e-4, 4, 10], [1e-3, 4, 10],
                                  [1e-4, 2, 10], [1e-3, 2, 10], [1e-2, 2, 10]]
    elif schedule == 'poly':
        gammas = [1.]
        poly_pows = [1, 5]
        gamma_inits = [0, 1e-1]  # [0, 1e-4, 1e-3, 1e-2, 1e-1]
        # Each entry is [init, pow]; full cross product.
        schedule_gamma_params += [[x, y] for x, y in itertools.product(gamma_inits, poly_pows)]
    elif schedule == 'polystep':
        gammas = [1.]
        poly_pows = [1, 2, 3, 4, 5]
        gamma_inits = [0, 1e-2, 1e-1]  # [0, 1e-4, 1e-3, 1e-2, 1e-1]
        steps = [10, 20, 50, 100]
        # Each entry is [init, pow, step].
        # NOTE(review): zip pairs the lists element-wise and stops at the
        # shortest (3 entries), unlike 'polymid' which takes the full
        # product — confirm this is intended.
        schedule_gamma_params += [[x, y, z] for x, y, z in zip(gamma_inits, poly_pows, steps)]
    elif schedule == 'polymid':
        gammas = [1.]
        poly_pows = [1, 2, 3, 4, 5]
        gamma_inits = [0, 1e-2, 1e-1]  # [0, 1e-4, 1e-3, 1e-2, 1e-1]
        one_steps = [100, 200, 500, 750]
        # Each entry is [init, pow, one_step]; full cross product.
        schedule_gamma_params += [[x, y, z] for x, y, z in
                                  itertools.product(gamma_inits, poly_pows, one_steps)]
    elif schedule == 'fixed':
        schedule_gamma_params += [[1]]  # [[0], [1]]
    else:
        # Previously an unknown name silently produced zero runs.
        sys.exit(f"Unknown schedule {schedule!r} in --schedule_gammas")
    # Label every params entry added by this branch with its schedule name.
    schedule_gammas += [schedule] * (len(schedule_gamma_params) - n_params_before)

if args.constraint.startswith('molWeight'):
    # molWeight sweeps its own 'poly' grid, replacing whatever was built above.
    # Both lists are replaced so they stay aligned (the old code appended one
    # nested list to schedule_gammas and relied on zip truncation).
    schedule_gamma_params = [[0, 0], [0, 0.1], [0, 0.2], [0, 0.5]]
    schedule_gammas = ['poly'] * len(schedule_gamma_params)

# Show the schedule configurations that will actually be swept.
print(list(zip(schedule_gammas, schedule_gamma_params)))

# Default YAML() is ruamel's round-trip loader/dumper, so comments and key
# order in the config files survive the rewrite in the launch loop.
yaml = ruamel.yaml.YAML()

# Master sweep file: top-level keys are dataset names; the value under
# args.dataset is iterated as the list of constraint parameter values.
with open(f"config/constraints/master_{args.constraint}.yaml", 'r') as master_file:
    constraint_vals = yaml.load(master_file)
if not constraint_vals or args.dataset not in constraint_vals:
    # Nothing to sweep for this dataset: exit with status 0 as before, but say why.
    print(f"No sweep values for dataset {args.dataset!r} in "
          f"master_{args.constraint}.yaml; nothing to launch.", file=sys.stderr)
    sys.exit()

all_params = constraint_vals[args.dataset]

# Base constraint config. It is mutated per run and written back to this same
# path before each launch, so after the sweep the file holds the last run's
# settings.
with open(f"config/constraints/{constraint_config}.yaml", 'r') as base_file:
    constr_config = yaml.load(base_file)
if "method" not in constr_config:
    constr_config["method"] = {'op': 'proj', 'gamma': 0}

##
## Parallel subprocesses 
##
# Job scheduler state. Device position i (into args.devices) may run up to
# args.pll_procs[i] concurrent main.py jobs.
device_idx = 0  # position in args.devices for the next launch
processes = {i: [] for i in range(len(args.devices))}  # device position -> live Popen objects
pids_to_device = {}  # pid -> device position it was launched on
pids_to_procsid = {}  # pid -> index in processes[device] at launch time (stale after pruning)

# Every job is handed the SAME config file name, which is rewritten before each
# launch. After launching we pause so the child can read it first.
# NOTE(review): this is a race — a child that has not read the YAML within
# this window picks up the next job's settings; per-job config files would
# remove it.
LAUNCH_GRACE_SECONDS = 5

# Centres of the qm9 property windows for "propIn-<prop>" constraints; each
# sweep value is used as the half-width around the centre.
QM9_PROP_CENTERS = {'homo': -0.2605, 'mu': 1.8049}


def _set_constraint_params(config, param_vals):
    """Store one sweep value `param_vals` into config["params"] in the layout args.constraint uses."""
    constraint = args.constraint
    if constraint in ('valency', 'atomCount'):
        config['params'] = [param_vals]
    elif constraint in ('nedges', 'nedgesl2'):
        config['params'] = ['zeros', param_vals]
    elif constraint.startswith('molWeight'):
        config['params'][1] = param_vals
    elif constraint == "cheeger":
        config["params"] = param_vals
    elif constraint.startswith("propIn"):
        parts = constraint.split('-')
        center = QM9_PROP_CENTERS.get(parts[1]) if len(parts) > 1 else None
        if args.dataset == 'qm9' and center is not None:
            config['params'][1] = center - param_vals
            config['params'][2] = center + param_vals
        else:
            # Previously silent: params stayed unchanged, so every sweep value
            # launched an identical run.
            print(f"WARNING: no propIn window for dataset={args.dataset!r}, "
                  f"constraint={constraint!r}; params left unchanged",
                  file=sys.stderr, flush=True)
    elif constraint.startswith("prop"):
        config['params'][1] = param_vals
    else:
        # Wrap scalars, pass sequences through. isinstance (not `type(...) is
        # list`) is required: ruamel's round-trip loader yields CommentedSeq,
        # a list subclass, which the exact-type check wrongly double-wrapped.
        config["params"] = param_vals if isinstance(param_vals, list) else [param_vals]


def _next_device():
    """Return a device position with a free slot; if all are full, block until a child exits."""
    for devid, procs in processes.items():
        if len(procs) < args.pll_procs[devid]:
            return devid
    # Reaps one child; its Popen stays listed until _prune_finished(), where
    # poll() reports it as finished.
    pid, _ = os.wait()
    return pids_to_device[pid]


def _prune_finished():
    """Drop every process that has exited from `processes` (poll() also reaps it)."""
    for devid, procs in processes.items():
        processes[devid] = [p for p in procs if p.poll() is None]


# These settings are the same for every run.
constr_config["implicit"] = False
constr_config["eq"] = args.eq

# Same iteration order as nested loops: the rightmost axis varies fastest.
sweep = itertools.product(seeds, solve_orders, zip(schedule_gammas, schedule_gamma_params),
                          burnins, add_diff_steps, rounding_schemes, gammas, all_params)
for (seed, solve_order, (sch_gamma, sch_gamma_params), burnin,
     add_diff_step, rounding, gamma, param_vals) in sweep:
    print(seed, solve_order, sch_gamma, sch_gamma_params, burnin,
          add_diff_step, rounding, gamma, param_vals, flush=True)
    constr_config["add_diff_step"] = add_diff_step
    constr_config["burnin"] = burnin
    constr_config["schedule"] = {'gamma': sch_gamma, 'params': sch_gamma_params}
    constr_config["rounding"] = rounding
    constr_config["method"]["gamma"] = gamma
    constr_config["method"]["solve_order"] = solve_order
    _set_constraint_params(constr_config, param_vals)

    # Close (and flush) the config before the child is started.
    # Presumably main.py resolves --constr_config to this file — confirm.
    with open(f"config/constraints/{constraint_config}.yaml", "w") as cfg_file:
        yaml.dump(constr_config, cfg_file)

    device = 'cuda:' + str(args.devices[device_idx])
    # To capture logs, add stdout=/stderr= file handles to this Popen call.
    proc = subprocess.Popen(["python", "main.py", "--type", "sample", "--config", dataset_config,
                             "--constr_config", constraint_config, "--device", device,
                             "--seed", str(seed)])
    time.sleep(LAUNCH_GRACE_SECONDS)
    processes[device_idx].append(proc)
    pids_to_device[proc.pid] = device_idx
    pids_to_procsid[proc.pid] = len(processes[device_idx]) - 1

    device_idx = _next_device()
    _prune_finished()

# NOTE(review): the script returns here without waiting for the last batch of
# jobs; they keep running in the background.