""" Training ACNF with SNF training objective with Ascent Regularization
"""

import os
import sys
import argparse


import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import pdb
from random import SystemRandom
from utils import get_logger, RunningAverageMeter, _ascent_monotonically, visualize_cnf, get_batch
import time
import pandas as pd
from datetime import datetime
from model import ACNF_SNF


# Iterations at which the training loop writes an intermediate checkpoint
# (ckpt_<iteration>.pth): 10 and 20, then every 50 iterations up to 1000,
# every 250 iterations up to 4000, and finally 4500 and 5000.
ITERATION_TO_SAVE = ([10, 20]
                     + list(range(50, 1001, 50))
                     + list(range(1250, 4001, 250))
                     + [4500, 5000])



class MVN(nn.Module):
    """Mixture of diagonal Gaussians exposed as an ``nn.Module``.

    The mixture is built as ``MixtureSameFamily(Categorical(mixture),
    Independent(Normal(mu, stds), 1))``, so:

    Args:
        mixture: mixture weights with one entry per component, passed to
            ``Categorical`` as ``probs`` (``Categorical`` normalizes them).
        mu: component means, one row per component (e.g. ``[K, D]``).
        stds: component standard deviations, same shape as ``mu``.

    NOTE: ``mixture``, ``mu`` and ``stds`` are stored as plain attributes,
    not buffers, so ``.to(device)`` on the module does not move them; pass
    tensors that already live on the desired device.
    """

    def __init__(self, mixture, mu, stds):
        super(MVN, self).__init__()
        self.mixture = mixture
        self.mu = mu
        self.stds = stds

        mix = torch.distributions.Categorical(mixture)
        # Independent(..., 1) treats the last dimension of mu/stds as the event dimension.
        comp = torch.distributions.Independent(torch.distributions.Normal(mu, stds), 1)
        self.dist = torch.distributions.mixture_same_family.MixtureSameFamily(mix, comp)

    def dlogprob(self, x, return_numpy=False):
        """Compute the score ``grad_x log p(x)``.

        Args:
            x: points to evaluate (tensor, NumPy array or nested sequence).
                The input is copied, so the caller's tensor is never modified
                or attached to a graph. Integer inputs are cast to ``mu``'s dtype.
            return_numpy: if True, return a NumPy array on the CPU.

        Returns:
            The gradient, with the same shape as ``x``.
        """
        x = torch.as_tensor(x, device=self.mu.device)
        if not x.is_floating_point():
            # requires_grad is only allowed on floating-point tensors.
            x = x.to(self.mu.dtype)
        x = x.detach().clone().requires_grad_(True)

        # enable_grad makes the score computable even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            log_prob = self.log_prob(x)
            grad = torch.autograd.grad(log_prob.sum(), x)[0].contiguous()

        assert grad.shape == x.shape

        if return_numpy:
            # .cpu() is required when the mixture lives on a GPU.
            return grad.detach().cpu().numpy()
        return grad

    def log_prob(self, x, return_numpy=False):
        """Compute ``log p(x)``; ``x`` may be a tensor or a NumPy array/sequence.

        Args:
            x: points to evaluate; moved to the device of ``mu``.
            return_numpy: if True, return a detached NumPy array on the CPU.
        """
        log_prob = self.dist.log_prob(torch.as_tensor(x, device=self.mu.device))
        if return_numpy:
            return log_prob.detach().cpu().numpy()
        return log_prob

    def sample(self, N, return_numpy=False):
        """Draw samples from the mixture.

        Args:
            N: either an int (number of samples) or a sample shape such as
                ``[n]`` / ``torch.Size([n])``.
            return_numpy: if True, return a NumPy array on the CPU.
        """
        sample_shape = torch.Size([N]) if isinstance(N, int) else N
        samples = self.dist.sample(sample_shape)
        if return_numpy:
            return samples.detach().cpu().numpy()
        return samples
    



# Per-iteration quantities recorded in training_history.csv. The misspelled
# 'Estiamted_NLL_traj' key is kept verbatim so CSV files written by earlier
# runs can still be resumed.
_HISTORY_KEYS = ['Iteration', 'Regularization_factor', 'Total_loss', 'Estimated_NLL',
                 'Constraint_loss', 'Estiamted_NLL_traj']

# The timing columns are never filled by this script and are written as empty
# (NaN) columns; they are kept only so the CSV layout stays unchanged.
_HISTORY_COLUMNS = _HISTORY_KEYS + ['Total_time', 'Time_forward_sim', 'Time_backward_sim', 'Time_backprop']


def _checkpoint_iterations(filenames):
    """Return the iteration numbers of the files named ``ckpt_<iteration>.pth``.

    Other files (e.g. ``ckpt.pth`` or anything else merely containing "ckpt")
    are ignored instead of crashing the integer parse.
    """
    iterations = []
    for filename in filenames:
        stem, ext = os.path.splitext(filename)
        prefix, _, number = stem.rpartition('_')
        if ext == '.pth' and prefix == 'ckpt' and number.isdigit():
            iterations.append(int(number))
    return iterations


def _command_without_load_dir(argv):
    """Join ``argv`` into one string with any ``--load_dir`` option removed.

    Both the full flag and the ``--load`` abbreviation (accepted by argparse)
    are dropped, in either the ``--flag value`` or the ``--flag=value`` form.
    """
    kept = []
    skip_value = False
    for token in argv:
        if skip_value:
            skip_value = False
        elif token in ('--load', '--load_dir'):
            skip_value = True
        elif not token.startswith(('--load=', '--load_dir=')):
            kept.append(token)
    return " ".join(kept)


def _empty_history():
    """Return an empty training history: one list per key in ``_HISTORY_KEYS``."""
    return {key: [] for key in _HISTORY_KEYS}


def _load_history(train_dir, iter_init):
    """Read ``<train_dir>/training_history.csv`` for resuming at ``iter_init``.

    Only rows with ``Iteration < iter_init`` are kept: the CSV may be ahead of
    the checkpoint being resumed, and those iterations will be re-run.
    NOTE: 'Estiamted_NLL_traj' comes back as the string repr of each list.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        KeyError: if a column of ``_HISTORY_KEYS`` is missing.
        pandas.errors.EmptyDataError / pandas.errors.ParserError: unreadable file.
    """
    history_df = pd.read_csv(os.path.join(train_dir, 'training_history.csv'))
    history_df = history_df[history_df['Iteration'] < iter_init]
    return {key: history_df[key].to_list() for key in _HISTORY_KEYS}


def _save_history(training_history, train_dir):
    """Write ``training_history`` to ``<train_dir>/training_history.csv`` (overwrites)."""
    df = pd.DataFrame(training_history, columns=_HISTORY_COLUMNS)
    df.to_csv(os.path.join(train_dir, 'training_history.csv'))


def _save_checkpoint(ckpt_path, model, optimizer, T, args, logger):
    """Save model/optimizer state, ``T``, both distributions and ``args`` to ``ckpt_path``."""
    torch.save({
        'func_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'T': T,
        'base_dist': model.base_dist,
        'target_dist': model.target_dist,
        'args': args,
    }, ckpt_path)
    logger.info('Stored ckpt at {}'.format(ckpt_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--adjoint', action='store_true', default=False)
    parser.add_argument('--niters', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--num_samples', type=int, default=512)
    parser.add_argument('--dataset', type=str, default="8blobs", choices=["circle", "moons", "blobs", "8blobs", "checkerboard", "olympics"])
    parser.add_argument('--width', type=int, default=64)
    parser.add_argument('--hidden_dim', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--train_dir', type=str, default="./experiments/acnf_snf/")
    parser.add_argument('--load_dir', type=str, default=None)
    parser.add_argument('--T', type=float, default = 10)
    parser.add_argument('--regularization_factor', type=float, default = 0.1)
    parser.add_argument('--training', action='store_true')
    args = parser.parse_args()

    if args.adjoint:
        from torchdiffeq import odeint_adjoint as odeint
        print("from torchdiffeq import odeint_adjoint as odeint")
    else:
        from torchdiffeq import odeint
        print("from torchdiffeq import odeint")

    T = args.T  # training T

    device = torch.device('cuda:' + str(args.gpu)
                          if torch.cuda.is_available() else 'cpu')

    # Base distribution: 2-D Gaussian N(0, 3^2 I). It must support the
    # transformations applied by the model.
    p_z0 = torch.distributions.MultivariateNormal(
        loc=torch.tensor([0.0, 0.0]).to(device),
        covariance_matrix=torch.tensor([[3.**2, 0.0], [0.0, 3.**2]]).to(device)
    )

    # Target distribution (only 8blobs is implemented). Use ==, not `is`:
    # a dataset name coming from the command line is not an interned string.
    if args.dataset == "8blobs":
        n_mixture = 8
        mixture = torch.ones(n_mixture).to(device)
        means = torch.tensor([[2, 0], [-2, 0], [0, -2], [0, 2],
                              [1.41, 1.41], [-1.41, -1.41], [-1.41, 1.41], [1.41, -1.41]]).to(device)
        stds = 0.3*torch.ones([8, 2]).to(device)
        data_dist = MVN(mixture, means, stds).to(device)
    else:
        raise ValueError("Not implemented")

    # Augmented ODE function
    acnf = ACNF_SNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width, base_dist=p_z0, target_dist = data_dist).to(device)
    optimizer = optim.Adam(acnf.parameters(), lr=args.lr)

    # running average of the total loss
    loss_meter = RunningAverageMeter()

    # Keyword arguments shared by every ODE solve. `adjoint_params` is only
    # understood by odeint_adjoint; plain torchdiffeq.odeint raises TypeError on it.
    ode_kwargs = dict(atol=1e-5, rtol=1e-5, method='dopri5')
    if args.adjoint:
        ode_kwargs['adjoint_params'] = list(acnf.parameters())

    iter_init = 1  # DEFAULT from iteration 1

    # if load_dir specified, load the latest checkpoint saved there
    if args.load_dir is not None:
        itrs = _checkpoint_iterations(os.listdir(args.load_dir))
        if itrs:
            iter_max = max(itrs)
            ckpt_path = os.path.join(args.load_dir, 'ckpt_{}.pth'.format(iter_max))
            # map_location lets a checkpoint saved on one device load on another.
            # NOTE(review): the checkpoint pickles whole objects (args, distributions);
            # only load trusted checkpoints. Recent torch versions may require weights_only=False here.
            checkpoint = torch.load(ckpt_path, map_location=device)
            acnf.load_state_dict(checkpoint['func_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            T = checkpoint['T']
            iter_init = iter_max + 1

            print('Loaded ckpt from {} at iteration {}'.format(ckpt_path, iter_max))
        elif not args.training:
            raise ValueError("There is no checkpoint file in the specified dir")
        train_dir = args.load_dir
    else:
        # create new training experiment folder
        if args.train_dir is not None:
            experimentID = int(SystemRandom().random()*100000)
            train_dir = os.path.join(args.train_dir, args.dataset, datetime.now().strftime(r'%m%d_%H%M%S'), str(experimentID))
            os.makedirs(train_dir, exist_ok=True)
            ckpt_path = os.path.join(train_dir, 'ckpt.pth')

    # logger
    input_command = _command_without_load_dir(sys.argv)

    log_path = os.path.join(train_dir, "training.log")
    logger = get_logger(logpath=log_path, filepath=os.path.abspath(__file__))
    logger.info(input_command)

    logger.info("Training saved in {}".format(train_dir))
    logger.info("CNF networks width: {}, hidden_dim: {}".format(args.width, args.hidden_dim))
    logger.info("Model architecture: {}".format(acnf))
    logger.info("T specified as {}".format(T))

    t0 = 0
    t1 = T
    # Loop-invariant time grids: the training solve over [t0, t1], and the
    # diagnostic trajectory over [0, 2T] with a step of 0.01.
    train_times = torch.tensor([t0, t1]).type(torch.float32).to(device)
    eval_times = torch.linspace(t0, 2*t1, int(2*t1/0.01)+1).type(torch.float32).to(device)

    # Training args.niters iterations
    if args.training:
        # Last finished iteration; defined even if the loop never runs or is
        # interrupted immediately, so the final save/log below cannot NameError.
        itr = iter_init - 1

        if iter_init == 1:
            training_history = _empty_history()
        else:
            logger.info("Resume training from iteration {}".format(iter_init))
            try:
                training_history = _load_history(train_dir, iter_init)
            except FileNotFoundError:
                logger.info("No training_history.csv in {}; starting a new history".format(train_dir))
                training_history = _empty_history()
            except (KeyError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
                # non-zero exit status so scripts can detect the failure
                sys.exit("Fail to load training_history.csv: {}".format(err))

        # Warm-up schedule for the constraint regularization: start from at
        # most 1e-4 and multiply by 1/0.96 per iteration (~x100 every ~113
        # iterations) until args.regularization_factor is reached, never above it.
        regularization_factor = min(1e-4, args.regularization_factor)
        if training_history['Regularization_factor']:
            # Continue the schedule of the resumed run instead of restarting the warm-up.
            last_factor = training_history['Regularization_factor'][-1]
            if last_factor > regularization_factor:
                regularization_factor = min(last_factor, args.regularization_factor)

        try:
            for itr in range(iter_init, args.niters + 1):
                if regularization_factor < args.regularization_factor:
                    regularization_factor = min(regularization_factor / 0.96, args.regularization_factor)
                optimizer.zero_grad()

                # sample from base distribution [n_samples, 2]
                z_t0 = p_z0.sample([args.num_samples]).to(device)

                # ------------ Generation direction for [z(t), logp(z(t)), grad_logp(z(t)), grad_logmu(z(t)), R] --------- #
                # simulate [z(t), log p(z(t)), grad log p(z(t)), R]
                # initial [z(0), log mu(z(0)), grad log p(z(0)), 0]
                log_mu_t0 = p_z0.log_prob(z_t0).to(device)
                # gradient of log p(z(0), 0)
                grad_log_mu_t0 = acnf.gradient_log_base_distribution(z_t0).to(device)

                z_t, logp_t, grad_logp_t, constraint_loss_t = odeint(
                    acnf,
                    (z_t0, torch.zeros_like(log_mu_t0), grad_log_mu_t0, torch.zeros_like(log_mu_t0)),
                    train_times,
                    **ode_kwargs
                )

                constraint_loss = constraint_loss_t[-1]

                # evaluate likelihood on last sample
                log_likelihoods_t1 = acnf.target_dist.log_prob(z_t[-1])

                # log unnormalized weights (estimate Z = sum of unnormalized weights)
                log_weights_t1 = log_likelihoods_t1 - log_mu_t0 + logp_t[-1]

                # average over batch (-> grad_y = 1/batch_size at t_1)
                estimated_nll = -log_weights_t1.mean(0)
                mean_constraint = constraint_loss.mean(0)
                loss = estimated_nll + regularization_factor*mean_constraint
                loss.backward()
                optimizer.step()

                loss_meter.update(loss.item())

                logger.info('Iter: {:04d}, regularization factor: {:.4f}, current total loss: {:.4f}, negative log estimated likelihood: {:.4f}, (continuous) constraint loss: {:.4f}, running avg loss: {:.4f}'.format(
                    itr, regularization_factor, loss.item(), estimated_nll.item(), mean_constraint.item(), loss_meter.avg))

                # evaluate log density of generated samples under the true target
                # distribution (no graph needed: these are diagnostics only)
                with torch.no_grad():
                    generated_ll = data_dist.log_prob(z_t[-1]).mean(0).item()
                    reference_ll = data_dist.log_prob(data_dist.sample([args.num_samples])).mean(0).item()
                logger.info('Evaluate generated samples on true target distribution: averaged log-likelihood {} (compare to true data {})'.format(generated_ll, reference_ll))

                # evaluate estimated NLL trajectory (on [0, 2T] with interval of 0.01)
                with torch.no_grad():
                    _z_t, _logp_t = odeint(acnf.forward_simulate,
                                           (z_t0, torch.zeros_like(log_mu_t0)),
                                           eval_times,
                                           **ode_kwargs)

                    # log weight at every grid time: log p_t + log target(z_t) - log mu(z_0)
                    _log_weights_t = torch.stack([_logp + acnf.target_dist.log_prob(_z) - log_mu_t0
                                                  for _z, _logp in zip(_z_t, _logp_t)], dim=0)
                    mean_log_weights = _log_weights_t.mean(1).squeeze()
                    rounded_traj = [round(_value.item(), 4) for _value in mean_log_weights.data]

                    if _ascent_monotonically(mean_log_weights):
                        logger.info('log estimate weights DO monotonically increase w.r.t. t (log_weights): ')
                        logger.info('log estimate weights DO monotonically increase w.r.t. t (log_weights): {}'.format(rounded_traj))
                    else:
                        logger.info('log estimate weights NOT monotonically increase w.r.t. t (log_weights): {}'.format(rounded_traj))
                        logger.info('log estimate weights NOT monotonically increase w.r.t. t (log_weights)')

                training_history['Iteration'].append(itr)
                training_history['Regularization_factor'].append(regularization_factor)
                training_history['Total_loss'].append(loss.item())
                training_history['Estimated_NLL'].append(estimated_nll.item())
                training_history['Constraint_loss'].append(mean_constraint.item())
                training_history['Estiamted_NLL_traj'].append([_value.item() for _value in mean_log_weights.data])

                if itr % 20 == 0:
                    _save_history(training_history, train_dir)

                if itr in ITERATION_TO_SAVE:
                    _save_checkpoint(os.path.join(train_dir, 'ckpt_{}.pth'.format(itr)), acnf, optimizer, T, args, logger)
                    # keep the CSV in sync with the checkpoint so a resume finds its history
                    _save_history(training_history, train_dir)

            _save_checkpoint(os.path.join(train_dir, 'ckpt_{}.pth'.format(itr)), acnf, optimizer, T, args, logger)
            _save_history(training_history, train_dir)

        except KeyboardInterrupt:
            _save_checkpoint(os.path.join(train_dir, 'ckpt_{}.pth'.format(itr)), acnf, optimizer, T, args, logger)
            _save_history(training_history, train_dir)
        logger.info('Training complete after {} iters.'.format(itr))
    else:
        # Test mode
        logger.info('Test mode')
        if not os.path.exists(ckpt_path):
            raise ValueError("No model exists to load for testing! Use --training to run new experiments or --load_dir ")