import json
import logging
import shutil
import sys
import typing as ty
from datetime import datetime
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable

import egr
from egr import util
from workflows import tasks, config

# Logger shared by every function in this module (named 'run', not __name__).
LOG = logging.getLogger(name='run')


def start_iteration(cfg, sample_id, fold):
    """Run each step in ``cfg.steps`` once for ``sample_id``.

    Every step's ``type`` names a callable on the ``tasks`` module, invoked as
    ``func(cfg, step, str(sample_id))``. Steps run in order and execution
    stops at the first failing step.

    Args:
        cfg: workflow configuration; only ``cfg.steps`` is read here.
        sample_id: sample identifier, passed to each task as a string.
        fold: not used by this function; kept for interface compatibility.

    Returns:
        True when every step succeeds; False when a step reports failure or
        raises ``AttributeError``, ``AssertionError`` or ``FileNotFoundError``
        (the exception is logged).

    A ``KeyboardInterrupt`` is logged and the process exits with status 0.
    """
    try:
        for step in cfg.steps:
            LOG.info('Running step: %s', step['name'])
            func: Callable = getattr(tasks, step['type'])
            result = func(cfg, step, str(sample_id))
            # main() unpacks task results as ``(ret, time_data)``. A non-empty
            # tuple is always truthy, so testing it directly would never
            # detect a failure; use its status element instead. Plain
            # truthy/falsy returns are still accepted.
            if isinstance(result, tuple):
                success = bool(result) and bool(result[0])
            else:
                success = bool(result)
            if not success:
                return False
    except KeyboardInterrupt as err:
        LOG.warning('^C keyboard interrupt, %s', err)
        sys.exit(0)
    except (AttributeError, AssertionError, FileNotFoundError) as err:
        LOG.exception('%s application error', err)
        return False
    return True


def generate_run_configs(cfg) -> ty.List:
    """Expand the configuration into one flow per (sample, variant, fold).

    Flows are ordered with ``sample_id`` outermost, then ``variant``, then
    ``fold``. Each flow is a list containing one keyword-argument dict per
    entry of ``cfg.iterations``, with keys ``sample_id``, ``variant``,
    ``fold`` and ``iteration`` (in that order).
    """
    return [
        [
            {
                'sample_id': sid,
                'variant': var,
                'fold': fld,
                'iteration': it,
            }
            for it in cfg.iterations
        ]
        for sid in cfg.sample_ids
        for var in cfg.variants
        for fld in cfg.folds
    ]


# Banner filler that brackets the per-fold header lines logged by main().
PADDING = 32 * '='


def _run_flow_iteration(cfg, task_args, fold, last_iteration):
    """Run the configured steps for one iteration of a flow and save timings.

    Each step's ``type`` names a callable on the ``tasks`` module, invoked as
    ``func(cfg, step, **task_args)`` and expected to return
    ``(ret, time_data)``. The ``time_data`` mappings are merged and written
    as JSON to ``<run_root>/<variant>/timings/fold-<fold>/<iteration>.json``,
    even when a step fails part-way through.
    """
    iteration = task_args['iteration']
    timings_dir = (
        cfg.run_root
        / task_args['variant']
        / 'timings'
        / f"fold-{task_args['fold']}"
    )
    timings_dir.mkdir(parents=True, exist_ok=True)
    timings_path = timings_dir / f'{iteration}.json'

    # Identical for every step of this iteration, so build it once.
    args_desc = '|'.join(f'{k}:{v}' for k, v in task_args.items())
    timings = {}
    for attr in cfg.steps:
        LOG.debug('attr:%s', attr)
        if iteration == last_iteration and attr['type'] != 'train':
            # Only training steps run on the last iteration. Use ``continue``
            # (not ``break``) so a training step listed after a non-training
            # one is not skipped as well.
            LOG.info(
                'Last iteration, skipping non-training step %s',
                attr['type'],
            )
            continue
        LOG.info('FOLD:%2d,Step:%s:%s', fold, attr['name'], args_desc)
        func: ty.Callable = getattr(tasks, attr['type'])
        ret, time_data = func(cfg, attr, **task_args)
        timings.update(time_data or {})
        if not ret:
            LOG.error(
                'Step %s failed; skipping remaining steps of this iteration',
                attr['name'],
            )
            break

    LOG.info('Writing times to %s', timings_path)
    util.save_json(timings, timings_path)


def main(args) -> None:
    """Run every flow produced by ``generate_run_configs`` for ``args``.

    Builds a ``config.WorkflowConfig`` from ``args``. When ``cfg.clean_root``
    is set, it first deletes ``cfg.run_root``. It then runs every iteration
    of every (sample_id, variant, fold) flow and logs the duration of each
    flow.

    Raises:
        FileNotFoundError: if ``./external/gaston`` does not exist.
    """
    # An explicit check rather than ``assert`` so it still runs under -O.
    gaston_dir = Path('./external/gaston')
    if not gaston_dir.exists():
        raise FileNotFoundError(f'Required directory not found: {gaston_dir}')

    cfg = config.WorkflowConfig(args)
    if getattr(cfg, 'clean_root', False) and cfg.run_root.exists():
        LOG.info('Cleaning root directory %s', cfg.run_root)
        shutil.rmtree(cfg.run_root)

    # Without iterations every flow is empty; indexing [-1] would crash.
    if not cfg.iterations:
        LOG.warning('No iterations configured; nothing to run')
        return
    last_iteration = cfg.iterations[-1]

    for flow in generate_run_configs(cfg):
        # Every entry of a flow shares one fold; read it from the flow rather
        # than re-deriving it from the flow's position.
        fold: int = flow[0]['fold']
        LOG.info('%s FOLD: %2d %s', PADDING, fold, PADDING)
        fold_begin = datetime.now()
        for task_args in flow:
            _run_flow_iteration(cfg, task_args, fold, last_iteration)
        LOG.info('FOLD:%2d duration: %s', fold, datetime.now() - fold_begin)
        LOG.info('%s FOLD: %2d %s', PADDING, fold, PADDING)


def run_fold(cfg, sample_id, fold: int):
    """Run each of ``cfg.iterations`` in order for ``sample_id`` on ``fold``."""
    iterations = cfg.iterations
    for it in iterations:
        run_iteration(cfg, sample_id, it, fold)


def run_iteration(cfg, sample_id, iteration, fold):
    """Select ``iteration`` on ``cfg``, run its steps, and log the duration.

    Exits the process with status 1 if ``start_iteration`` reports failure.
    """
    cfg.set_iteration(iteration)
    LOG.info('Starting sample:%s, iter:%s', sample_id, iteration)
    started_at = datetime.now()
    ok = start_iteration(cfg, sample_id, fold)
    elapsed = datetime.now() - started_at
    if ok:
        LOG.info('Finished iteration %s in %s', iteration, elapsed)
        return
    LOG.error('exiting')
    sys.exit(1)


if __name__ == '__main__':
    # Keep chatty third-party libraries at WARNING and above.
    for _noisy in ('httpx', 'matplotlib'):
        logging.getLogger(_noisy).setLevel(logging.WARNING)

    cli = ArgumentParser()
    egr.add_log_argument(cli)
    cli.add_argument('-c', '--config', type=Path, required=True)
    # Both optional config paths default to files under run_configs/.
    for _flag, _default in (
        ('--run-defaults', 'run_configs/run_defaults.yml'),
        ('--pattern-master', 'run_configs/pattern_master.yml'),
    ):
        cli.add_argument(_flag, type=Path, default=_default)
    args = cli.parse_args()
    egr.init_logging(level_name=args.log_level)

    # Also send root-logger output to a per-run file named after the config.
    log_name = f'{args.config.stem}-{egr.util.now_ts()}.log'
    args.log_dir.mkdir(parents=True, exist_ok=True)
    log_path = args.log_dir / log_name

    LOG.info('Writing to log file %s', log_path)
    handler = logging.FileHandler(log_path, mode='w')
    handler.setLevel(logging.DEBUG)
    logging.getLogger().addHandler(handler)

    t_start = datetime.now()
    main(args)
    LOG.info('Elapsed time %s', datetime.now() - t_start)
