'''MIVI for the negative binomial (NB) model.
'''
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import numpy as np
import copy
import scipy
from scipy.io import loadmat
from scipy.stats import nbinom
#from keras.utils import to_categorical
import pandas as pd
import gc


def log_sum_exp(value, dim=None, keepdim=False):
    """Compute ``log(sum(exp(value)))`` in a numerically stable way.

    The maximum is subtracted before exponentiating so that large entries do
    not overflow.

    Args:
        value: tensor of log-domain values.
        dim: dimension to reduce over. ``None`` reduces over every element
            and yields a 0-dim tensor.
        keepdim: when ``dim`` is given, keep the reduced dimension with
            size 1. Ignored when ``dim`` is ``None``.

    Returns:
        Tensor holding the log-sum-exp of ``value``.
    """
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
    else:
        m = torch.max(value)
    # An infinite maximum would make `value - m` evaluate inf - inf = nan.
    # Shifting by zero in that case lets exp/log propagate the +/-inf result.
    shift = torch.where(torch.isinf(m), torch.zeros_like(m), m)
    exp_shifted = torch.exp(value - shift)
    if dim is None:
        # torch.sum always returns a tensor, so torch.log is always the
        # right call here (no Python-number special case is needed).
        return shift + torch.log(torch.sum(exp_shifted))
    total = torch.sum(exp_shifted, dim=dim, keepdim=keepdim)
    if not keepdim:
        # `m` was computed with keepdim=True; drop the axis to match `total`.
        shift = shift.squeeze(dim)
    return shift + torch.log(total)
        
# Compute device shared by the whole script: the GPU when torch can see one,
# otherwise the CPU. (To force CPU, use: device = torch.device('cpu'))
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def sample_bet(L, mu, log_var):
    """Draw ``L`` reparameterised samples from a diagonal Gaussian.

    Each sample is ``mu + eps * exp(log_var / 2)`` with ``eps`` standard
    normal, so gradients flow back to ``mu`` and ``log_var``.

    Args:
        L: number of samples to draw.
        mu: tensor of means.
        log_var: tensor of log-variances, broadcastable against ``mu``.

    Returns:
        Tensor of shape ``[L, *mu.shape]``.
    """
    std = torch.exp(log_var / 2)
    # Create the noise directly on mu's device instead of allocating on the
    # CPU and copying to a module-level device; this also keeps the function
    # correct when mu does not live on that global device.
    eps = torch.randn([L] + list(mu.shape), device=mu.device)
    return mu + eps * std.unsqueeze(0)

def sample_logtau(L, mu_a, logvar_a):
    """Draw ``L`` reparameterised samples from a normal distribution.

    Args:
        L: number of samples to draw.
        mu_a: mean; must broadcast against a length-``L`` vector.
        logvar_a: tensor holding the log-variance; must broadcast against a
            length-``L`` vector.

    Returns:
        Tensor ``mu_a + z * exp(logvar_a / 2)`` with ``z`` a length-``L``
        standard-normal draw.
    """
    std = torch.exp(logvar_a / 2)
    # Sample on the same device as the scale parameter rather than relying on
    # a module-level device and a CPU-to-device copy.
    z = torch.randn(L, device=std.device)
    return mu_a + z * std

def adjust_learning_rate(optmz, epoch, init_lr=0.001, decay=0.9, decay_epoch=100):
    """Apply a step-wise exponential learning-rate schedule to an optimizer.

    The rate is ``init_lr * decay ** (epoch // decay_epoch)``, i.e. with the
    defaults it is multiplied by 0.9 once every 100 epochs. The value is
    written into the ``'lr'`` entry of every group in ``optmz.param_groups``.
    Returns nothing.
    """
    completed_decays = epoch // decay_epoch
    scheduled_lr = init_lr * (decay ** completed_decays)
    for group in optmz.param_groups:
        group.update(lr=scheduled_lr)

# Observed data: column `x` of NB_data.csv, loaded as a float tensor on the
# compute device. `n` is the number of observations.
x = torch.tensor(
    np.array(pd.read_csv('NB_data.csv').x),
    dtype=torch.float,
).to(device)
n = x.size(0)
#%% MIVI: Langevin + VI
# parameters
# NOTE(review): eps_size is not referenced anywhere else in the visible source.
eps_size=10
# Default input width of the `transition` and `discriminator` networks below.
z_size = 2
# NN to approximate gradient
class transition(nn.Module):
    """Four-layer ReLU MLP mapping ``size``-dimensional points back to ``size`` dimensions.

    Layer widths are ``size -> h_dim -> 2*h_dim -> h_dim -> size``; the last
    layer is linear (no activation). Every layer is moved to the module-level
    ``device`` on construction.

    Args:
        size: width of the input and of the output.
        h_dim: base hidden width.
    """
    def __init__(self, size=z_size, h_dim=20):
        super(transition, self).__init__()
        self.fc11 = nn.Linear(size, h_dim).to(device)
        self.fc12 = nn.Linear(h_dim, h_dim+h_dim).to(device)
        self.fc13 = nn.Linear(h_dim+h_dim, h_dim).to(device)
        # The output width follows the `size` argument. It was previously tied
        # to the global z_size, which silently produced the wrong output width
        # whenever a non-default `size` was passed.
        self.fc14 = nn.Linear(h_dim, size).to(device)

    def forward(self, z):
        """Return the network output for a batch ``z`` whose last dimension is ``size``."""
        h = F.relu(self.fc11(z))
        h = F.relu(self.fc12(h))
        h = F.relu(self.fc13(h))
        h = self.fc14(h)
        return h
# Instantiate the transition network with its defaults and an Adam optimizer
# over its weights.
# NOTE(review): neither `transit` nor `optimizer_transit` is referenced again
# in the visible source.
transit = transition()    
optimizer_transit = torch.optim.Adam(transit.parameters(), lr=0.001)

    
# GAN discriminator:
class discriminator(nn.Module):
    """Four-layer ReLU MLP that scores ``size``-dimensional points.

    Layer widths are ``size -> h_dim -> 2*h_dim -> h_dim -> 1``. The single
    output is a raw score: no sigmoid is applied inside the module. Every
    layer is moved to the module-level ``device`` on construction.
    """

    def __init__(self, size=z_size, h_dim=10):
        super().__init__()
        wide = h_dim + h_dim
        self.fc11 = nn.Linear(size, h_dim).to(device)
        self.fc12 = nn.Linear(h_dim, wide).to(device)
        self.fc13 = nn.Linear(wide, h_dim).to(device)
        self.fc14 = nn.Linear(h_dim, 1).to(device)

    def forward(self, z):
        """Return the raw score for each row of ``z``."""
        hidden = z
        # Three ReLU-activated layers followed by a linear read-out.
        for layer in (self.fc11, self.fc12, self.fc13):
            hidden = F.relu(layer(hidden))
        return self.fc14(hidden)
    
# Discriminator instance and its Adam optimizer; both are used inside the
# training loop.
discr = discriminator()    
optimizer_discr= torch.optim.Adam(discr.parameters(), lr=0.001)
# Per-epoch histories of the discriminator loss and of the q-fitting loss.
discr_losses = []
gen_losses = []

# Variational distribution: diagonal Gaussian q over the 2-d vector z, where
# (as used in the training loop) z[:, 0] is exponentiated to give r and
# z[:, 1] is passed through a sigmoid to give p.
SIVI = False  # NOTE(review): not referenced elsewhere in the visible source
mu=torch.zeros([2]).to(device).requires_grad_(True) # mean of q
log_var=(torch.tensor(-5.)+torch.zeros([2])).to(device).requires_grad_(True) # log-variance of q, initialised at -5
# Adam over the parameters of q only (mu and log_var).
optimizer_q = torch.optim.Adam([mu , log_var], lr=0.002)

# `num_epochs`, `LL` and `losses` used to be assigned here as well, but each
# is assigned again further down before its first use, so those dead
# assignments (including a misleading `num_epochs = 10000`) were dropped.

#%%  Langevin update + auto grad
# ---- Langevin-dynamics configuration ----
MCMC_grad = True
learn_all_steps = False   # True: a separate learnable step size per Langevin update
decreasing_step = False   # True: the step size shrinks along the chain
if decreasing_step:
    # Ratio of the final to the initial step size for the decaying schedule.
    final_step_rate = 0.1
gamma = 0.55              # exponent of the decaying step-size schedule

LL = 1000                 # number of samples drawn from q per epoch
n_LD = 10                 # number of Langevin updates per epoch

# Learnable log step size(s), one row per latent coordinate.
if not learn_all_steps:
    # A single step size per coordinate, initialised at exp(-6).
    logstep_size = (torch.zeros([2]) -6).to(device).requires_grad_(True)
else:
    # One step size per coordinate and per Langevin update, initialised with
    # a random offset and a linear decrease over the update index.
    logstep_size = (
        torch.rand([2, n_LD])*0.1 - 0.5
        - 0.1*torch.arange(1, n_LD+1, dtype=torch.float)
    ).to(device).requires_grad_(True)

# Joint optimizer over q and the step size, and one over the step size alone.
optimizer = torch.optim.Adam([mu, log_var, logstep_size], lr=0.001)
optimizer_stepsize = torch.optim.Adam([logstep_size], lr=0.001)

num_epochs = 12000
# Smoothed discriminator targets: 0.01 for the first LL*n_LD rows and 0.99
# for the last LL*n_LD rows, as a column vector.
label = torch.cat([
    torch.full([LL*n_LD], 0.01),
    torch.full([LL*n_LD], 0.99),
]).unsqueeze(-1).to(device)
losses = []

# Flags set by the training loop when it aborts on NaN values.
NAN = False
TooBig = False
banana2 = False
#%% Langevin: learn step size
# Main training loop. Each epoch:
#   1. draws LL samples from the Gaussian q and pushes them through n_LD
#      Langevin updates, recording the samples after every update;
#   2. refits q (mu, log_var) to the detached Langevin samples;
#   3. takes one discriminator step on fresh q samples vs. Langevin samples;
#   4. updates logstep_size by back-propagating `loss` through the updates.
for epoch in range(num_epochs):
    # Step sizes are parameterised on the log scale, so they stay positive.
    LD_step_size = logstep_size.exp()
    z = sample_bet(LL, mu, log_var)
    z_tilde = z.clone()
    # Buffer holding the samples after each of the n_LD Langevin updates.
    zs = torch.zeros([n_LD, LL, 2]).to(device)
    # sample from q_tilde
    if learn_all_steps == False:
        step_size = LD_step_size
        if decreasing_step:
            # Constants of the schedule aa/(bb+iii)^gamma, chosen so that it
            # equals step_size at iii=0 and final_step_rate*step_size at
            # iii=n_LD.
            bb = n_LD/((step_size/(final_step_rate*step_size)).pow(1/gamma)-1)
            aa = step_size*bb.pow(gamma)      
    for iii in range(n_LD):
        if learn_all_steps: # learn step size of every LD updates
            if decreasing_step:
                # Sum of the remaining (positive) per-update sizes, so the
                # effective step shrinks as iii grows.
                step_size = LD_step_size[:,iii::].sum(-1)
            else:
                step_size = LD_step_size[:,iii]
        elif decreasing_step: 
            # only learn the step size of the first LD update, 
            # and have decreasing step size
            step_size = aa/(bb+iii).pow(gamma)
#       # auto grad
        # r = exp(z[:,0]) and p = sigmoid(z[:,1]); the trailing unit axis
        # lets them broadcast against the data vector x.
        r = z_tilde[:,0].exp().unsqueeze(-1)
        p = z_tilde[:,1].sigmoid().unsqueeze(-1)
        # Negative-binomial log-pmf summed over all samples and observations,
        # minus 0.5*z^2 summed over all samples.
        # NOTE(review): a quadratic z penalty is subtracted here AND added
        # again (averaged over samples) through `logprior` below -- confirm
        # that counting it twice is intended.
        loglh = (torch.lgamma(x+r)-torch.lgamma(x+1)-torch.lgamma(r) + \
                r*p.log() + x*(1-p).log()).sum() -(0.5*z_tilde.pow(2)).sum() 
        logprior = -0.5*(z_tilde.pow(2)).mean(0).sum()
        logmarginal = loglh + logprior
        # create_graph=True keeps this gradient differentiable so that
        # loss.backward() further down can reach logstep_size through the
        # Langevin updates.
        dz = torch.autograd.grad(logmarginal, z_tilde, create_graph=True)[0]
        # Langevin move: half-step along the gradient plus Gaussian noise
        # scaled by sqrt(step_size).
        delta_z = 0.5*step_size*dz +\
                    step_size.sqrt()*torch.randn([LL,2]).to(device)
        if torch.isnan(delta_z).sum()>0: # or (delta_z.abs()>100).sum()>0:
            print("NAN")
            NAN = True
            break
        z_tilde = z_tilde+delta_z
        # NOTE(review): despite the "too big" message, only NaN is tested
        # here; the magnitude test is commented out.
        if torch.isnan(z_tilde).sum()>0: # or (z_tilde.abs()>100).sum()>0:
            print("too big")
            TooBig = True
            break
        zs[iii,:,:]=z_tilde.clone()
    # Flatten (update, sample) into a single batch axis: [n_LD*LL, 2].
    zs = zs.view([-1,2])    

    # Abort the whole run if the Langevin chain produced NaN values.
    if NAN or TooBig:
        break
    # Start GAN
    for _ in range(1):
    # train q
     for _ in range(1): 
      # minimize cross entropy
      # Gaussian log-density of the detached Langevin samples under q (up to
      # an additive constant); minimising its negative fits mu and log_var
      # to those samples.
      logq_z = ( -0.5*log_var - 0.5/log_var.exp()*(zs.detach()-mu).pow(2)
                     ).mean(0).sum()
      gener_loss = -logq_z
      optimizer_q.zero_grad()
      gener_loss.backward(retain_graph=True)
      optimizer_q.step()
     # train the discriminator: fresh q samples come first (target 0.01),
     # Langevin samples second (target 0.99), matching the layout of `label`.
     for _ in range(1): 
      z = sample_bet(LL*n_LD, mu, log_var)  
      pars = torch.cat([z.detach(), zs.detach()], dim=0)
      phat = discr(pars).sigmoid()
      d_loss = F.binary_cross_entropy(phat, label)#, reduction='none').mean()
      optimizer_discr.zero_grad()
      d_loss.backward(retain_graph=False)
      optimizer_discr.step() 
     discr_losses.append(d_loss.item())
     gen_losses.append(gener_loss.item())
    # train the main model
    rs = zs[:,0].exp().unsqueeze(-1)
    ps = zs[:,1].sigmoid().unsqueeze(-1)
    # Negative-binomial log-likelihood: averaged over the Langevin samples
    # (axis 0), then summed over the observations.
    loglh = (torch.lgamma(x+rs)-torch.lgamma(x+1)-torch.lgamma(rs) + \
                rs*ps.log() + x*(1-ps).log()).mean(0).sum() 

    # Same q log-density as above, but zs is NOT detached here, so gradients
    # flow back through the Langevin updates.
    logq_z = ( -0.5*log_var - 0.5/log_var.exp()*(zs-mu).pow(2)
                     ).mean(0).sum()
    logprior = -0.5*(zs.pow(2)).mean(0).sum()
    # Mean raw discriminator score on the Langevin samples (logged as
    # "loglhratio"); presumably a log-density-ratio estimate -- TODO confirm.
    KL = discr(zs).mean()
    loss = KL - loglh + logq_z - logprior 
    # Only optimizer_stepsize steps on this loss, so only logstep_size is
    # updated here. Gradients that backward() leaves on mu, log_var and the
    # discriminator weights are cleared by their own zero_grad() calls
    # before their next update.
    optimizer_stepsize.zero_grad()    
    loss.backward()
    optimizer_stepsize.step()
    losses.append(loss.item())
    # Progress report once every 200 epochs (epochs 1, 201, 401, ...).
    if  epoch%200==1:# and (i+1) %250 ==0:
            print ("Epoch[{}/{}], logq_bet: {:.4f}, loss: {:.4f}, loglh: {:.4f},\
                   discr_loss: {:.4f}, loglhratio: {:.4f}"#, MM: {:.4f}" 
                   .format(epoch+1, num_epochs, logq_z.item(), 
                           loss.item(), loglh.item(), d_loss.item(), 
                           KL.item()#, MM.item()
                           )        
                   )

# Overlay the three training histories on the current matplotlib figure, in
# this order: q-fitting loss, discriminator loss, step-size loss.
for _history in (gen_losses, discr_losses, losses):
    plt.plot(_history)
#%% plot p and r
# gaussian_kde is imported from scipy.stats directly: the `scipy.stats.kde`
# submodule that was imported here before is a deprecated namespace.
from scipy.stats import gaussian_kde

# Both densities are evaluated on the same fixed nbins x nbins grid over
# (1 - p) in [0.66, 0.74] and r in [1.7, 2.3]. The data-extent grids that
# used to be computed first were always overwritten by this one, so they
# were removed.
nbins=100
xi, yi = np.mgrid[(0.66): 0.74: nbins*1j,  (1.7): (2.3):nbins*1j] # banana
grid_points = np.vstack([xi.flatten(), yi.flatten()])

# from q distribution: fresh samples, mapped to r = exp(z0) and p = sigmoid(z1)
z = sample_bet(LL*n_LD, mu, log_var)
r = z[:,0].exp().unsqueeze(-1)
p = z[:,1].sigmoid().unsqueeze(-1)
z1 = r.cpu().detach().squeeze(-1); z2 = 1. - p.cpu().detach().squeeze(-1)
k_q = gaussian_kde([z2.numpy(), z1.numpy()])
z_q = k_q(grid_points)

# from \tilde q distribution: the Langevin samples of the last training epoch.
# NOTE: `rs` and `ps` only exist if the training loop completed at least one
# full epoch without aborting.
z1 = rs.cpu().detach().squeeze(-1); z2 = 1. - ps.cpu().detach().squeeze(-1)
k_qtilde = gaussian_kde([z2.numpy(), z1.numpy()])
z_qtilde = k_qtilde(grid_points)

# Contours of both densities on one new figure (x axis: 1 - p, y axis: r).
plt.figure()
plt.contour(xi, yi, z_q.reshape(xi.shape))
plt.contour(xi, yi, z_qtilde.reshape(xi.shape))

