#!python3

import argparse
import itertools
from pathlib import Path
from subprocess import Popen, PIPE
from copy import deepcopy
import time

import jinja2
import yaml

# Characters that may not appear in any sweep-file key. Keys are turned into
# "--key=value" arguments for the rendered job script, so whitespace, control
# characters, quotes, separators and shell metacharacters are rejected.
INVALID_CHARS = [
    # whitespace and control characters
    ' ', '\t', '\n', '\r', '\v', '\f', '\0', '\b',
    # separators, quotes and shell/glob metacharacters
    '\\', '/', ':', '*', '?', '"', '<', '>', '|', '%', '$', '!', '@', '&',
    # brackets of every kind
    '(', ')', '[', ']', '{', '}',
    # remaining punctuation
    ';', '`', '~', '#', '^', '=', '+', ',', "'",
]


def _flatten_items(d, prefix=None):
    """Yield (key, value) pairs of ``d`` with nested dicts flattened to dotted keys."""
    for key, value in d.items():
        # Top-level keys are passed through unchanged; nested ones are joined with '.'.
        name = key if prefix is None else f'{prefix}.{key}'
        if isinstance(value, dict):
            yield from _flatten_items(value, name)
        else:
            yield name, value


def all_combinations(d):
    """Yield one flat argument dict per point of the sweep described by ``d``.

    Nested dicts are flattened into dotted keys (``{'a': {'b': 1}}`` becomes
    ``{'a.b': 1}``). After flattening, every list value is a swept dimension
    and every other value is held constant. One dict is yielded per element of
    the Cartesian product of the swept lists, with the first swept key varying
    slowest. In each yielded dict the constant keys come first, followed by
    the swept keys, each group in input order.

    An empty ``d`` yields a single empty dict; an empty list value yields
    nothing (the product is empty).
    """
    items = list(_flatten_items(d))

    constant = [(k, v) for k, v in items if not isinstance(v, list)]
    sweep = [(k, v) for k, v in items if isinstance(v, list)]

    keys = [k for k, _ in constant] + [k for k, _ in sweep]
    fixed_values = [v for _, v in constant]

    # product() with no arguments yields exactly one empty tuple, so a sweep
    # without any list values still produces a single combination.
    for combo in itertools.product(*(v for _, v in sweep)):
        yield dict(zip(keys, fixed_values + list(combo)))
        

def check_args_for_errors(d):
    """Reject sweep mappings whose keys contain characters from ``INVALID_CHARS``.

    Keys end up in ``--key=value`` arguments (nested keys joined with '.'),
    so every key at every nesting level is checked, including keys whose
    value is itself a dict. Values are not checked.

    Args:
        d: Possibly nested mapping loaded from a sweep file.

    Raises:
        ValueError: For the first key found that contains an invalid character.
    """
    for k, v in d.items():
        # YAML may produce non-string keys (ints, bools); check their text form
        # instead of crashing with a TypeError on the ``in`` test.
        key_text = str(k)
        for c in INVALID_CHARS:
            if c in key_text:
                raise ValueError(f"Invalid character in key {k}: '{c}'")
        if isinstance(v, dict):
            check_args_for_errors(v)


def _submit_job(script):
    """Pipe a rendered job script to ``sbatch`` and report whether it succeeded.

    Prints " ...Done!" when sbatch exits with status 0. Otherwise prints the
    exit status and sbatch's stderr.

    Returns:
        True if sbatch exited with status 0, False otherwise.
    """
    proc = Popen(['sbatch'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
    _, stderr_data = proc.communicate(input=script.encode())
    if proc.returncode != 0:
        print(f" ...FAILED (sbatch exit status {proc.returncode})")
        message = stderr_data.decode(errors='replace').strip()
        if message:
            print(message)
        return False
    print(" ...Done!")
    return True


def parse_sweepfile(sweepfile, template, dry_run=False, spacing=0):
    """Render the job template for every sweep combination and submit it.

    Args:
        sweepfile: Parsed sweep definition (the result of ``yaml.safe_load``).
            Its ``experiment`` entry is passed to the template as
            ``experiment``; all other entries are expanded with
            ``all_combinations``. The mapping passed in is not modified.
        template: Name of the Jinja2 template, looked up relative to the
            current working directory.
        dry_run: If true, print each rendered script instead of submitting it.
        spacing: Seconds to sleep after each submission attempt (0 disables).

    Each template is rendered with the combination's keys plus ``args`` (all
    keys formatted as ``--key=value``), ``experiment`` and ``job_id`` (a
    1-based, zero-padded counter).

    Raises:
        ValueError: If a key contains a character from ``INVALID_CHARS``.
        KeyError: If ``sweepfile`` has no ``experiment`` entry.
    """
    check_args_for_errors(sweepfile)
    # Shallow copy so removing 'experiment' does not mutate the caller's mapping.
    sweepfile = dict(sweepfile)
    experiment = sweepfile.pop('experiment')

    template_env = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath="./"))
    template = template_env.get_template(template)

    i = 0
    n_failed = 0
    for i, cmb in enumerate(all_combinations(sweepfile), start=1):
        # Flush: the line has no newline yet and sbatch may block for a while.
        print(f"# Submitting {cmb}", end="", flush=True)
        all_args = ' '.join([f'--{k}={v}' for k, v in cmb.items()])
        # Constant values are the same objects in every combination, so give
        # each render its own copy.
        template_args = deepcopy(cmb)
        template_args['args'] = all_args
        template_args['experiment'] = experiment
        template_args['job_id'] = f'{i:04d}'
        rendered = template.render(template_args)

        if dry_run:
            print()
            print()
            print(rendered)
            print("#" * 80)
            print()
            print()
            continue

        if not _submit_job(rendered):
            n_failed += 1

        if spacing > 0:
            print(f"Sleeping for {spacing} seconds...")
            time.sleep(spacing)

    print(f"Submitted {i - n_failed} jobs!")
    if n_failed:
        print(f"Failed to submit {n_failed} jobs!")


def main(args):
    """Run ``parse_sweepfile`` on the sweep file(s) named on the command line.

    ``args.sweepfile`` is either a single YAML sweep file or a directory. For
    a directory, every regular file directly inside it with suffix ``.yaml``
    is processed in sorted path order. Each file is parsed with
    ``yaml.safe_load`` and handed to ``parse_sweepfile`` together with
    ``args.template``, ``args.dry_run`` and ``args.spacing``. Files that parse
    to ``None`` (empty documents) are skipped with a message.
    """
    from_dir = args.sweepfile.is_dir()
    if from_dir:
        # iterdir() order is filesystem-dependent; sort for a reproducible
        # submission order. is_file() skips sub-directories named "*.yaml".
        paths = sorted(p for p in args.sweepfile.iterdir()
                       if p.suffix == '.yaml' and p.is_file())
    else:
        paths = [args.sweepfile]

    for path in paths:
        if from_dir:
            print(f"Processing {path}...")
        with path.open('r', encoding='utf-8') as fp:
            sweep = yaml.safe_load(fp)
        if sweep is None:
            print(f"Skipping {path}: no sweep parameters found")
            continue
        parse_sweepfile(sweep, args.template, args.dry_run, args.spacing)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run multiple jobs using sbatch")
    # nargs="?" makes the positionals optional. Without it argparse always
    # requires them and the defaults below would never be used.
    parser.add_argument("sweepfile", type=Path, nargs="?", default="run_sweep.yaml",
                        help="Yaml file containing the relevant sweep parameters "
                             "or a directory containing multiple sweep files "
                             "(default: %(default)s)")
    parser.add_argument("template", type=str, nargs="?", default="run_tpl.sh",
                        help="Jinja2 template for the job script submitted to sbatch "
                             "(default: %(default)s)")
    parser.add_argument("--dry-run", action="store_true", help="Don't actually run anything")
    parser.add_argument("--spacing", type=int, default=0, help="Spacing between jobs in seconds")

    main(parser.parse_args())
