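"""Imputation procedure for NICE.

For every run listed in run_params.json, loads the trained NICE checkpoint
and writes imputed completions of the masked MNIST test images to
<run_dir>/tests/impute_<k>.tar.
"""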

import argparse
import torch, torchvision
import numpy as np
import nice, utils
from masked_mnist_nf import IndepMaskedMNIST, BlockMaskedMNIST, PatchMaskedMNIST
import time
import json
import os

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Root directory under which every run's models/tests folders live.
save_dir = './models'
data_dir = f'{save_dir}/mnist-data'

# Run configurations: a JSON list whose entries are read by main() through
# their 'MASK_FUNCTION' and 'MASK_PARAM' keys. Loaded at import time.
with open('./run_params.json', 'r') as paramfile:
    param_list = json.load(paramfile)
def _str2bool(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``'False'`` and ``'0'``, as True, so boolean options use this instead.

    Args:
        value: the raw option string (a bool is passed through unchanged).

    Returns:
        True for 'true'/'t'/'yes'/'y'/'1' and False for 'false'/'f'/'no'/'n'/'0'
        (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser('MNIST NICE PyTorch implementation')
    parser.add_argument('--dataset',
                        help='dataset to be modeled.',
                        type=str,
                        default='mnist')
    parser.add_argument('--batch_size',
                        help='number of images in a mini-batch.',
                        type=int,
                        default=200)
    parser.add_argument('--latent',
                        help='latent distribution.',
                        type=str,
                        default='logistic')
    parser.add_argument('--max_iter',
                        help='maximum number of iterations.',
                        type=int,
                        default=1000)
    parser.add_argument('--resume',
                        help='resume from checkpoint.',
                        type=_str2bool,
                        default=False)
    parser.add_argument('--sample_size',
                        help='number of images to generate.',
                        type=int,
                        default=64)
    parser.add_argument('--lr',
                        help='initial learning rate.',
                        type=float,
                        default=1e-3)
    parser.add_argument('--momentum',
                        help='beta1 in Adam optimizer.',
                        type=float,
                        default=0.9)
    # NOTE(review): 0.01 is an unusual Adam beta2 (0.999 is the common
    # choice); the value is not used by this script.
    parser.add_argument('--decay',
                        help='beta2 in Adam optimizer.',
                        type=float,
                        default=0.01)
    return parser


def _make_prior(latent):
    """Return the latent prior selected by ``--latent``.

    Args:
        latent: 'normal' (standard Normal on ``device``) or 'logistic'
            (``utils.StandardLogistic``).

    Raises:
        ValueError: for any other name.
    """
    if latent == 'normal':
        return torch.distributions.Normal(
            torch.tensor(0.).to(device), torch.tensor(1.).to(device))
    if latent == 'logistic':
        return utils.StandardLogistic()
    raise ValueError('unsupported latent distribution: %r' % (latent,))


def _make_masked_dataset(dataset, mask_function, mask_param, checkpoint):
    """Build the masked MNIST test-split dataset for one run configuration.

    The mean, std, and clamp range come from the loaded checkpoint. Presumably
    this makes preprocessing match training; confirm this in masked_mnist_nf.

    Args:
        dataset: dataset name; only 'mnist' is supported.
        mask_function: 'indep', 'block' or 'patch'.
        mask_param: forwarded as the first argument of the dataset class.
        checkpoint: dict read for its 'mean', 'scaling', 'min' and 'max'.

    Raises:
        ValueError: for an unsupported dataset or mask function.
    """
    if dataset != 'mnist':
        raise ValueError('unsupported dataset: %r' % (dataset,))
    dataset_classes = {
        'indep': IndepMaskedMNIST,
        'block': BlockMaskedMNIST,
        'patch': PatchMaskedMNIST,
    }
    try:
        dataset_cls = dataset_classes[mask_function]
    except KeyError:
        raise ValueError(
            'unsupported MASK_FUNCTION: %r' % (mask_function,)) from None
    return dataset_cls(mask_param, train=False, num_copies=5, override=True,
                       mean=checkpoint['mean'], std=checkpoint['scaling'],
                       clamp_min=checkpoint['min'],
                       clamp_max=checkpoint['max'])


def _run_imputation(params, dataset, batch_size, latent,
                    num_replications=10, num_steps=2000, sample_std=0.5):
    """Impute the masked test images for one run configuration.

    The steps are:

    1. Load ``<save_dir>/<MASK_FUNCTION>_<MASK_PARAM>/models/<name>iter1000.tar``.
    2. Rebuild the NICE flow from that checkpoint.
    3. For each replication, reset the dataset's latents and run
       ``get_latents``.
    4. Save the results to ``<run_dir>/tests/impute_<replication>.tar``.

    Args:
        params: one entry of ``param_list``; reads 'MASK_FUNCTION' and
            'MASK_PARAM'.
        dataset, batch_size, latent: the command-line options of the same
            name. They select the checkpoint file name and the prior.
        num_replications: number of imputation files to write.
        num_steps: forwarded to ``get_latents``.
        sample_std: forwarded to ``get_latents``.
    """
    # Model hyperparameters; they must match the trained checkpoint.
    coupling = 4
    mask_config = 1.
    full_dim, mid_dim, hidden = 1 * 28 * 28, 1000, 5

    run_name = f"{params['MASK_FUNCTION']!s}_{params['MASK_PARAM']!s}"
    run_dir = f'{save_dir}/{run_name}'
    model_dir = run_dir + '/models'
    test_dir = run_dir + '/tests'

    # Only an already-existing directory is expected here. Any other OS error,
    # such as a missing run_dir, propagates instead of being reported as "exists".
    try:
        os.mkdir(test_dir)
    except FileExistsError:
        print(run_name + ' exists')

    prior = _make_prior(latent)

    filename = (f'{dataset}_bs{batch_size}_{latent}_cp{coupling}_'
                f'md{mid_dim}_hd{hidden}_missing_')
    # Map tensors to the CPU when no GPU is available, so checkpoints saved
    # on a GPU still load. With CUDA available, load as before.
    map_location = 'cpu' if device.type == 'cpu' else None
    checkpoint = torch.load(model_dir + '/' + filename + 'iter1000.tar',
                            map_location=map_location)
    scaling = checkpoint['scaling']

    testset = _make_masked_dataset(dataset, params['MASK_FUNCTION'],
                                   params['MASK_PARAM'], checkpoint)

    # NOTE(review): `prior` is also passed as `prior_noise`. Confirm this
    # matches the configuration the checkpoint was trained with; a
    # Normal(0, 1.6) noise prior is another candidate.
    flow = nice.NICE(prior=prior,
                     prior_noise=prior,
                     coupling=coupling,
                     in_out_dim=full_dim,
                     mid_dim=mid_dim,
                     hidden=hidden,
                     mask_config=mask_config,
                     permutations_list=checkpoint['permutations']).to(device)
    flow.load_state_dict(checkpoint['model_state_dict'])

    testset.init_latents(flow)
    # The saved 'min'/'max' are read back from the dataset after
    # init_latents, not taken from the checkpoint.
    clamp_min = testset.min
    clamp_max = testset.max
    testset.reset_latents(flow, batch_size=5000, model_reset=False)
    # Read after reset_latents; the original statement order is preserved.
    mean = testset.means

    for replication in range(num_replications):
        print(params['MASK_FUNCTION'], params['MASK_PARAM'], replication)
        testset.reset_latents(flow, batch_size=5000, model_reset=True)
        testset.get_latents(flow, step=True, num_steps=num_steps,
                            sample_std=sample_std)
        torch.save({
            'completions': testset.projected_ends,
            'mask': testset.mask,
            'true': testset.data.data,
            'mean': mean,
            'scaling': scaling,
            'min': clamp_min,
            'max': clamp_max},
            test_dir + '/impute_%d.tar' % replication)


def main():
    """Run imputation for every configuration in ``param_list``.

    The options --resume, --max_iter, --sample_size, --lr, --momentum and
    --decay are parsed for command-line compatibility. This script does not
    use them.
    """
    args = _build_parser().parse_args()
    for params in param_list:
        _run_imputation(params, args.dataset, args.batch_size, args.latent)


    

if __name__ == "__main__":
    # Script entry point; importing this module does not run the imputation.
    main()
