import argparse
import heapq
import json
import logging
import os
from pathlib import Path
from types import SimpleNamespace

import torch

import utils
from hyperparams.load import get_config
from mhvae_vasco.evaluator.load import get_evaluator
from mhvae_vasco.hyperparameters.load import get_args
from mhvae_vasco.misc import load_data
from mhvae_vasco.model.load import get_model
from mhvae_vasco.trainer import Trainer

logger = logging.getLogger('custom')


def train_and_evaluate(args, debug=False):
    """Train and evaluate a single MHVAE run, either fresh or resumed.

    If ``args`` has a truthy ``resume_id`` attribute, the run directory for
    that id is looked up via ``utils.find_path_of_id``. The latest checkpoint
    and the saved args are loaded from it, and the passed ``args`` are
    replaced by the saved ones. Otherwise ``utils.setup_dir`` prepares the run
    directory (and returns the args to use), and the random seeds are set from
    ``args.seed``.

    Relies on the module-level ``config`` that is assigned in the
    ``__main__`` block.

    Args:
        args: Hyperparameter namespace (e.g. from ``get_args`` or a queue
            file), or a namespace holding only ``resume_id``.
        debug: Effective debug flag; forwarded to run-directory setup and to
            training/evaluation.
    """
    device = utils.setup_device()
    if resume_id := vars(args).get('resume_id'):
        run_path = utils.find_path_of_id(resume_id)
        checkpoint, args = load_checkpoint(run_path)
    else:
        checkpoint = None
        # Use the ``debug`` argument (the effective debug status) rather than
        # the raw CLI flag, so directory setup agrees with training and the
        # function does not depend on the global CLI namespace.
        args, run_path = utils.setup_dir(args, config, debug)
        utils.set_seeds(args.seed)

    logger = utils.set_logger(verbosity=config.logger_verbosity,
                              log_path=os.path.join(run_path, 'train.log'))
    logger.info(f'\n{"-" * 100}\nrun_path:\n{run_path}')
    logger.info(f'\nArgs:\n{utils.get_args_as_string(args)}')
    try:
        _train_and_evaluate(run_path, args, device, debug, checkpoint)
        logger.info('\nModel has been trained and evaluated.')
    finally:
        # Close the run's logger even if training fails, so a subsequent run
        # (e.g. the next queue item) starts from a clean logging state.
        utils.close_logger()


def _train_and_evaluate(run_path, args, device, debug, checkpoint,
                        splits=('val',)):
    """Build model, data loader and evaluators, then run training.

    Args:
        run_path: Directory of the current run; passed to the evaluators and
            to the trainer.
        args: Hyperparameter namespace; ``args.dset_name`` selects the
            dataset for the evaluators and the timer label.
        device: Torch device the model is moved to.
        debug: Debug flag forwarded to the evaluators and the trainer.
        checkpoint: Loaded checkpoint dict or ``None``. When given, its
            ``'state_dict'`` entry is loaded into the model and the whole
            checkpoint is forwarded to the trainer (presumably to restore
            optimizer/epoch state - TODO confirm in ``Trainer``).
        splits: Data splits that get an evaluator. Defaults to validation
            only; pass ``('train', 'val')`` to also evaluate on the training
            split.
    """
    model = get_model(args, device)
    if checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # Only the loader is needed for training; the dataset is discarded.
    _, loader = load_data(mode='train', args=args)
    evaluators = {
        split: get_evaluator(
            dataset_name=args.dset_name, run_path=run_path, split=split,
            args=args, device=device, debug=debug
        )
        for split in splits
    }
    trainer = Trainer(
        model=model, loader=loader, args=args, run_path=run_path,
        evaluators=evaluators, debug=debug, device=device,
        checkpoint=checkpoint
    )
    with utils.Timer(f'Train MHVAE (Vasco et al.) on {args.dset_name} dataset',
                     event_frequency='low'):
        trainer.train()


def load_checkpoint(run_path, epoch=None):
    """Load a model checkpoint and the saved hyperparameters of a run.

    Args:
        run_path: Run directory containing a ``models/`` subdirectory and an
            ``args.pt`` file.
        epoch: Epoch of the checkpoint to load
            (``models/model_epoch_{epoch}.pt``). If ``None``, the checkpoint
            file with the highest epoch number is used, where the epoch is
            the integer after the last ``_`` in the file stem. Files whose
            stem does not end in an integer are ignored.

    Returns:
        Tuple ``(ckpt, args)``. ``ckpt`` is the checkpoint as returned by
        ``utils.torch_load``. ``args`` is the saved args object with
        ``run_id`` reset to the name of ``run_path``.

    Raises:
        FileNotFoundError: If the requested checkpoint does not exist, if no
            checkpoint file is found when ``epoch`` is ``None``, or if
            ``models/`` itself is missing.
    """
    model_dir = os.path.join(run_path, 'models')
    if epoch is not None:
        # ``is not None`` so that an explicit ``epoch=0`` is honoured.
        ckpt_path = os.path.join(model_dir, f'model_epoch_{epoch}.pt')
    else:
        # Use checkpoint with highest epoch number. Candidates are stored as
        # (-epoch, path) so min() picks the highest epoch and breaks ties by
        # path, as the previous heap-based selection did.
        candidates = []
        for cur_file in Path(model_dir).iterdir():
            if not cur_file.is_file():
                continue
            try:
                cur_epoch = int(cur_file.stem.split('_')[-1])
            except ValueError:
                # Not an epoch-numbered checkpoint (e.g. a stray file).
                continue
            candidates.append((-cur_epoch, cur_file))
        if not candidates:
            raise FileNotFoundError(f'No checkpoint found in {model_dir}')
        _, ckpt_path = min(candidates)
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f'Checkpoint path {ckpt_path} does not exist')
    print(f'Loading checkpoint: {ckpt_path}')
    ckpt = utils.torch_load(ckpt_path)
    # NOTE(review): torch.load unpickles args.pt, so only load trusted run
    # directories. On torch>=2.6 the default ``weights_only=True`` may refuse
    # a pickled namespace; confirm against the installed torch version.
    args = torch.load(os.path.join(run_path, 'args.pt'))
    # Run_path and run_id may have become inconsistent
    args.run_id = Path(run_path).name
    print(f'\nArgs:\n{utils.get_args_as_string(args)}')
    return ckpt, args


def use_hps_from_queue(debug):
    """Train one run per hyperparameter file found in the MHVAE queue.

    Each file in ``<experiments>/working/run/queue_mhvae`` is read as a JSON
    object of hyperparameters and deleted. Its contents are then passed to
    ``train_and_evaluate`` as a ``SimpleNamespace``. Because a file is read and
    removed before training starts, several workers can share one queue. A
    worker that finds a file already gone (``FileNotFoundError`` on open or
    remove) skips it.

    Relies on the module-level ``config`` that is assigned in the
    ``__main__`` block.

    Args:
        debug: Debug flag forwarded to ``train_and_evaluate``.
    """
    queue_dir = os.path.join(
        config.dirs['experiments'], 'working', 'run', 'queue_mhvae')
    print(f'Iterating over args from {queue_dir}:')
    # Sorted for a deterministic processing order.
    for hp_file in sorted(Path(queue_dir).iterdir()):
        if not hp_file.is_file():
            # Subdirectories (or entries removed meanwhile) are not queue items.
            continue
        try:
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with open(hp_file, encoding='utf-8') as json_file:
                hp_dict = json.load(json_file)
            os.remove(hp_file)
        except FileNotFoundError:
            # File might have been already handled by other worker
            continue
        train_and_evaluate(SimpleNamespace(**hp_dict), debug)
    print(f'\nNo args left in {queue_dir}')


if __name__ == '__main__':
    # Command-line entry point. The parsed namespace is bound to the global
    # name ``parser`` and the project config to ``config``; functions above
    # read both as module globals, so these two names must not change.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--debug', action='store_true',
        help="Save results in debug/ to avoid spamming.",
    )
    cli.add_argument(
        '--dataset_name', default='flowers', choices=['cub_ft', 'flowers'],
    )
    cli.add_argument(
        '--hp_from_queue', action='store_true',
        help="Load hyperparameters from queue.",
    )
    cli.add_argument('--resume_id', default='')
    parser = cli.parse_args()
    config = get_config()
    debug = utils.check_debug_status(parser.debug, parser.resume_id)

    if parser.hp_from_queue:
        print('Loading hyperparameters from queue.')
        use_hps_from_queue(debug)
    elif parser.resume_id:
        # Resuming: everything else is restored from the run directory.
        args = SimpleNamespace(resume_id=parser.resume_id)
        train_and_evaluate(args, debug)
    else:
        print('Loading default hyperparameters.')
        args = get_args(parser.dataset_name)
        train_and_evaluate(args, debug)
