'''VAE whose encoder q_theta(z|x) is updated using Langevin dynamics.
'''
#%% Server
from __future__ import division
from __future__ import print_function
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import copy
import itertools
import torchvision.utils as utils
import scipy
from scipy.io import loadmat
#from keras.utils import to_categorical
import pandas as pd
import gc
import argparse

def adjust_learning_rate(optmz, epoch, init_lr=0.001, decay=0.9, decay_epoch=100):
    """Apply a step-decay schedule to every parameter group of an optimizer.

    The learning rate is ``init_lr * decay ** (epoch // decay_epoch)``, i.e. it
    is multiplied by ``decay`` once per ``decay_epoch`` completed epochs.

    Args:
        optmz: object exposing ``param_groups`` (a list of dicts), e.g. a
            ``torch.optim`` optimizer.
        epoch: current epoch index.
        init_lr: learning rate used before the first decay.
        decay: multiplicative factor applied at each decay step.
        decay_epoch: number of epochs between two decay steps.

    Returns:
        None. The ``'lr'`` entry of each parameter group is overwritten in place.
    """
    completed_steps = epoch // decay_epoch
    new_lr = init_lr * decay ** completed_steps
    for group in optmz.param_groups:
        group['lr'] = new_lr

def sample_bet(L, mu, log_var):
    """Draw ``L`` reparameterized samples from N(mu, diag(exp(log_var))).

    Args:
        L: number of samples to draw for every entry of ``mu``.
        mu: tensor of means.
        log_var: tensor of log-variances, same shape as ``mu``.

    Returns:
        Tensor of shape ``(L,) + mu.shape`` computed as ``mu + eps * std`` with
        ``eps`` standard normal, so gradients flow back to ``mu`` and ``log_var``.
    """
    std = torch.exp(log_var / 2)
    # Sample the noise directly on mu's device instead of relying on a
    # module-level `device` global: this keeps the helper usable on its own
    # and avoids a host-to-device copy of the noise tensor.
    eps = torch.randn((L,) + tuple(mu.shape), device=mu.device)
    return mu + eps * std.unsqueeze(0)

def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable ``value.exp().sum(dim, keepdim).log()``.

    The maximum is subtracted before exponentiating and added back afterwards,
    so large entries do not overflow.

    Args:
        value: input tensor.
        dim: dimension to reduce. ``None`` reduces over all elements and
            returns a 0-dim tensor (``keepdim`` is ignored in that case).
        keepdim: whether the reduced dimension is retained with size 1.

    Returns:
        Tensor holding the log-sum-exp of ``value``.
    """
    if dim is None:
        # Full reduction. torch.sum always returns a tensor, so the former
        # Python-number branch (which referenced the never-imported names
        # `Number` and `math` and therefore raised NameError) is not needed.
        m = torch.max(value)
        return m + torch.log(torch.sum(torch.exp(value - m)))

    m, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - m
    if not keepdim:
        # Drop the reduced axis so `m` lines up with the reduced sum below.
        m = m.squeeze(dim)
    return m + torch.log(torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim))
     

# Directory used as the MNIST download root here and, further down the script,
# as the output directory for sample images and checkpoints.
pwd= '/result'
# NOTE(review): the CUDA device index is hard-coded to 2; the script fails on
# machines without that device.
device = torch.device("cuda:2")
# testing data
dataset_test = torchvision.datasets.MNIST(root=pwd,
                                     train=False,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Number of test images kept.
nt=10000
dataset_test.data = dataset_test.data[0:nt]
# stochastically binarized testing data
# Fixed seed so the binarized test set is identical across runs.
torch.manual_seed(0)
# Each pixel becomes 1 with probability intensity/255.
m = torch.distributions.bernoulli.Bernoulli(probs=dataset_test.data.float()/255)
temp = m.sample()
# Store back as bytes in {0, 255}; presumably ToTensor rescales these to
# {0, 1} when items are loaded -- confirm against the torchvision version used.
dataset_test.data=(temp*255).type(torch.ByteTensor)
#
data_loader_test = torch.utils.data.DataLoader(dataset=dataset_test,
                                          batch_size=100,#nt, 
                                          shuffle=False)
######################################################
# training data 
# Flattened image size (images are viewed as 28x28 when saved later on).
p = 784
dataset = torchvision.datasets.MNIST(root=pwd,
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)
# Keep only the first 50000 training images.
dataset.data = dataset.data[0:50000]
batch_size = 2000  
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, 
                                          shuffle=True)
# Training-set size; only referenced by commented-out "*n/bs" rescalings below.
n = len(dataset) 
#%% Regular VAE
if True:
    # Baseline ("regular") VAE. Every name bound in this block (eps_size,
    # x_size, z_dim, h_dim, L, VAE) is rebound by the MIVI section that
    # follows in this file.
    # parameters
    eps_size = 0
    x_size = p
    z_dim = 10  # int(p/10)
    h_dim = 200
    L = 5

    class VAE(nn.Module):
        """VAE with a shared encoder trunk and separate mean / log-variance heads.

        Args:
            x_size: size of a flattened input.
            h_dim: width of the hidden layers.
            z_dim: size of the latent code.
            L: default number of latent samples, stored as ``self.L``.
        """

        def __init__(self, x_size=x_size, h_dim=h_dim,
                     z_dim=z_dim, L=L):
            super(VAE, self).__init__()
            self.L = L
            # Shared trunk feeding the mean and log-variance heads of q(z|x).
            self.encoder = nn.Sequential(
                nn.Linear(x_size, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, h_dim),
                nn.ReLU()
            )
            self.fc1_m = nn.Linear(h_dim, z_dim)
            self.fc1_logv = nn.Linear(h_dim, z_dim)
            self.decoder = nn.Sequential(
                nn.Linear(z_dim, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, x_size)
            )

        def encode(self, x):
            """Return ``(mean, log_var)`` of q(z|x)."""
            h = self.encoder(x)
            return self.fc1_m(h), self.fc1_logv(h)

        def reparameterize(self, mu, log_var, L):
            """Return ``L`` samples ``mu + eps * std`` of shape ``(L,) + mu.shape``."""
            std = torch.exp(log_var / 2)
            # Draw the noise on mu's device rather than on the module-level
            # `device` global, so the model works wherever it has been moved.
            eps = torch.randn((L,) + tuple(mu.shape), device=mu.device)
            return mu + eps * std.unsqueeze(0)

        def decode(self, z):
            """Map latent codes to Bernoulli means (sigmoid of the decoder output)."""
            h = self.decoder(z)
            return torch.sigmoid(h)

        def forward(self, x, L):
            """Encode ``x``, draw ``L`` latent samples and decode them.

            Returns:
                ``(x_reconst, z, z_mu, z_logvar)``.
            """
            z_mu, z_logvar = self.encode(x)
            z = self.reparameterize(z_mu, z_logvar, L)
            x_reconst = self.decode(z)
            return x_reconst, z, z_mu, z_logvar


#%% MIVI
# parameters
eps_size = 0
x_size = p      # size of a flattened input
z_dim = 10      # int(p/10)
h_dim = 200     # hidden-layer width
L = 10          # number of latent samples drawn per input during training


class VAE(nn.Module):
    """VAE with separate MLPs for the mean and log-variance of q(z|x).

    The decoder returns logits (no sigmoid is applied), so its output is meant
    for ``binary_cross_entropy_with_logits`` and needs ``.sigmoid()`` before
    being shown as an image.

    Args:
        x_size: size of a flattened input.
        h_dim: width of the hidden layers.
        z_dim: size of the latent code.
    """

    def __init__(self, x_size=x_size, h_dim=h_dim,
                 z_dim=z_dim):
        super(VAE, self).__init__()
        # Mean network of the Gaussian encoder N(z; mean, exp(log_var)).
        self.encoder_mean = nn.Sequential(
            nn.Linear(x_size, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, z_dim)
        )
        # Log-variance network of the Gaussian encoder.
        self.encoder_logv = nn.Sequential(
            nn.Linear(x_size, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, z_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, x_size)
        )

    def encode(self, x):
        """Return ``(mean, log_var)`` of q(z|x)."""
        return self.encoder_mean(x), self.encoder_logv(x)

    def reparameterize(self, mu, log_var, L):
        """Return ``L`` samples ``mu + eps * std`` of shape ``(L,) + mu.shape``."""
        std = torch.exp(log_var / 2)
        # Draw the noise on mu's device rather than on the module-level
        # `device` global: no hidden global dependency and no host-to-device
        # copy of the noise tensor.
        eps = torch.randn((L,) + tuple(mu.shape), device=mu.device)
        return mu + eps * std.unsqueeze(0)

    def decode(self, z):
        """Map latent codes to Bernoulli logits."""
        return self.decoder(z)

    def forward(self, x, L):
        """Encode ``x``, draw ``L`` latent samples and decode them.

        Returns:
            ``(x_reconst, z, z_mu, z_logvar)`` where ``x_reconst`` holds logits.
        """
        z_mu, z_logvar = self.encode(x)
        z = self.reparameterize(z_mu, z_logvar, L)
        x_reconst = self.decode(z)
        return x_reconst, z, z_mu, z_logvar

class stepsize(nn.Module):
    """MLP mapping an input ``x`` to one log step size per latent dimension.

    Args:
        x_size: size of a flattened input.
        z_dim: size of the latent code (and of the output).
        h_dim: width of the two hidden layers.
    """

    def __init__(self, x_size=x_size, z_dim=z_dim, h_dim=200):
        super(stepsize, self).__init__()
        layers = [
            nn.Linear(x_size, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, z_dim),
        ]
        self.step = nn.Sequential(*layers)

    def forward(self, x):
        """Return the log step sizes for ``x``; callers exponentiate the result."""
        return self.step(x)

# discriminator:
class discriminator(nn.Module):
    """Four-layer MLP producing one unnormalised score (logit) per input row.

    Every layer is moved to the module-level ``device`` at construction time.

    Args:
        size: size of the last input dimension (defaults to ``z_dim + p``).
        h_dim: base hidden width; the middle layer is twice as wide.
    """

    def __init__(self, size=z_dim+p, h_dim=200):
        super(discriminator, self).__init__()
        wide = h_dim + h_dim
        self.fc11 = nn.Linear(size, h_dim).to(device)
        self.fc12 = nn.Linear(h_dim, wide).to(device)
        self.fc13 = nn.Linear(wide, h_dim).to(device)
        self.fc14 = nn.Linear(h_dim, 1).to(device)

    def forward(self, z):
        """Return the logit for each row of ``z`` (no sigmoid applied)."""
        hidden = z
        for layer in (self.fc11, self.fc12, self.fc13):
            hidden = F.relu(layer(hidden))
        return self.fc14(hidden)
            

#%%  Langevin update + auto grad
# Networks: encoder/decoder, per-input Langevin step-size net, discriminator.
model = VAE().to(device)
logstepsize = stepsize().to(device)
discr = discriminator()  # its layers are placed on `device` by its constructor

# Parameter groups: step-size net + decoder share one optimizer, the two
# encoder networks share another.
step_decode = list(logstepsize.parameters()) + list(model.decoder.parameters())
param_q = (list(model.encoder_logv.parameters())
           + list(model.encoder_mean.parameters()))

# Flags raised by the training loop when NaNs show up in the Langevin chain.
NAN = False
TooBig = False

num_epochs = 2000
n_LD = 5  # number of Langevin steps per minibatch

# Loss histories.
losses = []
discr_losses = []
gen_losses = []

# Smoothed discriminator targets: the first L*n_LD rows get 0.01 and the last
# L*n_LD rows get 0.99; shape (2*L*n_LD, 1).
label = torch.tensor(
    [0.01] * (L * n_LD) + [0.99] * (L * n_LD)
).unsqueeze(-1).to(device)

# Initial learning rate handed to the optimizers.
lr = 0.00002
optimizer_elbo = torch.optim.Adam(step_decode, lr=lr)
optimizer_q = torch.optim.Adam(param_q, lr=lr)
optimizer_discr = torch.optim.Adam(discr.parameters(), lr=lr)

# First epoch from which the discriminator is trained and used in the loss.
M = 0


# load the pre-trained model instead (optional)
#checkpoint = torch.load('mnist_50k_state_dict_new')
#model.load_state_dict(checkpoint['model_state_dict'])
#discr.load_state_dict(checkpoint['discr_state_dict'])
#logstepsize.load_state_dict(checkpoint['stepsize_state_dict'])
#optimizer_elbo.load_state_dict(checkpoint['optimizer_elbo_state_dict'])
#optimizer_q.load_state_dict(checkpoint['optimizer_q_state_dict'])
#optimizer_discr.load_state_dict(checkpoint['optimizer_discr_state_dict'])
#loglh = checkpoint['loglh']



#%% Langevin: learn step size
# Training loop. For every minibatch:
#   1. draw L samples z from the encoder q(z|x);
#   2. run n_LD Langevin steps on them with a learned, per-input step size;
#   3. fit the encoder to the (detached) Langevin samples;
#   4. train the discriminator to separate encoder samples from Langevin samples;
#   5. update the decoder and step-size network on the combined loss.
for epoch in range(num_epochs):
 # Step-decay schedules; these overwrite the lr the optimizers were created with.
 adjust_learning_rate(optimizer_elbo, epoch, init_lr=0.001, decay=0.88, decay_epoch=100)
 adjust_learning_rate(optimizer_q, epoch, init_lr=0.002, decay=0.88, decay_epoch=100)
 adjust_learning_rate(optimizer_discr, epoch, init_lr=0.001, decay=0.88, decay_epoch=100)
 for i, (x, _) in enumerate(data_loader):
    # Forward pass
    # Flatten each image: x is (bs, x_size).
    x = x.to(device).view(-1, x_size)
    ## stochastically binarize data
    # A fresh binarization is drawn every time a batch is visited.
    m = torch.distributions.bernoulli.Bernoulli(probs=x)
    x = m.sample()
    ##
    # z is (L, bs, z_dim); x_reconst holds decoder logits.
    x_reconst, z, z_mu, z_log_var = model.forward(x, L)
    bs = x.size(0)
    z_tilde = z.clone()
    # Buffer for the chain state after each Langevin step.
    zs = torch.zeros([n_LD, L, bs, z_dim]).to(device)
    # sample from q_tilde
    # Per-input, per-latent-dimension step size, (bs, z_dim); broadcasts over L.
    step_size = logstepsize(x).exp()#*0.01 
    for iii in range(n_LD):
#       # auto grad
        x0 = model.decode(z_tilde)
        # Unnormalised log joint log p(x|z) + log p(z), summed over the batch.
        loglh0 = -F.binary_cross_entropy_with_logits(x0, x.repeat(L,1,1),
                                              reduction='none').sum()#*n/bs
        logprior0 = -0.5*z_tilde.pow(2).sum()#*n/bs
        logmarginal = loglh0 + logprior0
        # create_graph=True keeps dz differentiable, so the final loss can
        # backpropagate through the Langevin updates.
        dz = torch.autograd.grad(logmarginal, z_tilde, create_graph=True#False#
                                 )[0]
        # Langevin proposal: half-step along the gradient plus Gaussian noise
        # with variance step_size.
        delta_z = 0.5*step_size*dz +\
                    step_size.sqrt()*torch.randn([L,bs,z_dim]).to(device)
                        
        if torch.isnan(delta_z).sum()>0: # or (delta_z.abs()>100).sum()>0:
            print("NAN")
            NAN = True
            break
        z_tilde = z_tilde+delta_z
        # NOTE(review): despite the "too big" message only NaNs are checked;
        # the magnitude test is commented out.
        if torch.isnan(z_tilde).sum()>0: # or (z_tilde.abs()>100).sum()>0:
            print("too big")
            TooBig = True
            break
        zs[iii,:,:]=z_tilde.clone()
    # Abort the batch loop on numerical failure (the epoch loop is left below).
    if NAN or TooBig:
        break
    # Merge step and sample axes: (n_LD*L, bs, z_dim).
    zs = zs.view([-1,bs,z_dim])    
    # Start GAN
    for _ in range(1):
    # train q
     for _ in range(1): 
      # NOTE(review): the fresh log-variance is bound to `z_logvar`, but the
      # expression below reads `z_log_var` from the first forward pass, so the
      # value computed here is unused -- looks unintended, confirm.
      z_mu, z_logvar = model.encode(x)
      # Gaussian log-density of the detached Langevin samples under q (up to an
      # additive constant), averaged over samples and batch.
      logq_z = ( -0.5*z_log_var - 0.5/z_log_var.exp()*(zs.detach()-z_mu).pow(2)
                     ).mean([0,1]).sum()#*n/bs
      gener_loss = -logq_z
      optimizer_q.zero_grad()
      # The graph is retained because it is reused by the backward passes below.
      gener_loss.backward(retain_graph=True)
      optimizer_q.step()
    # train discriminator
     # Placeholder reported in the log line when the discriminator is not trained.
     d_loss = torch.tensor(0.1)
     if epoch>=M: 
      for _ in range(1): 
       # NOTE(review): same `z_logvar` / `z_log_var` mismatch as above.
       z_mu, z_logvar = model.encode(x)   
       z = sample_bet(L*n_LD, z_mu, z_log_var)  
       # Pair every latent sample with its input: (L*n_LD, bs, z_dim + x_size).
       zx = torch.cat([z, x.repeat([L*n_LD,1,1])], dim=-1)
       zsx = torch.cat([zs, x.repeat([L*n_LD,1,1])], dim=-1)
       # Encoder samples first, Langevin samples second, matching the order of
       # `label` (0.01 for the first half, 0.99 for the second). Both halves are
       # detached so only the discriminator receives gradients here.
       pars = torch.cat([zx.detach(), zsx.detach()], dim=0)
       phat = discr(pars).squeeze(-1)#.sigmoid()
       d_loss = F.binary_cross_entropy_with_logits(phat, label.repeat([1,bs]))#, reduction='none').mean()
       optimizer_discr.zero_grad()
       d_loss.backward(retain_graph=True)
       optimizer_discr.step() 
    # Decoder / step-size objective evaluated on the (non-detached) Langevin samples.
    x_reconst = model.decode(zs)
    loglh = -F.binary_cross_entropy_with_logits(x_reconst, x.repeat(L*n_LD,1,1),
                                        reduction='none').mean([0,1]).sum()#*n/bs 
    logprior = -0.5*zs.pow(2).mean([0,1]).sum()#*n/bs
    logq_z = ( -0.5*z_log_var - 0.5/z_log_var.exp()*(zs-z_mu).pow(2)
                     ).mean([0,1]).sum()#*n/bs
    if epoch>=M:
      # Mean discriminator logit on the Langevin samples; `zsx` was built in
      # the discriminator branch above, which is guarded by the same condition.
      KL = discr(zsx).mean([0,1]).sum()#*n/bs
    else:
      KL = torch.tensor(0.1)
    loss =  - loglh  - logprior + KL + logq_z 
    optimizer_elbo.zero_grad()
    # NOTE(review): this backward pass traverses graphs built before
    # optimizer_q.step() / optimizer_discr.step() changed the weights in place;
    # recent PyTorch releases may reject that with an in-place modification
    # error -- verify on the target version.
    loss.backward()
    # Only the step-size network and the decoder are updated here.
    optimizer_elbo.step()
    losses.append(loss.item())
    # Periodic logging and image dumps.
    if i%150 == 0  and epoch % 20 == 0 :
            print ("Epoch[{}/{}], logq_bet: {:.4f}, loss: {:.4f}, loglh: {:.4f}, discr_loss: {:.4f}, loglhratio: {:.4f}"#, MM: {:.4f}" 
                   .format(epoch+1, num_epochs, logq_z.item(), 
                           loss.item(), loglh.item(), d_loss.item(), 
                           KL.item()#, MM.item()
                           )        
                   )
            with torch.no_grad():
            # Save the sampled images
                # Decode draws from the standard-normal prior; the decoder
                # returns logits, hence the sigmoid before saving.
                z = torch.randn(100, z_dim).to(device)#[0,:]
                out = model.decode(z).view(-1, 1, 28, 28)
                # File names use epoch+2 whereas the log line prints epoch+1.
                save_image(out.sigmoid(), os.path.join(pwd,
                                         'sampled-{}.png'.format(epoch+2)))
                # Save the reconstructed images
                temp = min(100, batch_size)
                out, _, _, _ = model(x[0:temp,:],2)
                # Input and first reconstruction side by side (joined along width).
                x_concat = torch.cat([x[0:temp,:].view(-1, 1, 28, 28), 
                                  out[0,:,:].sigmoid().view(-1, 1, 28, 28)], dim=3)
                save_image(x_concat, os.path.join(pwd,
                                          'reconst-{}.png'.format(epoch+2)))
    # NOTE(review): this check sits inside the batch loop, so the checkpoint is
    # rewritten for every batch of each matching epoch -- probably meant to run
    # once per epoch; confirm.
    if epoch%100==0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'stepsize_state_dict': logstepsize.state_dict(),
	    'discr_state_dict': discr.state_dict(),	
	    'optimizer_elbo_state_dict': optimizer_elbo.state_dict(),	
	    'optimizer_q_state_dict': optimizer_q.state_dict(),
	    'optimizer_discr_state_dict': optimizer_discr.state_dict(),
            'loglh': loglh 
            }, pwd+'/mnist_50k_state_dict_new')
 # Stop training altogether once the Langevin chain produced NaNs.
 if NAN or TooBig:
    break
    

#%%
 


#%% Testing
# Average expected log marginal likelihood by importance sampling.
# Use normality of SGLD
# Evaluation on the binarized test set. Two importance-sampling estimates of
# the log marginal likelihood are accumulated per batch:
#   temp  -- proposal is the encoder q(z|x) ("no mcmc");
#   temp3 -- proposal is the distribution after LD Langevin steps, whose
#            density is approximated by a mixture over L_eps noise draws.
LD = 5# number of SGLD steps
temp = 0.
temp3 = 0.
LL = 1000 # This is J
L_eps=50 # This is K
for i, (x, _) in enumerate(data_loader_test):
    # Forward pass
    if i%200==0:
        print(i)
    x = x.to(device).view(-1, x_size)
    # LL encoder samples per test point: z is (LL, bs, z_dim).
    x_reconst, z, z_mu, z_log_var = model.forward(x,LL)
    bs = x.size(0)
    # Per-sample log p(x|z), log p(z) and log q(z|x); the Gaussian
    # -0.5*log(2*pi) constants are left out of both the prior and q terms.
    loglh2 = -F.binary_cross_entropy_with_logits(x_reconst, x.repeat(LL,1,1),
                                    reduction='none').sum(-1) 
    logprior2 = -0.5*z.pow(2).sum(-1)
    logq2 = (-0.5*z_log_var-0.5*(z-z_mu).pow(2)/z_log_var.exp()).sum(-1)
    log_elbo2 = loglh2 + logprior2 - logq2  
    # log of the mean importance weight over the LL samples, per test point.
    logmarginal2 = (log_sum_exp( log_elbo2, dim=0)-np.log(LL))
    temp += logmarginal2.mean().item()
#    print("no mcmc loglh: {:.4f}".format(logmarginal2.mean().item()))
    z_tilde = z.clone()
    # NOTE(review): `zs` is allocated but not used anywhere in this loop.
    zs = torch.zeros([LD, bs, z_dim]).to(device)
    # sample from q_tilde
    step_size = logstepsize(x).exp()#*0.01
    # The Langevin noise has variance step_size, so its log-variance is log(step_size).
    logvar_tilde = step_size.log()
    for iii in range(LD):
        x0 = model.decode(z_tilde)
        # z_tilde is (LL, bs, z_dim) on the first step and, after the first
        # noise draw broadcasts in an L_eps axis, (L_eps, LL, bs, z_dim) on
        # later steps; the target x is tiled to match.
        if iii==0:
            loglh0 = -F.binary_cross_entropy_with_logits(x0, x.repeat(LL,1,1),
                                              reduction='none').sum(-1)
        else:
            loglh0 = -F.binary_cross_entropy_with_logits(x0, 
                                                         x.repeat(L_eps,LL,1,1),
                                              reduction='none').sum(-1)
        logprior0 = -0.5*z_tilde.pow(2).sum()
        logmarginal = loglh0.sum() + logprior0
        # Gradient of the unnormalised log joint w.r.t. the current samples; no
        # higher-order graph is needed at test time.
        dz = torch.autograd.grad(logmarginal, z_tilde, create_graph=False#True#
                                 )[0]
        delta_z = 0.5*step_size*dz +\
                    step_size.sqrt()*torch.randn([L_eps,LL,bs,z_dim]).to(device)
#        if torch.isnan(delta_z).sum()>0: # or (delta_z.abs()>100).sum()>0:
#            print("NAN")
#            NAN = True
#            break
        # Mean of the Gaussian transition for this step (drift only, no noise).
        mu_tilde = z_tilde+0.5*step_size*dz
        z_tilde = z_tilde+delta_z
#        if torch.isnan(z_tilde).sum()>0: # or (z_tilde.abs()>100).sum()>0:
#            print("too big")
#            TooBig = True
#            break
        # Only the first of the L_eps chains is scored; the decode runs on
        # every step although its result is used on the last step only.
        x_mcmc = model.decode(z_tilde[0])
        if iii == LD-1:
         with torch.no_grad():
          loglh2 = -F.binary_cross_entropy_with_logits(x_mcmc, 
                                                       x.repeat(LL,1,1),
                                    reduction='none').sum(-1) 
          logprior2 = -0.5*z_tilde[0].pow(2).sum(-1)
          # Proposal log-density: a log-mean-exp over the L_eps transition
          # means is taken per latent coordinate, then summed over coordinates.
          logq2 = (log_sum_exp((-0.5*logvar_tilde-
                   0.5*(z_tilde[0,]-mu_tilde).pow(2)/logvar_tilde.exp()
                   ),dim=0)-np.log(L_eps)).sum(-1)
          log_elbo2 = loglh2 + logprior2 - logq2  
          logmarginal2 = (log_sum_exp( log_elbo2, dim=0)-np.log(LL)).mean()
          # This inner test repeats the enclosing condition and is always true here.
          if iii == LD-1:
              temp3 = temp3+logmarginal2.item()

# Averages over test batches.
print("(overall) no mcmc: {:.4f}".format(temp/len(data_loader_test)))   
print("(overall) last mcmc iter: {:.4f}".format(temp3/len(data_loader_test)))    
  


#%%#####################################################################################3
