""" 
Bayesian bridge regression using the data-augmentation scheme of Mike West (1987).
A Weibull distribution is used to approximate the marginal distribution of each augmented variable lambda_j.
"""
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import random
import numpy as np
import copy
import scipy
import sklearn.datasets
from scipy.io import loadmat
import pandas as pd
import gc
def log_sum_exp(value, dim=None, keepdim=False):
    """Return ``log(sum(exp(value)))`` computed in a numerically stable way.

    Delegates to :func:`torch.logsumexp`, which subtracts the maximum before
    exponentiating (no overflow for large inputs) and returns ``-inf`` rather
    than ``nan`` when every reduced element is ``-inf``.

    Args:
        value: input tensor.
        dim: dimension (or dimensions) to reduce over. If ``None``, reduce
            over every element and return a 0-d tensor.
        keepdim: when ``dim`` is given, keep the reduced dimension(s) with
            size 1. Ignored when ``dim`` is ``None``, as before.

    Returns:
        Tensor holding the log-sum-exp.
    """
    if dim is not None:
        return torch.logsumexp(value, dim=dim, keepdim=keepdim)
    # Flatten so the reduction covers all elements. The previous hand-rolled
    # version of this branch referenced ``Number`` and ``math`` without
    # importing them and therefore raised NameError.
    return torch.logsumexp(value.reshape(-1), dim=0)
        
        
# Use CUDA when it is available, otherwise the CPU. Tensors created at module
# level in this script are moved to this device with `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def sample_bet(L, mu, log_var):
    """Draw ``L`` reparameterised samples from a diagonal Gaussian.

    Args:
        L: number of samples.
        mu: mean tensor.
        log_var: log-variance tensor, same shape as ``mu``.

    Returns:
        Tensor of shape ``[L] + list(mu.shape)`` equal to
        ``mu + eps * exp(log_var / 2)`` with ``eps ~ N(0, 1)``, so gradients
        flow back to ``mu`` and ``log_var``.
    """
    std = torch.exp(log_var / 2)
    # Draw the noise directly on mu's device/dtype rather than on the CPU
    # followed by a copy to a module-level ``device``.
    eps = torch.randn([L] + list(mu.shape), device=mu.device, dtype=mu.dtype)
    return mu + eps * std.unsqueeze(0)


def adjust_learning_rate(optmz, epoch, init_lr=0.001, decay=0.9, decay_epoch=100):
    """Set a step-decayed learning rate on every parameter group of ``optmz``.

    The rate is ``init_lr`` multiplied by ``decay`` once per completed block
    of ``decay_epoch`` epochs. It overwrites whatever rate each group held.
    """
    n_decays = epoch // decay_epoch
    new_lr = init_lr * (decay ** n_decays)
    for group in optmz.param_groups:
        group['lr'] = new_lr


# Load sklearn's diabetes regression data and move it to `device` as float32.
dat = sklearn.datasets.load_diabetes(return_X_y=False)
x = torch.tensor(dat.data, dtype=torch.float).to(device)
y = torch.tensor(dat.target, dtype=torch.float).to(device)
# Keep NumPy copies of the float32 design matrix and the *uncentred* response.
# Note: this rebinds `dat` from the sklearn Bunch to a plain dict.
dat={}
dat['x']=x.cpu().numpy()
dat['y']=y.cpu().numpy()
# Centre the response (presumably because the model x @ beta has no intercept).
# This runs after the snapshot above, so dat['y'] stays uncentred.
y = y-y.mean()
n, p = x.shape  # n observations, p covariates

'''
Treat alpha and nu as known constants.
'''
# Fixed (non-learned) hyper-parameters of the bridge prior.
alpha=1.  # bridge exponent; with alpha = 1 the prior on beta is Laplace-type
nu = torch.tensor([0.35]).to(device)  # bridge penalty scale, held fixed
lognu = nu.log()
# Variational parameters of q. Each tensor is moved to `device` *before*
# requires_grad_(True), so it remains a leaf tensor the Adam optimizers can update.
mu_b=torch.zeros([p]).to(device).requires_grad_(True) # mean of the Gaussian q(bet), one entry per covariate
logvar_b=(torch.tensor(0.)+torch.zeros([p])).to(device).requires_grad_(True) # log variance of q(bet); starts at 0 (unit variance)
mu_logtau = torch.zeros(1).to(device).requires_grad_(True) # mean of the Gaussian q(log_tau), where y ~ N(x beta, 1/tau)
logvar_logtau = torch.zeros(1).to(device).requires_grad_(True) # log variance of q(log_tau)

class discriminator(nn.Module): # discriminator for beta. So far we don't have one for logtau
    """Four-layer ReLU MLP mapping a ``size``-dim input to a single logit.

    Scores draws of beta. There is no counterpart for logtau yet. Each linear
    layer is moved to the module-level ``device`` as it is built.

    Args:
        size: input dimension (defaults to the number of covariates ``p``).
        h_dim: hidden width; the middle layer is twice as wide.
    """

    def __init__(self, size=p, h_dim=int(1.5*p)):
        super().__init__()
        wide = 2 * h_dim
        # Layers are constructed in this order so parameter initialisation
        # consumes the global RNG in the same sequence.
        self.fc11 = nn.Linear(size, h_dim).to(device)
        self.fc12 = nn.Linear(h_dim, wide).to(device)
        self.fc13 = nn.Linear(wide, h_dim).to(device)
        self.fc14 = nn.Linear(h_dim, 1).to(device)

    def forward(self, z):
        """Return raw (pre-sigmoid) logits of shape ``z.shape[:-1] + (1,)``."""
        hidden = z
        for layer in (self.fc11, self.fc12, self.fc13):
            hidden = F.relu(layer(hidden))
        return self.fc14(hidden)

# Discriminator over beta draws, using the constructor defaults
# size=p and h_dim=int(1.5*p).
discr = discriminator()    

def rWeibull(L, a, b): # sample Weibull r.v.
    """Draw ``L`` reparameterised samples from Weibull(scale=a, shape=b).

    Uses inverse-CDF sampling, ``a * (-log U) ** (1 / b)`` with
    ``U ~ Uniform(0, 1)``, written as ``a * exp(log(-log U) / b)`` so that
    gradients flow to both ``a`` and ``b``.

    Args:
        L: number of draws.
        a: scale tensor.
        b: shape tensor, broadcastable with ``a``.

    Returns:
        Tensor of shape ``[L] + list(a.shape)`` on ``a``'s device and dtype.
    """
    u = torch.rand([L] + list(a.shape), device=a.device, dtype=a.dtype)
    # torch.rand samples [0, 1) and can return exactly 0, where -log(u) = inf
    # and the draw becomes inf (NaN downstream). Clamp away from 0. Since
    # u < 1 already, log(-log u) is always finite.
    u = u.clamp_min(torch.finfo(u.dtype).tiny)
    loglogu = (-u.log()).log()
    return a * (loglogu / b).exp()


# Log scale (a) and log shape (b) of the Weibull used for each augmented lambda_j;
# they are exponentiated before being passed to rWeibull(L, a, b).
loga = (torch.tensor(0.)+torch.zeros([p])).to(device).requires_grad_(True) # Weibull parameters (log scale)
logb = (torch.tensor(0.)+torch.zeros([p])).to(device).requires_grad_(True) # Weibull parameters (log shape)
##############################################################################
# optimizers
# The lr values given here are overwritten every epoch by adjust_learning_rate
# in the training loop.
optimizer_eta = torch.optim.Adam([loga, logb], lr=0.001)
optimizer_q_tau = torch.optim.Adam([mu_logtau, logvar_logtau], lr=0.001)
optimizer_q_bet = torch.optim.Adam([mu_b, logvar_b], lr=0.001)# optionally add weight_decay=0.0001 for an L2 penalty on these parameters
optimizer_discr= torch.optim.Adam(discr.parameters(), lr=0.001)
NAN = False  # set to True when the MCMC step produces NaN
TooBig = False  # NOTE(review): not referenced in the code visible here
LL = 100  # number of draws from q (and parallel MCMC chains) per epoch
I = torch.eye(p).to(device)  # p x p identity; NOTE(review): not referenced in the code visible here
num_epochs = 2000  # the training loop runs num_epochs*5 iterations
losses=[]
discr_losses=[]  # NOTE(review): not appended to in the code visible here
c = d = 1. # Gamma(c,d) prior for tau=1/sig2 
bs=n  # full-batch size, used in the n/bs likelihood scaling
n_LD = 3 # number of MCMC (Gibbs) sweeps per epoch, started from the q(beta) draws
# BCE targets for the discriminator: 0.01 for the LL*n_LD draws from q(beta),
# 0.99 for the LL*n_LD MCMC draws (smoothed 0/1 labels).
label = torch.tensor([0.01]*LL*n_LD + [0.99]*LL*n_LD).unsqueeze(-1).to(device) 
# Unit-rate Gamma(c + n/2 + p/2); a draw divided by the rate gives the Gibbs
# update of tau in the training loop.
ee = torch.distributions.gamma.Gamma(c+n/2+p/2, 1.)
xt = x.transpose(0,1) # p*n
xpy = torch.matmul(xt, y).unsqueeze(-1) # p*1, X'y
xpx = torch.matmul(xt,x)  # p*p, X'X
#%%
'''
MIVI training loop; nu is treated as a known constant.
'''
# MIVI training loop. Each epoch:
#   1. draws beta from q and runs n_LD Gibbs sweeps (tau, beta | lambda);
#   2. fits q(beta), q(tau) to the MCMC draws;
#   3. trains the discriminator on q draws vs MCMC draws;
#   4. updates the Weibull parameters (loga, logb) on the full objective.
total_epochs = num_epochs * 5
for epoch in range(total_epochs):
    # Step-decay schedules; these reset every optimizer's lr each epoch.
    adjust_learning_rate(optimizer_eta, epoch, init_lr=0.01, decay=0.9,
                         decay_epoch=500)
    adjust_learning_rate(optimizer_q_bet, epoch, init_lr=0.5, decay=0.9,
                         decay_epoch=500)
    adjust_learning_rate(optimizer_q_tau, epoch, init_lr=0.01, decay=0.9,
                         decay_epoch=500)
    adjust_learning_rate(optimizer_discr, epoch, init_lr=0.01, decay=0.9,
                         decay_epoch=500)
    bs = x.size(0)  # full batch

    # ---- 1a. sample from q --------------------------------------------------
    bet = sample_bet(LL, mu_b, logvar_b)  # LL x p
    # logtau is recomputed from the MCMC draws below; this draw is kept only
    # so the random-number stream matches the original ordering.
    logtau = sample_bet(LL, mu_logtau, logvar_logtau)  # LL x 1
    Xbet = torch.einsum('ip,lp->li', (x, bet))  # LL x n
    bet_tilde = bet.clone()
    Xbet_tilde = Xbet.clone()  # LL x n
    bets = torch.zeros([n_LD, LL, p]).to(device)
    taus = torch.zeros([n_LD, LL]).to(device)

    # ---- 1b. n_LD Gibbs sweeps started from the q(beta) draws ----------------
    nan_found = False
    for iii in range(n_LD):
        # Augmented variable lambda: one row, broadcast over the LL chains.
        lam = rWeibull(1, loga.exp(), logb.exp())  # 1 x p
        # tau | beta, lambda: unit-rate Gamma draw divided by the rate.
        tau = ee.sample([LL]).to(device) / (
            d + 0.5 * ((y - Xbet_tilde).pow(2).sum(-1)
                       + nu.pow(2 / alpha) * (bet_tilde.pow(2) * lam).sum(-1)))
        # beta | tau, lambda ~ N(Sig X'y, Sig / tau) with
        # Sig = (X'X + nu^(2/alpha) diag(lambda))^-1. One covariance is shared
        # by all chains, using the average tau.
        LAM = torch.zeros([p, p]).to(device)
        LAM[range(p), range(p)] = lam.mean(0)
        Sig = (xpx + nu.pow(2 / alpha) * LAM).inverse()
        gibbs_bet_cov = Sig / tau.mean()
        lower = torch.linalg.cholesky(gibbs_bet_cov)
        gibbs_bet_mean = torch.matmul(Sig, xpy).squeeze(-1)  # p
        bet_tilde = gibbs_bet_mean + torch.einsum(
            'ij,lj->li',
            (lower, torch.randn([LL] + list(mu_b.shape)).to(device)))
        if (torch.isnan(bet_tilde).any() or torch.isnan(tau).any()
                or torch.isnan(lam).any()):
            print("NAN")
            NAN = True
            nan_found = True
            break
        bets[iii, :, :] = bet_tilde.clone()
        taus[iii, :] = tau.clone()
        Xbet_tilde = torch.einsum('ip,lp->li', (x, bet_tilde))
    if nan_found:
        # Stop training. Previously only the Gibbs loop was exited, and the
        # epoch went on with zero-filled bets/taus (taus.log() = -inf).
        break
    bets = bets.view([-1, p])  # (n_LD*LL) x p
    taus = taus.view([-1])  # n_LD*LL

    # ---- 2. fit q(beta) and q(tau) to the MCMC draws -------------------------
    for _ in range(2):
        # Gaussian q(beta): maximise the mean log-density of the draws.
        logq_z = (-0.5 * logvar_b - 0.5 / logvar_b.exp()
                  * (bets.detach() - mu_b).pow(2)).mean(0).sum()
        gener_loss = -logq_z
        optimizer_q_bet.zero_grad()
        gener_loss.backward()
        optimizer_q_bet.step()
        # Log-normal q(tau). taus is detached, like bets above, so this
        # backward pass does not traverse (or need to retain) the Gibbs graph.
        log_taus = taus.detach().log()
        logq_tau = (-log_taus.mean(0) - 0.5 * logvar_logtau
                    - 0.5 / logvar_logtau.exp()
                    * (log_taus - mu_logtau).pow(2).mean(0))
        gener_loss_tau = -logq_tau
        optimizer_q_tau.zero_grad()
        gener_loss_tau.backward()
        optimizer_q_tau.step()

    # ---- 3. discriminator: q draws (target 0.01) vs MCMC draws (0.99) --------
    bet = sample_bet(LL * n_LD, mu_b, logvar_b)
    pars = torch.cat([bet.detach(), bets.detach()], dim=0)
    phat = discr(pars).sigmoid()
    d_loss = F.binary_cross_entropy(phat, label)
    optimizer_discr.zero_grad()
    d_loss.backward()
    optimizer_discr.step()

    # ---- 4. update the Weibull parameters (loga, logb) -----------------------
    # Gradients reach loga/logb through the reparameterised lambda draws that
    # feed bets/taus. loss.backward() also leaves grads on the q and
    # discriminator parameters; those are zeroed before their next step.
    Xbet = torch.einsum('ip,lp->li', (x, bets))
    logtau = taus.log().unsqueeze(-1)
    logsig2 = -logtau
    loglh = (-0.5 * bs * logsig2
             - 0.5 * (y - Xbet).pow(2).sum(-1, keepdim=True) / logsig2.exp()
             ).mean() * n / bs
    # Bridge prior on beta given tau.
    # NOTE(review): the tau-dependent normalising term (p/2 * logtau) is not
    # included, although ee's Gamma shape includes p/2. Confirm this is intended.
    logprior_bet = (-nu * taus.pow(alpha / 2)
                    * bets.abs().pow(alpha).sum(-1)).mean(0)
    # Gamma(c, d) prior on tau.
    logprior_tau = ((c - 1) * logtau - d * logtau.exp()).mean()
    # Gaussian q(beta) and log-normal q(tau), evaluated at the MCMC draws.
    logq_bet = (-0.5 * logvar_b
                - 0.5 / logvar_b.exp() * (bets - mu_b).pow(2).mean(0)).sum()
    logq_tau = (-logtau.mean(0) - 0.5 * logvar_logtau
                - 0.5 / logvar_logtau.exp()
                * (logtau - mu_logtau).pow(2).mean(0))
    KL = discr(bets).mean()  # discriminator logit used as the log-ratio term
    loss = (logq_bet + logq_tau
            - logprior_bet - logprior_tau - loglh + KL)
    optimizer_eta.zero_grad()
    loss.backward()
    optimizer_eta.step()
    losses.append(loss.item())
    if epoch % 100 == 1:
        print("Epoch[{}/{}], logq_bet: {:.1f}, loss: {:.4f}, loglkh: {:.4f}, "
              "loglhratio: {:.4f}"
              .format(epoch + 1, total_epochs, logq_bet.item(),
                      loss.item(), loglh.item(), KL.item()))
                   
# Plot the per-iteration training loss. No plt.show() is called here, so
# displaying the figure presumably relies on an interactive session/backend
# (e.g. running the script as #%% cells).
plt.plot(losses)        



