""" Generic training script. """

import argparse
import datetime
import matplotlib
matplotlib.use('Agg')
from lib.DensityVAE import DensityVAE
from lib.DisentanglementVAE import DisentanglementVAE
from lib.utils import run_cuda_diagnostics
import time
import torch
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from lib.utils import count_pars


def _split_arguments(arguments):
    """ Convert programmatic arguments into an argv-style list.

    Returns None for any falsy input (None, '', empty sequence) so that argparse falls back to sys.argv.
    A string is split on whitespace; any other sequence is copied element-wise as strings.
    """
    if not arguments:
        return None
    if isinstance(arguments, str):
        return arguments.split()
    return [str(token) for token in arguments]


def _count_requested_gpus(gpus):
    """ Normalize the value of the `gpus` option to a number that can be compared against 1.

    None or a blank string counts as 0, a comma separated string counts its non-empty entries,
    a list/tuple counts its elements, and anything else is converted with int() (so -1 stays -1).
    """
    if gpus is None:
        return 0
    if isinstance(gpus, str):
        if not gpus.strip():
            return 0
        if ',' in gpus:
            return sum(1 for gpu_id in gpus.split(',') if gpu_id.strip())
        return int(gpus)
    if isinstance(gpus, (list, tuple)):
        return len(gpus)
    return int(gpus)


def main(arguments=None):
    """ Parse the options, build the model selected by --task and train it.

    Args:
        arguments: optional options passed from python code, either as one whitespace separated
            string or as an already split sequence. When empty or None, the command line is parsed.

    Raises:
        ValueError: if --task is not one of the supported tasks.
    """

    # parse arguments passed via (i) python script; or (ii) command line;
    parser = get_parser()
    args = parser.parse_args(_split_arguments(arguments))

    # at the moment, the following tasks are supported
    task_to_model_map = {'density_estimation': DensityVAE,
                         'disentanglement':    DisentanglementVAE}
    # explicit check instead of `assert`, which is removed when python runs with -O
    if args.task not in task_to_model_map:
        raise ValueError("Unsupported task '{}'; expected one of: {}.".format(
            args.task, ", ".join(sorted(task_to_model_map))))
    model_class = task_to_model_map[args.task]

    # for reproducibility and synchronization
    pl.seed_everything(args.random_seed)

    # cuda diagnostics
    run_cuda_diagnostics(requested_num_gpus=args.gpus)

    # instantiate model; every parsed option is forwarded as a keyword argument
    model = model_class(**vars(args))
    signature = model.signature()
    print("model signature", signature)
    print("make checkpoints? ", args.make_checkpoint)
    print("use amp? ", args.amp)
    print("total number of parameters: ", count_pars(model))

    # construct full experiment name which contains hyper parameters and a time-stamp;
    # the time-stamp includes seconds so that names sort chronologically
    time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')
    full_exp_name = args.exp_name + "_" + time_stamp + "_" + signature

    # checkpoint directory
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints', args.exp_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # logs directory
    os.makedirs('logs/', exist_ok=True)

    # load existing, or create a new checkpoint
    detected_checkpoint = None
    if args.use_checkpoint:
        # the first file (in reverse lexicographic order) whose name starts with the signature is used
        # NOTE(review): lexicographic order need not match training order
        # (e.g. 'epoch=9' sorts after 'epoch=10') - confirm this is the intended pick
        for checkpoint in sorted(os.listdir(checkpoint_dir), reverse=True):
            if checkpoint.startswith(signature):
                detected_checkpoint = os.path.join(checkpoint_dir, checkpoint)
                full_exp_name = "CHK_" + full_exp_name
                print("Checkpoint found.")
                break

    # setup a checkpoint callback
    checkpoint_callback = None
    if args.make_checkpoint:
        checkpoint_callback = ModelCheckpoint(filepath=checkpoint_dir, monitor='val_loss',
                                              prefix=signature, period=args.check_val_every_n_epoch)

    # empty cache now that the model is created
    torch.cuda.empty_cache()

    # a distributed backend is only requested when more than one gpu is asked for;
    # `gpus` may be None, an int, a string or a list, hence the normalization
    use_distributed = _count_requested_gpus(args.gpus) > 1

    # train the model
    trainer = pl.Trainer.from_argparse_args(args,
                                            logger=TensorBoardLogger(save_dir="logs/", name=full_exp_name),
                                            progress_bar_refresh_rate=500,
                                            row_log_interval=500,
                                            log_save_interval=500,
                                            distributed_backend=(args.distributed_backend if use_distributed else None),
                                            terminate_on_nan=True,
                                            checkpoint_callback=checkpoint_callback,
                                            resume_from_checkpoint=detected_checkpoint
                                            )
    trainer.fit(model)


def get_parser():
    """ Build the argument parser: the pl.Trainer options plus the experiment options listed below. """

    # (flag, add_argument keyword arguments), registered in exactly this order
    experiment_options = (
        ('--task', dict(type=str, default="density_estimation")),
        ('--exp_name', dict(type=str, default="NoName", help='The name of the experiment.')),
        ('--z_size', dict(type=int, default=20, help='Number of stochastic feature maps per layer.')),
        ('--h_size', dict(type=int, default=160, help='Deterministic layer size.')),
        ('--max_scale_sdn', dict(type=int, default=64, help='Maximum scale to apply SDN on.')),
        ('--min_scale_sdn', dict(type=int, default=0, help='Minimum scale to apply SDN on.')),
        ('--state0', dict(type=int, default=300, help='The state size of bottom most SDN.')),
        ('--delta_state', dict(type=int, default=0, help='State difference between image scales.')),
        ('--num_dirs', dict(type=int, default=1, help='The number of SDN dirs.')),
        ('--iters', dict(type=int, default=1000000, help='Number of training iterations.')),
        ('--batch', dict(type=int, default=32, help='Batch size.')),
        ('--lrate', dict(type=float, default=0.002, help='Learning rate.')),
        ('--free_bits', dict(type=float, default=0, help='KL free-bits.')),
        ('--depth', dict(type=int, default=5, help='Num ladder blocks.')),
        ('--root', dict(type=str, default="./data/", help='Path to data folder.')),
        ('--dataset', dict(type=str, default="CIFAR10", help='Dataset name.')),
        ('--eval_iterations', dict(type=int, default=2500, help='Frequency of evaluation logs.')),
        ('--post_model', dict(type=str, default="IsoGaussian", help='Probabilistic posterior model.')),
        ('--prior_model', dict(type=str, default="IsoGaussian", help='Probabilistic prior model.')),
        ('--obs_model', dict(type=str, default="DL", help='Probabilistic observation model.')),
        ('--mix_components', dict(type=int, default=30, help='Number of mixture components, for DML.')),
        ('--ds_list', dict(nargs='*', type=int, help='Downsampling list.')),
        ('--ema_coef', dict(type=float, default=1, help='Exponential moving average coefficient.')),
        ('--random_seed', dict(type=int, default=13, help='Random seed.')),
        ('--beta_rate', dict(type=float, default=1, help='Beta rate for KL annealing.')),
        ('--num_workers', dict(type=int, default=0, help='Num workers.')),
        ('--downsample_first', dict(action='store_true', help='Downsample in the first layer of encoder.')),
        ('--amp', dict(action='store_true', help='Apply AMP.')),
        ('--sampling_temperature', dict(type=float, default=1.0, help='Sampling temperature.')),
        ('--make_checkpoint', dict(action='store_true', help='Flag to indicate whether we do checkpointing.')),
        ('--use_checkpoint', dict(action='store_true', help='Flag to indicate resuming of training.')),
        ('--nbits', dict(type=int, default=8, help='Number of bits per pixel.')),
        ('--figsize', dict(type=int, default=10, help='Size of images logged during training.')),
        ('--evaluation_mode', dict(action='store_true', help='If model is used only for evaluation.')),
        ('--beta_final', dict(type=float, default=1, help='Final beta value.')),
    )

    # trainer options are added first, the experiment options follow
    parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
    for flag, spec in experiment_options:
        parser.add_argument(flag, **spec)

    return parser


# run the training only when this file is executed as a script; options are then read from the command line
if __name__ == '__main__':
    main()
