import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

from IPython.display import Image

import numpy as np
from os.path import exists

# Mini-batch size: the training loop slices the dataset into chunks of this
# many rows and also divides the logged losses by it.
bs = 512
# Number of passes over the dataset performed by the training loop.
epochs = 1000

def to_var(x):
    """Wrap ``x`` in a ``Variable``, moving it to the GPU first if CUDA is available.

    Args:
        x: tensor to wrap.

    Returns:
        ``Variable(x)``, built from the CUDA copy of ``x`` when
        ``torch.cuda.is_available()`` is true and from ``x`` itself otherwise.
    """
    use_gpu = torch.cuda.is_available()
    return Variable(x.cuda() if use_gpu else x)

def flatten(x):
    """Collapse every dimension after the first into a single one.

    Args:
        x: tensor whose first dimension is kept as-is.

    Returns:
        A ``Variable`` of shape ``(x.size(0), -1)`` holding the same elements.
    """
    n_rows = x.size(0)
    flat = x.reshape(n_rows, -1)
    return Variable(flat)

def save_image(x, path='real_image.png'):
    """Write ``x`` to ``path`` by delegating to ``torchvision.utils.save_image``.

    NOTE: this definition rebinds the module-level name ``save_image`` that the
    file also imports from ``torchvision.utils``; after this point the bare
    name refers to this wrapper.

    Args:
        x: image data forwarded unchanged to ``torchvision.utils.save_image``.
        path: destination file; defaults to ``'real_image.png'``.
    """
    torchvision.utils.save_image(x, path)

class VAE(nn.Module):
    """Fully connected variational autoencoder operating on flattened inputs.

    The encoder maps an input of ``image_size`` features to ``2 * z_dim``
    values, which are split in half into the mean and log-variance of the
    approximate posterior. The decoder maps a ``z_dim`` latent sample back to
    ``image_size`` values squashed into ``(0, 1)`` by a sigmoid.

    Args:
        image_size: number of features per flattened input row.
        h_dim: width of the single hidden layer in encoder and decoder.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, image_size=32*32*10*3, h_dim=1024, z_dim=20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(image_size, h_dim),
            nn.LeakyReLU(0.2),
            # Twice z_dim: first half is mu, second half is logvar.
            nn.Linear(h_dim, z_dim * 2),
        )

        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, image_size),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Args:
            mu: posterior means, shape ``(batch, z_dim)``.
            logvar: posterior log-variances, same shape as ``mu``.

        Returns:
            A latent sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        # Sample the noise directly on std's device and dtype. Deciding the
        # device from torch.cuda.is_available() breaks when the model lives on
        # the CPU of a CUDA-capable machine (or on a non-default GPU).
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: flattened inputs of shape ``(batch, image_size)``.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        h = self.encoder(x)
        mu, logvar = torch.chunk(h, 2, dim=1)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

def loss_fn(recon_x, x, mu, logvar):
    """Compute the VAE objective: reconstruction error plus KL divergence.

    Both terms are summed over every element of the batch (not averaged).

    Args:
        recon_x: decoder output with values in ``[0, 1]``.
        x: target tensor with the same shape as ``recon_x`` and values in
            ``[0, 1]`` (required by binary cross-entropy).
        mu: posterior means.
        logvar: posterior log-variances, same shape as ``mu``.

    Returns:
        Tuple ``(BCE + KLD, BCE, KLD)`` of scalar tensors.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False argument.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # See Appendix B of Kingma and Welling, "Auto-Encoding Variational Bayes",
    # ICLR 2014. The KL divergence from N(mu, sigma^2) to N(0, I) is
    #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD, BCE, KLD

def save_model(model, blind, epoch):
    """Save ``model``'s state dict to ``./models/vae/vae_<blind>_<epoch>.pkl``.

    The target directory is created when it does not exist yet, so a fresh
    checkout does not lose a checkpoint to a missing folder.

    Args:
        model: module whose ``state_dict()`` is written.
        blind: dataset variant, either ``'color'`` or ``'shape'``.
        epoch: epoch number embedded in the file name.

    Raises:
        AssertionError: if ``blind`` is not one of the accepted values.
    """
    import os

    assert blind in ['color', 'shape'], "blind must be 'color' or 'shape'"
    ckpt_path = './models/vae/vae_{}_{}.pkl'.format(blind, epoch)
    # torch.save does not create parent directories on its own.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    print("saving model to %s..." % ckpt_path)
    torch.save(model.state_dict(), ckpt_path)

def load_model(model, blind, epoch):
    """Restore ``model`` from ``./models/vae/vae_<blind>_<epoch>.pkl``.

    Args:
        model: module that receives the stored state dict (updated in place).
        blind: dataset variant, either ``'color'`` or ``'shape'``.
        epoch: epoch number embedded in the checkpoint file name.

    Raises:
        AssertionError: if ``blind`` is not an accepted value or the
            checkpoint file does not exist.
    """
    assert blind in ['color', 'shape']

    ckpt_path = './models/vae/vae_{}_{}.pkl'.format(blind, epoch)
    assert exists(ckpt_path), "epoch misspecified"
    print("loading model from %s..." % ckpt_path)
    # Load onto the CPU first so a checkpoint written on a GPU machine can be
    # restored without CUDA; load_state_dict then copies the values into the
    # model's parameters on whatever device they already live on.
    state = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state)

if __name__ == '__main__':
    # Dataset variant to train on. It selects the input file and is passed to
    # save_model(), which only accepts 'color' or 'shape'.
    blind = 'color'
    images_large = np.load('./data/{}blind_images_large.npz'.format(blind))['arr_0']

    data_size = images_large.shape[0]
    # Only full batches are used; a trailing partial batch is dropped.
    n_batch = data_size // bs
    if n_batch == 0:
        raise ValueError(
            "dataset has {} samples, fewer than one batch of {}".format(data_size, bs)
        )

    # One row per sample, converted to float for the linear layers. The full
    # dataset stays on the CPU; batches are moved to the device one at a time.
    images_large = torch.from_numpy(images_large)
    images_large = flatten(images_large)
    images_large = images_large.float()

    # Pick the device once so the model and every batch agree, and so the
    # script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vae = VAE().to(device)

    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for idx in range(n_batch):
            images = images_large[bs*idx:bs*(idx+1)].to(device)
            recon_images, mu, logvar = vae(images)
            loss, BCE, KLD = loss_fn(recon_images, images, mu, logvar)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Reports the per-sample losses of the last batch of the epoch.
        print("Epoch[{}/{}] Loss: {:.3f} BCE: {:.3f} KLD: {:.3f}".format(
            epoch, epochs, loss.item()/bs, BCE.item()/bs, KLD.item()/bs))

        # Checkpoint every tenth epoch once past epoch 900.
        if epoch > 900 and epoch % 10 == 0:
            # save_model requires the dataset variant as its second argument.
            save_model(vae, blind, epoch)
