import argparse
import logging
import os

import utils
from hyperparams.load import get_config
from mhvae_vasco.evaluator.load import get_evaluator
from mhvae_vasco.model.load import get_model
from mhvae_vasco.run import load_checkpoint


def evaluate_mhvae_vasco(run_id, split, epoch):
    """Evaluate a saved MHVAE-VASCO checkpoint on one data split.

    Restores the model checkpoint for ``epoch`` from the run directory that
    ``run_id`` resolves to, builds the evaluator for the run's dataset and
    runs it. Log output goes to ``<run_path>/eval_<split>.log``.

    Args:
        run_id: Identifier of the training run, resolved to a directory via
            ``utils.find_path_of_id``. If the id contains the substring
            ``'debug'`` the evaluator is created with ``debug=True``.
        split: Data split to evaluate; used in the log-file name and passed
            to the evaluator.
        epoch: Epoch whose checkpoint is loaded; also forwarded to
            ``evaluator.evaluate``.
    """
    config = get_config()
    device = utils.setup_device()
    run_path = utils.find_path_of_id(run_id)
    utils.set_logger(verbosity=config.logger_verbosity,
                     log_path=os.path.join(run_path, f'eval_{split}.log'))
    logger = logging.getLogger('custom')

    # Everything after logger setup runs under try/finally so that
    # utils.close_logger() is reached even when loading or evaluation raises.
    try:
        logger.info(f'split: {split}\n'
                    f'run_path: {run_path}')

        # `args` comes back from the checkpoint loader and is reused to
        # rebuild the model and the evaluator (presumably the training-time
        # arguments — TODO confirm in mhvae_vasco.run.load_checkpoint).
        checkpoint, args = load_checkpoint(run_path, epoch)
        model = get_model(args, device)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        evaluator = get_evaluator(
            args.dset_name, run_path=run_path, split=split, args=args,
            device=device, debug='debug' in run_id
        )

        with utils.Timer('Evaluation'):
            evaluator.evaluate(model, epoch)
        logger.info('Evaluation has finished.')
    finally:
        utils.close_logger()


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``run_id`` (str), ``split`` (one of
        ``'train'``, ``'val'``, ``'test'``) and ``epoch`` (int).
    """
    p = argparse.ArgumentParser()
    p.add_argument('--run_id', default='2021-11-20_T_23-20-20.037043_debug')
    p.add_argument('--split', default='train', choices=['train', 'val', 'test'])
    # type=int so a value given on the command line has the same type as the
    # integer default instead of arriving as a string.
    p.add_argument('--epoch', type=int, default=750)
    return p.parse_args(argv)


if __name__ == '__main__':
    cli_args = parse_args()

    evaluate_mhvae_vasco(
        run_id=cli_args.run_id, split=cli_args.split, epoch=cli_args.epoch
    )
