import argparse
import heapq
import json
import os
from pathlib import Path
from types import SimpleNamespace

import methods
import utils
from data import load_data
from evaluation.load import get_evaluator
from hyperparams.load import get_args, get_config


def train_and_evaluate(args, config, debug, device):
    """Prepare a run, build its trainer and train it.

    Args:
        args: Run hyperparameters. An object with only ``resume_id`` set
            resumes an existing run (see ``_prepare``).
        config: Global configuration object; ``_prepare`` writes ``run_id``
            and ``run_path`` onto it.
        debug: Debug flag forwarded to directory setup and the evaluators.
        device: Device the model and evaluators are placed on.
    """
    run_args, resumed_ckpt = _prepare(args, config, debug)
    run_trainer = _load_agents(run_args, config, device, resumed_ckpt, debug)
    _run(run_args, run_trainer)


def _prepare(args, config, debug):
    """Set up the run directory and logger for a new or resumed run.

    If ``args`` carries a truthy ``resume_id``, the run directory is located
    with ``utils.find_path_of_id``. The latest checkpoint and the saved args
    are then loaded from it. Otherwise a fresh run directory is created with
    ``utils.setup_dir`` and the random seeds are set from ``args.seed``.

    ``config.run_id`` and ``config.run_path`` are updated in place. A logger
    writing to ``<run_path>/train.log`` is configured, and the run path and
    args are logged.

    Returns:
        ``(args, checkpoint)``: the effective args (the saved ones when
        resuming) and the loaded checkpoint, or ``None`` for a fresh run.
    """
    resume_id = vars(args).get('resume_id')
    if not resume_id:
        # Fresh run: create a new run directory and seed the RNGs.
        checkpoint = None
        args, run_path = utils.setup_dir(args, config, debug)
        utils.set_seeds(args.seed)
    else:
        # NOTE(review): seeds are not re-set when resuming — confirm intended.
        run_path = utils.find_path_of_id(resume_id)
        checkpoint, args = load_checkpoint(run_path)

    config.run_id, config.run_path = args.run_id, run_path

    log_file = os.path.join(run_path, 'train.log')
    logger = utils.set_logger(verbosity=config.logger_verbosity,
                              log_path=log_file)
    logger.info(f'\n{"-" * 100}\nrun_path:\n{run_path}')
    logger.info(f'\nArgs:\n{utils.get_args_as_string(args)}')
    return args, checkpoint


def _load_agents(args, config, device, checkpoint, debug=False):
    """Build the model, training loader, evaluators and trainer for a run.

    Args:
        args: Run hyperparameters; ``args.model`` selects the method package
            and ``args.dset_name`` the dataset/evaluator.
        config: Global configuration object forwarded to evaluators/trainer.
        device: Device for the model and evaluators.
        checkpoint: Loaded checkpoint to restore from, or ``None``.
        debug: Debug flag forwarded to evaluators and trainer.

    Returns:
        The trainer object created by ``methods.define_trainer``.
    """
    model_package = methods.get_package(args.model)
    model = methods.define_model(model_package, args, device, checkpoint)
    _, train_loader = load_data(mode='train', args=args)
    evaluator = get_evaluator(args.dset_name, args.model)

    # Evaluate a single split only, to speed up training. The synthetic
    # dataset has no validation split, so it falls back to 'train'
    # (previously 'val' was forced for every dataset, including synthetic).
    splits = ['train'] if 'synthetic' in args.dset_name else ['val']
    evaluators = {
        split: evaluator(split=split, args=args, config=config,
                         device=device, debug=debug)
        for split in splits
    }

    return methods.define_trainer(
        model_package, args=args, config=config, model=model,
        loader=train_loader, evaluators=evaluators, device=device,
        checkpoint=checkpoint, debug=debug)


def _run(args, trainer):
    """Call ``trainer.train`` inside a ``utils.Timer`` labelled with model and dataset.

    ``args.sizes['k']`` is passed straight to ``trainer.train``; its meaning
    is defined by the trainer (presumably a training length — TODO confirm).
    """
    description = f'Train {args.model} on the {args.dset_name} dataset'
    timer = utils.Timer(description, event_frequency='low')
    with timer:
        trainer.train(args.sizes['k'])


def load_hp_from_queue_on_disk(debug, device, config=None):
    """Consume hyperparameter files from the on-disk queue and train on each.

    Every regular file in ``<config.dirs['experiments']>/working/run/queue``
    is parsed as a JSON object of hyperparameters. The file is deleted from
    the queue, then passed to ``train_and_evaluate``. Several workers may
    drain the same queue; a file that vanishes before this worker can read or
    delete it is skipped.

    Args:
        debug: Debug flag forwarded to ``train_and_evaluate``.
        device: Device forwarded to ``train_and_evaluate``.
        config: Configuration object. Loaded with ``get_config()`` when
            omitted, so the function also works when this module is imported
            rather than run as a script.
    """
    if config is None:
        config = get_config()
    queue_dir = os.path.join(
        config.dirs['experiments'], 'working', 'run', 'queue'
    )
    print(f'Iterating over args from {queue_dir}:')
    # Snapshot + sort the listing: deterministic processing order, and no
    # directory handle is held open across (long) training runs.
    for cur_path in sorted(Path(queue_dir).iterdir()):
        if not cur_path.is_file():
            continue
        try:
            with open(cur_path, encoding='utf-8') as json_file:
                args = json.load(json_file)
            # Deleting the file claims the job for this worker.
            # NOTE(review): if training later fails, the job is not re-queued.
            os.remove(cur_path)
        except FileNotFoundError:
            # File might have been (simultaneously) handled by other worker
            continue
        print(f'Loading file {cur_path.name}')
        train_and_evaluate(SimpleNamespace(**args), config, debug, device)
        print('Finished run.')
    print('\nEither finished run(s) or hyperparameter queue is empty.')


def load_checkpoint(run_path, epochs=None):
    """Load a checkpoint and the saved args of an existing run.

    Args:
        run_path: Directory of the run to load from.
        epochs: Epoch of the checkpoint to load; by default the checkpoint
            chosen by ``_find_ckpt_path`` (highest epoch) is used.

    Returns:
        ``(checkpoint, args)``. ``args`` is read from ``<run_path>/args.pt``,
        and its ``run_id`` is reset to the run directory's name.
    """
    checkpoint_file = _find_ckpt_path(run_path, epochs)
    print(f'Loading checkpoint: {checkpoint_file}')
    loaded_checkpoint = utils.torch_load(checkpoint_file)
    saved_args = utils.torch_load(os.path.join(run_path, 'args.pt'))
    # The run directory may have been renamed since args.pt was written, so
    # the directory name is treated as the source of truth for run_id.
    saved_args.run_id = Path(run_path).name
    return loaded_checkpoint, saved_args


def _find_ckpt_path(run_path, epochs=None):
    model_dir = os.path.join(run_path, 'models')
    if os.path.isdir(model_dir):
        if epochs:
            ckpt_path = os.path.join(model_dir, f'model_epoch_{epochs}.pt')
        else:
            # Use checkpoint with highest epoch number
            heap = []
            for cur_d in Path(model_dir).iterdir():
                epoch = int(cur_d.stem.split('_')[-1])
                heapq.heappush(heap, (-epoch, cur_d))
            _, ckpt_path = heapq.heappop(heap)
    else:
        # Backward compatibility
        ckpt_path = os.path.join(run_path, 'best_model.pt')
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f'Checkpoint path {ckpt_path} does not exist')
    return ckpt_path


if __name__ == '__main__':
    # Command-line interface: pick model/dataset, or resume / drain a queue.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--model', default='multimodal_vae_moe',
        choices=[
            'multimodal_vae_poe',  # product of experts posterior
            'multimodal_vae_moe',  # mixture of experts posterior
        ],
        help='model name.')
    arg_parser.add_argument(
        '--dataset_name', default='flowers',
        choices=[
            'synthetic_data',
            'cub_ft',  # CUB with image feature vectors and caption feature vectors
            'flowers',  # Oxford Flower with images and caption feature vectors
            'flowers_ft',  # Oxford Flower with image feature vectors and caption feature vectors
            'cub',  # CUB with images and caption feature vectors
        ])
    arg_parser.add_argument('--debug', action='store_true',
                            help="Save results in debug/ to avoid spamming.")
    arg_parser.add_argument('--hp_from_queue', action='store_true',
                            help="Load hyperparameters from queue.")
    arg_parser.add_argument('--resume_id', default='')
    cli = arg_parser.parse_args()

    # Objects shared by the run(s) below; `config` is created at module scope.
    config = get_config()
    device = utils.setup_device()
    debug = utils.check_debug_status(cli.debug, cli.resume_id)

    if cli.hp_from_queue:
        print('Loading hyperparameters from queue.')
        load_hp_from_queue_on_disk(debug, device)
    elif cli.resume_id:
        # Only the id is needed; the saved args are restored from disk.
        args = SimpleNamespace(resume_id=cli.resume_id)
        train_and_evaluate(args, config, debug, device)
    else:
        print('Loading default hyperparameters.')
        args = get_args(cli.dataset_name, cli.model)
        train_and_evaluate(args, config, debug, device)
