import os

import click
import numpy as np
from ruamel import yaml


def _priority_group(name):
    """Return the name a config is matched under for ``--high-priority``.

    A trailing ``_r<digit>`` component (e.g. ``foo_r3``) is dropped, so every
    such variant of one config shares a single name. Any other name is
    returned unchanged.
    """
    parts = name.split('_')
    tail = parts[-1]
    has_run_suffix = len(tail) == 2 and tail[0] == 'r' and tail[1].isdigit()
    # Require something in front of the suffix: a config literally called
    # "r1" keeps its own name instead of collapsing to the empty string.
    if has_run_suffix and len(parts) > 1:
        return '_'.join(parts[:-1])
    return name


def _rule_name(name):
    """Turn a config name into a rule name for the generated Snakefile.

    Every character that is not a letter, a digit or ``_`` is replaced by
    ``_`` (previously only ``.`` was replaced), so names containing e.g. ``-``
    no longer end up verbatim in the ``rule <name>:`` header.
    NOTE(review): a name starting with a digit is passed through unchanged.
    """
    return ''.join(c if c.isalnum() or c == '_' else '_' for c in name)


@click.command()
@click.option('--config-dir', type=click.Path(), default='ablations/configs',
              help='Folder that contains the configurations')
@click.option('--snakefile', type=click.Path(), default='ablations/Snakefile',
              help='Where to save the snakefile')
@click.option('--output-dir', type=click.Path(), default='ablations/outputs',
              help='Folder where training logs will be saved')
@click.option('--skip-existing', is_flag=True,
              help='Do not create rules for configs with existing output files')
@click.option('--high-priority', '-p', multiple=True,
              help='Which jobs to run with high priority')
def main(config_dir, output_dir, snakefile, skip_existing, high_priority):
    """
    Creates a Snakefile that runs the ablation studies in the given folder.
    """
    # The docstring above is kept short on purpose: click shows it as the
    # command's --help text.
    #
    # For every `<name>.yaml` in `config_dir` one rule is generated that maps
    # the config to `<output_dir>/<name>.out` via `ablations/run-conf.sh`.
    # A `rule all` listing every generated output is written first.
    # Rules whose name (minus an optional `_r<digit>` suffix) was passed via
    # `--high-priority` get priority 100, all others priority 0.

    # Doubled braces survive str.format() as single braces, so the generated
    # rule contains the literal placeholders {input}, {output} and {rule}.
    rule_template = r'''
rule {rule_name}:
    input: "{input_file}"
    output: "{output_file}"
    priority: {priority}
    shell: "bash ablations/run-conf.sh {{input}} {{output}} {{rule}}"
'''

    config_suffix = '.yaml'
    experiments = []
    rows = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # generated Snakefile identical across runs on the same inputs.
    for fname in sorted(os.listdir(config_dir)):
        if not fname.endswith(config_suffix):
            continue

        # Strip only the trailing extension; str.replace() would also rewrite
        # a '.yaml' that occurs in the middle of the file name.
        name = fname[:-len(config_suffix)]
        print(f'Handling {name} ...')

        priority = 100 if _priority_group(name) in high_priority else 0

        output_file = os.path.join(output_dir, name + '.out')
        if skip_existing and os.path.exists(output_file):
            print('Skipping', output_file)
            continue

        experiments.append(output_file)
        rows.append(rule_template.format(
            rule_name=_rule_name(name),
            input_file=os.path.join(config_dir, fname),
            output_file=output_file,
            priority=priority,
        ))

    with open(snakefile, 'w') as f:
        # `rule all` comes first and lists every output as its input.
        f.write('rule all:\n    input:\n')
        for e in sorted(experiments):
            f.write(f'        "{e}",\n')
        f.write('\n'.join(rows))

    print(f'{len(experiments)} runs created')


# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
