import os
import sys
import argparse


import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from random import SystemRandom
from utils import get_logger, RunningAverageMeter, _ascent_monotonically, visualize_cnf, get_batch
import time
import pandas as pd
from datetime import datetime
from acnf import CNF


# Iterations after which the training loop writes a checkpoint: a handful of
# early snapshots, then one every 250 iterations from 750 up to and
# including 10000.
_EARLY_SNAPSHOTS = [10, 100, 200, 500]
_REGULAR_SNAPSHOTS = list(range(750, 10000 + 1, 250))
ITERATION_TO_SAVE = _EARLY_SNAPSHOTS + _REGULAR_SNAPSHOTS



# Column order of training_history.csv. The misspelt 'Estiamted_NLL_traj' key
# is kept as-is so that history files written by earlier runs still load.
HISTORY_COLUMNS = ['Iteration', 'Regularization_factor', 'Total_loss', 'Estimated_NLL', 'Constraint_loss', 'Estiamted_NLL_traj']


def strip_cli_option(argv, option_names):
    """Return a copy of *argv* without the given options and their values.

    Both the ``--opt value`` and the ``--opt=value`` spellings are removed.

    Args:
        argv: Sequence of command-line tokens.
        option_names: Option spellings to drop (e.g. ``("--load", "--load_dir")``).

    Returns:
        A new list with the remaining tokens in their original order.
    """
    kept = []
    skip_value = False
    for token in argv:
        if skip_value:
            # This token is the value belonging to a dropped option.
            skip_value = False
            continue
        if token in option_names:
            skip_value = True
            continue
        if any(token.startswith(name + "=") for name in option_names):
            continue
        kept.append(token)
    return kept


def latest_checkpoint_iteration(filenames):
    """Return the largest N among files named ``ckpt_N.pth``, or None.

    Only names of exactly that shape are considered, because that is the
    only name the loader builds afterwards; anything else is ignored
    instead of crashing the integer parse.
    """
    prefix, suffix = 'ckpt_', '.pth'
    iterations = []
    for name in filenames:
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        middle = name[len(prefix):-len(suffix)]
        if middle.isdigit():
            iterations.append(int(middle))
    return max(iterations) if iterations else None


def warmup_regularization(n_steps, target, start=1e-5, ratio=0.99):
    """Advance the regularisation warm-up schedule by up to *n_steps* steps.

    Each step divides the factor by *ratio* as long as it is still below
    *target*; once the target is reached the factor no longer changes.

    Args:
        n_steps: Number of schedule steps to apply.
        target: Value at which the factor stops growing.
        start: Factor before the first step.
        ratio: Divisor applied per step.

    Returns:
        The factor after the requested steps.
    """
    factor = start
    for _ in range(n_steps):
        if factor >= target:
            break
        factor /= ratio
    return factor


def new_training_history():
    """Return an empty history: one independent list per CSV column."""
    return {column: [] for column in HISTORY_COLUMNS}


def load_training_history(csv_path, resume_iteration=None):
    """Load training_history.csv into a dict of per-column lists.

    Args:
        csv_path: Path of the CSV written by ``save_training_history``.
        resume_iteration: When given, rows whose 'Iteration' is at or after
            this value are dropped so iterations that are about to be re-run
            are not recorded twice.

    Raises:
        OSError: The file cannot be opened.
        ValueError: pandas cannot parse the file.
        KeyError: An expected column is missing.
    """
    frame = pd.read_csv(csv_path)
    if resume_iteration is not None:
        frame = frame[frame['Iteration'] < resume_iteration]
    return {column: frame[column].to_list() for column in HISTORY_COLUMNS}


def save_training_history(history, train_dir):
    """Write *history* to ``<train_dir>/training_history.csv``."""
    df = pd.DataFrame(history, columns=HISTORY_COLUMNS)
    df.to_csv(os.path.join(train_dir, 'training_history.csv'))


def load_checkpoint(ckpt_path, device):
    """Load a checkpoint written by this script, mapping tensors to *device*.

    NOTE: the checkpoint stores the argparse namespace, so it needs the full
    pickle loader. Only load checkpoints you produced yourself.
    """
    try:
        # Recent PyTorch defaults to weights_only=True, which refuses the
        # pickled 'args' entry.
        return torch.load(ckpt_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases without the weights_only parameter are
        # expected to reject the keyword; retry without it.
        return torch.load(ckpt_path, map_location=device)


def save_checkpoint(train_dir, itr, cnf, optimizer, T, args, logger):
    """Store model/optimizer state as ``<train_dir>/ckpt_<itr>.pth``.

    Returns:
        The path the checkpoint was written to.
    """
    ckpt_path = os.path.join(train_dir, 'ckpt_{}.pth'.format(itr))
    torch.save({
        'func_state_dict': cnf.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'T': T,
        'args': args,
    }, ckpt_path)
    logger.info('Stored ckpt at {}'.format(ckpt_path))
    return ckpt_path


if __name__ == '__main__':
    # Script entry point: parse options, build the CNF, optionally resume from
    # a checkpoint directory, then train (--training) and/or visualize (--viz).
    parser = argparse.ArgumentParser()
    parser.add_argument('--adjoint', action='store_true', default=False)
    parser.add_argument('--viz', action='store_true', default=False)
    parser.add_argument('--niters', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--num_samples', type=int, default=512)
    # argparse rejects any dataset outside `choices`, so no later check is needed.
    parser.add_argument('--dataset', type=str, default="circle", choices=["circle", "moons", "blobs", "8blobs", "checkerboard", "olympics"])
    parser.add_argument('--width', type=int, default=64)
    parser.add_argument('--hidden_dim', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--train_dir', type=str, default="./experiments/acnf_g/")
    parser.add_argument('--load_dir', type=str, default=None)
    parser.add_argument('--T', type=float, default = 10)
    parser.add_argument('--regularization_factor', type=float, default = 0.1)
    parser.add_argument('--training', action='store_true')
    args = parser.parse_args()

    if args.adjoint:
        from torchdiffeq import odeint_adjoint as odeint
        print("from torchdiffeq import odeint_adjoint as odeint")
    else:
        from torchdiffeq import odeint
        print("from torchdiffeq import odeint")

    T = args.T # training T

    device = torch.device('cuda:' + str(args.gpu)
                          if torch.cuda.is_available() else 'cpu')

    # base distribution assumed zero mean, 0.1 variance diagonal Gaussian
    p_z0 = torch.distributions.MultivariateNormal(
        loc=torch.tensor([0.0, 0.0]).to(device),
        covariance_matrix=torch.tensor([[0.1, 0.0], [0.0, 0.1]]).to(device)
    )
    # Augment ODE function
    cnf = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width, base_dist=p_z0).to(device)
    optimizer = optim.Adam(cnf.parameters(), lr=args.lr)

    # running average of the total loss
    loss_meter = RunningAverageMeter()

    iter_init = 1 # DEFAULT from iteration 1
    ckpt_path = None

    if args.load_dir is not None:
        # Resume from the highest-numbered ckpt_<N>.pth in load_dir, if any.
        iter_max = latest_checkpoint_iteration(os.listdir(args.load_dir))
        if iter_max is not None:
            ckpt_path = os.path.join(args.load_dir, 'ckpt_{}.pth'.format(iter_max))
            checkpoint = load_checkpoint(ckpt_path, device)
            cnf.load_state_dict(checkpoint['func_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            T = checkpoint['T']
            iter_init = iter_max + 1

            print('Loaded ckpt from {} at iteration {}'.format(ckpt_path, iter_init))
        elif not args.training:
            raise ValueError("There is no checkpoint file in the specified dir")
        train_dir = args.load_dir
    else:
        # create new training experiment folder
        experimentID = int(SystemRandom().random()*100000)
        # train_dir: './experiments/acnf_g/moons/{datetime}/experimentID/'
        train_dir = os.path.join(args.train_dir, args.dataset, datetime.now().strftime(r'%m%d_%H%M%S'), str(experimentID))
        os.makedirs(train_dir, exist_ok=True)
        ckpt_path = os.path.join(train_dir, 'ckpt.pth')

    # logger: record the invoking command without the load option
    # (both the full name and the '--load' abbreviation argparse accepts).
    input_command = " ".join(strip_cli_option(sys.argv, ("--load", "--load_dir")))

    log_path = os.path.join(train_dir, "training.log")
    logger = get_logger(logpath=log_path, filepath=os.path.abspath(__file__))
    logger.info(input_command)

    logger.info("Training saved in {}".format(train_dir))
    logger.info("CNF networks width: {}, hidden_dim: {}".format(args.width, args. hidden_dim))
    logger.info("Model architecture: {}".format(cnf))
    logger.info("T specified as {}".format(T))

    t0 = 0
    t1 = T

    # Solver settings shared by every odeint call. `adjoint_params` is only
    # understood by odeint_adjoint; the plain odeint rejects the keyword, so
    # it is added only when --adjoint is active.
    ode_kwargs = dict(atol=1e-5, rtol=1e-5, method='dopri5')
    if args.adjoint:
        ode_kwargs['adjoint_params'] = list(cnf.parameters())

    # Training args.niters iterations
    if args.training:
        # Bound before the loop so the final save/log also work when the loop
        # body never runs (already at niters, or interrupted before it starts).
        itr = iter_init - 1
        try:
            if iter_init == 1:
                training_history = new_training_history()
            else:
                logger.info("Resume training from iteration {}".format(iter_init))
                try:
                    training_history = load_training_history(
                        os.path.join(train_dir, 'training_history.csv'),
                        resume_iteration=iter_init,
                    )
                except (OSError, ValueError, KeyError):
                    print("Fail to load training_history.csv")
                    sys.exit(1)

            # Replay the warm-up for the iterations already done so that a
            # resumed run continues the schedule instead of restarting at 1e-5.
            regularization_factor = warmup_regularization(iter_init - 1, args.regularization_factor)
            for itr in range(iter_init, args.niters + 1):
                regularization_factor = warmup_regularization(
                    1, args.regularization_factor, start=regularization_factor)
                optimizer.zero_grad()

                # sample x [n_samples, 2]
                x = get_batch(args.num_samples, args.dataset).to(device)

                # ----------- Normalization direction for [z(T), logp_est_T(z(T))] ----------- #
                # initial [z(0) = x, logp_est_0(x)=logmu(x)]
                # logp_est_t0 [n_sample, 1]
                logp_est_t0 = p_z0.log_prob(x).to(device).view(args.num_samples, 1)

                z_t_forward, logp_est_t = odeint(
                    cnf.forward_simulate,
                    (x, logp_est_t0),
                    torch.linspace(t0, t1, 10).type(torch.float32).to(device),
                    **ode_kwargs
                )
                z_t1, logp_est_t1 = z_t_forward[-1], logp_est_t[-1]

                # ------------ Generation direction: integrate the constraint from T back to 0 --------- #
                # initial state [z(T), grad log mu(z(T)), grad log mu(z(T)), 0]
                grad_log_mu_t1 = cnf.gradient_log_base_distribution(z_t1).to(device)

                # continuous version of constraint loss; only the last state
                # component is used below
                _, _, _, constraint_loss_t = odeint(
                    cnf.integrate_constraint,
                    (z_t1, grad_log_mu_t1, grad_log_mu_t1, torch.zeros_like(logp_est_t1)),
                    torch.tensor([t1, t0]).type(torch.float32).to(device),
                    **ode_kwargs
                )

                constraint_loss = -constraint_loss_t[-1]

                # loss: negative estimated log-density at T plus the weighted
                # constraint term, both averaged over the batch
                nll = -logp_est_t1.mean(0)
                mean_constraint = constraint_loss.mean(0)
                loss = nll + regularization_factor*mean_constraint
                loss.backward()
                optimizer.step()

                loss_meter.update(loss.item())

                logger.info('Iter: {:04d}, regularization factor: {:.4f}, current total loss: {:.4f}, negative log estimated likelihood: {:.4f}, (continuous) constraint loss: {:.4f}, running avg loss: {:.4f}'.format(
                    itr, regularization_factor, loss.item(), nll.item(), mean_constraint.item(), loss_meter.avg))

                # batch-mean of the estimated log-density at the 10 solver output times
                mean_logp_traj = logp_est_t.mean(1).squeeze()
                rounded_traj = [round(_value.item(), 4) for _value in mean_logp_traj.data]
                if _ascent_monotonically(mean_logp_traj):
                    logger.info('log estimate density DO monotonically increase w.r.t. t (logp_est_t): {}'.format(rounded_traj))
                else:
                    logger.info('log estimate density NOT monotonically increase w.r.t. t (logp_est_t): {}'.format(rounded_traj))

                # evaluate the estimated log-density trajectory on [0, 2T] with a step of 0.01
                with torch.no_grad():
                    _, _logp_est_t = odeint(
                        cnf.forward_simulate,
                        (x, logp_est_t0),
                        torch.linspace(t0, 2*t1, int(2*t1/0.01)+1).type(torch.float32).to(device),
                        **ode_kwargs
                    )

                training_history['Iteration'].append(itr)
                training_history['Regularization_factor'].append(regularization_factor)
                training_history['Total_loss'].append(loss.item())
                training_history['Estimated_NLL'].append(nll.item())
                training_history['Constraint_loss'].append(mean_constraint.item())
                training_history['Estiamted_NLL_traj'].append([_value.item() for _value in _logp_est_t.mean(1).squeeze().data])

                # The history is flushed together with every checkpoint (not
                # only every 20 iterations) so a checkpoint can always be
                # resumed with a matching training_history.csv.
                checkpoint_due = itr in ITERATION_TO_SAVE
                if itr % 20 == 0 or checkpoint_due:
                    save_training_history(training_history, train_dir)
                if checkpoint_due:
                    ckpt_path = save_checkpoint(train_dir, itr, cnf, optimizer, T, args, logger)

            if itr >= iter_init:
                # final snapshot after the last iteration
                save_training_history(training_history, train_dir)
                ckpt_path = save_checkpoint(train_dir, itr, cnf, optimizer, T, args, logger)

        except KeyboardInterrupt:
            # Save progress on Ctrl-C, but only if the loop had started.
            if itr >= iter_init:
                save_training_history(training_history, train_dir)
                ckpt_path = save_checkpoint(train_dir, itr, cnf, optimizer, T, args, logger)
        logger.info('Training complete after {} iters.'.format(itr))
    else:
        # Test mode
        logger.info('Test mode')
        if ckpt_path is None or not os.path.exists(ckpt_path):
            raise ValueError("No model exists to load for testing! Use --training to run new experiments or --load_dir ")

    # Visualize result after training/test
    if args.viz:
        visualize_cnf(cnf, os.path.join(train_dir, 'cnf_visualization'),  args.dataset, p_z0, 0, T, device)


