"""Training procedure for NICE.
"""

import argparse
import torch, torchvision
import numpy as np
import nice, utils
from masked_mnist_nf import IndepMaskedMNIST, BlockMaskedMNIST, PatchMaskedMNIST
import time
import json
import os

# Run on the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Root directory under which every run writes its outputs.
save_dir = './models'

# List of run configurations, read once at import time; main() iterates over
# it and reads the 'MASK_FUNCTION' and 'MASK_PARAM' keys of each entry.
with open('./run_params.json', 'r') as paramfile:
    param_list = json.loads(paramfile.read())
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as True; this converter interprets the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser('MNIST NICE PyTorch implementation')
    parser.add_argument('--dataset',
                        help='dataset to be modeled.',
                        type=str,
                        default='mnist')
    parser.add_argument('--batch_size',
                        help='number of images in a mini-batch.',
                        type=int,
                        default=200)
    parser.add_argument('--latent',
                        help='latent distribution.',
                        type=str,
                        default='logistic')
    parser.add_argument('--max_iter',
                        help='maximum number of iterations.',
                        type=int,
                        default=1000)
    parser.add_argument('--resume',
                        help='resume from checkpoint.',
                        type=_str2bool,
                        default=False)
    parser.add_argument('--sample_size',
                        help='number of images to generate.',
                        type=int,
                        default=64)
    parser.add_argument('--lr',
                        help='initial learning rate.',
                        type=float,
                        default=1e-3)
    parser.add_argument('--momentum',
                        help='beta1 in Adam optimizer.',
                        type=float,
                        default=0.9)
    parser.add_argument('--decay',
                        help='beta2 in Adam optimizer.',
                        type=float,
                        default=0.01)
    return parser.parse_args()


def _build_trainset(mask_function, mask_param):
    """Return the masked-MNIST training set selected by ``mask_function``.

    Raises:
        ValueError: if ``mask_function`` is not 'indep', 'block' or 'patch'.
            Failing here avoids silently reusing the dataset of a previous
            run configuration.
    """
    if mask_function == 'indep':
        return IndepMaskedMNIST(mask_param, train=True)
    if mask_function == 'block':
        return BlockMaskedMNIST(mask_param, train=True)
    if mask_function == 'patch':
        return PatchMaskedMNIST(mask_param, train=True)
    raise ValueError('unknown MASK_FUNCTION: %r' % (mask_function,))


def _build_prior(latent):
    """Return the latent prior for ``latent`` ('normal' or 'logistic').

    Raises:
        ValueError: if ``latent`` names neither supported distribution.
    """
    if latent == 'normal':
        return torch.distributions.Normal(
            torch.tensor(0.).to(device), torch.tensor(1.).to(device))
    if latent == 'logistic':
        return utils.StandardLogistic()
    raise ValueError('unknown latent distribution: %r' % (latent,))


def _checkpoint_state(flow, trainset, total_iter, hparams, mean,
                      clamp_min, clamp_max):
    """Assemble the dictionary written to disk by ``torch.save``.

    ``hparams`` holds the model hyperparameters (dataset, batch_size, latent,
    coupling, mid_dim, hidden, mask_config) and is copied in as-is.
    """
    state = {
        'total_iter': total_iter,
        'model_state_dict': flow.state_dict(),
    }
    state.update(hparams)
    state.update({
        'permutations': flow.permutations,
        'mean': mean,
        'scaling': trainset.stdev,
        'min': clamp_min,
        'max': clamp_max,
    })
    return state


def main():
    """Train one NICE flow per entry of ``param_list``.

    For every run configuration this creates the output directories, builds
    the masked-MNIST training set and the flow, then trains for ``--max_iter``
    epochs. Every 50 epochs it writes reconstruction and sample image grids,
    and (except at epoch 0 of a fresh run) a checkpoint followed by a latent
    refresh on the training set. A final checkpoint is written when training
    ends.

    Raises:
        ValueError: on an unsupported dataset, latent distribution or
            MASK_FUNCTION.
    """
    args = _parse_args()

    # model hyperparameters
    resume = args.resume
    dataset = args.dataset
    batch_size = args.batch_size
    latent = args.latent
    max_iter = args.max_iter
    sample_size = args.sample_size
    coupling = 4
    mask_config = 1.

    # NOTE(review): --lr, --momentum and --decay are accepted on the command
    # line but never used; the RMSprop optimizer below is hard-coded. Confirm
    # whether they should be wired in before relying on those flags.

    if dataset != 'mnist':
        raise ValueError('unsupported dataset: %r' % (dataset,))
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)

    hparams = {
        'dataset': dataset,
        'batch_size': batch_size,
        'latent': latent,
        'coupling': coupling,
        'mid_dim': mid_dim,
        'hidden': hidden,
        'mask_config': mask_config,
    }

    for params in param_list:
        mask_function = params['MASK_FUNCTION']
        mask_param = params['MASK_PARAM']
        run_name = str(mask_function) + '_' + str(mask_param)
        run_dir = save_dir + '/' + run_name
        sample_dir = run_dir + '/samples'
        model_dir = run_dir + '/models'
        recons_dir = run_dir + '/reconstruction'

        if os.path.isdir(run_dir):
            print(run_name + ' exists')
        # makedirs also creates a missing parent and any sub-directory that
        # is absent from a partially created run directory.
        for path in (run_dir, sample_dir, model_dir, recons_dir):
            os.makedirs(path, exist_ok=True)

        zca = None

        trainset = _build_trainset(mask_function, mask_param)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=True)

        prior = _build_prior(latent)
        # NOTE(review): prior_noise is constructed but the flow below is
        # given `prior` as its prior_noise; kept as in the original, confirm
        # which one is intended.
        prior_noise = torch.distributions.Normal(
            torch.tensor(0.0).to(device), torch.tensor(1.6).to(device))

        filename = '%s_' % dataset \
                 + 'bs%d_' % batch_size \
                 + '%s_' % latent \
                 + 'cp%d_' % coupling \
                 + 'md%d_' % mid_dim \
                 + 'hd%d_' % hidden \
                 + 'missing_'

        perms = None
        checkpoint_dict = None
        if resume:
            # NOTE(review): the checkpoint path is hard-coded.
            # map_location lets a GPU-saved checkpoint load on a CPU host.
            checkpoint_dict = torch.load(
                './models/mnist/mnist_bs200_logistic_cp4_md1000_hd5_missing_iter500.tar',
                map_location=device)
            perms = checkpoint_dict['permutations']

        flow = nice.NICE(prior=prior,
                         prior_noise=prior,
                         coupling=coupling,
                         in_out_dim=full_dim,
                         mid_dim=mid_dim,
                         hidden=hidden,
                         mask_config=mask_config,
                         permutations_list=perms).to(device)
        if resume:
            flow.load_state_dict(checkpoint_dict['model_state_dict'])

        optimizer = torch.optim.RMSprop(
            flow.parameters(), lr=1e-5, alpha=0.99, eps=1e-08,
            weight_decay=0, momentum=0.9, centered=False)

        total_iter = 0
        running_loss = 0.0
        running_count = 0  # batches accumulated since the last loss report
        trainset.init_latents(flow)

        # Read after init_latents, in the original order, in case that call
        # populates these attributes.
        clamp_min = trainset.min
        clamp_max = trainset.max
        trainset.reset_latents(flow, batch_size=5000, model_reset=False)

        epoch = 0
        scalings = trainset.stdev.to(device)
        mean = trainset.means

        sample_std = 1.814
        while epoch < max_iter:
            if epoch > 450:
                sample_std = 0.5

            if (resume or epoch != 0) and epoch % 50 == 0:
                torch.save(
                    _checkpoint_state(flow, trainset, total_iter, hparams,
                                      mean, clamp_min, clamp_max),
                    model_dir + '/checkpoint_%i.tar' % epoch)
                # Refresh the dataset's latents after each checkpoint.
                trainset.reset_latents(flow, batch_size=5000, model_reset=True)
                trainset.get_latents(
                    flow, step=True, num_steps=500, sample_std=sample_std)

            if (not resume) and (epoch < 50) and (epoch != 0):
                trainset.reset_latents(flow, batch_size=5000, model_reset=False)

            if epoch % 50 == 0:
                with torch.no_grad():
                    # Reconstruct the first 200 projected training points by
                    # a round trip through the flow (f then g).
                    reconstructed = trainset.projected_ends[0:200].view(200, 784)
                    reconstructed = torch.min(reconstructed, clamp_max)
                    reconstructed = torch.max(reconstructed, clamp_min)
                    z, _ = flow.f(reconstructed)
                    reconst = flow.g(z).cpu()
                    reconst = utils.prepare_data(
                        reconst, dataset, zca=zca, mean=mean,
                        rescale=trainset.stdev, reverse=True)
                    samples = flow.sample(sample_size).cpu()
                    samples = utils.prepare_data(
                        samples, dataset, zca=zca, mean=mean,
                        rescale=trainset.stdev, reverse=True)
                    torchvision.utils.save_image(
                        torchvision.utils.make_grid(reconst),
                        recons_dir + '/' + filename + 'iter%d.png' % epoch)
                    torchvision.utils.save_image(
                        torchvision.utils.make_grid(samples),
                        sample_dir + '/' + filename + 'iter%d.png' % epoch)

            start = time.time()
            for _data, _mask, _labels, projected_ends, _latents, _index in trainloader:
                flow.train()    # set to training mode

                optimizer.zero_grad()    # clear gradient tensors

                # Clamp to the dataset range, then add uniform noise of
                # width 1/256 divided by the per-dimension scaling.
                inputs = torch.min(
                    projected_ends.view(len(projected_ends), 28 * 28), clamp_max)
                inputs = torch.max(inputs, clamp_min)
                noise = torch.empty(inputs.shape, device=device).uniform_() - 0.5
                inputs = inputs + (noise / 256.0).div(scalings)

                # log-likelihood of input minibatch
                loss = -flow(inputs).mean()

                running_loss += float(loss)
                running_count += 1

                # backprop and update parameters
                loss.backward()
                optimizer.step()

                if total_iter % 1000 == 0:
                    # Average over the batches actually accumulated, so the
                    # very first report is not divided by 1000.
                    mean_loss = running_loss / running_count \
                              + torch.log(trainset.stdev).sum()
                    bit_per_dim = (mean_loss + np.log(256.) * full_dim) \
                                / (full_dim * np.log(2.))
                    print('iter %s:' % epoch,
                          'loss = %.3f' % mean_loss,
                          'bits/dim = %.3f' % bit_per_dim)
                    running_loss = 0.0
                    running_count = 0

                    flow.eval()        # set to inference mode

                total_iter += 1
            print(time.time() - start)
            epoch += 1

        print('Finished training!')

        torch.save(
            _checkpoint_state(flow, trainset, total_iter, hparams,
                              mean, clamp_min, clamp_max),
            model_dir + '/' + filename + 'iter%d.tar' % epoch)

        print('Checkpoint Saved')
    

# Start training only when this file is executed as a script; importing the
# module does not call main().
if __name__ == '__main__':
    main()
