import click
import os
import textwrap
import shutil
import inspect
import signal
import re
import subprocess
from jinja2 import Template
from graph_learning.main import get_version

@click.command()
@click.option('--experiment', '-e', required=True,
              help='experiment name')
@click.option('--experiment-set', '-es', default='test',
              help='experiment set name')
@click.option('--scripts', '-s', required=True, multiple=True,
              help='config files to be exceuted in order')
@click.option('--scripts-base', type=int, default=0,
              help='config file saving name base')
@click.option('--test-scripts',
              help='test on which config file')
@click.option('--var', multiple=True,
              help='variable assign file')
@click.option('--output-dir', default='./results',
              help='results output dir')
@click.option('--deamon', '-d', is_flag=True)
@click.option('--prints', default='nohup.out',
              help='for deamon,(default nohup.out)')
@click.option('--cuda')
@click.option('--run-cmd', '-r', required=True,
              help='running subcommand (train/test)')
@click.option('--times', type=int, default=1,
              help='times of replication')
@click.option('--data-versions', type=int, default=1,
              help='data version used for random split experiment settings')
@click.option('--force', '-f', is_flag=True,
              help='clear the experiment dir')
def main(experiment,
         experiment_set,
         scripts,
         scripts_base,
         test_scripts,
         var,
         output_dir,
         deamon,
         prints,
         cuda,
         run_cmd,
         times,
         data_versions,
         force):
    """Run experiment config scripts through main.py, once per replication.

    Every script is rendered with the --var assignments, combined with the
    global options and executed as a shell command.
    """
    # --cuda is optional: a missing or blank value means "run on CPU". The
    # same flag drives the environment variable, the --use-cuda switch and the
    # ``device`` placeholder so the three can never disagree.
    use_cuda = cuda is not None and cuda.strip() != ''
    visible_devices = cuda if use_cuda else ''
    cuda_cmd = f'CUDA_VISIBLE_DEVICES={visible_devices}'
    use_cuda_cmd = '--use-cuda' if use_cuda else ''
    device = 'cuda:0' if use_cuda else 'cpu'

    # Always bound, so the daemon suffix can never be referenced undefined.
    deamon_cmd = ''
    if deamon:
        # Line-buffered, unbuffered python so the log file fills in real time.
        main_cmd = 'nohup stdbuf -oL python -W ignore::UserWarning -u'
        deamon_cmd = f'> {prints} 2>&1 &'
    else:
        main_cmd = 'python -W ignore::UserWarning'
    main_cmd = ' '.join([main_cmd, 'main.py'])

    # click passes a ``multiple=True`` option as a tuple (empty when the
    # option is not given), so this branch is effectively always taken.
    if isinstance(var, tuple):
        # basename also copes with var files given without a directory part.
        var_names = '_'.join(os.path.basename(s) for s in var)
        experiment = experiment + '/' + var_names
    experiment_path = os.path.join(output_dir, 'experiments', experiment_set, experiment)
    if force:
        shutil.rmtree(experiment_path, ignore_errors=True)

    def clean_str(s, sep='\n'):
        """Drop the empty lines of ``s`` and join the remaining ones with ``sep``."""
        return sep.join([ss for ss in s.split('\n') if ss])

    def run_script(experiment_path, script_path, script_save_name, **formats):
        """Render the script at ``script_path`` and run it as a shell command.

        ``formats`` supplies the values for the ``str.format`` placeholders
        left in the script after the variable files have been applied.
        """
        with open(script_path, 'r') as f:
            cmd = f.read()
        # Strip ``# ...`` comments together with their line break, then blank lines.
        cmd = re.sub(r'#[^\n]+\n', '', cmd)
        cmd = clean_str(cmd, '\n')

        if isinstance(var, tuple):
            var_assign = ''
            for varf in var:
                with open(varf, 'r') as f:
                    var_assign += ('\n' + f.read())
            # One assignment per line becomes a keyword-argument list.
            var_assign = clean_str(var_assign, ',')
            # NOTE(review): eval() executes the content of the script and var
            # files as Python code -- only run this with trusted files.
            cmd = eval(f'Template("""{cmd}""").render({var_assign})')
            cmd = eval(f'{repr(cmd)}.format({var_assign})')
        cmd = cmd.format(**formats)

        # The multi-line form is handed to --script, the flat one is executed.
        script = cmd
        cmd = cmd.replace('\n', ' ')

        # set global config
        gconf_cmd = f'global_ {use_cuda_cmd} --output-dir={experiment_path} --script="{script}" --script-save-name={script_save_name}'

        cmd = ' '.join([cuda_cmd, main_cmd, gconf_cmd, cmd.strip(), run_cmd])
        if deamon:
            cmd = ' '.join([cmd, deamon_cmd])

        subprocess.run(cmd, shell=True)

    for v in range(times):
        experiment_path_v = os.path.join(experiment_path, f'version_{v}')
        formats = dict(version=v, data_version=v % data_versions, device=device)

        if test_scripts is not None:
            # Test mode: run the script saved under the requested name.
            script_save_name = f'script_{test_scripts}'
            script_path = os.path.join(experiment_path_v, script_save_name)
            run_script(experiment_path_v, script_path, script_save_name, **formats)
            continue

        for i, script in enumerate(scripts):
            script_save_name = f'script_{scripts_base + i}'
            script_path = os.path.join(experiment_path_v, script_save_name)
            # Use the copy already in the experiment dir when there is one,
            # otherwise fall back to the config file given on the command line.
            if not os.path.exists(script_path):
                script_path = script
            run_script(experiment_path_v, script_path, script_save_name, **formats)


# Script entry point: run the click command only when executed directly.
if __name__ == "__main__":
    main()
