import itertools
import os
import shlex
import subprocess
import sys

# ---- Command-line switches (parsed by simple membership tests on sys.argv) ----
dry_run = '--dryrun' in sys.argv          # write slurm scripts but do not submit/launch
clear = '--clear' in sys.argv             # NOTE(review): not read anywhere in this script as shown
local = '--local' in sys.argv             # run jobs on this machine instead of via sbatch
double_book = '--double-book' in sys.argv
quad_book = '--quad-book' in sys.argv

# "--prefix foo" appends "_foo" to every job name; fail with a clear message
# (instead of a bare IndexError) when the value is missing.
if '--prefix' in sys.argv:
    _prefix_value_index = sys.argv.index('--prefix') + 1
    if _prefix_value_index >= len(sys.argv):
        sys.exit("error: --prefix requires a value")
    prefix = "_" + sys.argv[_prefix_value_index]
else:
    prefix = ""

# --double-book takes precedence over --quad-book when both are given.
# NOTE(review): `increment` is not read anywhere in this script as shown;
# confirm whether it was meant to influence local GPU assignment.
if double_book:
    increment = 2
elif quad_book:
    increment = 4
else:
    increment = 1


# Output directories for job logs and generated sbatch scripts.
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs("slurm_logs", exist_ok=True)
os.makedirs("slurm_scripts", exist_ok=True)
code_dir = '.'  # NOTE(review): not read anywhere in this script as shown
# Grid keys that are consumed by this launcher rather than passed to the script.
excluded_flags = {'main_file', 'embed_job'}


def construct_varying_keys(grids):
    """Return the set of grid keys whose value differs somewhere in the sweep.

    All grids are pooled together: a key is "varying" when it takes more than
    one distinct value across them. A grid that lacks a key contributes the
    placeholder "<<NONE>>" for it, so a key present in only some grids also
    counts as varying.
    """
    every_key = set()
    for g in grids:
        every_key.update(g.keys())

    pooled_values = {}
    for key in every_key:
        seen = set()
        for g in grids:
            seen.update(g.get(key, ["<<NONE>>"]))
        pooled_values[key] = seen

    return {key for key, seen in pooled_values.items() if len(seen) > 1}


def construct_jobs(grids):
    """Expand each grid into the Cartesian product of its value lists.

    Every grid maps a flag name to a list of candidate values; each combination
    becomes one job dict (keys in the grid's order). Jobs from all grids are
    concatenated in order.
    """
    jobs = []
    for grid in grids:
        flag_names = list(grid)
        value_lists = [grid[name] for name in flag_names]
        for combination in itertools.product(*value_lists):
            jobs.append(dict(zip(flag_names, combination)))
    return jobs



def construct_flag_string(job, excluded=None):
    """Construct the string of arguments to be passed to the script.

    Each entry of ``job`` becomes `` --<flag> <value>``. Boolean entries become
    a bare `` --<flag>`` switch when True and are omitted when False. Flags in
    ``excluded`` (default: the module-level ``excluded_flags``) are skipped.

    Non-boolean values are shell-quoted because the result is spliced into a
    shell command line (``os.system`` and the generated bash script). Plain
    numbers and simple tokens are left unchanged by the quoting.

    Returns the concatenated arguments, each with a leading space (e.g.
    `` --n 4 --lr 0.0005``), or "" when nothing is emitted.
    """
    if excluded is None:
        excluded = excluded_flags
    parts = []
    for flag, value in job.items():
        if flag in excluded:
            continue
        if isinstance(value, bool):
            # False booleans are omitted entirely rather than passed as a value.
            if value:
                parts.append(" --" + flag)
        else:
            parts.append(" --" + flag + " " + shlex.quote(str(value)))
    return "".join(parts)

def construct_name(job, varying_keys):
    """Build a job name from the module-level ``basename``.

    For every flag of ``job`` that appears in ``varying_keys``, a suffix
    ``_<flag><value>`` is appended, in the job's key order.
    """
    suffixes = [
        "_" + flag + str(value)
        for flag, value in job.items()
        if flag in varying_keys
    ]
    return basename + "".join(suffixes)

# Sweep definition: construct_jobs expands each grid into one job per
# combination of the listed values.
main_grid = [
    dict(
        main_file=['main'],
        iterations=[300000],
        n=[4, 6],
        hidden_dim=[30],
        lr=[0.0005],
        anti_dim=[1, 20, 40, 60],
    ),
]



# Every job name starts with this stem (plus the optional --prefix suffix).
basename = "antisym" + prefix

# The single grid defined above is the one that gets expanded and launched.
embed_grid = main_grid

embed_jobs = construct_jobs(embed_grid)
embed_varying_keys = construct_varying_keys(embed_grid)



header_template = "NOT starting {} jobs:" if dry_run else "Starting {} jobs:"
print(header_template.format(len(embed_jobs)))

# Launch (or, with --dryrun, only prepare) every job in the sweep.
for i, job in enumerate(embed_jobs):
    jobname = construct_name(job, embed_varying_keys)
    flagstring = construct_flag_string(job)
    print("  " + jobname)

    # jobname may contain '/', so create any intermediate log directories.
    slurm_log_dir = 'slurm_logs/' + jobname
    os.makedirs(os.path.dirname(slurm_log_dir), exist_ok=True)

    # The bare training command. The slurm script wraps it in `srun`; local
    # mode runs it directly, since it pins a GPU itself via CUDA_VISIBLE_DEVICES
    # and should not go through the scheduler.
    python_command = "python {}.py{}".format(job['main_file'], flagstring)
    jobcommand = "srun " + python_command

    embed_jobs[i]['name'] = jobname
    if local:
        # Round-robin over GPU ids 0-3 (assumes 4 local GPUs).
        gpu_id = i % 4
        log_path = "slurm_logs/" + jobname
        local_command = (
            "env CUDA_VISIBLE_DEVICES={gpu_id} {command} "
            "> {log_path}.out 2> {log_path}.err &"
        ).format(gpu_id=gpu_id, command=python_command, log_path=log_path)
        # Honour --dryrun in local mode too (previously jobs were launched anyway).
        if not dry_run:
            os.system(local_command)

    else:
        slurm_script_path = 'slurm_scripts/' + jobname + '.slurm'
        os.makedirs(os.path.dirname(slurm_script_path), exist_ok=True)

        script_lines = [
            "#!/bin/bash",
            "#SBATCH --job-name=" + jobname,
            "#SBATCH --open-mode=append",
            "#SBATCH --output=slurm_logs/" + jobname + ".out",
            "#SBATCH --error=slurm_logs/" + jobname + ".err",
            "#SBATCH --export=ALL",
            "#SBATCH --time=1-00",
            "#SBATCH -N 1",
            "#SBATCH --gres=gpu:1",
            "#SBATCH --constraint=turing|volta",
            "",
            "source activate prime2",
            "module load cuda-10.1",
            jobcommand,
        ]
        # The script is written even in dry-run mode so it can be inspected.
        with open(slurm_script_path, 'w') as slurmfile:
            slurmfile.write("\n".join(script_lines) + "\n")

        if not dry_run:
            job_subproc_cmd = ["sbatch", "--parsable", slurm_script_path]
            start_result = subprocess.run(job_subproc_cmd, stdout=subprocess.PIPE)
            # `sbatch --parsable` prints the job id (optionally ";cluster")
            # followed by a newline; strip the trailing whitespace.
            jobid = start_result.stdout.decode('utf-8').strip()
            if start_result.returncode != 0:
                print("WARNING: sbatch failed for {} (exit code {})".format(
                    jobname, start_result.returncode))
            embed_jobs[i]['jobid'] = jobid



