
import torch
import torchvision
import torch.nn as nn
import numpy as np
import scipy.stats
import scipy.io
import scipy.sparse
from scipy.io import loadmat
import pandas as pd
import matplotlib.pyplot as plt
import torch.distributions as td

from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from sklearn.ensemble import ExtraTreesRegressor
from sklearn.experimental import enable_iterative_imputer
from sklearn.linear_model import BayesianRidge
from sklearn.impute import IterativeImputer
from sklearn.impute import SimpleImputer


def mse(xhat, xtrue, mask, obs_std, true_std):
    """Return the mean squared error of the imputations on the missing entries.

    Both ``xhat`` and ``xtrue`` are multiplied by ``obs_std / true_std`` before
    being compared, so the error is expressed in units of ``true_std`` when the
    inputs were standardised with ``obs_std``.

    Args:
        xhat: Array-like of imputed values.
        xtrue: Array-like of reference values, same shape as ``xhat``.
        mask: Array-like of the same shape, truthy where an entry was
            *observed*. Only entries where ``mask`` is falsy contribute.
            Boolean arrays, 0/1 numeric arrays and nested lists are accepted.
        obs_std: Scale (scalar or broadcastable array) applied to both inputs.
        true_std: Scale (scalar or broadcastable array) both inputs are
            divided by.

    Returns:
        Mean of the squared differences over the entries where ``mask`` is
        falsy.
    """
    xhat = np.array(xhat) * obs_std / true_std
    xtrue = np.array(xtrue) * obs_std / true_std
    # Coerce to a boolean array first: ``~`` is not defined for lists or
    # float arrays, and on integer arrays it would be a bitwise complement.
    observed = np.asarray(mask, dtype=bool)
    return np.mean(np.power(xhat - xtrue, 2)[~observed])


# Name of the benchmark to load. Supported values are the ones tested below:
# 'breast', 'red', 'white', 'yeast', 'banknote' and 'concrete'.
data_set = 'banknote'

if data_set == 'breast':
    from sklearn.datasets import load_breast_cancer
    # return_X_y is keyword-only in current scikit-learn releases.
    data = load_breast_cancer(return_X_y=True)[0]
elif data_set == 'red':
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
    data = np.array(pd.read_csv(url, low_memory=False, sep=';'))
elif data_set == 'white':
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"
    data = np.array(pd.read_csv(url, low_memory=False, sep=';'))
elif data_set == 'yeast':
    # This branch used to be guarded by a second `data_set == 'breast'` test
    # with an unindented body, which made the whole file a syntax error.
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/yeast/yeast.data"
    # The file has no header row; header=None keeps the first observation
    # instead of consuming it as column names.
    data = np.array(pd.read_csv(url, low_memory=False, sep=r'\s+', header=None, usecols=[1,2,3,4,5,6,7,8]))[:,0:8]
elif data_set == 'banknote':
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00267/data_banknote_authentication.txt"
    # No header row in this file either; keep the first four columns only.
    data = np.array(pd.read_csv(url, low_memory=False, sep=',', header=None))[:,0:4]
elif data_set == 'concrete':
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls"
    data = np.array(pd.read_excel(url))[:,0:9]
else:
    # Fail here with a clear message rather than with a NameError on `data`.
    raise ValueError('Unknown data_set: %r' % data_set)
print(data.shape)

avg_mse = []  # final imputation error of every independent run
for run in range(0, 5):
    xfull = np.copy(data)  # standardised further down with observed-entry statistics
    n = xfull.shape[0] # number of observations
    p = xfull.shape[1] # number of features

    # We remove 50% of the entries uniformly at random. This corresponds to a
    # *missing completely at random (MCAR)* scenario.

    perc_miss = 0.5 # 50% of missing data
    # The `np.int` alias was removed in NumPy 1.24; the builtin `int` yields
    # the same number of flat indices to drop.
    miss_pattern = np.random.choice(n*p, int(np.floor(n*p*perc_miss)), replace=False)

    # Zero-filled copy of the data plus a 0/1 indicator of observed entries,
    # used only to compute per-feature statistics of the observed values.
    masked_data = np.copy(data)
    masked_data = masked_data.flatten()
    observed = np.ones_like(masked_data)
    masked_data[miss_pattern] = 0.0
    observed[miss_pattern] = 0.0
    masked_data = masked_data.reshape([n,p])
    observed = observed.reshape([n,p])
    obs_prob = np.mean(observed, 0)  # fraction of observed entries per feature
    # Missing entries are zero, so dividing the column mean by the observed
    # fraction gives the mean over observed entries only.
    obs_mean = np.mean(masked_data, 0)/obs_prob
    # Standard deviation over observed entries, floored at 1e-6 so the
    # division below never hits zero.
    obs_std = np.maximum(  (np.mean(((masked_data - obs_mean)*observed)**2, 0)/obs_prob)**0.5,1e-6 * np.ones_like(obs_mean))
    true_std = np.std(data, 0)
    print(obs_std, np.std(data,0))

    # Standardise the complete data with the observed-entry statistics.
    xfull = (xfull - obs_mean)/obs_std

    xmiss = np.copy(xfull)
    xmiss_flat = xmiss.flatten()

    xmiss_flat[miss_pattern] = np.nan
    xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans
    mask = np.isfinite(xmiss) # boolean mask: True where a value is observed, False where it is missing

    # A simple way of imputing the incomplete data is to replace the missing
    # values by zeros. This xhat_0 is what will be fed to the encoder.
    xhat_0 = np.copy(xmiss)
    xhat_0[np.isnan(xmiss)] = 0


    # Hyperparameters
    h = 128  # hidden units per layer (shared by every MLP)
    d = 10   # dimension of the latent space
    K = 20   # number of importance samples during training

    # Standard normal prior over the d-dimensional latent code, held on the GPU.
    prior_loc = torch.zeros(d).cuda()
    prior_scale = torch.ones(d).cuda()
    p_z = td.Independent(td.Normal(loc=prior_loc, scale=prior_scale), 1)

    # Decoder: latent code -> 3*p outputs, i.e. per-feature mean, scale and
    # degrees of freedom. It is built before the encoder, as in the original
    # ordering, so parameters are drawn in the same sequence.
    decoder_layers = [
        nn.Linear(d, h),
        nn.ReLU(),
        nn.Linear(h, h),
        nn.ReLU(),
        nn.Linear(h, 3*p),
    ]
    decoder = nn.Sequential(*decoder_layers)

    # Encoder: zero-imputed data -> 2*d outputs, i.e. mean and diagonal scale
    # of the variational posterior.
    encoder_layers = [
        nn.Linear(p, h),
        nn.ReLU(),
        nn.Linear(h, h),
        nn.ReLU(),
        nn.Linear(h, 2*d),
    ]
    encoder = nn.Sequential(*encoder_layers)

    # Both networks live on the GPU.
    encoder.cuda()
    decoder.cuda()



    def miwae_loss(iota_x,mask):
        """Return the negative K-sample importance-weighted bound for a batch.

        Args:
            iota_x: Tensor of shape (batch, p) fed to the encoder.
            mask: Tensor of shape (batch, p) that multiplies the per-feature
                log-likelihoods, so entries where it is 0 do not contribute.

        Returns:
            Scalar tensor: minus the batch mean of the log-sum-exp, over the K
            samples, of ``log p(x_obs|z) + log p(z) - log q(z|x_obs)``. The
            ``log K`` normalisation is not subtracted here.
        """
        n_rows = iota_x.shape[0]

        # Variational posterior q(z | x_obs): diagonal Gaussian.
        enc_out = encoder(iota_x)
        q_loc = enc_out[..., :d]
        q_scale = F.softplus(enc_out[..., d:(2*d)])
        q_zgivenxobs = td.Independent(td.Normal(loc=q_loc, scale=q_scale), 1)

        # K reparameterised samples per row, flattened for the decoder.
        z_samples = q_zgivenxobs.rsample([K])
        dec_out = decoder(z_samples.reshape([K*n_rows, d]))

        # Student-t observation model: mean, scale (> 0.001), df (> 3).
        px_loc = dec_out[..., :p]
        px_scale = F.softplus(dec_out[..., p:(2*p)]) + 0.001
        px_df = F.softplus(dec_out[..., (2*p):(3*p)]) + 3

        # Tile data and mask K times so they line up row by row with dec_out.
        x_tiled = iota_x.repeat(K, 1)
        mask_tiled = mask.repeat(K, 1)

        # Element-wise log-density, then sum over the features kept by the mask.
        log_px = td.StudentT(df=px_df, loc=px_loc, scale=px_scale).log_prob(x_tiled)
        log_px_obs = (log_px * mask_tiled).sum(1).reshape([K, n_rows])

        log_prior = p_z.log_prob(z_samples)
        log_post = q_zgivenxobs.log_prob(z_samples)

        return -torch.mean(torch.logsumexp(log_px_obs + log_prior - log_post, 0))


    # A single Adam optimiser (learning rate 1e-3) updates the encoder and the
    # decoder parameters jointly.
    optimizer = optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)



    def miwae_impute(iota_x,mask,L):
        """Impute a batch by self-normalised importance sampling.

        Args:
            iota_x: Tensor of shape (batch, p) fed to the encoder.
            mask: Tensor of shape (batch, p) that multiplies the per-feature
                log-likelihoods, so entries where it is 0 do not affect the
                importance weights.
            L: Number of importance samples drawn per row.

        Returns:
            Tensor of shape (batch, p): for every row, the importance-weighted
            average of L draws from the decoder's Student-t distribution. All
            p features are returned; the caller picks the ones it needs. The
            result carries no gradient.
        """
        # Imputation is pure inference: skip autograd bookkeeping so the
        # L * batch intermediate activations are not kept alive.
        with torch.no_grad():
            batch_size = iota_x.shape[0]
            out_encoder = encoder(iota_x)
            q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

            zgivenx = q_zgivenxobs.rsample([L])
            zgivenx_flat = zgivenx.reshape([L*batch_size,d])

            # Student-t observation model: mean, scale (> 0.001), df (> 3).
            out_decoder = decoder(zgivenx_flat)
            all_means_obs_model = out_decoder[..., :p]
            all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
            all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

            # These tensors inherit the device of the inputs, which already has
            # to match the encoder's device for the call above to succeed.
            data_flat = torch.Tensor.repeat(iota_x,[L,1]).reshape([-1,1])
            tiledmask = torch.Tensor.repeat(mask,[L,1])

            all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
            all_log_pxgivenz = all_log_pxgivenz_flat.reshape([L*batch_size,p])

            logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([L,batch_size])
            logpz = p_z.log_prob(zgivenx)
            logq = q_zgivenxobs.log_prob(zgivenx)

            xgivenz = td.Independent(td.StudentT(loc=all_means_obs_model, scale=all_scales_obs_model, df=all_degfreedom_obs_model),1)

            # Softmax over the L samples gives the normalised weights
            # w_1, ..., w_L for every row of the batch.
            imp_weights = torch.nn.functional.softmax(logpxobsgivenz + logpz - logq,0)
            xms = xgivenz.sample().reshape([L,batch_size,p])
            # Weighted average of the L sampled reconstructions per row.
            xm = torch.einsum('ki,kij->ij', imp_weights, xms)

            return xm

    def weights_init(layer):
        """Apply orthogonal initialisation to the weight of a linear layer.

        Intended for ``module.apply(weights_init)``. Layers that are not
        ``nn.Linear`` (or a subclass of it) are left untouched, and biases are
        never modified.

        Args:
            layer: Any ``nn.Module``.
        """
        # isinstance also covers subclasses of nn.Linear, which an exact type
        # comparison would silently skip.
        if isinstance(layer, nn.Linear):
            torch.nn.init.orthogonal_(layer.weight)
       


    # In[74]:


    miwae_loss_train=np.array([])
    mse_train=np.array([])  # imputation MSE recorded at every monitoring step
    mse_train2=np.array([])
    bs = 64 # batch size
    n_epochs = 32002
    xhat = np.copy(xhat_0) # This will be our imputed data matrix

    # Re-initialise every linear layer with orthogonal weights.
    encoder.apply(weights_init)
    decoder.apply(weights_init)

    # np.array_split needs a positive integer number of sections; the floor
    # of one keeps data sets smaller than a single batch working.
    n_batches = max(n // bs, 1)

    for ep in range(1,n_epochs):
        perm = np.random.permutation(n) # We use the "random reshuffling" version of SGD
        batches_data = np.array_split(xhat_0[perm,], n_batches)
        batches_mask = np.array_split(mask[perm,], n_batches)
        for batch_x, batch_m in zip(batches_data, batches_mask):
            # The optimiser owns every encoder and decoder parameter, so one
            # zero_grad call clears all gradients.
            optimizer.zero_grad()
            b_data = torch.from_numpy(batch_x).float().cuda()
            b_mask = torch.from_numpy(batch_m).float().cuda()
            loss = miwae_loss(iota_x = b_data,mask = b_mask)
            loss.backward()
            optimizer.step()
        if ep % 100 == 1:
            print('Epoch %g' %ep)
            # Monitoring only: no gradients are needed on the full data set.
            with torch.no_grad():
                full_neg_bound = miwae_loss(iota_x = torch.from_numpy(xhat_0).float().cuda(),mask = torch.from_numpy(mask).float().cuda()).cpu().data.numpy()
            # miwae_loss omits the log K normalisation, so subtract it here.
            print('MIWAE likelihood bound  %g' %(-np.log(K)-full_neg_bound))

            ### Now we do the imputation

            with torch.no_grad():
                imputed = miwae_impute(iota_x = torch.from_numpy(xhat_0).float().cuda(),mask = torch.from_numpy(mask).float().cuda(),L=10).cpu().data.numpy()
            xhat[~mask] = imputed[~mask]  # only missing entries are overwritten
            err = np.array([mse(xhat,xfull,mask, obs_std, true_std)])
            mse_train = np.append(mse_train,err,axis=0)
            # Format the scalar, not the 1-element array: implicit
            # array-to-scalar conversion is deprecated in NumPy.
            print('Imputation MSE  %g' %err[0])
            print('-----')
    # Final evaluation of the run: impute every batch with many importance
    # samples and report the MSE over all missing entries.
    with torch.no_grad():
        # Build the evaluation batches here instead of relying on variables
        # left over from the last training epoch. The order of the rows does
        # not matter for the error.
        n_eval_batches = max(n // bs, 1)
        eval_data = np.array_split(xhat_0, n_eval_batches)
        eval_mask = np.array_split(mask, n_eval_batches)
        eval_full = np.array_split(xfull, n_eval_batches)

        sq_err_sum = 0.0  # sum of squared errors over missing entries so far
        num = 0           # number of missing entries seen so far
        for b_data_np, b_mask_np, b_full_np in zip(eval_data, eval_mask, eval_full):
            b_missing = ~b_mask_np
            n_missing = int(np.sum(b_missing))
            if n_missing == 0:
                # Nothing to impute; the batch MSE would be NaN.
                continue
            b_data = torch.from_numpy(b_data_np).float().cuda()
            b_mask = torch.from_numpy(b_mask_np).float().cuda()
            b_hat = np.copy(b_data_np)
            b_imputed = miwae_impute(iota_x = b_data,mask = b_mask,L=10000).cpu().data.numpy()
            b_hat[b_missing] = b_imputed[b_missing]
            # Weight each batch by its number of missing entries (not its
            # number of rows) so the total is the MSE over all missing
            # entries, the same quantity reported during training.
            sq_err_sum += n_missing * mse(b_hat, b_full_np, b_mask_np, obs_std, true_std)
            num += n_missing
        err = sq_err_sum / num if num > 0 else float('nan')
        print('Final Imputation MSE  %g' %err)
        print('-----')
        avg_mse.append(err)

# Summary over all runs: mean and standard deviation of the final errors.
nmse_mean, nmse_std = np.mean(avg_mse), np.std(avg_mse)
print("NMSE Results: ", nmse_mean, nmse_std)


