#!/usr/bin/env python

import argparse
import json
import re
import os
import pdb
import random
import shutil
import subprocess
import sys
import tarfile
import time
import datetime

import logging
# Module-level logger. Records propagate to the root logger, which main()
# configures via logging.basicConfig().
logger = logging.getLogger(name=__name__)



def tar_code(code_dir, config_path, output_path, prune_dirnames=None, prune_paths=None):
    """Create a new tar archive of the ``.py`` files under ``code_dir`` plus a config file.

    Each Python file is stored under a top-level ``code/`` directory, keeping its
    path relative to ``code_dir``.  ``config_path`` is stored as ``code/config.json``.

    Args:
        code_dir: Directory to walk.  Symlinked directories are followed.
        config_path: File to store in the archive as ``code/config.json``.
        output_path: Archive to create.  It is opened with mode ``'x'``, so it must
            not already exist.
        prune_dirnames: Directory basenames to skip wherever they occur in the tree.
        prune_paths: Directories to skip, matched by file identity (``os.stat``)
            rather than by path string.  Entries that do not exist are ignored,
            because they cannot match any walked directory.

    Raises:
        ValueError: If a walked directory path is longer than 512 characters.
            Symlinks are followed, so this most likely means a symlink cycle.
        FileExistsError: If ``output_path`` already exists.
    """
    prune_dirnames = set(prune_dirnames or ())

    # Stat each prune path once up front instead of once per walked directory.
    # os.path.samefile() would raise FileNotFoundError for a missing prune path,
    # so missing paths are skipped here.
    prune_stats = []
    for ppath in prune_paths or ():
        try:
            prune_stats.append(os.stat(ppath))
        except FileNotFoundError:
            continue

    def _is_pruned(parent, dirname):
        # True if the subdirectory ``dirname`` of ``parent`` must not be walked.
        if dirname in prune_dirnames:
            return True
        if not prune_stats:
            return False
        dir_stat = os.stat(os.path.join(parent, dirname))
        return any(os.path.samestat(dir_stat, pstat) for pstat in prune_stats)

    paths_to_archive = []
    for root, dirnames, filenames in os.walk(code_dir, followlinks=True):
        if len(root) > 512:
            raise ValueError("Abnormally long path detected.  This script follows symlinks; please check if there is an infinite loop somewhere.")

        # Slice assignment mutates the list os.walk holds, so pruned directories
        # are never descended into.  Sorting makes the archive order deterministic.
        dirnames[:] = sorted(d for d in dirnames if not _is_pruned(root, d))

        paths_to_archive.extend(
            os.path.join(root, filename)
            for filename in sorted(filenames)
            if filename.endswith('.py')
        )

    with tarfile.open(output_path, 'x') as tarf:
        for path in paths_to_archive:
            tarf.add(path, arcname=os.path.join('code', os.path.relpath(path, code_dir)))
        tarf.add(config_path, arcname=os.path.join('code', 'config.json'))

def make_job_wrapper(output_dir):
    """Write ``run_job.py`` into ``output_dir`` and return its path.

    When run, the generated script does the following:
      * reads ``meta_config.json`` from ``output_dir`` to find ``main_pyfile`` and
        ``do_profile``;
      * writes ``pre_run_info.json`` (UTC start timestamp, hostname, environment);
      * runs ``python3 <main_pyfile> <output_dir>/tmp_rundir/code/config.json``
        with ``tmp_rundir/code`` as the working directory, optionally under
        ``cProfile`` (output written to ``profile.out``);
      * writes ``post_run_info.json`` (exit status, wall-clock runtime, timestamps);
      * exits with the child's return code.

    Raises:
        FileExistsError: If ``run_job.py`` already exists (opened with mode ``'x'``).
    """
    wrapper_pyfile = os.path.join(output_dir, 'run_job.py')
    wrapper_pyfile_lines = [
        'import datetime',
        'import json',
        'import os',
        'import sys',
        'import socket',
        'import subprocess',
        'import time',
        '',
        'output_dir = {}'.format(escape_string_python(output_dir)),
        'with open(os.path.join(output_dir, "meta_config.json"), "rt") as f:',
        '    meta_config = json.load(f)',
        'main_pyfile = meta_config["main_pyfile"]',
        'profile = meta_config["do_profile"]',
        'hostname = socket.gethostname()',
        '',
        'print("Starting run for dir: {}".format(output_dir), file=sys.stderr)',
        '',
        # Use an aware UTC datetime.  datetime.utcnow() returns a naive value, and
        # .timestamp() interprets naive values as local time, which skews the
        # result by the host's UTC offset.
        'start_time = datetime.datetime.now(datetime.timezone.utc).timestamp()',
        'with open(os.path.join(output_dir, "pre_run_info.json"), "x") as f:',
        '    json.dump({"start_time_gmt": start_time, "hostname": hostname, "env": dict(os.environ)}, f)',
        '',
        'pyargs = ["python3"]',
        'if profile:',
        '    pyargs += ["-m", "cProfile", "-o", os.path.join(output_dir, "profile.out")]',
        '',
        'start_perf = time.perf_counter()',
        'retcode = subprocess.call(pyargs + [main_pyfile, os.path.join(output_dir, "tmp_rundir", "code", "config.json")], cwd=os.path.join(output_dir, "tmp_rundir", "code"))',
        'elapsed = time.perf_counter() - start_perf',
        '',
        'with open(os.path.join(output_dir, "post_run_info.json"), "x") as f:',
        '    json.dump({"exit_status": retcode, "total_runtime_seconds": elapsed, "hostname": hostname, "start_timestamp": start_time, "end_timestamp": datetime.datetime.now(datetime.timezone.utc).timestamp()}, f)',
        'print("Finished run for dir: {}".format(output_dir), file=sys.stderr)',
        'sys.exit(retcode)',
    ]
    with open(wrapper_pyfile, 'x') as f:
        f.write('\n'.join(wrapper_pyfile_lines) + '\n')

    return wrapper_pyfile



def _build_seed_map(seed_map, num_trials, start_trial_num):
    """Normalize prepare_experiment()'s ``seed_map`` argument into an indexable map.

    * ``random.Random``: seeds are drawn for trial indices 0 through
      ``start_trial_num + num_trials - 1``, so a trial's seed does not depend on
      ``start_trial_num``.  This advances the generator's state.
    * ``int`` X: trial t gets seed X + t.
    * ``list`` of length ``num_trials``: re-keyed to start at ``start_trial_num``.
    * ``list`` of length ``start_trial_num + num_trials``: used as-is, indexed by
      trial number.
    * Anything else is returned unchanged and must support ``[trial_num]``.

    Raises:
        ValueError: If a list has neither of the accepted lengths.
    """
    end = start_trial_num + num_trials
    if isinstance(seed_map, random.Random):
        # Generate all the way up so we get the same mapping no matter what start_trial_num is.
        return {i: seed_map.randint(0, 2**32-1) for i in range(end)}
    if isinstance(seed_map, int):
        return {t: seed_map + t for t in range(start_trial_num, end)}
    if isinstance(seed_map, list):
        if len(seed_map) == num_trials:
            return {start_trial_num + i: seed for i, seed in enumerate(seed_map)}
        if len(seed_map) != end:
            raise ValueError("Wrong number of seeds given in list")
    return seed_map


def prepare_experiment(experiment_id, main_pyfile, code_dir, config_dir, experiments_dir, num_trials, seed_map, overwrite_unfinished=False, start_trial_num=1, base_config=None, do_profile=False):
    """Create one run directory per trial of an experiment and return the new ones.

    For each trial number in ``[start_trial_num, start_trial_num + num_trials)``,
    this creates ``<experiments_dir>/experiment_v{experiment_id}_r{trial}``
    containing the following:
      * ``config.json``: the base config plus ``seed`` and ``output_dir``, with
        every ``${TRIAL}`` in top-level string values replaced by the trial number;
      * ``meta_config.json``: bookkeeping read by the generated ``run_job.py``;
      * ``run_job.py``: the wrapper script from make_job_wrapper();
      * ``code.tar``: a snapshot of the ``.py`` files in ``code_dir``, extracted
        into ``tmp_rundir/``.

    Existing trial directories that contain ``post_run_info.json`` are skipped.
    Other existing directories raise ValueError, unless ``overwrite_unfinished`` is
    set and the directory has no ``output`` subdirectory, in which case it is
    deleted and recreated.

    seed_map: A map from trial index to random seeds.  E.g., it could be a List
              or a Dict, or some class that overrides __getitem__.  It can be
              an integer X, in which case a trial's seed is X+trial_index.  It
              can also be a random.Random object, in which case the seeds are
              created from the generator.

    base_config: If None, loaded from ``<config_dir>/config_v{experiment_id}.json``.
              Otherwise a shallow copy is used.

    Returns:
        List of absolute paths of the trial directories created by this call.

    Raises:
        ValueError: If ``main_pyfile`` is absolute, if the config already has
            ``output_dir`` or ``seed``, if the seed list has the wrong length, or
            if an unfinished trial directory cannot be overwritten.
    """
    seed_map = _build_seed_map(seed_map, num_trials, start_trial_num)

    experiments_dir = os.path.abspath(experiments_dir)

    if os.path.isabs(main_pyfile):
        raise ValueError("main_pyfile should be relative to code_dir")

    logger.info('Initializing config...')
    if base_config is None:
        with open(os.path.join(config_dir, f'config_v{experiment_id}.json')) as f:
            base_config = json.load(f)
    else:
        base_config = base_config.copy()

    # These keys are filled in per trial below; refuse configs that already set them.
    for key in ('output_dir', 'seed'):
        if key in base_config:
            raise ValueError(f"'{key}' already in config; refusing to overwrite")

    # Directories excluded from the code snapshot.  experiments_dir is pruned so
    # that, if it lives inside code_dir, earlier trials' tmp_rundir copies are not
    # archived again into every new trial.
    prune_paths = [config_dir, experiments_dir]
    prune_dirnames = ['.git', '__pycache__']

    output_dirs = []
    for trial_num in range(start_trial_num, start_trial_num + num_trials):
        config = base_config.copy()
        config['seed'] = seed_map[trial_num]

        output_dir = os.path.join(experiments_dir, f'experiment_v{experiment_id}_r{trial_num}')
        if os.path.exists(output_dir):
            if os.path.exists(os.path.join(output_dir, 'post_run_info.json')):
                logger.info("Trial {} already exists ({}); skipping...".format(trial_num, output_dir))
                continue
            elif overwrite_unfinished and not os.path.exists(os.path.join(output_dir, 'output')):
                logger.info('Overwriting trial {} which exists ({}) but was not finished...'.format(trial_num, output_dir))
                shutil.rmtree(output_dir)
            else:
                raise ValueError("Unfinished trial found ({}).".format(output_dir))
        os.makedirs(output_dir)
        output_dirs.append(output_dir)

        config['output_dir'] = os.path.join(output_dir, 'output')

        # Only top-level string values are substituted.
        for key, value in list(config.items()):
            if isinstance(value, str) and '${TRIAL}' in value:
                new_str = value.replace('${TRIAL}', str(trial_num))
                logger.info('Replacing {} with {} for config key {}'.format(value, new_str, key))
                config[key] = new_str

        with open(os.path.join(output_dir, 'config.json'), 'x') as f:
            json.dump(config, f)

        meta_config = {
            'experiment_id': experiment_id,
            'trial_num': trial_num,
            'main_pyfile': main_pyfile,
            # Aware UTC datetime: utcnow().timestamp() would be skewed by the
            # local UTC offset, because naive datetimes are treated as local time.
            'creation_timestamp': datetime.datetime.now(datetime.timezone.utc).timestamp(),
            'output_dir': output_dir,
            'do_profile': do_profile,
            'runconfig': config,
        }
        with open(os.path.join(output_dir, 'meta_config.json'), 'x') as f:
            json.dump(meta_config, f)

        make_job_wrapper(output_dir)

        logger.info(f'Backing up files for trial {trial_num}...')
        tar_path = os.path.join(output_dir, 'code.tar')
        tar_code(code_dir, os.path.join(output_dir, 'config.json'), tar_path, prune_dirnames=prune_dirnames, prune_paths=prune_paths)

        # Jobs run from this extracted snapshot, not from the live code_dir.
        with tarfile.open(tar_path, 'r:*') as tarf:
            tarf.extractall(path=os.path.join(output_dir, 'tmp_rundir'))

    if len(output_dirs) == 0:
        logger.warning("All trial directories already exist.  Exiting.")

    return output_dirs


def escape_string_bash(string):
    """Return ``string`` (coerced via ``str``) quoted as a single POSIX-shell word.

    Uses ``shlex.quote`` rather than JSON encoding.  A JSON string is emitted in
    double quotes, where bash still expands ``$VAR`` and backticks, and bash does
    not decode JSON's ``\\uXXXX`` escapes for non-ASCII characters.  Strings made
    only of shell-safe characters are returned unchanged.
    """
    import shlex  # local import: keeps this fix self-contained

    return shlex.quote(str(string))

def escape_string_python(string):
    """Render ``string`` (after ``str()`` conversion) as a Python string literal.

    The result can be pasted into generated Python source and evaluates back to
    the same text.
    """
    as_text = str(string)
    return "%r" % (as_text,)

def run_jobs(args, output_dirs):
    """Launch each prepared trial directory using the backend in ``args.backend``.

    Args:
        args: Parsed CLI namespace.  Reads ``backend``, ``n_gpus`` and ``n_cpus``.
        output_dirs: Trial directories previously created by prepare_experiment().

    Backends:
        'basic': runs each trial's ``run_job.py`` sequentially with ``python3`` and
            logs an error for every trial that exits with a nonzero status.
        'slurm': not implemented yet; raises NotImplementedError on the first trial.

    Raises:
        ValueError: For any other backend name.
    """
    n_gpus = args.n_gpus if args.n_gpus is not None else 1
    n_cpus = args.n_cpus if args.n_cpus is not None else n_gpus

    if args.backend == 'basic':
        for output_dir in output_dirs:
            pyfile = os.path.join(output_dir, 'run_job.py')
            retcode = subprocess.call(['python3', pyfile], cwd=os.path.join(output_dir, 'tmp_rundir', 'code'))
            # Report failures instead of silently discarding the exit status,
            # in the same way the slurm branch reports sbatch failures.
            if retcode != 0:
                logger.error("Job exited with status {} for {}".format(retcode, output_dir))
    elif args.backend == 'slurm':
        for output_dir in output_dirs:
            slurm_pyfile = os.path.join(output_dir, 'run_job.py')

            slurm_filename = os.path.join(output_dir, 'sbatch_script.sh')
            exp_code_dir = os.path.join(output_dir, 'tmp_rundir', 'code')
            # TODO: everything below is unreachable until the conda env to
            # activate is configurable (it is currently hard-coded to 'al').
            raise NotImplementedError("Need to specify conda env to activate")
            slurm_lines = [
                "#!/bin/bash -l",
                "#SBATCH --ntasks=1",
                "#SBATCH --cpus-per-task={:d}".format(n_cpus),
                "#SBATCH --gpus={:d}".format(n_gpus),
                "#SBATCH --gpus-per-node={:d}".format(n_gpus),
                "#SBATCH --output={output_path}".format(output_path=escape_string_bash(os.path.join(output_dir, 'slurm_output.out'))),
                "#SBATCH --job-name={job_name}".format(job_name=escape_string_bash(os.path.basename(output_dir))),
                '',
                'echo user $USER',
                'echo host $(hostname)',
                'echo pwd $(pwd)',
                '',
                'source ~/.bashrc',
                'source ~/.bash_aliases',
                'source ~/miniconda3/etc/profile.d/conda.sh',
                'conda activate {conda_env}'.format(conda_env=escape_string_bash('al')),
                '',
                'cd {exp_code_dir}'.format(exp_code_dir=escape_string_bash(exp_code_dir)),
                'python3 {slurm_pyfile}'.format(slurm_pyfile=escape_string_bash(slurm_pyfile)),
            ]
            with open(slurm_filename, 'x') as f:
                f.write('\n'.join(slurm_lines) + '\n')

            retcode = subprocess.call(['sbatch', slurm_filename], cwd=exp_code_dir)
            if retcode != 0:
                logger.error("sbatch failed for {}".format(output_dir))
    else:
        raise ValueError("unknown backend {}".format(args.backend))

def main():
    """Command-line entry point: prepare trial directories, then run them.

    Parses arguments, resolves the ``--seed`` specification, calls
    prepare_experiment() once per ``experiment_id``, and passes every newly
    created trial directory to run_jobs().
    """
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] (run_experiment %(name)s):  %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
        ],
    )

    parser = argparse.ArgumentParser()
    parser.add_argument('--profile', default=False, action='store_true', help='Run with cProfile and save the profile to profile.out in the output directory')
    # NOTE(review): run_jobs() has no branch for 'drexo', so choosing it currently
    # raises ValueError only after the trial directories have been prepared.
    parser.add_argument('--backend', default='basic', choices=['drexo', 'slurm', 'basic'])
    parser.add_argument('--n_gpus', default=None, type=int, help='Number of gpus to allocate.  WARNING: Experimental: Does not work with all backends, might fail silently, and default values may differ between backends')
    parser.add_argument('--n_cpus', default=None, type=int, help='Number of cpus each job will use.  WARNING: Experimental: Does not work with all backends, might fail silently, and default values may differ between backends')
    parser.add_argument('--overwrite_unfinished', default=False, action='store_true', help='Whether to overwrite directories for trials that haven\'t been finished (or possible not even started) yet')
    parser.add_argument('--code_dir', default='.')
    parser.add_argument('--config_dir', default='<code_dir>/runconfigs/')
    parser.add_argument('--experiments_dir', default='../data/experiments/')
    parser.add_argument('--num_trials', default=1, type=int)
    parser.add_argument('--start_trial_num', default=1, type=int, help="The trial num to start from.  Trial numbers will go from this number (inclusive) to this number plus the number of trials (exclusive).")
    parser.add_argument('--seed', default=None, help="Random seed to start from for trials (trial 1 uses seed+1, trial 2 is seed+2, etc).  If not given, a random one is chosen.")
    parser.add_argument('main_pyfile', help='filename of main python file to run inside code_dir')
    # type=int turns a bad id into an argparse usage error instead of a traceback.
    parser.add_argument('experiment_id', nargs='+', type=int)
    args = parser.parse_args()

    if args.seed is None:
        # We only go to 2**31, so that we don't end up with a seed bigger than
        # 2**32 after adding trial_index in prepare_experiment()
        args.seed = random.Random().randint(0, 2**31)
    else:
        # "5" is a base seed (trial t gets 5+t).  A trailing comma ("5,") or
        # several values ("1,2,3") gives an explicit per-trial seed list.
        use_fixed = args.seed.endswith(',')
        try:
            seeds = [int(x) for x in args.seed.strip(',').split(',')]
        except ValueError:
            parser.error("--seed must be an integer or a comma-separated list of integers")
        args.seed = seeds[0] if (len(seeds) == 1 and not use_fixed) else seeds
    logger.info("Using base seed {}".format(args.seed))

    if args.config_dir == '<code_dir>/runconfigs/':
        args.config_dir = os.path.join(args.code_dir, 'runconfigs')

    output_dirs = []
    for experiment_id in args.experiment_id:
        logger.info('Preparing experiment {}...'.format(experiment_id))
        output_dirs.extend(
                prepare_experiment(
                    experiment_id,
                    args.main_pyfile,
                    args.code_dir,
                    args.config_dir,
                    args.experiments_dir,
                    args.num_trials,
                    args.seed,
                    overwrite_unfinished=args.overwrite_unfinished,
                    start_trial_num=args.start_trial_num,
                    base_config=None,
                    # Forward --profile so the generated run_job.py runs under cProfile.
                    do_profile=args.profile))

    run_jobs(args, output_dirs)

if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()

