import os
import sys
import argparse
import glob
from PIL import Image
import numpy as np
import matplotlib
from tqdm import tqdm

matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import pdb
from random import SystemRandom
from utils import get_logger, RunningAverageMeter, AverageMeter, _ascent_monotonically, count_parameters, count_nfe, batch_iter, load_data
import time
import pandas as pd
from datetime import datetime

from model import CNF_VAE, NONLINEARITIES



# Training iterations at which a numbered checkpoint (ckpt_<itr>.pth) is
# written: four early snapshots, then every 250 iterations from 750 up to
# and including 10000.
ITERATION_TO_SAVE = [10, 100, 200, 500] + list(range(750, 10001, 250))

# Solver names accepted by the --test_solver command-line option.
SOLVERS = [
    "dopri5",
    "bdf",
    "rk4",
    "midpoint",
    "adams",
    "explicit_adams",
    "fixed_adams",
]

# Number of learning-rate decay steps applied so far (0, 1 or 2);
# mutated by update_lr().
ndecs = 0

def update_lr(optimizer, n_vals_without_improvement, *, base_lr=None, early_stopping=None):
    """Step the learning rate down as validation stops improving.

    The learning rate of every parameter group is set to
    ``base_lr / 10**ndecs`` where the module-level counter ``ndecs`` is
    advanced (at most one step per call, and never decreased):

    * 0 -> 1 once ``n_vals_without_improvement`` exceeds one third of the
      early-stopping patience,
    * 1 -> 2 once it exceeds two thirds of the patience.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            with an ``"lr"`` entry (e.g. a ``torch.optim`` optimizer).
        n_vals_without_improvement: number of consecutive validations
            without a new best loss.
        base_lr: undecayed learning rate. Defaults to the module-level
            ``args.lr`` (set when the file is run as a script).
        early_stopping: early-stopping patience. Defaults to the
            module-level ``args.early_stopping``.
    """
    global ndecs

    # Fall back to the parsed command-line arguments only when the caller
    # did not supply explicit values, so the function is usable (and
    # testable) without the script-level ``args`` global.
    if base_lr is None:
        base_lr = args.lr
    if early_stopping is None:
        early_stopping = args.early_stopping

    one_third = early_stopping // 3
    if ndecs == 0 and n_vals_without_improvement > one_third:
        ndecs = 1
    elif ndecs == 1 and n_vals_without_improvement > one_third * 2:
        ndecs = 2

    # Always (re)write the rate so that it stays pinned to the current
    # decay level even if something else modified the optimizer.
    decayed_lr = base_lr / 10 ** ndecs
    for param_group in optimizer.param_groups:
        param_group["lr"] = decayed_lr




def load_data(dataset, batch_size, test_batch_size = 1000):
    """Build the MNIST datasets and their data loaders.

    Args:
        dataset: dataset name; only ``"mnist"`` is handled.
        batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) validation loader.

    Returns:
        The tuple ``(train_data, val_data, train_loader, val_loader)``.

    Raises:
        ValueError: if ``dataset`` is anything other than ``"mnist"``.
    """
    import torchvision.transforms as transforms
    from torchvision import datasets
    from torch.utils.data import DataLoader

    # Reject unsupported datasets before touching the disk or the network.
    if dataset != "mnist":
        raise ValueError("current don't support dataset other than MNIST")

    # The same ToTensor-only pipeline is shared by both splits.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_root = '../../data'

    # MNIST train split for training, MNIST test split for validation;
    # both are downloaded into data_root when not already present.
    train_data = datasets.MNIST(root=data_root, train=True, download=True, transform=to_tensor)
    val_data = datasets.MNIST(root=data_root, train=False, download=True, transform=to_tensor)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=test_batch_size, shuffle=False)

    return train_data, val_data, train_loader, val_loader



    


if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Script entry point: parse options, build the CNF-VAE, optionally
    # resume from a checkpoint directory, then either train (--training)
    # or enter test mode.
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--adjoint', action='store_true', default=False)
    parser.add_argument('--niters', type=int, default=8000)
    parser.add_argument('--dataset', choices=['mnist', 'cifar10'], type=str, default='mnist')
    parser.add_argument('--layer_type', choices = ["hypernet", "concatnet"], type=str, default="hypernet")
    parser.add_argument('--latent_dim', type=int, default = 16)
    parser.add_argument('--width', type=int, default=8)
    parser.add_argument('--hid_factor', type=int, default=2)
    parser.add_argument("--divergence_fn", type=str, default="hutchinson", choices=["naive", "hutchinson"])
    parser.add_argument("--nonlinearity", type=str, default="tanh", choices=NONLINEARITIES)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--train_dir', type=str, default="./experiments/acnf_vae/")
    parser.add_argument('--load_dir', type=str, default=None)
    parser.add_argument('--T', type=float, default = 1.0)
    parser.add_argument('--regularization_factor', type=float, default = 0.1)
    parser.add_argument('--gradual_r', action='store_true')
    parser.add_argument('--training', action='store_true')
    parser.add_argument('--deepnet', action='store_true')
    parser.add_argument('--fix_vae', action='store_true')
    parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None])
    parser.add_argument('--test_atol', type=float, default=None)
    parser.add_argument('--test_rtol', type=float, default=None)
    parser.add_argument('--early_stopping', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--test_batch_size', type=int, default=None)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--val_freq', type=int, default=20, help="validate model every n iteration")
    parser.add_argument('--log_freq', type=int, default=1, help="log training every n iteration")

    args = parser.parse_args()

    # Select torchdiffeq's adjoint or plain odeint according to --adjoint.
    if args.adjoint:
        from torchdiffeq import odeint_adjoint as odeint
        print("from torchdiffeq import odeint_adjoint as odeint")
    else:
        from torchdiffeq import odeint
        print("from torchdiffeq import odeint")

    T = args.T # training T
    # The validation batch size falls back to the training batch size.
    test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size

    device = torch.device('cuda:' + str(args.gpu)
                          if torch.cuda.is_available() else 'cpu')


    if args.dataset not in ['mnist', 'cifar10']:
        raise Warning("Please specify dataset from choices [mnist, cifar10]")

    # data loader (the validation batch size is now actually forwarded)
    train_data, val_data, train_loader, val_loader = load_data(args.dataset, args.batch_size, test_batch_size)
    # [batch_size, 1, 28, 28] for mnist
    # Only the first batch is needed to read the shape; do not walk the
    # whole training set.
    _input_shape = list(next(iter(train_loader))[0].shape)


    # CNF-VAE model
    cnf_vae = CNF_VAE(z_dim = args.latent_dim, input_size = _input_shape[1:], hidden_dim=args.latent_dim*args.hid_factor, width=args.width, \
                        layer_type = args.layer_type, deeper = args.deepnet, activation = args.nonlinearity, divegence_method = args.divergence_fn, device = device, T = args.T).to(device)
    optimizer = optim.Adam(cnf_vae.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # load pre-trained vae (mandatory: the script aborts when it is missing)
    vae_ckpt_path = os.path.join('./experiments/acnf_vae', 'vae_ckpt_d{}.pth'.format(args.latent_dim))
    try:
        vae_ckpt = torch.load(vae_ckpt_path, map_location=device)
        cnf_vae.vae.load_state_dict(vae_ckpt['func_state_dict'])
    except Exception as err:
        # Top-level boundary: report the cause and exit with a non-zero
        # status so that callers can detect the failure.
        print('fail to load pre-trained vae model from {}: {!r}'.format(vae_ckpt_path, err))
        sys.exit(1)
    print('Load pretrained VAE model from {}'.format(vae_ckpt_path))
    if args.fix_vae:
        # freeze VAE for now
        for param in cnf_vae.vae.parameters():
            param.requires_grad = False

    # loss function setup: construction loss + expectation over log q_T(z(0))
    criterion = nn.BCELoss(reduction='sum')

    iter_init = 0 # DEFAULT from iteration 1

    def _ckpt_iteration(filename):
        """Return N for a file named 'ckpt_N.pth', or None for any other name."""
        stem, ext = os.path.splitext(filename)
        prefix, _, number = stem.rpartition('_')
        if ext == '.pth' and prefix == 'ckpt' and number.isdigit():
            return int(number)
        return None

    # if load_dir specified, load model saved before, load the checkpoint
    if args.load_dir is not None:
        # Only numbered checkpoints are considered. The best-model file
        # 'ckpt.pth' also lives in this directory and has no iteration
        # number, so it must not be parsed as one.
        itrs = [n for n in map(_ckpt_iteration, os.listdir(args.load_dir)) if n is not None]
        if itrs:
            iter_max = max(itrs)
            ckpt_path = os.path.join(args.load_dir, 'ckpt_{}.pth'.format(iter_max))
            # map_location lets a checkpoint saved on another device load here.
            checkpoint = torch.load(ckpt_path, map_location=device)
            cnf_vae.load_state_dict(checkpoint['func_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # [temporary fix] overwrite lr to args.lr
            for _g in optimizer.param_groups:
                _g['lr'] = args.lr
            T = checkpoint['T']
            iter_init = iter_max

            print('Loaded ckpt from {} at iteration {}'.format(ckpt_path, iter_init))
        else:
            if not args.training:
                raise ValueError("There is no checkpoint file in the specified dir")
        train_dir = args.load_dir
    else:
        # create new training experiment folder
        # (--train_dir always has a string default, so it is never None)
        train_dir = os.path.join(args.train_dir, args.dataset, datetime.now().strftime(r'%m%d_%H%M%S'))
        os.makedirs(train_dir, exist_ok=True)
        ckpt_path = os.path.join(train_dir, 'ckpt.pth')

    # logger
    # Log the command line without the load-directory option and its value.
    # argparse also accepts the unambiguous prefix "--load", so both
    # spellings are removed.
    input_command = sys.argv
    ind = [i for i, token in enumerate(input_command) if token in ("--load", "--load_dir")]
    if len(ind) == 1:
        ind = ind[0]
        input_command = input_command[:ind] + input_command[(ind+2):]
    input_command = " ".join(input_command)

    log_path = os.path.join(train_dir, "training.log")
    logger = get_logger(logpath=log_path, filepath=os.path.abspath(__file__))
    logger.info(input_command)

    logger.info("Training saved in {}".format(train_dir))
    # logger.info("CNF-VAE architecture: {}".format(cnf_vae))
    logger.info("Number of trainable parameters of CNF-VAE: {}".format(count_parameters(cnf_vae)))
    logger.info("VAE architecture: {}".format(cnf_vae.vae))
    logger.info("Number of trainable parameters of VAE: {}".format(count_parameters(cnf_vae.vae)))
    logger.info("CNF use {} architecture".format(args.layer_type))
    if args.layer_type == "hypernet":
        logger.info("CNF networks width: {}, hidden_dim: {}".format(args.width, args.latent_dim*args.hid_factor))
    elif args.layer_type == "concatnet":
        logger.info("CNF networks hidden_dim: {}".format(args.latent_dim*args.hid_factor))
    logger.info("CNF architecture: {}".format(cnf_vae.cnf))
    logger.info("Number of trainable parameters of CNF: {}".format(count_parameters(cnf_vae.cnf)))
    logger.info("T specified as {}".format(T))
    logger.info("R gradually increases" if args.gradual_r else "R set as {} fixed".format(args.regularization_factor))

    t0 = 0
    t1 = T

    # Training args.niters iterations
    if args.training:

        # Column order of training_history.csv.
        history_columns = ['Iteration', 'Regularization_factor', 'Total_loss', 'Reconstruction_loss', 'Estimated_KL', 'ELBO', 'Constraint_loss',
                           'Total_time', 'Time_forward_sim', 'Time_backprop']
        history_path = os.path.join(train_dir, 'training_history.csv')

        def _save_history():
            """Write the accumulated training history to training_history.csv."""
            df = pd.DataFrame(training_history, columns=history_columns)
            df.to_csv(history_path)

        def _save_checkpoint(path, announce=True):
            """Save model/optimizer state, T and args to `path`; optionally log it."""
            torch.save({
                'args': args,
                'func_state_dict': cnf_vae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'T': T,
            }, path)
            if announce:
                logger.info('Stored ckpt at {}'.format(path))

        def _should_stop_early():
            """True once validation has failed to improve more than the patience allows."""
            return args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping

        training_history = {column: [] for column in history_columns}
        if iter_init == 0:
            # new training
            if args.regularization_factor > 1e-5 and args.gradual_r:
                # initial regularization factor if using gradual increasing r to train
                regularization_factor = 1e-5
            else:
                regularization_factor = args.regularization_factor
        else:
            logger.info("Resume training from iteration {}".format(iter_init))
            # Load old training history csv file
            try:
                _history = pd.read_csv(history_path)
                _loaded = {column: _history[column].to_list() for column in history_columns}
                regularization_factor = _loaded['Regularization_factor'][-1]
                training_history = _loaded
            except (OSError, KeyError, IndexError, ValueError) as err:
                # Missing file, missing column, empty history or unparsable
                # CSV: start a fresh history instead of aborting the resume.
                print("Fail to load training_history.csv: {!r}".format(err))
                training_history = {column: [] for column in history_columns}
                regularization_factor = args.regularization_factor

        itr = iter_init
        # running average of log-likelihood
        loss_meter = RunningAverageMeter(0.98) # total loss
        kl_meter = RunningAverageMeter(0.98) # negative likelihood
        regularization_meter = RunningAverageMeter(0.98) # regularization loss
        reconstruction_meter = RunningAverageMeter(0.98) # reconstruction loss
        elbo_meter =  RunningAverageMeter(0.98) # ELBO
        nfef_meter = RunningAverageMeter(0.98) # forward NFE
        nfeb_meter = RunningAverageMeter(0.98) # backward NFE
        forward_time_meter = RunningAverageMeter(0.98)
        backprop_time_meter = RunningAverageMeter(0.98)
        total_time_meter = RunningAverageMeter(0.98) # total time

        n_vals_without_improvement = 0
        i_epoch = 0
        best_loss = float('inf')

        try:
            # Run until --niters iterations are done or early stopping fires.
            while itr < args.niters:
                # stop once validation has not improved for too long
                if _should_stop_early():
                    break
                i_epoch += 1
                for x, _ in train_loader:
                    # Check both stop conditions before counting the
                    # iteration so that `itr` only counts completed steps.
                    if itr >= args.niters or _should_stop_early():
                        break
                    cnf_vae.train()
                    start_time = time.time()
                    itr += 1

                    # Grow the regularization factor geometrically until it
                    # reaches the requested value.
                    if regularization_factor < args.regularization_factor and args.gradual_r:
                        regularization_factor /= 0.99

                    optimizer.zero_grad()
                    # data on gpu?
                    x = x.view(-1, *cnf_vae.vae.input_size).to(device)
                    time_load_data = time.time()
                    # collect reconstruction, estimate on generation [0, T]
                    reconstruction, delta_logp_t0, old_kld, constraint_loss = cnf_vae(x)

                    time_forward_simulation = time.time()
                    nfe_forward = count_nfe(cnf_vae.cnf)
                    nfef_meter.update(nfe_forward)

                    # bce loss: take sum of all dim of [28, 28]
                    bce_loss = criterion(reconstruction, x)/x.shape[0]
                    # total loss = reconstruction loss + KLD - expectation over int delta_logp + constraint loss
                    new_kl = old_kld.mean(0) - delta_logp_t0.mean(0)
                    regularization_loss = constraint_loss.mean(0)
                    neg_elbo = bce_loss + new_kl
                    loss = neg_elbo + regularization_factor * regularization_loss

                    loss_meter.update(loss.item())
                    reconstruction_meter.update(bce_loss.item())
                    kl_meter.update(new_kl.item())
                    elbo_meter.update(-neg_elbo.item())
                    regularization_meter.update(regularization_loss.item())

                    loss.backward()

                    time_backward_prop = time.time()
                    nfe_total = count_nfe(cnf_vae.cnf)
                    # Note: backward with more augmented states
                    nfe_backward = nfe_total - nfe_forward
                    nfeb_meter.update(nfe_backward)

                    optimizer.step()

                    total_time_meter.update(time.time() - start_time)
                    forward_time_meter.update(time_forward_simulation-time_load_data)
                    backprop_time_meter.update(time_backward_prop-time_forward_simulation)

                    if itr % args.log_freq == 0:

                        logger.info('Iter: {:04d} | Epoch {:.2f} | LR: {:5f} | R: {:.5f} | total loss: {:.4f} ({:.4f}) | '
                                        'Reconstruction loss: {:.4f} ({:.4f}) | KL: {:.4f} ({:.4f}) | ELBO: {:.4f} ({:.4f}) | (continuous) Reg loss: {:.4f} ({:.4f}) |'
                                        'NFE forward {:.0f}({:.1f}) | NFE Backward {:.0f}({:.1f})'.format(
                                                itr, i_epoch, optimizer.param_groups[0]['lr'],
                                                regularization_factor, loss.item(), loss_meter.avg,
                                                reconstruction_meter.val, reconstruction_meter.avg,
                                                kl_meter.val, kl_meter.avg,
                                                elbo_meter.val, elbo_meter.avg,
                                                regularization_loss.item(), regularization_meter.avg,
                                                nfe_forward, nfef_meter.avg,
                                                nfe_backward, nfeb_meter.avg))

                        logger.info('Time elapse for 1 iteration: {:.4f} ({:.4f}) | forward simulation: {:.4f} ({:.4f}) | '
                                    'backprop: {:.4f} ({:.4f})'.format(
                                            total_time_meter.val, total_time_meter.avg,
                                            forward_time_meter.val, forward_time_meter.avg,
                                            backprop_time_meter.val, backprop_time_meter.avg))

                        training_history['Iteration'].append(itr)
                        training_history['Regularization_factor'].append(regularization_factor)
                        training_history['Total_loss'].append(loss.item())
                        training_history['Reconstruction_loss'].append(bce_loss.item())
                        training_history['Estimated_KL'].append(new_kl.item())
                        training_history['ELBO'].append(-neg_elbo.item())
                        training_history['Constraint_loss'].append(regularization_loss.item())
                        training_history['Total_time'].append(total_time_meter.val)
                        training_history['Time_forward_sim'].append(forward_time_meter.val)
                        training_history['Time_backprop'].append(backprop_time_meter.val)

                    # Numbered snapshot at the pre-defined iterations.
                    if train_dir is not None and itr in ITERATION_TO_SAVE:
                        ckpt_path = os.path.join(train_dir, 'ckpt_{}.pth'.format(itr))
                        _save_checkpoint(ckpt_path)

                    if itr % args.val_freq == 0:
                        # validate model on validation data
                        cnf_vae.eval()
                        # Validation runs with the "naive" method; the
                        # training setting is restored afterwards.
                        cnf_vae.cnf.method = "naive"
                        with torch.no_grad():
                            val_loss_meter = AverageMeter()
                            val_kl_meter = AverageMeter()
                            val_regularization_loss_meter = AverageMeter()
                            val_reconstruction_loss_meter = AverageMeter()
                            val_elbo_meter = AverageMeter()
                            val_nfe_forward = AverageMeter() # monitor forward nfe for validation
                            for x, _ in val_loader:
                                x = x.view(-1, *cnf_vae.vae.input_size).to(device)

                                reconstruction, delta_logp_t0, old_kld, constraint_loss = cnf_vae(x)
                                val_nfe_forward.update(count_nfe(cnf_vae.cnf))

                                # total loss = reconstruction loss + KLD - expectation over int delta_logp + constraint loss
                                # (validation-only names, so the training
                                # loop's tensors are not overwritten)
                                val_bce_loss = criterion(reconstruction, x)/x.shape[0]
                                val_new_kl = old_kld.mean(0) - delta_logp_t0.mean(0)
                                val_neg_elbo = val_bce_loss + val_new_kl
                                val_regularization_loss = constraint_loss.mean(0)
                                # The validation loss always uses the target
                                # regularization factor, not the gradual one.
                                val_loss = val_neg_elbo + args.regularization_factor * val_regularization_loss

                                val_loss_meter.update(val_loss.item(), x.shape[0])
                                val_reconstruction_loss_meter.update(val_bce_loss.item(), x.shape[0])
                                val_kl_meter.update(val_new_kl.item(), x.shape[0])
                                val_elbo_meter.update(-val_neg_elbo.item(), x.shape[0])
                                val_regularization_loss_meter.update(val_regularization_loss.item(), x.shape[0])

                        if val_loss_meter.avg < best_loss:
                            # New best model: keep it as ckpt.pth (not logged).
                            best_loss = val_loss_meter.avg
                            _save_checkpoint(os.path.join(train_dir, 'ckpt.pth'), announce=False)
                            n_vals_without_improvement = 0
                        else:
                            n_vals_without_improvement += 1
                        update_lr(optimizer, n_vals_without_improvement)
                        cnf_vae.train()
                        cnf_vae.cnf.method = args.divergence_fn

                        logger.info('[VAL] Iter {:06d} | Total Val Loss {:.6f} (R =  {:.4f}) | Reconstruction loss: {:.4f} | Estimated_KL: {:.4f} | ELBO: {:.4f} |'
                                    ' Reg loss: {:.4f} | NFE forward {:.0f} | '
                                    'NoImproveEpochs {:02d}/{:02d}'.format(
                                    itr, val_loss_meter.avg, args.regularization_factor, val_reconstruction_loss_meter.avg, val_kl_meter.avg, val_elbo_meter.avg,
                                        val_regularization_loss_meter.avg, val_nfe_forward.avg, n_vals_without_improvement, args.early_stopping))

                        # save training history
                        _save_history()

        except KeyboardInterrupt:
            # Nothing to save here: the final checkpoint and history are
            # written once, by the code right after this handler.
            logger.info('Training interrupted at iteration {}'.format(itr))

        logger.info('Training complete after {} iters.'.format(itr))
        if train_dir is not None:
            ckpt_path = os.path.join(train_dir, 'ckpt_{}.pth'.format(itr))
            _save_checkpoint(ckpt_path)
            # save training history
            _save_history()
    else:
        # Test mode
        logger.info('Test mode')
        if not os.path.exists(ckpt_path):
            raise ValueError("No model exists to load for testing! Use --training to run new experiments or --load_dir ")