"""Training procedure for NICE.
"""

import argparse
import torch, torchvision
import numpy as np
import nice, utils
from matplotlib import pyplot as pl
from masked_mnist_nf import IndepMaskedMNIST, BlockMaskedMNIST, PatchMaskedMNIST

def main(args):
    """Inpaint one batch of masked MNIST test images with a trained NICE flow.

    Loads the checkpoint at ``args.checkpoint``, rebuilds the flow, takes the
    first batch of ``args.sample_size`` images from ``BlockMaskedMNIST`` and
    replicates it ``repetition_len`` times.  It then runs 4001 accept/reject
    iterations on latent proposals: the decoded proposal (``flow.g``) is
    scored against the observed pixels (``masks == 1``) and under the flow's
    ``log_prob``.  Every ``chain_length`` iterations the reconstructions are
    summarised (plain mean and log-prob-weighted mean), the RMSE over the
    masked-out pixels is printed, image grids are saved, and the chains are
    re-initialised.

    Args:
        args: Namespace with ``checkpoint`` (path handed to ``torch.load``)
            and ``sample_size`` (int, batch size of test images).

    Raises:
        ValueError: If the checkpoint names a dataset other than ``'mnist'``
            or a latent prior other than ``'normal'`` / ``'logistic'``.

    Side effects:
        Requires CUDA device ``cuda:0``; writes PNG grids into ``./test/``
        (created if missing) and prints progress to stdout.
    """
    import os  # local import: only this entry point needs it

    auxiliary_std = 1e-3   # scale of the observed-pixel residual term
    proposal_std = 0.05    # std of the local random-walk step in latent space
    resample_std = 0.5     # std of fresh latent draws

    device = torch.device("cuda:0")

    # Every save_image call below targets this directory; create it up front
    # so the first save does not fail on a fresh checkout.
    out_dir = './test/'
    os.makedirs(out_dir, exist_ok=True)

    checkpoint_file = args.checkpoint
    checkpoint_dict = torch.load(checkpoint_file)

    # model hyperparameters
    dataset = checkpoint_dict['dataset']
    latent = checkpoint_dict['latent']
    sample_size = args.sample_size
    coupling = checkpoint_dict['coupling']
    mask_config = checkpoint_dict['mask_config']
    mean = checkpoint_dict['mean']
    scaling = checkpoint_dict['scaling']
    clamp_min = checkpoint_dict['min']
    clamp_max = checkpoint_dict['max']

    zca = None
    if dataset == 'mnist':
        # NOTE(review): mid_dim / hidden are hard-coded here rather than read
        # from the checkpoint, exactly as before -- confirm this is intended.
        (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
        transform = torchvision.transforms.ToTensor()
        # The returned training set is not used below; the call is kept for
        # its download side effect.  NOTE(review): confirm it is still needed.
        torchvision.datasets.MNIST(root='~/torch/data/MNIST',
            train=True, download=True, transform=transform)
        # IndepMaskedMNIST / PatchMaskedMNIST (imported at file top) are the
        # alternative masking schemes.
        testset = BlockMaskedMNIST(12, train=False)
        testloader = torch.utils.data.DataLoader(testset,
            batch_size=sample_size, shuffle=True)
    else:
        raise ValueError("unsupported dataset in checkpoint: %r" % (dataset,))

    if latent == 'normal':
        prior = torch.distributions.Normal(
            torch.tensor(0.).to(device), torch.tensor(1.).to(device))
    elif latent == 'logistic':
        prior = utils.StandardLogistic()
    else:
        raise ValueError("unsupported latent prior in checkpoint: %r" % (latent,))

    flow = nice.NICE(prior=prior, prior_noise=prior,
                coupling=coupling,
                in_out_dim=full_dim,
                mid_dim=mid_dim,
                hidden=hidden,
                mask_config=mask_config, permutations_list = checkpoint_dict['permutations']).to(device)

    flow.load_state_dict(checkpoint_dict['model_state_dict'])

    flow.eval()        # set to inference mode
    sample_std = resample_std
    repetition_len = 10    # number of parallel copies of the batch
    progression = None     # stays None if the loader yields no batch
    with torch.no_grad():
        for data, masks, _, _, _, _ in testloader:
            inputs = data.view(len(data), 1, 28, 28)

            # 3-channel visualisation: observed pixels in all channels,
            # masked-out pixels tinted through channel 0 only.
            masked_input = inputs.repeat(1, 3, 1, 1)
            masked_input[:, 0, :, :] = inputs.squeeze().mul(masks) + 0.75*(1.0 - masks)
            masked_input[:, 1, :, :] = inputs.squeeze().mul(masks)
            masked_input[:, 2, :, :] = inputs.squeeze().mul(masks)
            print(masks.shape)
            torchvision.utils.save_image(torchvision.utils.make_grid(inputs.cpu(), nrow=3),
                        out_dir + 'sampleinput.png')
            torchvision.utils.save_image(torchvision.utils.make_grid(masked_input.cpu(), nrow=3),
                        out_dir + 'maskedinput.png')

            # One channel per snapshot of the first chain (plus the start).
            progression_len = 10
            progression = inputs.repeat(1, progression_len + 1, 1, 1)
            num_progression = progression.shape[0]

            inputs = utils.prepare_data(
                inputs.cpu(), dataset, zca=zca, mean=mean, rescale=scaling).to(device)
            masks = masks.view([len(masks), 784]).to(device)

            # Run repetition_len independent copies of every image.
            inputs = inputs.repeat(repetition_len, 1)
            masks = masks.repeat(repetition_len, 1)

            # Drawn on the CPU and then moved, as before.
            original_proposals = torch.empty(inputs.shape).normal_(mean=0.0, std=sample_std).to(device)
            samples = utils.prepare_data(flow.sample(30).cpu(), dataset, zca=zca, mean=mean, rescale=scaling, reverse=True)
            samples = torch.clamp(samples, 0.0, 1.0)
            torchvision.utils.save_image(torchvision.utils.make_grid(samples, nrow=6), out_dir + 'samples.png')
            # Current image state: observed pixels from the data, the rest
            # from the decoded latent.
            projected_end = inputs.mul(masks) + flow.g(original_proposals).mul(1.0 - masks)

            prop_std = proposal_std*torch.ones(original_proposals.shape).to(device)
            means = torch.zeros(sample_size, 784).to(device)
            # NOTE(review): acceptances / total_proposals are accumulated but
            # never read or printed in this function.
            acceptances = torch.zeros([len(original_proposals), 1]).to(device)
            total_proposals = torch.zeros(acceptances.shape).to(device)
            gibbs_prob = 1.0   # with 1.0 every latent coordinate is perturbed
            chain_length = 2000
            snapshot_every = chain_length // progression_len
            # Exactly one of the two conditions holds, so this evaluates to
            # 0.5 for every i.
            resample_rule = lambda i : (0.5)*(((i-1) % chain_length) < 500) + (0.5)*(((i-1) % chain_length) >= 500)

            # Pre-allocated buffers, refilled in place every iteration.
            current_mask = torch.empty(original_proposals.shape).to(device)
            resample_mask = torch.empty([len(original_proposals), 1]).to(device)
            normal_samples = torch.empty(original_proposals.shape).to(device)
            acceptance_samples = torch.empty([len(original_proposals), 1]).to(device)
            perturbations = torch.empty(original_proposals.shape).to(device)
            cuda_scaling = scaling.to(device)
            reconstruction_nll = []   # per-chain flow.log_prob of reconstructions
            reconstructions = []
            for i in range(0, 4001):
                # NOTE(review): `starts` is not read afterwards, and
                # `latent_means` is a forward pass multiplied by zero; both
                # are kept in place so the call sequence on `flow` is unchanged.
                starts = flow.f(projected_end)[0]
                latent_means = 0*flow.f(inputs.mul(masks))[0]
                current_mask.bernoulli_(gibbs_prob)
                resample_prob = resample_rule(i)
                # Per row: 1 -> local step around the current latent,
                #          0 -> fresh draw with std sample_std.
                resample_mask.bernoulli_(1.0 - resample_prob)
                normal_samples.normal_(mean=0.0, std=1.0)
                new_proposals = (original_proposals + normal_samples.mul(prop_std).mul(current_mask)).mul(resample_mask)
                new_proposals += (latent_means + normal_samples.mul(sample_std)).mul(1.0 - resample_mask)
                proposal = flow.g(new_proposals)
                # Uniform noise of width 1/255, rescaled by `scaling`.
                perturbations.uniform_()
                perturbations = ((perturbations - 0.5)/255.0).div(cuda_scaling)
                # Difference of squared residuals on the observed pixels
                # (current minus proposed), divided by 2 * auxiliary_std**2.
                bayes_mod = (((flow.g(original_proposals)-inputs + perturbations).mul(masks)**2).sum(dim=1)/2.0 - ((proposal-inputs+perturbations).mul(masks)**2).sum(dim=1)/2.0)/(auxiliary_std**2)

                # Overwrite the observed pixels of both candidates with the
                # (perturbed) data before scoring them under the flow.
                proposal = proposal.mul(1.0 - masks) + (inputs+perturbations).mul(masks)
                projected_end = projected_end.mul(1.0 - masks) + (inputs+perturbations).mul(masks)
                # exp(residual term + log-prob difference + squared-latent-norm
                # term); the last term applies only to freshly resampled rows.
                acceptance_prob = torch.exp(bayes_mod.unsqueeze(1) + flow.log_prob(proposal).unsqueeze(1) - flow.log_prob(projected_end).unsqueeze(1) + ((new_proposals**2).sum(1)/2.0 - (original_proposals**2).sum(1)/2.0).unsqueeze(1).mul(1.0 - resample_mask)/sample_std**2)

                acceptance_samples.uniform_()

                current_acceptances = (acceptance_samples < acceptance_prob).float()
                acceptances += current_acceptances.mul(resample_mask)
                total_proposals += resample_mask.mul(resample_mask)

                # Keep the proposal where accepted, the previous state otherwise,
                # then restore the unperturbed observed pixels.
                projected_end = proposal.mul(current_acceptances) + projected_end.mul(1.0 - current_acceptances)
                projected_end = projected_end.mul(1.0 - masks) + inputs.mul(masks)
                original_proposals = new_proposals.mul(current_acceptances) + original_proposals.mul(1.0 - current_acceptances)

                # Snapshot of the first chain for the progression strip.
                if i % snapshot_every == 0 and i <= chain_length:
                    recon = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)
                    recon = torch.max(recon, clamp_min)
                    recon = torch.min(recon, clamp_max)
                    outputs = utils.prepare_data(
                       recon.cpu()[0:num_progression], dataset, zca=zca, mean=mean, rescale=scaling, reverse=True)
                    progression[:, i // snapshot_every, :, :] = outputs.squeeze()

                # End of a chain: summarise, report, save and restart.
                if i % chain_length == 0 and i != 0:
                    recon = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)
                    recon = torch.max(recon, clamp_min)
                    recon = torch.min(recon, clamp_max)

                    for k in range(0, repetition_len):
                        reconstructions.append(recon[k*sample_size: (k+1)*sample_size].clone().to(device))
                        reconstruction_nll.append(flow.log_prob(recon[k*sample_size: (k+1)*sample_size]).unsqueeze(1))
                    # Weights proportional to exp(log_prob), computed after
                    # subtracting the per-image maximum.
                    stacked_prob = torch.stack(reconstruction_nll)
                    stack_max = stacked_prob.max(0)[0]
                    sum_factor = torch.exp(torch.stack([entry - stack_max for entry in reconstruction_nll])).sum(0)
                    weighted_mean = torch.stack([reconstructions[index].mul(torch.exp(reconstruction_nll[index] - stack_max)).div(sum_factor) for index in range(0, len(reconstructions))]).sum(0)

                    # Running sum of per-chain averages; divided by the number
                    # of finished chains wherever it is used.
                    means += recon.view(repetition_len, sample_size, 784).mean(0).squeeze().view(sample_size, 784)

                    # RMSE over masked-out pixels, averaged over the batch:
                    # first for the plain mean, then for the weighted mean.
                    print(float(((means.to(device)/(float(i/chain_length))-inputs[0:sample_size]).mul(cuda_scaling).mul(1.0-masks[0:sample_size])**2).sum(1).div((1.0-masks[0:sample_size]).sum(1)).sqrt().sum()/float(len(masks[0:sample_size]))))
                    print(float(((weighted_mean.to(device)-inputs[0:sample_size]).mul(cuda_scaling).mul(1.0-masks[0:sample_size])**2).sum(1).div((1.0-masks[0:sample_size]).sum(1)).sqrt().sum()/float(len(masks[0:sample_size]))))

                    outputs = utils.prepare_data(
                       means.cpu()/(float(i/chain_length)), dataset, zca=zca, mean=mean, rescale=scaling, reverse=True)

                    torchvision.utils.save_image(torchvision.utils.make_grid(outputs, nrow=3),
                        out_dir + 'samplemeans%d.png' % i)
                    outputs = utils.prepare_data(
                       weighted_mean.cpu(), dataset, zca=zca, mean=mean, rescale=scaling, reverse=True)

                    torchvision.utils.save_image(torchvision.utils.make_grid(outputs, nrow=3),
                        out_dir + 'weightedmeans%d.png' % i)

                    # Restart every chain from a fresh latent draw.
                    original_proposals = torch.empty(inputs.shape).normal_(mean=0.0, std=sample_std).to(device)
                    projected_end = inputs.mul(masks) + flow.g(original_proposals).mul(1.0 - masks)
                    acceptances = torch.zeros(acceptances.shape).to(device)
                    total_proposals = torch.zeros(acceptances.shape).to(device)

                # Periodic progress image and log line.
                if i % 100 == 0:
                    recon = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)
                    recon = torch.max(recon, clamp_min)
                    recon = torch.min(recon, clamp_max)
                    outputs = utils.prepare_data(
                       recon.cpu(), dataset, zca=zca, mean=mean, rescale=scaling, reverse=True)

                    torchvision.utils.save_image(torchvision.utils.make_grid(outputs[0:sample_size], nrow=3),
                        out_dir + 'sampleoutput%d.png' % i)
                    print(i, (-flow.log_prob(inputs) + flow.log_prob(projected_end)).mean(), float(((projected_end-inputs).mul(cuda_scaling).mul(1.0-masks)**2).sum(1).div((1.0-masks).sum(1)).sqrt().sum()/float(len(masks))))

            # Only the first batch is processed.
            break
        if progression is not None:
            torchvision.utils.save_image(torchvision.utils.make_grid(progression.view(num_progression*(progression_len+1), 28, 28).unsqueeze(1), nrow=(progression_len+1)), out_dir + 'progression.png')
    print('Finished Testing')


if __name__ == '__main__':
    # Command-line entry point: parse the checkpoint path and the batch size,
    # then hand them to main().
    # The descriptive text is passed as `description`: the first positional
    # parameter of ArgumentParser is `prog`, which would make it replace the
    # program name in the usage line.
    parser = argparse.ArgumentParser(
        description='MNIST NICE PyTorch implementation')
    parser.add_argument('--checkpoint',
                        help='checkpointfile.',
                        type=str,
                        default='./models/mnist/mnist_bs200_logistic_cp4_md1000_hd5_full_iter1000.tar')
    parser.add_argument('--sample_size',
                        help='number of images to generate.',
                        type=int,
                        default=6)
    args = parser.parse_args()
    main(args)
