"""Inference script for Glow on CelebA (sampling and latent-space inpainting).

Usage:
    infer_celeba.py <hparams> <dataset_root> <z_dir>
"""
import os
import cv2
import random
import torch
import vision
import numpy as np
from docopt import docopt
from torchvision import transforms
from glow.builder import build
from glow.config import JsonConfig
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


def select_index(name, l, r, description=None):
    """Interactively ask the user on stdin for an index in ``[l, r)``.

    Entering ``l - 1`` selects a uniformly random index in ``[l, r)`` instead.
    Non-integer or out-of-range input causes the prompt to be shown again.

    Args:
        name: label of the thing being selected, shown in the prompt.
        l: inclusive lower bound of the valid range.
        r: exclusive upper bound of the valid range.
        description: optional iterable of strings; each is printed as
            ``"<position>: <text>"`` (positions start at 0) before every prompt.

    Returns:
        The chosen (or randomly drawn) index.

    Raises:
        EOFError: if stdin is exhausted before a valid index is entered.
    """
    while True:
        print("Select {} with index [{}, {}), "
              "or {} for random selection".format(name, l, r, l - 1))
        if description is not None:
            for i, d in enumerate(description):
                print("{}: {}".format(i, d))
        try:
            choice = int(input().strip())
        except ValueError:
            # Not an integer: prompt again. EOFError is deliberately not
            # caught, otherwise a closed stdin would loop forever.
            continue
        if choice == l - 1:
            return random.randint(l, r - 1)
        if l <= choice < r:
            return choice


def run_z(graph, z):
    """Run ``graph`` in reverse mode and return the first decoded sample.

    Args:
        graph: the flow model (an ``nn.Module``). It is put in eval mode and
            called as ``graph(z=..., eps_std=1.0, reverse=True)``.
        z: latent for a single sample (array-like or tensor, without a batch
            axis), or ``None`` to call the graph with ``z=None``.

    Returns:
        ``x[0]``, the first element of the graph's output, left on the model's
        device (layout presumably channels-first — TODO confirm).
    """
    graph.eval()
    z_batch = None
    if z is not None:
        # Add a batch axis and match the model's device/dtype so a numpy
        # latent (e.g. float64 from np.load) can be decoded directly.
        z_batch = torch.as_tensor(z).unsqueeze(0)
        param = next(graph.parameters(), None)
        if param is not None:
            z_batch = z_batch.to(device=param.device, dtype=param.dtype)
    x = graph(z=z_batch, eps_std=1.0, reverse=True)
    return x[0]


def save_images(images, names, out_dir="pictures/infer/", show=True):
    """Write each image to ``<out_dir>/<name>.png``, optionally displaying it.

    Args:
        images: iterable of float image arrays; values are clipped to [0, 1]
            and scaled to uint8 before being written with OpenCV.
        names: iterable of file stems paired with ``images`` via ``zip``
            (extra items in the longer iterable are ignored).
        out_dir: output directory; created (with parents) if missing.
        show: if True, show each image in an OpenCV window named "img" and
            block until a key is pressed.

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)
    for img, name in zip(images, names):
        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        path = os.path.join(out_dir, "{}.png".format(name))
        # cv2.imwrite signals failure via its return value, not an exception.
        if not cv2.imwrite(path, img):
            raise OSError("Failed to write image `{}`".format(path))
        if show:
            cv2.imshow("img", img)
            cv2.waitKey()


if __name__ == "__main__":
    # Settings for the accept/reject latent-space inpainting loop below.
    auxiliary_std = 1e-3   # divides the squared observed-pixel error in bayes_prob
    proposal_std = 0.02    # std of the Gaussian step added to the latents
    num_proposals = 25000  # number of sampler iterations
    batch_size = 6         # images per batch (also the number of chains)

    args = docopt(__doc__)
    hparams_path = args["<hparams>"]
    dataset_root = args["<dataset_root>"]
    z_dir = args["<z_dir>"]
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(dataset_root):
        raise FileNotFoundError(
            "Failed to find root dir `{}` of dataset.".format(dataset_root))
    if not os.path.exists(hparams_path):
        raise FileNotFoundError(
            "Failed to find hparams json `{}`".format(hparams_path))
    if not os.path.exists(z_dir):
        print("Generate Z to {}".format(z_dir))
        os.makedirs(z_dir)
        generate_z = True
    else:
        print("Load Z from {}".format(z_dir))
        generate_z = False

    # Use the hparams file given on the command line.
    hparams = JsonConfig(hparams_path)
    dataset = vision.Datasets["celeba"]
    # set transform of dataset
    transform = transforms.Compose([
        transforms.CenterCrop(hparams.Data.center_crop),
        transforms.Resize(hparams.Data.resize),
        transforms.ToTensor()])
    # build
    graph = build(hparams, False)["graph"]
    dataset = dataset(dataset_root, transform=transform)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True)
    with torch.no_grad():
        # Per-attribute delta Z: load from z_dir, falling back to generating
        # them. The "detla" file-name spelling is kept so that files saved
        # by earlier runs still load.
        if not generate_z:
            try:
                delta_Z = [
                    np.load(os.path.join(z_dir, "detla_z_{}.npy".format(i)))
                    for i in range(hparams.Glow.y_classes)]
            except FileNotFoundError:
                # Missing files: regenerate them below instead of exiting.
                generate_z = True
                print("Failed to load {} Z, generating them instead".format(
                    hparams.Glow.y_classes))
        if generate_z:
            delta_Z = graph.generate_attr_deltaz(dataset)
            for i, z in enumerate(delta_Z):
                np.save(os.path.join(z_dir, "detla_z_{}.npy".format(i)), z)
            print("Finish generating")
        # NOTE(review): delta_Z is not used by the sampling code below.

        for x in dataloader:
            data = x.cuda()
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(1, 1, 1)

            # Draw two batches of samples at sample_std and save them as a grid.
            sample_std = 0.5
            projected_end = graph.reverse_flow(None, None, sample_std)
            print(projected_end.shape)
            n_samples = len(projected_end)
            samples = projected_end.repeat(2, 1, 1, 1)
            for k in range(0, 2):
                samples[k * n_samples:(k + 1) * n_samples, :, :, :] = \
                    graph.reverse_flow(None, None, sample_std)
            projected_end = graph.reverse_flow(None, None, sample_std)
            grid = make_grid(samples.cpu(), nrow=6).permute(1, 2, 0)
            ax.imshow(grid)
            plt.savefig('./samples.png')

            # Initial chain state: encode one sample per data image, then
            # decode from the last returned latent (index 2) with eps_std=0.01.
            projected_end = projected_end[0:batch_size]
            full_original_proposals, old_resample_nll, y_logits = \
                graph.full_normal_flow(projected_end, None)
            projected_end = graph.flow.decode(
                full_original_proposals[2], eps_std=0.01)

            # Mask: 1 keeps the data pixel, 0 marks the region to inpaint
            # (indices 32:64 of the last axis, presumably image width).
            data_masks = torch.ones(data.shape).cuda()
            data_masks[:, :, :, 32:64] = 0.0
            data_inputs = data.mul(data_masks) + projected_end.mul(1.0 - data_masks)
            inputs = data_inputs.clone()
            masks = data_masks

            # Latent proposal buffers. Shapes are hard-coded and presumably
            # assume 64x64 inputs and a 3-level model — TODO confirm vs hparams.
            new_proposals = [torch.empty(len(inputs), 6, 32, 32).cuda(),
                             torch.empty(len(inputs), 12, 16, 16).cuda(),
                             torch.empty(len(inputs), 48, 8, 8).cuda()]
            prop_std = proposal_std
            bayes_mod = 1e6
            acceptances = torch.empty(len(inputs), 1, 1, 1).cuda()
            grid = make_grid(data.cpu(), nrow=2).permute(1, 2, 0)

            ax.imshow(grid)
            plt.savefig('./input.png')
            acceptance_count = 2.0
            trials = 10.0

            perturbations = torch.empty(inputs.shape).cuda()
            for i in range(0, num_proposals):
                # Fresh uniform dequantization noise on the 0-255 pixel grid.
                inputs = (255.0 * data_inputs + perturbations.uniform_()) / 256.0
                full_original_proposals, old_resample_nll, y_logits = \
                    graph.full_normal_flow(projected_end, None, relative=True,
                                           resample_logdet=True)
                # NOTE(review): the results of partial_encode are never used.
                original_proposals2, old_resample_nll = \
                    graph.flow.partial_encode(projected_end)
                new_original_completions = projected_end.mul(1.0 - masks) + inputs.mul(masks)
                if i % 100 == 0:
                    acceptance_count = 0.0
                    trials = 0.0
                    grid = make_grid(new_original_completions.cpu(), nrow=2).permute(1, 2, 0)
                    ax.imshow(grid)
                    plt.savefig('./test%i.png' % i)
                old_comp_latents, original_nll, y_logits = \
                    graph.full_normal_flow(new_original_completions, None)
                # Gaussian step of size prop_std around the current latents.
                for j in range(0, 3):
                    new_proposals[j].normal_(mean=0.0, std=1.0)
                    new_proposals[j] = new_proposals[j].mul(prop_std) + full_original_proposals[j]
                # FIXME(review): `new_projected_end` is never assigned, so the
                # next line raises NameError on the first iteration. It
                # presumably should be decoded from `new_proposals`, but the
                # exact graph/flow call is not determinable from this file.
                full_new_proposals, new_resample_nll2, y_logits_new = \
                    graph.full_normal_flow(new_projected_end, None, relative=True,
                                           resample_logdet=True)
                new_completions = new_projected_end.mul(1.0 - masks) + inputs.mul(masks)
                new_comp_latents, new_nll, y_logits = \
                    graph.full_normal_flow(new_completions, None)
                # Difference of squared errors on the observed (mask == 1)
                # pixels between the current and proposed states, per sample.
                old_err = ((projected_end - inputs).mul(masks) ** 2) / 2.0
                new_err = ((new_projected_end - inputs).mul(masks) ** 2) / 2.0
                bayes_prob = bayes_mod * (old_err - new_err).view(len(inputs), -1).sum(dim=1) / (auxiliary_std ** 2)
                # NOTE(review): if full_normal_flow returns a *negative*
                # log-likelihood, this term looks sign-inverted
                # (original_nll - new_nll) — confirm.
                acceptance_prob = torch.clamp(bayes_prob + new_nll - original_nll, -50, 50)
                acceptance_prob = torch.exp(acceptance_prob).unsqueeze(1).unsqueeze(2).unsqueeze(3)

                # `t != t` is True only for NaN: never accept NaN states.
                acceptance_prob[acceptance_prob != acceptance_prob] = 0.0
                acceptances.uniform_()
                acceptances = (acceptances <= acceptance_prob).float()
                acceptances[acceptances != acceptances] = 0.0
                acceptances[new_nll != new_nll] = 0.0
                new_projected_end[new_projected_end != new_projected_end] = 0.0
                projected_end = new_projected_end.mul(acceptances) + projected_end.mul(1.0 - acceptances)
                masked_err = float(((projected_end - inputs).mul(masks) ** 2 / 2.0).mean())
                top_latent_energy = float(
                    (old_comp_latents[2].view(len(inputs), -1) ** 2).sum(dim=1).mean() / 2.0)
                print(i, float(acceptances.sum()), float(original_nll.mean()),
                      top_latent_energy, masked_err)
                if np.isnan(masked_err):
                    print(new_projected_end.max(), new_projected_end.min())
                    print(acceptances.max(), acceptances.min())
                    print(acceptances[acceptances != acceptances])
                    print(projected_end[projected_end != projected_end])
                    break
                trials += 1.0
                acceptance_count += float(acceptances.sum())

            # Only the first batch is processed.
            break
        
        
