import json
import math
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

from datasets import get_CIFAR10, get_SVHN, postprocess
from model import Glow

# The model is moved to this device after its weights are loaded.
device = torch.device("cuda")

# Folder holding the trained checkpoint and the hyper-parameter file.
output_folder = './glow/'
model_name = 'glow_affine_coupling.pt'

# Sampler settings: std multiplier for latent proposals, std passed as
# `aux_std` to the sampler, and the total number of proposal steps to run.
proposal_std = 0.01
auxiliary_std = 1e-3
num_proposals = 25000

# JSON is UTF-8 by specification; do not depend on the platform default.
with open(output_folder + 'hparams.json', encoding='utf-8') as json_file:
    hparams = json.load(json_file)

image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'], hparams['dataroot'], True)

model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,
             hparams['learn_top'], hparams['y_condition'])

# Deserialize onto the CPU: the model is still on the CPU at this point, and
# this keeps loading independent of the device the checkpoint was saved from.
model.load_state_dict(torch.load(output_folder + model_name, map_location='cpu'))
model.set_actnorm_init()

model = model.to(device)

model = model.eval()

def sample(model):
    """Draw a batch of images by running ``model`` in reverse.

    When the module-level ``hparams['y_condition']`` is set, a batch of
    one-hot labels that cycles through the classes is passed to the model;
    otherwise no labels are passed.

    Args:
        model: callable accepting ``y_onehot``, ``temperature`` and
            ``reverse`` keyword arguments.

    Returns:
        The model output, moved to the CPU.
    """
    # The number of generated images is hard-coded inside the model for now,
    # so the label batch has to have exactly this many rows.
    batch_size = 32

    with torch.no_grad():
        if hparams['y_condition']:
            # Stack enough identity matrices to cover the batch, then trim.
            # `repeat` needs one count per tensor dimension.
            y = torch.eye(num_classes)
            y = y.repeat(batch_size // num_classes + 1, 1)
            y = y[:batch_size, :].to(device)
        else:
            y = None

        images = model(y_onehot=y, temperature=1, reverse=True)

    return images.cpu()

def run_completions(model, data_inputs, masks, projected_end, prop_std,aux_std, steps=20):
    """Run ``steps`` accept/reject updates of ``projected_end`` in latent space.

    Each step maps ``projected_end`` to the model's latents, adds Gaussian
    noise scaled by ``prop_std`` to every latent, decodes the result and
    accepts it per sample when a uniform draw is <= ``exp(bayes_prob +
    new_nll - original_nll)``. Accepted samples are written back into
    ``projected_end`` in place, clamped to [-0.5, 0.5].

    Args:
        model: object providing ``full_normal_flow`` and ``flow.full_decode``.
        data_inputs: reference images; pixels where ``masks`` is 1 are copied
            from here into the completions.
        masks: multiplied elementwise with the images; 1 keeps the pixel of
            ``data_inputs``, 0 keeps the pixel of the model-generated image.
        projected_end: current images; MODIFIED IN PLACE and also returned.
        prop_std: multiplier applied to the standard-normal proposal noise.
        aux_std: the masked squared-error difference is divided by
            ``aux_std**2``.
        steps: number of iterations; with ``steps <= 0`` the loop is skipped.

    Returns:
        ``(new_original_completions, projected_end)``. The first element is
        the completion built at the START of the last iteration (or directly
        from the incoming ``projected_end`` when the loop is skipped).

    Notes:
        * Buffers are created with ``.cuda()`` and with hard-coded latent
          shapes (N, 6, 16, 16), (N, 12, 8, 8), (N, 48, 4, 4); images are
          flattened as 3*32*32.
        * One diagnostic line is printed per step.
        * ``total_acceptances`` is accumulated but never returned.
        * NOTE(review): the second value returned by
          ``model.full_normal_flow`` is named ``nll`` but enters the
          acceptance exponent with a positive sign for the proposal - confirm
          its sign convention.
    """
    sample_temp = 1.0  # never read in this function
    gibbs_prob = 1.0  # Bernoulli keep-probability per latent entry; 1.0 perturbs every entry
    with torch.no_grad():
        original_completions = projected_end.clone()
        # Fallback return value when the loop below runs zero times.
        new_original_completions = original_completions.mul(1.0 - masks) + data_inputs.mul(masks)
        # Scratch buffers, refilled in place with random draws each step.
        perturbations = torch.empty(data_inputs.shape).cuda()
        total_acceptances = torch.zeros(len(data_inputs),1,1,1).cuda()
        gibbs_masks = [torch.empty(len(data_inputs), 6, 16, 16).cuda(), torch.empty(len(data_inputs), 12, 8, 8).cuda(), torch.empty(len(data_inputs), 48, 4, 4).cuda()]
        new_proposals = [torch.empty(len(data_inputs), 6, 16, 16).cuda(), torch.empty(len(data_inputs), 12, 8, 8).cuda(), torch.empty(len(data_inputs), 48, 4, 4).cuda()]
        acceptances = torch.empty(len(data_inputs),1,1,1).cuda()
        for i in range(0,steps):
            # The uniform noise is multiplied by 0, so it adds nothing;
            # uniform_() is still called and refills `perturbations`.
            inputs = (256.0*data_inputs + 0*perturbations.uniform_())/256.0
            # Latents of the current state (old_resample_nll is not used).
            full_original_proposals, old_resample_nll = model.full_normal_flow(projected_end, None, relative=False)
            original_completions = model.flow.full_decode(full_original_proposals)
            # Current completion: decoded image with the masked-in pixels taken from inputs.
            new_original_completions = original_completions.mul(1.0 - masks) + inputs.mul(masks)
            old_comp_latents, original_nll = model.full_normal_flow(new_original_completions, None)
            current_index = 0  # never read
            # Proposal: current latent + prop_std * N(0, 1) noise on the entries
            # selected by a Bernoulli(gibbs_prob) mask.
            for j in range(0, 3):
                new_proposals[j].normal_(mean=0.0, std=1.0)
                new_proposals[j] = new_proposals[j].mul(gibbs_masks[j].bernoulli_(gibbs_prob)).mul(prop_std)
                new_proposals[j] += full_original_proposals[j]
            new_projected_end = model.flow.full_decode(new_proposals, None, relative=False)
            
            # Results of this call are not used below.
            new_proposals2, new_resample_nll2 = model.full_normal_flow(new_projected_end, None)
            new_completions = new_projected_end.mul(1.0 - masks) + inputs.mul(masks)
            # new_comp_latents is not used; only new_nll enters the acceptance ratio.
            new_comp_latents, new_nll = model.full_normal_flow(new_completions, None)
            # Per-sample difference of the masked half squared errors (current
            # minus proposed), divided by aux_std**2.
            bayes_prob = (((projected_end - inputs).mul(masks)**2)/2.0 - ((new_projected_end - inputs).mul(masks)**2)/2.0).view(len(inputs), 3*32*32).sum(dim=1)/(aux_std**2)
            acceptance_prob = torch.exp(bayes_prob + new_nll  - original_nll).unsqueeze(1).unsqueeze(2).unsqueeze(3)
            # `x != x` is true only for NaN: never accept when the probability
            # or the proposal's score is NaN.
            acceptance_prob[acceptance_prob != acceptance_prob] = 0.0
            acceptance_prob[new_nll.unsqueeze(1).unsqueeze(2).unsqueeze(3) != new_nll.unsqueeze(1).unsqueeze(2).unsqueeze(3)] = 0.0
            # Accept where a fresh uniform draw is <= the acceptance probability.
            acceptances.uniform_()
            acceptances = (acceptances <= acceptance_prob).float()
            # acceptances holds only 0.0/1.0 here, so this NaN check matches nothing.
            acceptances[acceptances != acceptances] = 0.0
            acceptances[new_nll != new_nll] = 0.0
            # Replace NaN pixels in the proposals by 0.
            new_projected_end[new_projected_end != new_projected_end] = 0.0
            new_completions[new_completions != new_completions] = 0.0
            total_acceptances += acceptances
            
            # Diagnostics: step, accepted count, mean original_nll, mean half
            # squared norm of the third latent, mean masked half squared error,
            # max and min of the proposed images.
            print(i, float(acceptances.sum()), float(original_nll.mean()), float((old_comp_latents[2].view(len(inputs), 48*4*4)**2).sum(dim=1).mean()/2.0), float(((projected_end - inputs).mul(masks)**2/2.0).mean()), float(torch.max(new_projected_end)), float(torch.min(new_projected_end)))
            # In-place update of the accepted samples, clamped to [-0.5, 0.5].
            projected_end[acceptances.squeeze().bool()] = torch.clamp(new_projected_end[acceptances.squeeze().bool()],-0.5,0.5)#new_projected_end[acceptances.squeeze().bool()]# 
    return new_original_completions, projected_end



def conv_mask(batch, mask, iterations=2, full=False):
    """Smooth ``batch`` with a fixed 3x3 kernel, optionally keeping known pixels.

    Every channel of every image is convolved independently, ``iterations``
    times, with a symmetric 3x3 kernel (centre 0.5, edges 0.0883825, corners
    0.0366175; the weights sum to 1) using zero padding of one pixel.

    Args:
        batch: 4-D tensor of shape (num_samples, channels, height, width).
        mask: tensor that multiplies ``batch``; 1 marks a known pixel. When
            ``full`` is False it must hold as many elements as ``batch``.
        iterations: number of convolution passes.
        full: when True the whole image is smoothed. When False, pixels where
            ``mask`` is 1 are reset to their ``batch`` value before every
            pass and once more at the end, so only the remaining pixels are
            filled in.

    Returns:
        Tensor with the same shape as ``batch``.
    """
    num_samples, channels, height, width = batch.shape
    # Fold the channels into the batch axis so one single-channel kernel
    # serves every channel.
    flat_shape = (num_samples * channels, 1, height, width)

    def _restore_known(filled):
        # Known pixels come from `batch`, everything else from `filled`.
        return batch.mul(mask).view(flat_shape) + filled.mul(1.0 - mask.reshape(flat_shape))

    with torch.no_grad():
        # Build the kernel on the device/dtype of the input instead of
        # forcing CUDA, so CPU tensors work as well.
        weight = torch.tensor(
            [[0.073235, 0.176765, 0.073235],
             [0.176765, 1.0, 0.176765],
             [0.073235, 0.176765, 0.073235]],
            dtype=batch.dtype,
            device=batch.device,
        ).div_(2.0).view(1, 1, 3, 3)
        inpaint = (batch.mul(mask) + batch.mul(1.0 - mask)).view(flat_shape)
        for _ in range(iterations):
            if not full:
                inpaint = _restore_known(inpaint)
            inpaint = torch.nn.functional.conv2d(inpaint, weight, padding=1)
    if not full:
        inpaint = _restore_known(inpaint)
    return inpaint.view(num_samples, channels, height, width)


import numpy as np
# Work on the first 24 test images only.
test_loader = torch.utils.data.DataLoader(test_cifar, batch_size=24, shuffle=False, drop_last=False)
inputs, classes = next(iter(test_loader))

# One figure/axes pair, reused for every image grid saved by this script.
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1,1,1)
#torch.manual_seed(897)
with torch.no_grad():
    # Report the model's mean score on the unmodified inputs.
    test, test_nll = model.full_normal_flow(inputs.cuda(), None)
    print(test_nll.mean())

    # Mask is 1 on known pixels and 0 on the central 16x16 square to fill in.
    masks = torch.ones(len(inputs), 3, 32, 32).cuda()
    masks[:, :, 8:24, 8:24] = 0.0
    inputs = inputs.cuda()
    projected_end = inputs.mul(masks)

    # Initial guess: overwrite every pixel with a randomly chosen known pixel.
    # The copy is in place, so a chosen source pixel may itself already have
    # been overwritten earlier in the scan.
    for i in range(len(projected_end)):
        # Pull the mask to the host once instead of comparing single GPU
        # elements; the list keeps the row-major order of a nested scan.
        known = (masks[i, 0] == 1.0).tolist()
        pixel_list = [[j, k] for j in range(32) for k in range(32) if known[j][k]]
        num_known = len(pixel_list)
        for j in range(32):
            for k in range(32):
                src_j, src_k = pixel_list[np.random.choice(num_known)]
                # Copy all three channels of the source pixel at once.
                projected_end[i, :, j, k] = projected_end[i, :, src_j, src_k]

    # Smooth the whole random fill (full=True ignores the known-pixel reset).
    projected_end = conv_mask(projected_end, masks, iterations=2, full=True).cuda()
    # The returned values are not used; the call is kept as in the original.
    full_latents, nll = model.full_normal_flow(projected_end, None)

    projected_end = torch.clamp(projected_end, -0.5, 0.5)
    # steps=-1 skips the sampling loop, so this only pastes the known pixels
    # of `inputs` over the initial guess to produce `original_completions`.
    original_completions,projected_end = run_completions(model, inputs, masks, projected_end, 0.005*torch.ones(len(inputs), 1, 1, 1).cuda(), auxiliary_std, steps=-1)

# Draw len(inputs) images from the model. No gradients are needed for
# plotting, so generate them under no_grad like `sample` does.
# NOTE(review): presumably unconditional samples at temperature 1.0 - confirm
# against Glow.reverse_flow.
with torch.no_grad():
    samples = model.reverse_flow(None, None, 1.0, len(inputs))

# Save three reference grids: model samples, the initial completions and the
# original inputs.
for images, filename in (
    (samples, 'cifar10_samples.png'),
    (original_completions, 'start.png'),
    (inputs, 'inputs.png'),
):
    grid = make_grid(postprocess(images.cpu()), nrow=6).permute(1,2,0)
    # Drop the previously drawn image so artists do not pile up on the axes.
    ax.clear()
    ax.imshow(grid)
    plt.savefig(filename)



import time
# Per-sample proposal scale, shaped to broadcast over the latent tensors.
prop_std = proposal_std*torch.ones(len(inputs), 1, 1, 1).cuda()
# Run the sampler in rounds of this many steps, saving a snapshot per round.
steps_per_round = 100
for i in range(num_proposals // steps_per_round):
    start = time.time()
    original_completions,projected_end = run_completions(model, inputs, masks, projected_end, prop_std, auxiliary_std, steps=steps_per_round)
    grid = make_grid(postprocess(original_completions.cpu()), nrow=6).permute(1,2,0)
    # Clear the axes first; otherwise every round adds one more image to the
    # same axes and each savefig has to redraw all of them.
    ax.clear()
    ax.imshow(grid)
    plt.savefig('cifar10_%i.png' % (steps_per_round*(i)))
    # Wall-clock seconds spent on this round (sampling + plotting).
    print(time.time() - start)



