'''
MIVI for Bayesian logistic regression: Gibbs-sampling updates of beta combined with
a learned neural sampler for the Polya-Gamma (PG) distributed auxiliary variables w.
'''
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from matplotlib import pyplot as plt
import random
import numpy as np
import copy
import scipy
from scipy.io import loadmat
#from keras.utils import to_categorical
import pandas as pd
import gc
def log_sum_exp(value, dim=None, keepdim=False):
    """Compute log(sum(exp(value))) stably by shifting by the maximum.

    Args:
        value: input tensor.
        dim: dimension to reduce over; ``None`` reduces over all elements.
        keepdim: when ``dim`` is given, keep the reduced dimension with size 1.

    Returns:
        The reduced tensor (a 0-d tensor when ``dim`` is None).
    """
    def _finite_shift(m):
        # A maximum of +/-inf would produce inf - inf = nan below; shift by 0
        # instead so all -inf inputs give -inf and any +inf input gives +inf.
        return torch.where(torch.isfinite(m), m, torch.zeros_like(m))

    if dim is not None:
        m = _finite_shift(torch.max(value, dim=dim, keepdim=True)[0])
        summed = torch.sum(torch.exp(value - m), dim=dim, keepdim=keepdim)
        if not keepdim:
            m = m.squeeze(dim)
        return m + torch.log(summed)
    # torch.sum always returns a tensor, so the former
    # `isinstance(sum_exp, Number)` / `math.log` branch was dead -- and it
    # raised NameError because neither `Number` nor `math` was imported.
    m = _finite_shift(torch.max(value))
    return m + torch.log(torch.sum(torch.exp(value - m)))
        
        
# Run on the GPU when one is available; replace with torch.device('cpu') to force CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sample_bet(L, mu, log_var):
    """Draw L reparameterised samples from N(mu, diag(exp(log_var))).

    Returns a tensor of shape (L, *mu.shape) equal to
    mu + eps * exp(log_var / 2), where eps ~ N(0, I) is moved to `device`.
    """
    noise = torch.randn([L] + list(mu.shape)).to(device)
    scale = torch.exp(0.5 * log_var)
    return mu + noise * scale[None]

#%% load data
# Toy logistic-regression data: n observations of a p-dimensional correlated
# Gaussian design and Bernoulli responses.
torch.manual_seed(0)
n = 1000
p = 4

# Design covariance: unit variances, corr(x0, x1) = -0.8, corr(x2, x3) = 0.9.
Sigma = torch.eye(p).to(device)
Sigma[3, 2] = 0.9
Sigma[2, 3] = 0.9
Sigma[0, 1] = -0.8
Sigma[1, 0] = -0.8
# torch.linalg.cholesky replaces the deprecated Tensor.cholesky (same lower factor).
lower = torch.linalg.cholesky(Sigma)
# Correlated design matrix: row x_n = lower @ z_n with z_n ~ N(0, I), so Cov(x_n) = Sigma.
x = torch.einsum('ip,np->ni', (lower, torch.randn(n, p).to(device)))
bet0 = torch.tensor([-2, -1, 1, 2], dtype=torch.float).to(device)  # true coefficients
# Logit = x @ bet0 plus Gaussian noise (sd 0.2); y ~ Bernoulli(sigmoid(logit)).
yy = (x*bet0).sum(-1) + torch.randn(n).to(device)*0.2
m = torch.distributions.bernoulli.Bernoulli(torch.sigmoid(yy))
y = m.sample()
# Keep n, p consistent with the design matrix actually in use.
n, p = x.shape
#%% MIVI
# Variational parameters of q(bet) = N(mu, diag(exp(log_var))), both initialised
# to zero and optimised directly (leaf tensors with gradients enabled).
mu = torch.zeros(p).to(device).requires_grad_(True)  # mean of q(bet)
log_var = torch.zeros(p).to(device).requires_grad_(True)  # log-variance of q(bet)

class discriminator(nn.Module):
    """MLP mapping a ``size``-dimensional input (default p) to one unnormalised logit.

    Layer widths: size -> h_dim -> 2*h_dim -> h_dim -> 1, with ReLU after each
    hidden layer and no activation on the output. Every layer is moved to `device`.
    """

    def __init__(self, size=p, h_dim=int(1.5*p)):
        super().__init__()
        wide = 2 * h_dim
        # Construction order is kept fixed so parameter initialisation draws the
        # same random numbers.
        self.fc11 = nn.Linear(size, h_dim).to(device)
        self.fc12 = nn.Linear(h_dim, wide).to(device)
        self.fc13 = nn.Linear(wide, h_dim).to(device)
        self.fc14 = nn.Linear(h_dim, 1).to(device)

    def forward(self, z):
        for hidden_layer in (self.fc11, self.fc12, self.fc13):
            z = F.relu(hidden_layer(z))
        return self.fc14(z)


discr = discriminator()



class rPG(nn.Module):
    """Neural sampler standing in for the Polya-Gamma (PG) auxiliary variables w.

    Each linear predictor is concatenated with ``eps_size`` fresh standard-normal
    draws and passed through a small MLP. The output is exponentiated, so every
    returned w is strictly positive.
    """

    def __init__(self, eps_size=p, h_dim=int(1.2*p)):
        super().__init__()
        self.eps_size = eps_size
        self.transform = nn.Sequential(
            nn.Linear(eps_size+1, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1)
        )

    def forward(self, Xbet):
        """Sample w given linear predictors.

        Args:
            Xbet: tensor of linear predictors with any shape ``(...)``; the
                training script passes ``(LL, n)``.

        Returns:
            Positive tensor of shape ``(..., 1)`` on Xbet's device and dtype.
        """
        # Draw the noise directly in its target shape on Xbet's device/dtype.
        # This is the same random stream as randn_like on a repeated copy, but
        # without the throwaway allocation. It also works for any number of
        # leading dims, and cannot end up on a different device than Xbet
        # (which would make torch.cat fail).
        eps = torch.randn(*Xbet.shape, self.eps_size,
                          dtype=Xbet.dtype, device=Xbet.device)
        xbet_eps = torch.cat([Xbet.unsqueeze(-1), eps], dim=-1)
        return self.transform(xbet_eps).exp()

# Sampler network for w, plus one Adam optimiser per parameter group:
# rpg (stepped on the main loss), the variational parameters (mu, log_var) of q,
# and the discriminator.
rpg = rPG().to(device)     
optimizer_elbo = torch.optim.Adam(rpg.parameters(), lr=0.001)
optimizer_q = torch.optim.Adam([mu , log_var], lr=0.001)
optimizer_discr= torch.optim.Adam(discr.parameters(), lr=0.001)
# Status flags. The training loop sets NAN to True when a Gibbs draw contains NaN.
# NOTE(review): TooBig is not read or written anywhere else in the visible code.
NAN = False
TooBig = False
LL = 200  # number of beta samples (parallel chains) drawn from q per epoch
# NOTE(review): bias is not used in the visible code.
bias = torch.ones([2*LL,1]).to(device)
# Identity matrix: added as the N(0, I) prior precision in the Gibbs covariance.
I = torch.eye(p).to(device)
num_epochs = 250
losses=[]  # training loss recorded once per epoch
discr_losses=[]  # NOTE(review): never appended to in the visible code
ee=torch.distributions.gamma.Gamma(1., 1.)  # NOTE(review): unused in the visible code
K=20  # NOTE(review): unused in the visible code
n_LD = 5  # Gibbs steps per epoch; each contributes LL rows to `bets`
# Smoothed discriminator targets: 0.01 for the LL*n_LD samples from q, then 0.99
# for the LL*n_LD samples from q_tilde, matching the order in which the training
# loop concatenates them. Shape (2*LL*n_LD, 1).
label = torch.tensor([0.01]*LL*n_LD + [0.99]*LL*n_LD).unsqueeze(-1).to(device) 

#%% training loop
# Each epoch:
#   1. draw LL betas from q and run n_LD Gibbs-style steps (beta | w, with w from
#      the rpg network) to collect LL*n_LD samples `bets` from q_tilde;
#   2. fit q to `bets` by maximising their mean Gaussian log-density under q;
#   3. train the discriminator to separate q samples (label 0.01) from `bets` (0.99);
#   4. update rpg on  log q(bets) - log prior - log likelihood + KL,  where KL is
#      the discriminator's mean logit on `bets` (roughly log q_tilde/q near its optimum).
# Training stops early if a Gibbs draw contains NaN.
GAN_START_EPOCH = 0  # epoch from which the discriminator is trained and its KL term used

# X^T kappa with kappa = y - 1/2 (the Polya-Gamma Gibbs mean term). It depends
# only on the data, so compute it once instead of on every Gibbs step.
xpkappa = (x*(y-0.5).unsqueeze(-1)).sum(0)

for epoch in range(num_epochs):
    bs = x.size(0)  # full batch, so the n/bs likelihood rescaling below equals 1
    # sample from q
    bet = sample_bet(LL, mu, log_var)
    bet_tilde = bet.clone()
    Xbet_tilde = torch.einsum('ip,lp->li', (x, bet))  # (LL, n) linear predictors
    bets = torch.zeros([n_LD, LL, p]).to(device)
    # sample from q_tilde
    for iii in range(n_LD):
        # NN to approximate PG: w has shape (LL, n, 1)
        w = rpg(Xbet_tilde)
        # sqrt(w) * x averaged over the LL chains -> one (n, p) matrix shared by all chains
        xw_sqrt = (x*w.sqrt()).mean(0)
        # Gibbs covariance (xw_sqrt^T xw_sqrt + I)^{-1}; I is the N(0, I) prior precision
        gibbs_bet_cov = (torch.einsum('ij,ik->jk', (xw_sqrt, xw_sqrt)) + I).inverse()
        # torch.linalg.cholesky replaces the deprecated Tensor.cholesky. The separate
        # name keeps the design-matrix factor `lower` from being overwritten.
        gibbs_lower = torch.linalg.cholesky(gibbs_bet_cov)
        gibbs_bet_mean = (gibbs_bet_cov*xpkappa).sum(-1)  # = gibbs_bet_cov @ xpkappa
        # LL independent draws from N(gibbs_bet_mean, gibbs_bet_cov)
        bet_tilde = gibbs_bet_mean + torch.einsum(
            'ij,lj->li', (gibbs_lower, torch.randn([LL]+list(mu.shape)).to(device)))
        Xbet_tilde = torch.einsum('ip,lp->li', (x, bet_tilde))  # (LL, n)

        if torch.isnan(bet_tilde).any():
            print("NAN")
            NAN = True
            break
        bets[iii, :, :] = bet_tilde.clone()
    if NAN:
        # The remaining rows of `bets` are still zeros. Training on them would pull
        # q toward 0, so stop instead of silently continuing.
        break
    bets = bets.view([-1, p])  # (n_LD*LL, p)

    # train q: maximise the mean Gaussian log-density (up to constants) of `bets` under q.
    # NOTE(review): `bets` is not detached, so this gradient also flows along the
    # sampling path bets <- rpg <- bet(mu, log_var). Use `bets.detach()` if only the
    # density-fitting gradient is intended.
    for _ in range(2):
        logq_z = (-0.5*log_var - 0.5/log_var.exp()*(bets - mu).pow(2)).mean(0).sum()
        gener_loss = -logq_z
        optimizer_q.zero_grad()
        # retain_graph: the `bets` graph is backpropagated again below
        gener_loss.backward(retain_graph=True)
        optimizer_q.step()

    # train discriminator on q samples (label 0.01) vs q_tilde samples (label 0.99)
    if epoch >= GAN_START_EPOCH:
        bet = sample_bet(LL*n_LD, mu, log_var)
        pars = torch.cat([bet.detach(), bets.detach()], dim=0)
        # BCE on logits: sigmoid followed by binary_cross_entropy loses the
        # gradient once the logits saturate.
        d_loss = F.binary_cross_entropy_with_logits(discr(pars), label)
        optimizer_discr.zero_grad()
        d_loss.backward()
        optimizer_discr.step()

    # train the main model (rpg)
    Xbet = torch.einsum('ip,lp->li', (x, bets))  # (n_LD*LL, n) logits
    # Bernoulli log-likelihood: mean over samples, sum over observations. Computed
    # from logits for the same saturation reason as above.
    loglh = -F.binary_cross_entropy_with_logits(
        Xbet, y.expand_as(Xbet), reduction='none').mean(0).sum()*n/bs
    # N(0,1) prior for bet
    logprior = -0.5*bets.pow(2).mean(0).sum()
    logq_bet = (-0.5*log_var - 0.5/log_var.exp()*(bets - mu).pow(2).mean(0)).sum()
    if epoch >= GAN_START_EPOCH:
        KL = discr(bets).mean()
    else:
        KL = torch.tensor(0.1).to(device)
    loss = logq_bet - logprior - loglh + KL
    # Only rpg is stepped here. The gradients this leaves on mu/log_var and on the
    # discriminator are cleared by their own zero_grad() before they are next used.
    optimizer_elbo.zero_grad()
    loss.backward()
    optimizer_elbo.step()
    losses.append(loss.item())
    if epoch % 50 == 1:
        print("Epoch[{}/{}], logq_bet: {:.1f}, loss: {:.4f}, loglh: {:.4f}, "
              "loglhratio: {:.4f}".format(epoch+1, num_epochs, logq_bet.item(),
                                          loss.item(), loglh.item(), KL.item()))
                   
# Training-loss curve (no plt.show() is called here; display depends on the backend).
plt.plot(losses)

# Empirical coordinate-by-coordinate correlation matrices of the most recent draws
# from q (`bet`) and from the Gibbs sampler (`bet_tilde`).
for final_draws in (bet, bet_tilde):
    print(np.corrcoef(final_draws.detach().cpu().numpy().T))
