"""Performs higher order error correction for specified checkpoints at each order
"""
import os
import torch
import hydra
import wandb
import random
import numpy as np

from error_correction.utils import logger
from error_correction.diffeqs import *
from error_correction.models import NNDESolver
from error_correction.trainers import OperatorTrainer, FullBatchTrainer


def set_seed(seed):
    """Seed every random number generator used here so runs are repeatable

    The same seed is handed to Python's `random`, to NumPy, and to torch
    (both `torch.manual_seed` and `torch.cuda.manual_seed_all`).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def can_skip(save_dir, fname, order, checkpoints=None):
    """Build the expected model save paths for a config and check they all exist

    Args:
        save_dir: directory in which the model files are looked up.
        fname: base file name, without the ``.pt`` extension.
        order: correction order; any order above 0 adds an ``order-{order}_``
            prefix to every file name.
        checkpoints: optional iterable of checkpoint identifiers; each one adds
            a ``checkpoint-{c}_`` file to the expected set. ``None`` (the
            default) means no checkpoint files are expected.

    Returns:
        A ``(paths, skip)`` tuple. ``paths`` is a list holding the final model
        path followed by one path per checkpoint, in the given order. ``skip``
        is True only when every path in ``paths`` is an existing file.
    """
    # a None default avoids sharing one mutable list object between calls
    if checkpoints is None:
        checkpoints = ()

    order_str = f"order-{order}_" if order > 0 else ""
    paths = [f"{save_dir}/{order_str}{fname}.pt"]
    paths.extend(
        f"{save_dir}/{order_str}checkpoint-{c}_{fname}.pt"
        for c in checkpoints
    )
    return paths, all(os.path.isfile(path) for path in paths)


def _multiorder_correction(trainer, hparams, order=0):
    r"""Recursively perform error correction on checkpoints

    Trains (or reloads) the model held by ``trainer`` at the current ``order``,
    then builds one new trainer per entry of ``hparams.checkpoints`` and recurses
    on it with ``order + 1``.

    Example: (epochs = 100, orders = 2, injection_points = {0: [0.1, 0.5], 1: [0.1, 0.5]})

    The checkpoint-error correction recursion tree looks like -


    order 0 [N]                         100, checkpoint @ (10, 50)
                                        /                       \
    order 1 [N_e1]      90, checkpoint @ (9, 45)        50, checkpoint @ (5, 25)
                        /                      \        /                      \
    order 2 [N_e2]     81                      45      45                      25


    The i-th edge in a level (from left-to-right) corresponds to loading the i-th checkpointed model
    from the level above it. Training of the base model (N) or error models (N_e{i}) happen at the nodes.
    The first number at each node denotes the amount of epochs to train for, each being equal to the parent
    epochs minus the i-th parent checkpoint.

    Args:
        trainer: trainer for the current node. Its ``model``, ``enc``,
            ``save_dir``, ``details``, ``dataset`` and ``test_dataset``
            attributes are read, and ``train`` / ``evaluate`` / ``visualize`` /
            ``postprocess`` are called on it.
        hparams: run configuration. ``epochs``, ``checkpoints`` and
            ``model.prefix`` are overwritten in place while descending.
            ``model.prefix`` is put back to its value on entry before this
            function returns; ``epochs`` and ``checkpoints`` are left at
            whatever the last child set.
        order: correction order of the current node (0 is the base model).
    """
    model = trainer.model

    # remember the prefix on entry so it can be restored exactly when going
    # back up the tree, whether or not this node appends a segment to it
    entry_prefix = hparams.model.prefix

    # reinitialize wandb logging if applicable
    # NOTE(review): unlike the wandb.init call in multiorder_correction, no
    # `entity` is passed here - confirm that this is intended
    if hparams.wandb.log:
        run = wandb.init(reinit=True, project=hparams.wandb.project_name, group=trainer.enc)

    # *** train the base/error correction models ***
    # NOTE: if checkpoint files already exist in the saved_models directory,
    # skip this training run
    existing_checkpoints, skip = can_skip(trainer.save_dir, trainer.details, order, hparams.checkpoints)
    if not skip:  # not all files are present in the model save_dir -> train
        trainer.train()
    else:  # all .pt files for this config are present, load them
        logger.info(f"*** Order-{trainer.model.order} checkpoints found. Skipping training. ***\n")

        # load the final (non-checkpoint) model file, which is listed first
        model.load_model(order, existing_checkpoints[0])

        # wrap-up procedure similar to Trainer.finish_training
        outputs = trainer.evaluate()
        trainer.visualize(outputs['prediction'])
        trainer.postprocess()

    # finish wandb process
    if hparams.wandb.log:
        run.finish()

    # must keep track of path in tree - encoded via epochs
    # e.g. [order N prefix] = [order 0 prefix]-[order 0 epochs]-...-[order N-1 epochs]
    if len(hparams.injection_points.get(order, [])) > 0:
        hparams.model.prefix = f"{hparams.model.prefix}-{hparams.epochs}"

    # snapshot this node's epochs and checkpoints: both hparams fields are
    # overwritten for every child (and again deeper in the recursion)
    node_epochs = hparams.epochs
    node_checkpoints = list(hparams.checkpoints)

    for cpt in node_checkpoints:
        # get the correct checkpoint model path
        model_fname = f"checkpoint-{cpt}_{trainer.details}.pt"
        if order > 0:
            model_fname = f"order-{order}_{model_fname}"

        # new trainer for error correction, set custom epoch if applicable
        # NOTE: handle the checkpoint loading internally within the trainer
        hparams.epochs = node_epochs - cpt
        hparams.checkpoints = [
            int(hparams.epochs * ijpt)
            for ijpt in hparams.injection_points.get(order + 1, [])
        ]

        # REQUIRED: set the current order of correction number in the model
        model.set_order(order + 1)

        trainer_cls = OperatorTrainer if hparams.optimizer.name != 'LBFGS' else FullBatchTrainer
        next_trainer = trainer_cls(
            model,
            trainer.dataset,
            trainer.test_dataset,
            hparams,
            load_checkpoint=(
                order,
                f"{trainer.save_dir}/{model_fname}"
            )
        )

        # recursively error correct each checkpoint
        _multiorder_correction(next_trainer, hparams, order + 1)

    # reset the prefix - going up one level in the tree.
    # Restoring the saved value undoes exactly what this node appended (if
    # anything), so a node that appended nothing never strips a segment that
    # belongs to its parent.
    hparams.model.prefix = entry_prefix


def multiorder_correction(dataset, hparams):
    """Set up datasets, model and base trainer, then start the recursion

    `dataset` is called twice with `hparams` (mode 'train', then mode 'test');
    the order-0 checkpoints are written to `hparams.checkpoints` before the
    model and the trainer are built.
    """
    train_split, test_split = (
        dataset(hparams, mode=split) for split in ('train', 'test')
    )

    # base model checkpoints: one per order-0 injection point, scaled by epochs
    base_fractions = hparams.injection_points.get(0, [])
    hparams.checkpoints = [int(hparams.epochs * frac) for frac in base_fractions]

    # build the solver; the reparameterization hook is optional on the dataset
    reparam = getattr(train_split, 'reparameterize', None)
    solver = NNDESolver(hparams, orders=hparams.orders, reparam_fn=reparam)

    # LBFGS goes through the full-batch trainer, everything else through the operator trainer
    if hparams.optimizer.name != 'LBFGS':
        trainer_cls = OperatorTrainer
    else:
        trainer_cls = FullBatchTrainer
    base_trainer = trainer_cls(solver, train_split, test_split, hparams)

    # remote train logging via weights & biases
    if hparams.wandb.log:
        wandb_kwargs = {
            'project': hparams.wandb.project_name,
            'entity': hparams.wandb.entity,
            'group': base_trainer.enc,
        }
        wandb.init(**wandb_kwargs)

    # run error correction over all orders, starting from the base model
    _multiorder_correction(base_trainer, hparams)


@hydra.main(config_path='../config', config_name='config')
def main(hparams):
    """Seed the RNGs, pick the dataset class and run multi-order correction

    Args:
        hparams: configuration object; `main` is invoked without arguments, so
            it is expected to be supplied by the `hydra.main` decorator.

    Raises:
        KeyError: if `hparams.data.name` is not one of the dataset names
            listed in the lookup table below.
        ValueError: if `hparams.orders` is smaller than the number of entries
            in `hparams.injection_points`.
    """
    # set random seed everywhere for reproducibility
    set_seed(hparams.random_seed)

    # get the datasets and loaders
    dataset = {
        'nPBE': NonlinearPBE,
        'HenonHeiles': HenonHeiles,
        'NonlinearOscillator': NonlinearOscillator
    }[hparams.data.name]

    # check params with an explicit raise: an `assert` is removed when Python
    # runs with -O, which would silently skip this validation
    if hparams.orders < len(hparams.injection_points):
        raise ValueError(
            f"Total orders {hparams.orders} cannot be less than injection_points parameter dict"
        )

    multiorder_correction(dataset, hparams)
    

# Script entry point. `main` is called without arguments; its `hparams`
# parameter is expected to be filled in by the `hydra.main` decorator.
if __name__ == '__main__':
    main()
    
