import argparse
import heapq
import json
import os
from pathlib import Path
from types import SimpleNamespace

import torch

import utils
from data.flowers.main_raw import load_flowers_data
from disentanglement_vae.evaluator.evaluator import Evaluator
from disentanglement_vae.misc import get_args
from disentanglement_vae.models import MoeDisentanglementVae
from disentanglement_vae.trainer import Trainer
from hyperparams.load import get_config


def train_and_evaluate(args, debug=False):
    """Set up (or resume) a run directory, then train and evaluate a model.

    Args:
        args: Hyperparameter namespace (anything ``vars()`` accepts). If it
            carries a truthy ``resume_id``, the run with that id is located
            and its stored args and latest checkpoint are loaded instead.
            Otherwise, a new run directory is created via
            ``utils.setup_dir``, which also returns the args used from then
            on.
        debug: Debug flag forwarded to run-directory setup and to the
            trainer/evaluators.

    Note:
        Reads the module-level ``config`` object, which is only created in
        this script's ``__main__`` block.
    """
    device = utils.setup_device()
    if resume_id := vars(args).get('resume_id'):
        run_path = utils.find_path_of_id(resume_id)
        checkpoint, args = load_checkpoint(run_path)
    else:
        checkpoint = None
        # Use the `debug` argument rather than the raw CLI flag so the
        # function does not depend on the module-level `parser` global and
        # honours the debug status the caller resolved.
        args, run_path = utils.setup_dir(args, config, debug)
        # Seeds are only set for fresh runs; resumed runs are not reseeded here.
        utils.set_seeds(args.seed)

    logger = utils.set_logger(verbosity=config.logger_verbosity,
                              log_path=os.path.join(run_path, 'train.log'))
    logger.info(f'\n{"-" * 100}\nrun_path:\n{run_path}')
    logger.info(f'\nArgs:\n{utils.get_args_as_string(args)}')
    try:
        _train_and_evaluate(run_path, args, device, checkpoint, debug)
        logger.info('\nModel has been trained and evaluated.')
    finally:
        # Always release the logger, even if training fails, so a following
        # run (e.g. the next queue entry) starts with a clean logger.
        utils.close_logger()


def _train_and_evaluate(run_path, args, device, checkpoint, debug=False):
    """Build the model, data loader and evaluators, then run training.

    Args:
        run_path: Directory of the current run; handed to the evaluators and
            the trainer.
        args: Hyperparameter namespace; ``args.bs`` is used as batch size.
        device: Device the model is moved to and that the evaluators and
            trainer receive.
        checkpoint: Optional checkpoint; when truthy, its ``'state_dict'``
            entry is loaded into the model and it is passed to the trainer.
        debug: Debug flag forwarded to the evaluators and the trainer.
    """
    vae = MoeDisentanglementVae(args)
    if checkpoint:
        vae.load_state_dict(checkpoint['state_dict'])
    vae.to(device)

    _train_set, train_loader = load_flowers_data(mode='train',
                                                 batch_size=args.bs)

    # Only the validation split is evaluated; adding 'train' to this tuple
    # would also evaluate on the training split.
    eval_splits = ('val',)
    evaluators = {
        split_name: Evaluator(run_path=run_path, split=split_name, args=args,
                              device=device, debug=debug)
        for split_name in eval_splits
    }

    trainer = Trainer(model=vae, dataloader=train_loader, args=args,
                      run_path=run_path, evaluators=evaluators, debug=debug,
                      device=device, checkpoint=checkpoint)
    with utils.Timer('Train disentanglement VAE on Flower dataset',
                     event_frequency='low'):
        trainer.train()


def load_checkpoint(run_path, epoch=None):
    """Load a model checkpoint and the saved run arguments of a run.

    Args:
        run_path: Directory of a previous run. Must contain an ``args.pt``
            file and a ``models/`` subdirectory with checkpoint files.
        epoch: Epoch whose checkpoint (``models/model_epoch_<epoch>.pt``)
            to load. If ``None``, the checkpoint with the highest epoch
            number is used. The epoch is parsed from the last
            ``_``-separated part of each file name's stem, and files without
            a parsable epoch are ignored.

    Returns:
        Tuple ``(ckpt, args)``: the object returned by ``utils.torch_load``
        for the checkpoint, and the run's args with ``run_id`` reset to the
        name of ``run_path``.

    Raises:
        FileNotFoundError: If the requested checkpoint does not exist, or no
            file with a parsable epoch number is found in ``models/``.
    """
    model_dir = os.path.join(run_path, 'models')
    # `is not None` so an explicit epoch 0 is honoured instead of silently
    # falling back to the latest checkpoint.
    if epoch is not None:
        ckpt_path = os.path.join(model_dir, f'model_epoch_{epoch}.pt')
    else:
        # Collect (-epoch, path) pairs; the smallest pair is the highest
        # epoch, with ties broken by path exactly like the previous
        # heap-based selection.
        candidates = []
        for cur_file in Path(model_dir).iterdir():
            try:
                cur_epoch = int(cur_file.stem.split('_')[-1])
            except ValueError:
                # Not a checkpoint file (name does not end in an epoch number).
                continue
            candidates.append((-cur_epoch, cur_file))
        if not candidates:
            raise FileNotFoundError(
                f'No checkpoint with an epoch number found in {model_dir}')
        _, ckpt_path = min(candidates)
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f'Checkpoint path {ckpt_path} does not exist')
    print(f'Loading checkpoint: {ckpt_path}')
    ckpt = utils.torch_load(ckpt_path)
    # NOTE: torch.load unpickles the file; only load run directories you trust.
    args = torch.load(os.path.join(run_path, 'args.pt'))
    # Run_path and run_id may have become inconsistent (e.g. after the run
    # directory was moved or renamed), so trust the directory name.
    args.run_id = Path(run_path).name
    print(f'\nArgs:\n{utils.get_args_as_string(args)}')
    return ckpt, args


def use_hps_from_queue(debug):
    """Consume queued hyperparameter files one at a time and run each.

    Every entry of ``<experiments>/working/run/queue_mdvae`` is parsed as
    JSON, deleted from the queue, and its contents are passed as keyword
    fields of a ``SimpleNamespace`` to ``train_and_evaluate``. An entry that
    disappears before it can be read or deleted is skipped, because another
    worker may already have claimed it.

    Args:
        debug: Debug flag forwarded to ``train_and_evaluate``.

    Note:
        Reads the module-level ``config`` object, which is only created in
        this script's ``__main__`` block.
    """
    queue_dir = os.path.join(
        config.dirs['experiments'], 'working', 'run', 'queue_mdvae')
    print(f'Iterating over args from {queue_dir}:')
    for queue_file in Path(queue_dir).iterdir():
        try:
            with open(queue_file) as handle:
                hparams = json.load(handle)
            # Removing the file claims the entry: if another worker removed it
            # first, this raises FileNotFoundError and the entry is skipped.
            os.remove(queue_file)
        except FileNotFoundError:
            continue
        else:
            train_and_evaluate(SimpleNamespace(**hparams), debug)
    print(f'\nNo args left in {queue_dir}')


if __name__ == '__main__':
    # Command-line interface. The parsed namespace, `config` and `debug` stay
    # module-level globals because other functions in this file read them.
    cli = argparse.ArgumentParser()
    cli.add_argument('--debug', action='store_true',
                     help="Save results in debug/ to avoid spamming.")
    cli.add_argument('--hp_from_queue', action='store_true',
                     help="Load hyperparameters from queue.")
    cli.add_argument('--resume_id', default='')
    parser = cli.parse_args()
    config = get_config()
    debug = utils.check_debug_status(parser.debug, parser.resume_id)

    # Three modes: drain the hyperparameter queue, resume an existing run,
    # or start a fresh run with default hyperparameters.
    if parser.hp_from_queue:
        print('Loading hyperparameters from queue.')
        use_hps_from_queue(debug)
    elif parser.resume_id:
        train_and_evaluate(SimpleNamespace(resume_id=parser.resume_id), debug)
    else:
        print('Loading default hyperparameters.')
        train_and_evaluate(get_args(), debug)
