'''MIVI for the toy experiments in the Appendix.
To run, simply copy and paste everything in this file into a Python 3 console.
'''
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import numpy as np
import copy
import scipy
from scipy.io import loadmat
#from keras.utils import to_categorical
import pandas as pd
import gc
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation

    value.exp().sum(dim, keepdim).log()

    The maximum is subtracted before exponentiating and added back after the
    logarithm, so large entries do not overflow.

    Args:
        value: input tensor.
        dim: dimension to reduce over; ``None`` reduces over all elements.
        keepdim: whether the reduced dimension is retained. Only honored when
            ``dim`` is given; the full reduction ignores it.

    Returns:
        Tensor holding log(sum(exp(value))) along ``dim`` (or over every
        element when ``dim`` is ``None``).
    """
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        # Truthiness (not identity with False) so that keepdim=0 also drops
        # the reduced dimension of m, matching what torch.sum does below.
        if not keepdim:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0),
                                       dim=dim, keepdim=keepdim))
    # Full reduction. torch.max(value, dim=None) is not accepted, hence the
    # separate branch. torch.sum returns a tensor here, so no special case for
    # plain Python numbers is needed (the previous one referenced the
    # un-imported names `Number` and `math` and raised NameError).
    m = torch.max(value)
    return m + torch.log(torch.sum(torch.exp(value - m)))
        
# Uncomment the next line (and comment out the one after it) to force CPU.
#device = torch.device('cpu')
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def sample_bet(L, mu, log_var):
    """Draw L reparameterized samples from a diagonal Gaussian.

    Each sample is ``mu + eps * exp(log_var / 2)`` with ``eps`` standard
    normal, so the result stays differentiable w.r.t. ``mu`` and ``log_var``.

    Args:
        L: number of samples to draw.
        mu: mean tensor.
        log_var: log-variance tensor, broadcastable against ``mu``.

    Returns:
        Tensor of shape ``[L, *mu.shape]``.
    """
    std = torch.exp(log_var / 2)
    # Draw the noise directly on mu's device: no dependence on a module-level
    # `device` and no host-to-device copy. (On CUDA the noise therefore comes
    # from the CUDA generator rather than the CPU one.)
    eps = torch.randn([L] + list(mu.shape), device=mu.device)
    return mu + eps * std.unsqueeze(0)

def sample_logtau(L, mu_a, logvar_a): # sample from normal dist
    """Draw L reparameterized samples from a normal distribution.

    Args:
        L: number of samples to draw.
        mu_a: mean (tensor or number broadcastable against a length-L vector).
        logvar_a: log-variance tensor, broadcastable the same way.

    Returns:
        ``mu_a + z * exp(logvar_a / 2)`` with ``z`` a length-L standard normal
        draw.
    """
    std = torch.exp(logvar_a / 2)
    # Draw the noise on the parameter's own device instead of copying it to a
    # module-level `device`.
    z = torch.randn(L, device=logvar_a.device)
    return mu_a + z * std

eps_size=10  # NOTE(review): not referenced anywhere in the visible code
z_size = 2  # default input width of the discriminator network
class discriminator(nn.Module):
    """Four-layer fully connected network mapping a point to one raw score.

    Layer widths are size -> h_dim -> 2*h_dim -> h_dim -> 1, with ReLU after
    each of the first three layers and no activation on the output.
    """

    def __init__(self, size=z_size, h_dim=10):
        super().__init__()
        wide = h_dim + h_dim
        # Layers are created in the same order as before so parameter
        # initialization consumes the RNG identically.
        self.fc11 = nn.Linear(size, h_dim).to(device)
        self.fc12 = nn.Linear(h_dim, wide).to(device)
        self.fc13 = nn.Linear(wide, h_dim).to(device)
        self.fc14 = nn.Linear(h_dim, 1).to(device)

    def forward(self, z):
        """Return the un-squashed score for each row of ``z``."""
        hidden = z
        for layer in (self.fc11, self.fc12, self.fc13):
            hidden = F.relu(layer(hidden))
        return self.fc14(hidden)
    
# Discriminator instance trained in the loop below, plus per-iteration
# histories of the discriminator loss and the q-fitting loss.
discr = discriminator()    
discr_losses = []
gen_losses = []

# Variational parameters of the diagonal Gaussian q: both are 2-vectors and
# leaf tensors with gradients enabled (mean starts at 0, log-variance at 0.5).
mu=torch.zeros([2]).to(device).requires_grad_(True) # mean for q(theta1)
log_var=(torch.tensor(0.5)+torch.zeros([2])).to(device).requires_grad_(True) # log var of q(theta1)

#%%  Langevin update + auto grad
# Target selection flags. The training loop tests them in the order
# banana, correlate, mixture2 and uses the first one that is True; when all
# are False it falls back to a two-component normal mixture.
correlate = True
banana= False
mixture2 = False
# Log of the per-dimension Langevin step size (initialized to -1); the loop
# exponentiates it to obtain the actual step size.
logstep_size = (torch.zeros([2]) -1 ).to(device).requires_grad_(True)
if correlate:
  # Covariance and precision of the correlated Gaussian target.
  Sigma = torch.tensor([[1,0.8],[0.8,1]]).to(device)
  prec = Sigma.inverse()
if mixture2:
  # Means, covariances and precisions of the two mixture components.
  mu1 = torch.tensor([-1.,-1.]).to(device)
  Sigma1 = torch.tensor([[1,-0.5],[-0.5,1]]).to(device)
  prec1 = Sigma1.inverse()
  mu2 = torch.tensor([1.3,1.3]).to(device)
  Sigma2 = torch.tensor([[1.,0.3],[0.3,1.]]).to(device)
  prec2 = Sigma2.inverse()
if False: # plot ground truth of mixture2
    # Disabled by default. To use it, set the condition to True and make sure
    # mixture2 is True so that mu1/Sigma1 and mu2/Sigma2 are defined.
    # gaussian_kde is imported from scipy.stats directly; the scipy.stats.kde
    # submodule used previously is a deprecated namespace.
    from scipy.stats import gaussian_kde
    n = 5000
    nbins = 150
    # NOTE(review): Tensor.cholesky is deprecated in recent PyTorch in favor
    # of torch.linalg.cholesky; switch once the minimum version allows it.
    lower1 = Sigma1.cholesky().cpu()
    lower2 = Sigma2.cholesky().cpu()
    # Exact draws from each component: mean + L @ standard normal.
    z1 = torch.matmul(lower1, torch.randn([2, n])).t() + mu1.cpu()
    z2 = torch.matmul(lower2, torch.randn([2, n])).t() + mu2.cpu()
    # Per-sample component indicator with probability 0.5 each. `.float()`
    # replaces torch.tensor(<tensor>, dtype=...), which warns and copies.
    b = (torch.rand([n, 1]) > 0.5).float()
    z = z1*b + z2*(1-b)
    zplt = z.cpu().detach()
    z1 = zplt[:, 0]; z2 = zplt[:, 1]
    # Density estimate with the second coordinate on the first axis.
    k_mcmc = gaussian_kde([z2.numpy(), z1.numpy()])
    xi, yi = np.mgrid[-4:4:nbins*1j, -4:4:nbins*1j] # banana
    z_mcmc = k_mcmc(np.vstack([xi.flatten(), yi.flatten()]))
    plt.figure()
    plt.contour(xi, yi, z_mcmc.reshape(xi.shape))
    plt.figure()
    plt.pcolormesh(xi, yi, z_mcmc.reshape(xi.shape))#, cmap=plt.cm.Greens_r)

num_epochs = 2000  # number of training epochs (shown in the progress message)
LL = 500  # samples drawn from q at the start of each epoch
n_LD =  20   # Langevin steps applied to those samples per epoch
# Discriminator targets for a batch of 2*LL*n_LD rows: 0.01 for the first
# half and 0.99 for the second half (smoothed 0/1 labels), shaped [N, 1].
label = torch.tensor([0.01]*LL*n_LD + [0.99]*LL*n_LD).unsqueeze(-1).to(device) 
losses=[]  # history of the step-size loss, one entry per epoch
# Flags set when the Langevin chain produces NaNs; either one stops training.
NAN = False
TooBig = False    
# NOTE(review): `optimizer` is not used anywhere in the visible code; the loop
# uses the three dedicated optimizers below.
optimizer = torch.optim.Adam([mu , log_var, logstep_size], lr=0.00005)
optimizer_stepsize = torch.optim.Adam([logstep_size], lr=0.00005) 
optimizer_q = torch.optim.Adam([mu , log_var], lr=0.00005) 
optimizer_discr= torch.optim.Adam(discr.parameters(), lr=0.00005)
#%% Langevin: learn step size
# Each epoch:
#   (1) draw LL samples from q and run n_LD Langevin steps on the selected
#       target log-density, storing every intermediate state in zs;
#   (2) refit q (mu, log_var) and the discriminator on the detached samples;
#   (3) backpropagate a loss through the Langevin chain and update only
#       logstep_size.
# The epoch count follows num_epochs (previously a separate literal 2000, so
# changing num_epochs only altered the progress message).
for epoch in range(num_epochs):
    step_size = logstep_size.exp()
    z = sample_bet(LL, mu, log_var)
    z_tilde = z.clone()
    zs = torch.zeros([n_LD, LL, 2]).to(device)
    # sample from q_tilde
    for iii in range(n_LD):
        # Target log-density summed over the batch: the terms are independent
        # per sample, so one autograd.grad call yields every sample's gradient.
        if banana:
            loglh = (-0.5*(z_tilde[:, 0]-z_tilde[:, 1].pow(2)/4).pow(2) -
                     0.5*z_tilde[:, 1].pow(2)/4).sum()  # banana
        elif correlate:
            loglh = -0.5*(torch.matmul(z_tilde, prec)*z_tilde).sum(-1).sum()
        elif mixture2:
            expo1 = np.log(0.5)-0.5*Sigma1.det().log()-0.5*(
                torch.matmul(z_tilde-mu1, prec1)*(z_tilde-mu1)).sum(-1)
            expo2 = np.log(0.5)-0.5*Sigma2.det().log()-0.5*(
                torch.matmul(z_tilde-mu2, prec2)*(z_tilde-mu2)).sum(-1)
            loglh = log_sum_exp(
                torch.cat([expo1.unsqueeze(-1), expo2.unsqueeze(-1)],
                          dim=-1),
                dim=-1
            ).sum()
        else:
            loglh = log_sum_exp(
                torch.cat(
                    [(-0.5*z_tilde[:, 1].pow(2)-0.5*(z_tilde[:, 0]+2).pow(2)).unsqueeze(-1),
                     (-0.5*z_tilde[:, 1].pow(2)-0.5*(z_tilde[:, 0]-2).pow(2)).unsqueeze(-1)
                     ], dim=-1), dim=-1
            ).sum()  # normal mixture
        # create_graph=True keeps dz itself differentiable, so the loss at the
        # end of the epoch can be backpropagated through every Langevin step.
        dz = torch.autograd.grad(loglh, z_tilde, create_graph=True)[0]
        # Langevin proposal: half a step along the gradient plus Gaussian
        # noise scaled by sqrt(step_size).
        delta_z = 0.5*step_size*dz + \
            step_size.sqrt()*torch.randn([LL, 2]).to(device)
        if torch.isnan(delta_z).sum() > 0:  # or (delta_z.abs()>100).sum()>0:
            print("NAN")
            NAN = True
            break
        z_tilde = z_tilde+delta_z
        if torch.isnan(z_tilde).sum() > 0:  # or (z_tilde.abs()>100).sum()>0:
            print("too big")
            TooBig = True
            break
        zs[iii, :, :] = z_tilde.clone()
    # Flatten (step, sample) into one batch axis: [n_LD*LL, 2].
    zs = zs.view([-1, 2])
    if NAN or TooBig:
        break
    for _ in range(2):
        # Fit q to the detached Langevin samples by maximizing their Gaussian
        # log-density (up to an additive constant).
        for _ in range(1):
            logq_z = (-0.5*log_var - 0.5/log_var.exp()*(zs.detach()-mu).pow(2)
                      ).mean(0).sum()
            gener_loss = -logq_z
            optimizer_q.zero_grad()
            gener_loss.backward()
            optimizer_q.step()
        # Train the discriminator to tell fresh q samples (first half of the
        # batch, target 0.01) from Langevin samples (second half, target 0.99).
        for _ in range(3):
            z = sample_bet(LL*n_LD, mu, log_var)
            pars = torch.cat([z.detach(), zs.detach()], dim=0)
            phat = discr(pars).sigmoid()
            d_loss = F.binary_cross_entropy(phat, label)  # , reduction='none').mean()
            optimizer_discr.zero_grad()
            d_loss.backward(retain_graph=False)
            optimizer_discr.step()
        discr_losses.append(d_loss.item())
        gen_losses.append(gener_loss.item())
    # Same target log-density as above, now averaged over all stored Langevin
    # states and kept attached to the graph (zs is not detached here).
    if banana:
        loglh = (-0.5*(zs[:, 0]-zs[:, 1].pow(2)/4).pow(2) -
                 0.5*zs[:, 1].pow(2)/4).mean()  # banana
    elif correlate:
        loglh = -0.5*(torch.matmul(zs, prec)*zs).sum(-1).mean()
    elif mixture2:
        expo1 = (0.5*Sigma1.det().pow(-0.5)).log()-0.5*(
            torch.matmul(zs-mu1, prec1)*(zs-mu1)).sum(-1)
        expo2 = (0.5*Sigma2.det().pow(-0.5)).log()-0.5*(
            torch.matmul(zs-mu2, prec2)*(zs-mu2)).sum(-1)
        loglh = log_sum_exp(
            torch.cat([expo1.unsqueeze(-1), expo2.unsqueeze(-1)], dim=-1),
            dim=-1
        ).mean(0)
    else:
        loglh = log_sum_exp(
            torch.cat(
                [(-0.5*zs[:, 1].pow(2)-0.5*(zs[:, 0]+2).pow(2)).unsqueeze(-1),
                 (-0.5*zs[:, 1].pow(2)-0.5*(zs[:, 0]-2).pow(2)).unsqueeze(-1)],
                dim=-1
            ), dim=-1
        ).mean(0)  # normal mixture
    logq_z = (-0.5*log_var - 0.5/log_var.exp()*(zs-mu).pow(2)
              ).mean(0).sum()
    # Mean raw discriminator output on the Langevin samples (reported as
    # "loglhratio" in the progress message).
    KL = discr(zs).mean()
    logprior = -0.5*(zs.pow(2)).mean(0).sum()
    loss = KL - loglh + logq_z - logprior  # + 0.01/step_size.pow(2).sum()

    # Only logstep_size is stepped here. Gradients that this backward pass
    # also accumulates on mu, log_var and the discriminator are cleared by the
    # zero_grad calls that precede their own updates.
    optimizer_stepsize.zero_grad()
    loss.backward()
    optimizer_stepsize.step()
    losses.append(loss.item())
    if epoch % 200 == 1:  # and (i+1) %250 ==0:
        # Adjacent literals instead of a backslash continuation inside the
        # string, which used to print a long run of indentation spaces.
        print("Epoch[{}/{}], logq_bet: {:.4f}, loss: {:.4f}, loglh: {:.4f}, "
              "discr_loss: {:.4f}, loglhratio: {:.4f}"  # , MM: {:.4f}"
              .format(epoch+1, num_epochs, logq_z.item(),
                      loss.item(), loglh.item(), d_loss.item(),
                      KL.item()  # , MM.item()
                      )
              )

# Step-size loss trajectory over the epochs that ran.
plt.plot(losses)
#%%
# gaussian_kde is imported from scipy.stats directly; the scipy.stats.kde
# submodule used previously is a deprecated namespace.
from scipy.stats import gaussian_kde
# Evaluate a gaussian kde on a regular grid of nbins x nbins
nbins = 150
# Fresh draws from the fitted q, kept available as an alternative to plot.
z = sample_bet(LL*n_LD, mu, log_var)
# Plot the stored Langevin samples; swap in z.cpu().detach() or
# z_tilde.cpu().detach() to plot q or the final Langevin state instead.
zplt = zs.cpu().detach()
z1 = zplt[:, 0]; z2 = zplt[:, 1]
# Density estimate with the second coordinate on the first axis.
k_mcmc = gaussian_kde([z2.numpy(), z1.numpy()])
# Grid choice made explicit: previously the data-extent grid was computed and
# then unconditionally overwritten by the fixed window.
use_data_extents = False
if use_data_extents:
    xi, yi = np.mgrid[z2.numpy().min():z2.numpy().max():nbins*1j,
                      z1.numpy().min():z1.numpy().max():nbins*1j]
else:
    xi, yi = np.mgrid[-4:4:nbins*1j, -4:4:nbins*1j]  # banana
z_mcmc = k_mcmc(np.vstack([xi.flatten(), yi.flatten()]))
plt.figure()
plt.contour(xi, yi, z_mcmc.reshape(xi.shape))
