"""Training procedure for NICE on masked (missing-value) UCI datasets.
"""

import argparse
import torch, torchvision
import numpy as np
import nice, utils
from masked_uci_nf import IndepMaskedUCI
import time

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Root directory under which main() stores one checkpoint folder per dataset.
save_dir = './models'

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def main(argv=None):
    """Train NICE flows on every supported UCI dataset and save checkpoints.

    For each dataset in the built-in list, five independent runs are trained
    for ``--max_iter`` epochs and each run is written to
    ``<save_dir>/<dataset>/<filename>_run<k>.tar``.

    Args:
        argv: optional list of command-line tokens; ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: from argparse when the command line is invalid.
        ValueError: if the latent distribution name is not supported.

    NOTE(review): ``--dataset``, ``--batch_size``, ``--lr``, ``--momentum`` and
    ``--decay`` are still accepted for compatibility but have no effect: the
    dataset loop and the per-dataset batch-size table override the first two,
    and the optimiser below uses fixed hyperparameters.
    """
    import os  # local import: only needed to create the checkpoint directory

    parser = argparse.ArgumentParser('MNIST NICE PyTorch implementation')
    parser.add_argument('--dataset',
                        help='dataset to be modeled.',
                        type=str,
                        default='breast')
    parser.add_argument('--batch_size',
                        help='number of images in a mini-batch.',
                        type=int,
                        default=50)
    parser.add_argument('--latent',
                        help='latent distribution.',
                        type=str,
                        choices=('normal', 'logistic'),
                        default='normal')
    parser.add_argument('--max_iter',
                        help='maximum number of iterations.',
                        type=int,
                        default=1000)
    # type=bool would turn "--resume False" into True; parse the text instead.
    # A bare "--resume" (no value) also enables resuming.
    parser.add_argument('--resume',
                        help='resume from checkpoint.',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        default=False)
    parser.add_argument('--sample_size',
                        help='number of images to generate.',
                        type=int,
                        default=64)
    parser.add_argument('--lr',
                        help='initial learning rate.',
                        type=float,
                        default=1e-3)
    parser.add_argument('--momentum',
                        help='beta1 in Adam optimizer.',
                        type=float,
                        default=0.9)
    parser.add_argument('--decay',
                        help='beta2 in Adam optimizer.',
                        type=float,
                        default=0.999)
    args = parser.parse_args(argv)

    # model hyperparameters
    resume = args.resume
    latent = args.latent
    max_iter = args.max_iter
    sample_size = args.sample_size
    coupling = 4
    mask_config = 1.
    mid_dim = 120
    hidden = 5

    # Per-dataset settings: input dimensionality and mini-batch size.
    dataset_list = ['breast', 'red', 'white', 'banknote', 'yeast', 'concrete']
    full_dims = {'breast': 30, 'red': 12, 'white': 12,
                 'banknote': 4, 'yeast': 8, 'concrete': 18}
    batchdict = {'breast': 1500, 'red': 3000, 'white': 10000,
                 'banknote': 3000, 'yeast': 3000, 'concrete': 2000}

    for dataset in dataset_list:
        full_dim = full_dims[dataset]
        batch_size = batchdict[dataset]
        model_dir = save_dir + '/' + dataset
        # Only 'concrete' is built with horizontal copies and equalised
        # permutations; every other dataset uses the defaults.
        is_concrete = (dataset == 'concrete')
        dataset_kwargs = {'horizontal_copies': 2} if is_concrete else {}

        for run_num in range(0, 5):
            trainset = IndepMaskedUCI(0.5, num_copies=10, dataset=dataset,
                                      **dataset_kwargs)
            trainloader = torch.utils.data.DataLoader(
                trainset, batch_size=batch_size, shuffle=True)

            if latent == 'normal':
                prior = torch.distributions.Normal(
                    torch.tensor(0.).to(device), torch.tensor(1.).to(device))
            elif latent == 'logistic':
                prior = utils.StandardLogistic()
            else:
                # Defensive: argparse `choices` should already prevent this.
                raise ValueError('unsupported latent distribution: %r'
                                 % (latent,))

            filename = '%s_bs%d_%s_cp%d_md%d_hd%d_missing_' % (
                dataset, batch_size, latent, coupling, mid_dim, hidden)

            perms = None
            checkpoint_dict = None
            if resume:
                # map_location lets a GPU-saved checkpoint load on this host's
                # selected device.
                checkpoint_dict = torch.load('./models/uci-breast/resume.tar',
                                             map_location=device)
                perms = checkpoint_dict['permutations']

            # NOTE(review): the same distribution is passed as both `prior`
            # and `prior_noise`; the separate Normal(0, 1.6) that used to be
            # built here was never handed to the model. Confirm this is
            # intended before changing it.
            flow = nice.NICE(prior=prior,
                             prior_noise=prior,
                             coupling=coupling,
                             in_out_dim=full_dim,
                             mid_dim=mid_dim,
                             hidden=hidden,
                             mask_config=mask_config,
                             permutations_list=perms,
                             equalize_perms=is_concrete).to(device)
            if resume:
                flow.load_state_dict(checkpoint_dict['model_state_dict'])

            # Fixed optimiser settings (the --lr/--momentum/--decay flags are
            # not wired in; see the docstring).
            optimizer = torch.optim.Adamax(flow.parameters(), lr=0.002,
                                           betas=(0.9, 0.999), eps=1e-8,
                                           weight_decay=0.0)

            total_iter = 0
            prev_iter = 0
            running_loss = 0

            trainset.init_latents(flow)
            clamp_min = trainset.min
            clamp_max = trainset.max
            print(clamp_min)
            print(clamp_max)
            trainset.reset_latents(flow, batch_size=5000, model_reset=False)

            mean = trainset.means

            print(total_iter)
            sample_std = 1.0
            epoch = 0
            while True:
                print(epoch)
                # `>=` (not `==`) so a negative --max_iter cannot loop forever.
                if epoch >= max_iter:
                    break

                # Every 50th epoch (including epoch 0 when resuming) the
                # latents are reset with model_reset=True and refreshed via
                # get_latents; later epochs use a larger num_steps.
                if (resume or epoch != 0) and epoch % 50 == 0:
                    num_steps = 500 if epoch < 550 else 1000
                    trainset.reset_latents(flow, batch_size=5000,
                                           model_reset=True,
                                           sample_std=sample_std)
                    trainset.get_latents(flow, step=True,
                                         num_steps=num_steps,
                                         sample_std=sample_std)

                # Warm-up for fresh runs: epochs 1..49 reset the latents with
                # model_reset=False before each pass over the data.
                if (not resume) and 0 < epoch < 50:
                    trainset.reset_latents(flow, batch_size=5000,
                                           model_reset=False)

                start = time.time()
                for batch_idx, batch in enumerate(trainloader):
                    _data, _mask, projected_ends, _latents, _index = batch

                    flow.train()    # set to training mode
                    optimizer.zero_grad()    # clear gradient tensors

                    # Clamp to the range reported by the dataset, then add a
                    # little Gaussian noise on the same device as the inputs
                    # (previously hard-coded to CUDA).
                    inputs = torch.min(projected_ends, clamp_max)
                    inputs = torch.max(inputs, clamp_min)
                    noise = torch.empty(
                        inputs.shape,
                        device=inputs.device).normal_(mean=0.0, std=1.0)
                    inputs = inputs + 0.01 * noise

                    # negative mean log-likelihood of the minibatch
                    loss = -flow(inputs).mean()
                    running_loss += float(loss)

                    # backprop and update parameters
                    loss.backward()
                    optimizer.step()
                    total_iter += 1

                    # Report on the first batch of every 10th epoch.
                    if epoch % 10 == 0 and batch_idx == 0:
                        # Average loss since the last report, shifted by the
                        # summed log of the dataset's per-feature stdev.
                        mean_loss = ((running_loss) / (total_iter - prev_iter)
                                     + torch.log(trainset.stdev).sum())
                        prev_iter = total_iter
                        bit_per_dim = ((mean_loss + np.log(256.) * full_dim)
                                       / (full_dim * np.log(2.)))

                        # Squared error over the entries where (1 - mask) is
                        # non-zero, restricted to the first `base_len` rows.
                        base = trainset.base_len
                        inv_mask = 1.0 - trainset.mask[0:base]
                        diff = (trainset.projected_ends[0:base]
                                - trainset.data[0:base])
                        mse = float((diff.mul(inv_mask) ** 2).sum()
                                    / float(inv_mask.sum()))
                        print('iter %s:' % epoch,
                              'loss = %.3f' % mean_loss,
                              'bits/dim = %.3f' % bit_per_dim,
                              ' mse = %.3f' % mse)
                        running_loss = 0.0

                        flow.eval()        # set to inference mode

                        # NOTE(review): the reconstructions and samples are
                        # currently discarded. They are kept because
                        # flow.sample presumably draws random numbers, so
                        # removing the calls could alter later training.
                        with torch.no_grad():
                            reconstructed = trainset.projected_ends[0:200].view(
                                200, trainset.image_size)
                            reconstructed = torch.min(reconstructed, clamp_max)
                            reconstructed = torch.max(reconstructed, clamp_min)
                            z, _ = flow.f(reconstructed)
                            reconst = flow.g(z).cpu()
                            reconst = (reconst.mul(trainset.stdev)
                                       + trainset.means)
                            samples = flow.sample(sample_size).cpu()
                            samples = (samples.mul(trainset.stdev)
                                       + trainset.means)

                print(time.time() - start)
                epoch += 1

            print('Finished training!')

            # Make sure the target folder exists so a finished run is not
            # lost to a missing directory.
            os.makedirs(model_dir, exist_ok=True)
            torch.save({
                'total_iter': total_iter,
                'model_state_dict': flow.state_dict(),
                'dataset': dataset,
                'batch_size': batch_size,
                'latent': latent,
                'coupling': coupling,
                'mid_dim': mid_dim,
                'hidden': hidden,
                'mask_config': mask_config,
                'permutations': flow.permutations,
                'mean': mean,
                'scaling': trainset.stdev,
                'min': clamp_min,
                'max': clamp_max,
                'masks': trainset.mask[0:trainset.base_len]},
                model_dir + '/' + filename + '_run%d.tar' % run_num)

            print('Checkpoint Saved')

# Script entry point. main() parses the command line itself, so nothing is
# passed in (the previous `main(args)` referenced an undefined name).
if __name__ == '__main__':
    main()
