import argparse
import datetime
import logging
import os

import utils
from disentanglement_vae.evaluator.evaluator import Evaluator
from disentanglement_vae.models import MoeDisentanglementVae
from disentanglement_vae.run import load_checkpoint
from hyperparams.load import get_config


def evaluate_mdvae(run_id, split, epoch):
    """Evaluate a stored MoeDisentanglementVae checkpoint on one data split.

    The run directory is resolved from ``run_id``, a log file named
    ``eval_<split>.log`` is set up inside it, the checkpoint for ``epoch`` is
    loaded into a freshly built model, and the model is handed to an
    ``Evaluator`` for the requested split.

    Args:
        run_id: Identifier of the run; resolved to a directory with
            ``utils.find_path_of_id``. When it contains the substring
            ``'debug'`` the evaluator is created with ``debug=True``.
        split: Name of the split to evaluate; forwarded to ``Evaluator`` and
            used in the log file name.
        epoch: Epoch of the checkpoint; forwarded to ``load_checkpoint`` and
            to ``Evaluator.evaluate``.

    Raises:
        Exception: Any error raised while loading or evaluating is logged
            with its traceback and then re-raised unchanged.
    """
    config = get_config()
    device = utils.setup_device()
    run_path = utils.find_path_of_id(run_id)
    utils.set_logger(verbosity=config.logger_verbosity,
                     log_path=os.path.join(run_path, f'eval_{split}.log'))
    logger = logging.getLogger('custom')

    # Everything after the logger is set up runs under try/finally so the
    # logger is closed even when loading or evaluation fails.
    try:
        logger.info(f'Starting evaluation with split: {split}; '
                    f'time: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")}'
                    f'\nrun_path: {run_path}')

        checkpoint, args = load_checkpoint(run_path, epoch)
        model = MoeDisentanglementVae(args)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        evaluator = Evaluator(
            run_path, split=split, args=args, device=device,
            debug='debug' in run_id
        )

        with utils.Timer('Evaluation'):
            evaluator.evaluate(model, epoch)
        logger.info('Evaluation has finished.')
    except Exception:
        # Record the traceback in the eval log before propagating, otherwise
        # the failure would only be visible on stderr.
        logger.exception('Evaluation failed.')
        raise
    finally:
        utils.close_logger()


if __name__ == '__main__':
    # Command-line entry point: evaluate one checkpoint of one run.
    p = argparse.ArgumentParser(
        description='Evaluate a MoeDisentanglementVae checkpoint on a data split.'
    )
    p.add_argument('--run_id', default='2021-11-20_T_23-20-20.573002_debug',
                   help='identifier of the run to evaluate')
    p.add_argument('--split', default='val', choices=['train', 'val', 'test'],
                   help='data split to evaluate on')
    # type=int keeps a user-supplied epoch the same type as the int default;
    # without it, argparse would pass the command-line value through as a str.
    p.add_argument('--epoch', type=int, default=750,
                   help='epoch of the checkpoint to load')
    cli_args = p.parse_args()

    evaluate_mdvae(cli_args.run_id, cli_args.split, cli_args.epoch)
