"""Training procedure for NICE.
"""

import argparse
import torch, torchvision
import numpy as np
import nice, utils
from masked_mnist_nf import IndepMaskedMNIST, BlockMaskedMNIST
import time

def main(args):
    """Train a NICE flow and periodically save checkpoints and image grids.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``resume``, ``dataset``, ``batch_size``, ``latent``,
            ``max_iter`` (compared against the epoch counter, i.e. it counts
            full passes over the loader) and ``sample_size``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.latent`` is not one of the
            values handled below.

    Side effects:
        Loads statistics from ``./statistics``, writes checkpoints under
        ``./models/mnist`` and image grids under
        ``./reconstruction/DeathCheck`` and ``./samples/DeathCheck``, and
        prints progress to stdout. Requires the ``cuda:0`` device.

    NOTE(review): after the dataset is built, this function calls
    ``init_latents`` / ``reset_latents`` / ``get_latents`` on ``trainset``
    and unpacks six items per batch. In this file only the ``mnist`` branch
    builds a dataset from ``masked_mnist_nf``; whether the torchvision
    datasets of the other branches satisfy that interface is not visible
    here -- confirm before relying on them.
    """
    device = torch.device("cuda:0")

    # model hyperparameters
    resume = args.resume
    dataset = args.dataset
    batch_size = args.batch_size
    latent = args.latent
    max_iter = args.max_iter
    sample_size = args.sample_size
    coupling = 4
    mask_config = 1.

    zca = None
    mean = None
    if dataset == 'mnist':
        mean = torch.load('./statistics/mnist_mean.pt')
        (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
        # Alternative masking scheme: BlockMaskedMNIST(12, train=True).
        trainset = IndepMaskedMNIST(0.4)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True)
    elif dataset == 'fashion-mnist':
        mean = torch.load('./statistics/fashion_mnist_mean.pt')
        (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
        transform = torchvision.transforms.ToTensor()
        trainset = torchvision.datasets.FashionMNIST(root='~/torch/data/FashionMNIST',
            train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True, num_workers=2)
    elif dataset == 'svhn':
        zca = torch.load('./statistics/svhn_zca_3.pt')
        mean = torch.load('./statistics/svhn_mean.pt')
        (full_dim, mid_dim, hidden) = (3 * 32 * 32, 2000, 4)
        transform = torchvision.transforms.ToTensor()
        trainset = torchvision.datasets.SVHN(root='~/torch/data/SVHN',
            split='train', download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True, num_workers=2)
    elif dataset == 'cifar10':
        zca = torch.load('./statistics/cifar10_zca_3.pt')
        mean = torch.load('./statistics/cifar10_mean.pt')
        transform = torchvision.transforms.Compose(
            [torchvision.transforms.RandomHorizontalFlip(p=0.5),
             torchvision.transforms.ToTensor()])
        (full_dim, mid_dim, hidden) = (3 * 32 * 32, 2000, 4)
        trainset = torchvision.datasets.CIFAR10(root='~/torch/data/CIFAR10',
            train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True, num_workers=2)
    else:
        # Fail here rather than with a NameError on full_dim/mid_dim later.
        raise ValueError('unsupported dataset: %r' % (dataset,))

    if latent == 'normal':
        prior = torch.distributions.Normal(
            torch.tensor(0.).to(device), torch.tensor(1.).to(device))
    elif latent == 'logistic':
        prior = utils.StandardLogistic()
    else:
        # Fail here rather than with a NameError on prior below.
        raise ValueError('unsupported latent distribution: %r' % (latent,))

    # Common prefix of the final checkpoint and of the image file names.
    filename = '%s_' % dataset \
             + 'bs%d_' % batch_size \
             + '%s_' % latent \
             + 'cp%d_' % coupling \
             + 'md%d_' % mid_dim \
             + 'hd%d_' % hidden \
             + 'missing_'

    perms = None
    if resume:
        # NOTE(review): the resume checkpoint path is hard-coded and does not
        # depend on the current dataset/batch-size arguments.
        checkpoint_dict = torch.load('./models/mnist/mnist_bs200_logistic_cp4_md1000_hd5_missing_iter500.tar')
        perms = checkpoint_dict['permutations']

    # No branch above accepts 'concrete', so this is always False here.
    equalize_perms = (dataset == 'concrete')

    # NOTE(review): `prior` is passed for both `prior` and `prior_noise`. An
    # earlier revision also built an unused Normal(0.0, 1.6) named
    # `prior_noise`; confirm which distribution is intended here.
    flow = nice.NICE(prior=prior,
                prior_noise=prior,
                coupling=coupling,
                in_out_dim=full_dim,
                mid_dim=mid_dim,
                hidden=hidden,
                mask_config=mask_config, permutations_list=perms, equalize_perms=equalize_perms).to(device)
    if resume:
        flow.load_state_dict(checkpoint_dict['model_state_dict'])

    # NOTE(review): optimizer settings are hard-coded; args.lr, args.momentum
    # and args.decay are not consulted.
    optimizer = torch.optim.RMSprop(flow.parameters(), lr=1e-5, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0.9, centered=False)

    total_iter = 0
    running_loss = 0.0
    iters_since_report = 0  # number of losses currently summed in running_loss

    # Per-pixel statistics of the raw data, captured before the latents are
    # initialised, and compared below against what the dataset reports.
    test_mean = trainset.data.data.mean(dim=0).view(1, full_dim).clone().to(device)
    test_std = trainset.data.data.std(dim=0).view(1, full_dim).clone().to(device)
    trainset.init_latents(flow)

    clamp_min = trainset.min
    clamp_max = trainset.max
    trainset.reset_latents(flow, batch_size=5000, model_reset=False)
    print("Mean test: ", ((test_mean - trainset.means.to(device))**2).sum())
    print("std test: ", ((test_std - trainset.stdev.to(device))**2).sum())

    epoch = 0
    scalings_cuda = trainset.stdev.to(device)
    # From here on `mean` is the dataset's own estimate, not the loaded file.
    mean = trainset.means

    def checkpoint_state():
        """Return the dict stored in every checkpoint, using current values."""
        return {
            'total_iter': total_iter,
            'model_state_dict': flow.state_dict(),
            'dataset': dataset,
            'batch_size': batch_size,
            'latent': latent,
            'coupling': coupling,
            'mid_dim': mid_dim,
            'hidden': hidden,
            'mask_config': mask_config,
            'permutations': flow.permutations,
            'mean' : mean,
            'scaling' : trainset.stdev,
            'min' : clamp_min,
            'max': clamp_max}

    print(total_iter)
    sample_std = 1.814
    while True:
        print(epoch, sample_std)
        if epoch == max_iter:
            break

        if epoch > 450:
            sample_std = 0.5

        # Every 50 epochs (including epoch 0 when resuming): checkpoint, then
        # reset the latents with model_reset=True and run 500 get_latents
        # steps (semantics of those calls live in masked_mnist_nf).
        if (resume or epoch != 0) and epoch % 50 == 0:
            torch.save(checkpoint_state(),
                './models/mnist/checkpoint_%i.tar' % epoch)
            trainset.reset_latents(flow, batch_size=5000, model_reset=True)
            trainset.get_latents(flow, step=True, num_steps=500, sample_std=sample_std)

        # During the first 49 epochs of a fresh run the latents are reset
        # before every epoch without the model_reset flag.
        if (not resume) and 0 < epoch < 50:
            trainset.reset_latents(flow, batch_size=5000, model_reset=False)

        start = time.time()
        for data, mask, labels, projected_ends, latents, index in trainloader:
            flow.train()    # set to training mode (eval() is set when reporting)

            optimizer.zero_grad()    # clear gradient tensors

            # Clamp the flattened batch element-wise to the dataset's range.
            inputs = torch.min(projected_ends.view(len(projected_ends), full_dim), clamp_max)
            inputs = torch.max(inputs, clamp_min)
            # Add uniform noise in [-0.5/256, 0.5/256), divided by the
            # per-pixel scaling, allocated directly on the device.
            noise = (torch.empty(inputs.shape, device=device).uniform_() - 0.5) / 256.0
            inputs = inputs + noise.div(scalings_cuda)

            # negative mean log-likelihood of the input minibatch
            loss = -flow(inputs).mean()

            running_loss += float(loss)
            iters_since_report += 1

            # backprop and update parameters
            loss.backward()
            optimizer.step()

            if total_iter % 1000 == 0:
                # Average over the losses actually accumulated: the very
                # first report holds one loss, later ones hold 1000.
                mean_loss = running_loss / iters_since_report + torch.log(trainset.stdev).sum()
                bit_per_dim = (mean_loss + np.log(256.) * full_dim) \
                            / (full_dim * np.log(2.))
                print('iter %s:' % epoch,
                    'loss = %.3f' % mean_loss,
                    'bits/dim = %.3f' % bit_per_dim)
                running_loss = 0.0
                iters_since_report = 0

                flow.eval()        # set to inference mode

                with torch.no_grad():
                    # Round-trip the first 200 stored items through f then g.
                    reconstructed = trainset.projected_ends[0:200].view(200, full_dim)
                    reconstructed = torch.min(reconstructed, clamp_max)
                    reconstructed = torch.max(reconstructed, clamp_min)
                    z, _ = flow.f(reconstructed)
                    reconst = flow.g(z).cpu()
                    reconst = utils.prepare_data(
                        reconst, dataset, zca=zca, mean=mean, rescale=trainset.stdev, reverse=True)
                    samples = flow.sample(sample_size).cpu()
                    samples = utils.prepare_data(
                        samples, dataset, zca=zca, mean=mean, rescale=trainset.stdev, reverse=True)
                    torchvision.utils.save_image(torchvision.utils.make_grid(reconst),
                        './reconstruction/DeathCheck/' + filename +'iter%d.png' % epoch)
                    torchvision.utils.save_image(torchvision.utils.make_grid(samples),
                        './samples/DeathCheck/' + filename +'iter%d.png' % epoch)

            total_iter += 1
        print(time.time() - start)    # wall-clock seconds for this epoch
        epoch += 1

    print('Finished training!')

    torch.save(checkpoint_state(),
        './models/mnist/' + filename +'iter%d.tar' % epoch)

    print('Checkpoint Saved')

def str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty token -- including ``"False"`` -- becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: a bool (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"``; case and surrounding whitespace
            are ignored. The empty string maps to ``False``.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: if the token is not recognised.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('MNIST NICE PyTorch implementation')
    parser.add_argument('--dataset',
                        help='dataset to be modeled.',
                        type=str,
                        default='mnist')
    parser.add_argument('--batch_size',
                        help='number of images in a mini-batch.',
                        type=int,
                        default=50)
    parser.add_argument('--latent',
                        help='latent distribution.',
                        type=str,
                        default='logistic')
    # main() compares this value against its epoch counter.
    parser.add_argument('--max_iter',
                        help='maximum number of iterations.',
                        type=int,
                        default=500)
    # `--resume`, `--resume True` and `--resume 1` enable resuming;
    # `--resume False` now really disables it.
    parser.add_argument('--resume',
                        help='resume from checkpoint.',
                        type=str2bool,
                        nargs='?',
                        const=True,
                        default=False)
    parser.add_argument('--sample_size',
                        help='number of images to generate.',
                        type=int,
                        default=64)
    # NOTE(review): main() hard-codes the optimizer settings, so the three
    # options below currently do not affect training.
    parser.add_argument('--lr',
                        help='initial learning rate.',
                        type=float,
                        default=1e-3)
    parser.add_argument('--momentum',
                        help='beta1 in Adam optimizer.',
                        type=float,
                        default=0.9)
    parser.add_argument('--decay',
                        help='beta2 in Adam optimizer.',
                        type=float,
                        default=0.01)
    args = parser.parse_args()
    main(args)
