import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np
import utils

class MaskedMNIST(Dataset):
    """MNIST wrapper pairing each image with a mask and per-image imputation state.

    Subclasses implement ``generate_masks``, which must set ``self.mask`` to
    one mask per image (``len(self)`` of them).  Throughout this class a mask
    value of 1 keeps the data pixel and 0 takes the flow's output instead
    (``flow.g(...).mul(1 - mask) + data.mul(mask)``).

    Per-image state:
      * ``latents``: MCMC chain state.  NOTE(review): methods disagree on its
        space and shape -- ``get_latents``/``reset_latents`` apply ``flow.f``
        to it, while ``resample_latents`` and the ``get_latents_*`` samplers
        feed it to ``flow.g``; it is (N, 784) after ``init_latents`` /
        ``reset_latents`` and (N, 28, 28) after ``get_latents_full``.  Row
        writes go through ``_store_latents`` so either layout is accepted.
      * ``projected_ends``: the current imputed images, (N, 28, 28) once one
        of the initialisation methods has run.

    Most methods hard-code 28x28 images and call ``.cuda()`` directly.
    """

    def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0,
                 train=True, device=torch.device("cuda:0"), num_copies=1,
                 override=False, mean=None, std=None, clamp_min=None,
                 clamp_max=None):
        """Load MNIST, tile it ``num_copies`` times and generate the masks.

        Args:
            data_dir: root passed to ``torchvision.datasets.MNIST``
                (downloaded if missing).
            image_size: 28 uses ``ToTensor`` only; other values add a
                ``Resize`` to the torchvision transform (``self.data.data``
                itself is used as loaded).
            random_seed: seed of ``self.rnd`` (a NumPy ``RandomState``) used
                by mask generators.
            train: load the training split if True, else the test split.
            device: where pixels, labels and masks are stored.
            num_copies: how many times the images and labels are tiled.
            override: if True, ``init_latents`` keeps ``mean``/``std``/
                ``clamp_min``/``clamp_max`` instead of recomputing them.
            mean, std: initial ``self.means`` / ``self.stdev``.
            clamp_min, clamp_max: initial ``self.min`` / ``self.max``.
        """
        self.rnd = np.random.RandomState(random_seed)
        self.image_size = image_size
        self.train = train
        if image_size == 28:
            transform = transforms.ToTensor()
        else:
            transform = transforms.Compose([
                transforms.Resize(image_size), transforms.ToTensor()])
        self.data = datasets.MNIST(
            data_dir, train=self.train, download=True, transform=transform)
        # Pixels become [0, 1] floats, tiled num_copies times.
        self.data.data = self.data.data.float().repeat(num_copies, 1, 1).to(device) / 255.0
        # Tile the labels the same way (and keep the result!) so that
        # __getitem__ can index every copy.
        self.data.targets = self.data.targets.repeat(num_copies).to(device)
        self.generate_masks()
        self.mask = self.mask.float().to(device)
        # Placeholders; the init/reset methods replace these.
        self.latents = torch.zeros(len(self.data.data))
        self.projected_ends = torch.zeros(len(self.data.data))
        self.override = override
        self.stdev = std
        self.means = mean
        self.observed_values = None
        self.min = clamp_min
        self.max = clamp_max

    def _store_latents(self, index, values):
        """Write ``values`` into ``self.latents[index]``, reshaped to fit.

        ``self.latents`` is (N, 784) after ``init_latents``/``reset_latents``
        but (N, 28, 28) after ``get_latents_full``, so the values are
        reshaped to the destination's per-row shape instead of a fixed one.
        """
        row_shape = tuple(self.latents.shape[1:])
        self.latents[index] = values.reshape((len(values),) + row_shape)

    def init_latents(self, flow, batch_size=200):
        """Normalise the data in place and initialise latents and imputations.

        1. Unless ``override``: set ``self.means``/``self.stdev`` to per-pixel
           mean/std computed over observed pixels only (each (1, 784), CPU).
        2. Per batch: zero the latents, take ``flow.f(latents)[0]``, rewrite
           the batch of ``self.data.data`` through ``utils.prepare_data``
           (with ``mean=self.means, rescale=self.stdev``; presumably a
           standardisation -- confirm in utils), and set ``projected_ends``
           to ``flow.g`` output on unobserved pixels and the rewritten data on
           observed ones.
        3. Unless ``override``: track per-pixel min/max of observed values in
           ``self.min``/``self.max`` (unobserved entries count as 100 for the
           min and 0 for the max).

        ``self.data.data`` is modified in place, so calling this again would
        re-apply ``prepare_data`` to already-processed data.
        """
        n = len(self.data.data)
        # N(0, 1) draw (consumes the RNG); rows are zeroed in the loop below.
        self.latents = torch.empty(n, 28*28).normal_(mean=0.0, std=1.0).cuda()
        self.projected_ends = self.data.data.mul(1.0 - self.mask)
        # Per-pixel fraction of images in which the pixel is observed.
        # NOTE(review): a pixel that is never observed makes this divisor 0.
        vis_prob = self.mask.mean(dim=0)
        observed_mean = (self.data.data.mul(self.mask)).mean(dim=0).div(vis_prob)
        if not self.override:
            self.means = observed_mean.view(1, 784).cpu()
            # NOTE(review): the additive floor 1/(255 + sqrt(12)) may have been
            # meant as 1/(255 * sqrt(12)), the std of U(-0.5, 0.5)/255 -- confirm.
            self.stdev = (((self.data.data - observed_mean).mul(self.mask)**2).mean(dim=0)).div(vis_prob)**0.5 + 1.0/(255.0 + (12.0**0.5))
            self.stdev = self.stdev.view(1, 784).cpu()
        with torch.no_grad():
            for ndx in range(0, n, batch_size):
                rows = slice(ndx, min(ndx + batch_size, n))
                self.latents[rows] = 0*self.latents[rows]
                original_proposals = flow.f(self.latents[rows])[0]
                raw = self.data.data[rows]
                count = len(raw)
                self.data.data[rows] = utils.prepare_data(
                    raw.view(count, 1, 28, 28).cpu(), 'mnist', zca=None,
                    mean=self.means, rescale=self.stdev).cuda().view(count, 28, 28)
                masks = self.mask[rows]
                masks = masks.view(len(masks), 28*28)
                original_proposals = original_proposals.view(len(original_proposals), 28*28)
                # Rewritten observed pixels; unobserved entries are zero.
                observed = self.data.data[rows].view(count, 28*28).mul(masks)
                projected_end = flow.g(original_proposals).mul(1.0 - masks) + observed
                self.projected_ends[rows] = projected_end.view(len(projected_end), 28, 28)
                if not self.override:
                    # Unobserved entries are pushed to 100 so they never win the min.
                    min_candidates = observed + 100*(1.0 - masks)
                    if ndx == 0:
                        self.min = min_candidates.min(dim=0)[0]
                        self.max = observed.max(dim=0)[0]
                    else:
                        self.min = torch.min(min_candidates, self.min).min(dim=0)[0]
                        self.max = torch.max(observed, self.max).max(dim=0)[0]

    def reset_latents(self, flow, batch_size=200, model_reset=False):
        """Re-create ``latents`` and recompute ``projected_ends`` batch by batch.

        Latents start as N(0, 1) draws; with ``model_reset`` each batch is
        replaced by ``flow.g`` of N(0, 1.814**2) noise.  ``projected_ends``
        becomes ``flow.g(flow.f(latents)[0])`` on unobserved pixels and the
        data on observed ones.
        """
        sample_std = 1.814
        n = len(self.data.data)
        self.latents = torch.empty(n, 28*28).normal_(mean=0.0, std=1.0).cuda()
        self.projected_ends = torch.zeros(self.data.data.shape).cuda()
        with torch.no_grad():
            for ndx in range(0, n, batch_size):
                rows = slice(ndx, min(ndx + batch_size, n))
                if model_reset:
                    noise = torch.empty(self.latents[rows].shape).normal_(mean=0.0, std=1.0).mul(sample_std).cuda()
                    self.latents[rows] = flow.g(noise)
                original_proposals = flow.f(self.latents[rows])[0]
                masks = self.mask[rows]
                original_proposals = original_proposals.view(len(original_proposals), 28*28)
                masks = masks.view(len(masks), 28*28)
                inputs = self.data.data[rows]
                projected_end = flow.g(original_proposals).mul(1.0 - masks) + inputs.view(len(inputs), 28*28).mul(masks)
                self.projected_ends[rows] = projected_end.view(len(projected_end), 28, 28)

    def _resample_rows(self, flow, index):
        """Replace the latents of rows ``index`` with fresh ``flow.sample`` draws.

        Elements whose ``flow.g`` image is NaN keep their previous latent
        value; ``projected_ends`` is refreshed from the new latents.
        """
        original_proposals = self.latents[index]
        original_proposals = original_proposals.view(len(original_proposals), 28*28)
        new_proposals = flow.sample(len(original_proposals))
        inputs = self.data.data[index]
        masks = self.mask[index]
        inputs = inputs.view(len(inputs), 28*28)
        masks = masks.view(len(masks), 28*28)
        proposal = flow.g(new_proposals)
        is_nan = proposal != proposal
        new_proposals[is_nan] = original_proposals[is_nan]
        self._store_latents(index, new_proposals)
        self.projected_ends[index] = (flow.g(new_proposals).mul(1.0 - masks) + inputs.mul(masks)).view(len(new_proposals), 28, 28)

    def resample_latents(self, flow, batch_size=200):
        """Resample every row's latents from ``flow.sample``, in batches."""
        n = len(self.data.data)
        with torch.no_grad():
            for ndx in range(0, n, batch_size):
                self._resample_rows(flow, slice(ndx, min(ndx + batch_size, n)))

    def resample_latents_index(self, flow, indices, batch_size=200):
        """Resample the latents of ``indices`` from ``flow.sample``.

        ``batch_size`` is unused; it is kept for interface compatibility.
        """
        with torch.no_grad():
            self._resample_rows(flow, indices)

    def get_latents(self, flow, batch_size=10000, step=True, num_steps=5, sample_std=1.814):
        """Run ``num_steps`` sweeps of batched Metropolis-Hastings updates.

        Per row, the state is ``flow.f(latents)[0]``.  With probability 0.5 the
        proposal is a random-walk step (std 0.05); otherwise it is
        ``sample_std`` times a standard-normal draw.  Observed pixels get
        uniform noise in [-0.5/255, 0.5/255) divided by ``self.stdev``.  The
        log acceptance ratio is 1e6 times the drop in half squared error on
        observed pixels, plus the change in ``flow.log_prob`` of the imputed
        image, plus a Gaussian correction for independent proposals.
        Accepted rows store ``flow.g`` of the new state in ``latents``.

        Returns:
            The fraction of proposals accepted, or 0.0 when none were made
            (``step=False``, ``num_steps=0`` or an empty dataset).
        """
        acceptances = 0.0
        tries = 0.0
        with torch.no_grad():
            resample_prob = 0.5  # probability of an independent proposal
            prop_std = 0.05      # random-walk step size
            gibbs_prob = 1.0     # per-coordinate update probability (1: all)
            if step:
                n = len(self.data.data)
                all_indices = range(0, n)
                # Reusable random buffers, sliced to the current batch length.
                current_mask_holder = torch.empty(batch_size, 28*28).cuda()
                resample_mask_holder = torch.empty([batch_size, 1]).cuda()
                acceptance_samples = torch.empty([batch_size, 1]).cuda()
                normal_sample_holder = torch.empty(batch_size, 28*28).cuda()
                uniform_holder = torch.empty(batch_size, 28*28).cuda()
                stdev = self.stdev.cuda()
                for _ in range(num_steps):
                    for ndx in range(0, n, batch_size):
                        batch_length = min(ndx + batch_size, n) - ndx
                        affected_indices = all_indices[ndx:ndx + batch_length]
                        original_proposals = flow.f(self.latents[affected_indices])[0]
                        inputs = self.data.data[affected_indices]
                        resample_mask_holder.bernoulli_(1.0 - resample_prob)
                        # 1 -> random-walk step, 0 -> independent proposal.
                        resample_mask = resample_mask_holder[0:batch_length]
                        masks = self.mask[affected_indices]
                        original_proposals = original_proposals.view(len(original_proposals), 28*28)
                        masks = masks.view(len(masks), 28*28)
                        inputs = inputs.view(len(inputs), 28*28)
                        projected_end = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)
                        current_mask_holder.bernoulli_(gibbs_prob)
                        current_mask = current_mask_holder[0:batch_length]
                        normal_sample_holder.normal_(mean=0.0, std=1.0)
                        normal_samples = normal_sample_holder[0:batch_length]
                        new_proposals = (original_proposals + normal_samples.mul(prop_std)).mul(current_mask).mul(resample_mask)
                        new_proposals += normal_samples.mul(sample_std).mul(1.0 - resample_mask)
                        proposal = flow.g(new_proposals)
                        uniform_holder.uniform_()
                        # Sliced to batch_length so a short final batch still
                        # broadcasts against `inputs`.
                        perturbations = ((uniform_holder[0:batch_length] - 0.5)/255.0).div(stdev)
                        bayes_mod = (1e6)*(((flow.g(original_proposals)-inputs).mul(masks)**2).sum(dim=1)/2.0 - ((proposal-inputs).mul(masks)**2).sum(dim=1)/2.0)
                        proposal = proposal.mul(1.0 - masks) + (inputs + perturbations).mul(masks)
                        projected_end = projected_end.mul(1.0 - masks) + (inputs + perturbations).mul(masks)
                        acceptance_samples.uniform_()
                        acceptance_prob = torch.exp(
                            bayes_mod.unsqueeze(1)
                            + flow.log_prob(proposal).unsqueeze(1)
                            - flow.log_prob(projected_end).unsqueeze(1)
                            + ((new_proposals**2).sum(1)/2.0 - (original_proposals**2).sum(1)/2.0).unsqueeze(1).mul(1.0 - resample_mask)/sample_std**2
                        )
                        current_acceptances = (acceptance_samples[0:batch_length] < acceptance_prob).float()
                        acceptances += float(current_acceptances.sum())
                        tries += float(len(current_acceptances))
                        original_proposals = new_proposals.mul(current_acceptances) + original_proposals.mul(1.0 - current_acceptances)
                        self._store_latents(affected_indices, flow.g(original_proposals))
                        self.projected_ends[affected_indices] = (flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)).view(len(original_proposals), 28, 28)
        return acceptances / tries if tries else 0.0

    def _mh_update(self, flow, latents, inputs, masks, resample_prob, prop_std, sample_std, binarize=False):
        """Run one Metropolis-Hastings update on a batch of chains.

        Args:
            flow: model providing ``g`` and ``log_prob``.
            latents, inputs, masks: chain states, pixels and masks for the
                batch; each is flattened to (B, 784).
            resample_prob: probability that a row takes a random-walk step
                (std ``prop_std``); other rows draw an independent
                N(0, ``sample_std``**2) proposal.
            prop_std: random-walk step size.
            sample_std: std of the independent proposal.
            binarize: snap each proposed pixel to the nearer of two dithered
                levels, ``(255 + U)/256`` and ``U/256`` with U ~ U(0, 1),
                both normalised with ``self.means``/``self.stdev``.

        Returns:
            ``(new_latents, new_ends)``: updated (B, 784) chain states and the
            matching (B, 28, 28) imputed images.
        """
        original_proposals = latents.view(len(latents), 28*28)
        resample_mask = torch.empty([len(inputs), 1]).bernoulli(resample_prob).cuda()
        masks = masks.view(len(masks), 28*28)
        inputs = inputs.view(len(inputs), 28*28)
        projected_end = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)
        new_proposals = (original_proposals + torch.empty(inputs.shape).normal_(mean=0.0, std=1.0).cuda().mul(prop_std)).mul(resample_mask)
        new_proposals += torch.empty(inputs.shape).normal_(mean=0.0, std=sample_std).cuda().mul(1.0 - resample_mask)
        proposal = flow.g(new_proposals)
        if binarize:
            highs = ((255.0 + torch.distributions.Uniform(0., 1.).sample(proposal.size()))/256.0 - self.means).div(self.stdev).cuda()
            lows = ((torch.distributions.Uniform(0., 1.).sample(proposal.size()))/256.0 - self.means).div(self.stdev).cuda()
            diffs = (proposal - highs)**2 - (proposal - lows)**2
            # Exact ties (diffs == 0) map to 0, as before.
            proposal = highs.mul((diffs < 0).float()) + lows.mul((diffs > 0).float())
        proposal = proposal.mul(1.0 - masks) + inputs.mul(masks)
        log_ratio = (
            flow.log_prob(proposal).unsqueeze(1)
            - (proposal**2).div(2).sum(dim=1).unsqueeze(1)
            - flow.log_prob(projected_end).unsqueeze(1)
            + (projected_end**2).div(2).sum(dim=1).unsqueeze(1)
            + ((new_proposals**2).sum(1)/2.0 - (original_proposals**2).sum(1)/2.0).unsqueeze(1).mul(1.0 - resample_mask)/sample_std**2
        )
        acceptance_prob = torch.exp(log_ratio)
        current_acceptances = (torch.empty(acceptance_prob.shape).uniform_().cuda() < acceptance_prob).float()
        new_latents = new_proposals.mul(current_acceptances) + original_proposals.mul(1.0 - current_acceptances)
        new_ends = (flow.g(new_latents).mul(1.0 - masks) + inputs.mul(masks)).view(len(new_latents), 28, 28)
        return new_latents, new_ends

    def get_latents_indexed(self, flow, indices, step=True, num_steps=5):
        """Run ``num_steps`` MH updates (see ``_mh_update``) on rows ``indices``.

        The first step uses only independent proposals; later steps use a
        random-walk step with probability 0.5.
        """
        with torch.no_grad():
            if step:
                for sample_idx in range(num_steps):
                    resample_prob = 0.0 if sample_idx == 0 else 0.5
                    new_latents, new_ends = self._mh_update(
                        flow, self.latents[indices], self.data.data[indices],
                        self.mask[indices], resample_prob,
                        prop_std=0.01, sample_std=1.841)
                    self._store_latents(indices, new_latents)
                    self.projected_ends[indices] = new_ends

    def get_latents_full(self, flow, step=True, num_steps=5):
        """Like ``get_latents_indexed`` but over the whole dataset at once.

        Replaces ``self.latents`` with an (N, 28, 28) tensor and
        ``self.projected_ends`` with the new imputations.
        """
        with torch.no_grad():
            if step:
                for sample_idx in range(num_steps):
                    resample_prob = 0.0 if sample_idx == 0 else 0.5
                    new_latents, new_ends = self._mh_update(
                        flow, self.latents, self.data.data, self.mask,
                        resample_prob, prop_std=0.01, sample_std=1.841)
                    self.latents = new_latents.view(len(new_latents), 28, 28)
                    self.projected_ends = new_ends

    def get_binarized_latents_indexed(self, flow, indices, step=True, num_steps=5):
        """Like ``get_latents_indexed`` but with binarised proposals.

        Each proposed pixel is snapped to the nearer of two dithered levels
        (see ``_mh_update``); the independent-proposal std is 1.814.
        """
        with torch.no_grad():
            if step:
                for sample_idx in range(num_steps):
                    resample_prob = 0.0 if sample_idx == 0 else 0.5
                    new_latents, new_ends = self._mh_update(
                        flow, self.latents[indices], self.data.data[indices],
                        self.mask[indices], resample_prob,
                        prop_std=0.01, sample_std=1.814, binarize=True)
                    self._store_latents(indices, new_latents)
                    self.projected_ends[indices] = new_ends

    def __getitem__(self, index):
        """Return (image, mask, label, imputed image, latents, index)."""
        return self.data.data[index], self.mask[index], self.data.targets[index], self.projected_ends[index], self.latents[index], index

    def __len__(self):
        """Length of the wrapped torchvision dataset (its tiled ``data``)."""
        return len(self.data)

    def generate_masks(self):
        """Set ``self.mask`` to one mask per image; implemented by subclasses."""
        raise NotImplementedError


class BlockMaskedMNIST(MaskedMNIST):
    """MaskedMNIST whose masks are 1 on a single axis-aligned rectangle.

    With ``block_len=None`` the rectangle's height and width are drawn
    independently from ``[7, image_size - 7)``; otherwise it is a
    ``block_len`` x ``block_len`` square.  Its top-left corner is uniform over
    the positions where it fits.  All draws use ``self.rnd``.
    ``mask_info[i]`` records ``(row_start, col_start, height, width)``.
    """

    def __init__(self, block_len=None, *args, **kwargs):
        self.block_len = block_len
        super().__init__(*args, **kwargs)

    def generate_masks(self):
        side = self.image_size
        shortest = 7
        longest = side - shortest  # exclusive upper bound for randint
        total = len(self)
        self.mask = torch.zeros((total, side, side), dtype=torch.uint8)
        self.mask_info = [None] * total
        for idx in range(total):
            if self.block_len is not None:
                height = width = self.block_len
            else:
                height = self.rnd.randint(shortest, longest)
                width = self.rnd.randint(shortest, longest)
            top = self.rnd.randint(0, side - height + 1)
            left = self.rnd.randint(0, side - width + 1)
            # self.mask starts at zero, so only the rectangle needs writing.
            self.mask[idx, top:top + height, left:left + width] = 1
            self.mask_info[idx] = top, left, height, width


class IndepMaskedMNIST(MaskedMNIST):
    """MaskedMNIST whose mask pixels are independent Bernoulli draws.

    Every pixel of mask ``i`` is 1 with probability ``p_i``.  ``p_i`` equals
    ``obs_prob`` when ``obs_prob_high`` is None; otherwise it is drawn from
    ``U[obs_prob, obs_prob_high)`` with ``self.rnd``.  The Bernoulli draws
    themselves use torch's global generator.  ``obs_prob == 1.0`` yields
    all-ones masks without any sampling.
    """

    def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
        self.prob = obs_prob
        self.prob_high = obs_prob_high
        super().__init__(*args, **kwargs)

    def generate_masks(self):
        side = self.image_size
        low, high = self.prob, self.prob_high
        count = len(self)
        self.mask = torch.ByteTensor(count, side, side)
        if low == 1.0:
            self.mask.fill_(1)
            return
        for idx in range(count):
            p = low if high is None else self.rnd.uniform(low, high)
            self.mask[idx] = torch.ByteTensor(side, side).bernoulli_(p)


class ShadowMaskedMNIST(MaskedMNIST):
    """MaskedMNIST with "shadow" masks cast from one of the four image edges.

    For each image a start edge (top, left, bottom or right) is drawn with
    ``self.rnd``.  Along every row/column walking in from that edge, mask
    entries are 1 up to and including the first pixel brighter than
    ``depth`` and 0 beyond it.
    """

    def __init__(self, depth=0.89, *args, **kwargs):
        self.depth = depth
        super().__init__(*args, **kwargs)

    def generate_masks(self):
        imsize = self.image_size
        depth = self.depth
        n_pixels = imsize * imsize
        n_masks = len(self)
        self.mask = torch.ByteTensor(n_masks, imsize, imsize)
        for i in range(n_masks):
            # MaskedMNIST.__init__ has already turned self.data.data into
            # [0, 1] floats on `device` before calling generate_masks, so the
            # stored pixels are read directly.  torchvision's MNIST.__getitem__
            # rebuilds a PIL image from that tensor and presumably fails for
            # float/CUDA data.  NOTE(review): the stored images are not
            # resized, so this assumes image_size == 28.
            bright = (self.data.data[i].reshape(n_pixels) > depth).tolist()
            direction = self.rnd.randint(1, 5)  # 1 top, 2 left, 3 bottom, 4 right
            backwards = int(direction > 2)
            sign = 1 - 2 * backwards
            start = (n_pixels - 1) * backwards
            depth_step = sign * int(imsize ** (direction % 2))
            ray_step = sign * int(imsize ** ((direction + 1) % 2))
            mask = [0] * n_pixels
            for ray in range(imsize):
                visible = 1
                for step in range(imsize):
                    pos = start + ray_step * ray + depth_step * step
                    mask[pos] = visible
                    if bright[pos]:
                        visible = 0
            self.mask[i] = torch.tensor(mask, dtype=torch.uint8).view(imsize, imsize)

class PatchMaskedMNIST(MaskedMNIST):
    """MaskedMNIST whose masks are all ones except ``num_patches`` zeroed patches.

    Each patch has a width ``w`` drawn from ``[1, 10)`` and ``25 // w`` rows,
    so it covers roughly 25 pixels.  Its start is a flat pixel index drawn
    from ``[1, image_size**2)``.  All draws use ``self.rnd``.
    NOTE(review): patches are laid out on the flattened image, so a patch
    near the right edge wraps onto the next row, and pixel 0 is never a
    start -- confirm both are intended.
    """

    def __init__(self, num_patches=27, *args, **kwargs):
        self.num_patches = num_patches
        super().__init__(*args, **kwargs)

    def generate_masks(self):
        imsize = self.image_size
        n_pixels = imsize * imsize
        n_masks = len(self)
        self.mask = torch.ByteTensor(n_masks, imsize, imsize)
        for i in range(n_masks):
            widths = self.rnd.randint(1, 10, self.num_patches)
            mask = torch.ByteTensor(n_pixels).bernoulli_(1)
            for width in widths:
                start = self.rnd.randint(1, n_pixels)
                # Row stride is the image side (previously hard-coded as 28).
                for row in range(int(25 // width)):
                    offset = start + imsize * row
                    mask[offset:offset + width] = 0
            self.mask[i] = mask.view(imsize, imsize)


