import argparse
import logging
import os
from pathlib import Path

import wandb

import methods
import utils
from evaluation.load import get_evaluator
from hyperparams.load import get_config
from run import load_checkpoint


def eval_id(cur_id, epoch, bs=None, split=None):
    """Load run ``cur_id`` at ``epoch`` and evaluate it on ``split``.

    Args:
        cur_id: Identifier of the run; resolved to a directory by
            ``utils.find_path_of_id``.
        epoch: Epoch selector forwarded to ``load_checkpoint`` and to the
            evaluator.
        bs: Optional evaluation batch size; when truthy it replaces the
            ``eval_bs`` stored with the run's args.
        split: Dataset split handed to the evaluator; also used to name the
            ``eval_{split}.log`` file.
    """
    config = get_config()
    # The loaded checkpoint is already baked into the model, so it is unused here.
    args, run_path, device, model, logger, _checkpoint = _prepare(
        cur_id, config, bs, split, epoch
    )
    runner = _setup_evaluator(device, args, run_path, config, split)
    _eval(runner, model, logger, epoch)


def _prepare(cur_id, config, bs, split, epochs):
    """Set up logging, then load the run's checkpoint and build its model.

    Side effects: deletes any existing ``eval_{split}.log`` in the run
    directory and configures logging to write there.

    Args:
        cur_id: Run identifier resolved through ``utils.find_path_of_id``.
        config: Global config; only ``config.logger_verbosity`` is read here.
        bs: When truthy, overrides ``args.eval_bs`` after the model is built.
        split: Split name, used in the log-file name and the start message.
        epochs: Epoch selector passed straight to ``load_checkpoint``.

    Returns:
        Tuple ``(args, run_path, device, model, logger, checkpoint)``.
    """
    device = utils.setup_device()
    run_path = utils.find_path_of_id(cur_id)

    # Begin each evaluation with a fresh log file for this split.
    log_file = os.path.join(run_path, f'eval_{split}.log')
    Path(log_file).unlink(missing_ok=True)
    utils.set_logger(verbosity=config.logger_verbosity, log_path=log_file)
    logger = logging.getLogger('custom')
    logger.info(
        f'\nStarting evaluation:\nrun_path: {run_path}\nsplit: {split}'
    )

    checkpoint, args = load_checkpoint(run_path, epochs)
    model = methods.define_model(
        methods.get_package(args.model), args, device, checkpoint
    )

    # Only a truthy bs replaces the stored batch size (None/0 keep eval_bs).
    # This happens after define_model, so the model sees the original args.
    if bs:
        args.eval_bs = bs
    args_text = utils.get_args_as_string(args)
    logger.info(f'\nArgs:\n{args_text}\nrun_path: {run_path}')
    return args, run_path, device, model, logger, checkpoint


def _setup_evaluator(device, args, run_path, config, split=None):
    """Instantiate the evaluator registered for this run's dataset and model.

    Side effect: writes ``run_path`` and ``args.run_id`` onto ``config``
    before the evaluator is constructed with it.

    Debug mode is enabled whenever the run id contains the substring
    ``'debug'``.
    """
    config.run_path = run_path
    config.run_id = args.run_id
    make_evaluator = get_evaluator(args.dset_name, args.model)
    is_debug_run = 'debug' in args.run_id
    return make_evaluator(
        split=split,
        args=args,
        config=config,
        device=device,
        debug=is_debug_run,
    )


def _eval(evaluator, model, logger, epoch):
    """Run ``evaluator.evaluate(model, epoch)`` under a timer, then clean up.

    ``utils.close_logger()`` and ``wandb.finish()`` are called whether or not
    evaluation succeeds. Previously, an exception in ``evaluate`` skipped both.
    On failure, the traceback is logged and the wandb run is finished with
    ``exit_code=1``, so it is not recorded as a clean finish. The exception is
    then re-raised.
    """
    try:
        with utils.Timer('Evaluation'):
            evaluator.evaluate(model, epoch)
    except BaseException:
        logger.exception('Evaluation failed.')
        utils.close_logger()
        wandb.finish(exit_code=1)
        raise
    logger.info('Evaluation has finished.')
    utils.close_logger()
    wandb.finish()


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--id', default='2021-11-20_T_23-18-12.060479_debug',
                   help="id to evaluate")
    p.add_argument('--split', default='val', choices=['train', 'val', 'test'],
                   help='split used for evaluation')
    # type=int: without it, values given on the command line arrive as
    # strings while the defaults are ints (e.g. '0' would be truthy).
    p.add_argument('--bs', type=int, default=1,
                   help='Batch size for likelihood computation. '
                        'Pass 0 to use eval_bs from the original '
                        'hyperparameters.')
    # NOTE(review): --gpu is parsed but never forwarded; the device is chosen
    # by utils.setup_device(), which is called without arguments.
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--epoch', type=int, default=750,
                   help='Specify training epochs for model to be loaded.')
    cli_args = p.parse_args()

    eval_id(cli_args.id, cli_args.epoch, cli_args.bs, cli_args.split)
