
import pickle
import numpy as np
import matplotlib.pyplot as plt
from utilities import currupt_data, simulate_xor_data
from sklearn.experimental import enable_iterative_imputer 
from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer
from architectures import sterlizedNeuron
from sklearn.utils import shuffle
from sklearn.metrics import roc_auc_score
from numpy.random import seed
from tensorflow.random import set_seed


################################### CONFIGS ###################################


# Experiment settings shared by the simulation loop and the plotting code below.
save_path = '/home/' # Prefix for result files; change this to a local writable path (file names are appended directly, so keep the trailing separator)
runs = 100   # Number of experiment repetitions
steps = 100     # Number of evaluation points on the learning curve; models train 10 epochs between consecutive points (steps*10 = epochs)
conditions = ['MCAR','MAR','MNAR']  # Missing-data mechanisms, each passed as `method` to currupt_data
missing_rate = 0.5 # Select the missing sample rate: 0.3 or 0.5 for reproducing the results in the paper


###############################################################################


def _test_auc(model, X_eval, y_true):
    """Return the ROC AUC of ``model.predict(X_eval)`` scored against ``y_true``."""
    return roc_auc_score(y_true, model.predict(X_eval))


# Run the full experiment once per missing-data mechanism. Within a run, the
# order of model construction, imputation, fit and predict calls is kept
# fixed on purpose: all of them may draw from the seeded global RNGs, so
# reordering would change the results.
for missing_data_mechanism in conditions:

    ############################### PLACE HOLDERS #################################

    # Weight snapshots, appended once per evaluation step of every run.
    params = {'NN': [], 'PROMISSING': [], 'mPROMISSING': []}

    # AUC arrays are indexed [run, step, variant].
    # auc_NN_full:    evaluated on the complete test set; variant 0 is the
    #                 network trained on complete data, variants 1-4 are the
    #                 networks trained on zero/mean/KNN/MICE-imputed data.
    # auc_NN_imputed: the same four imputation networks evaluated on the
    #                 correspondingly imputed test set.
    # auc_prom / auc_mprom: variant 0 = complete test set, 1 = corrupted test set.
    auc_NN_full = np.zeros([runs, steps, 5])
    auc_NN_imputed = np.zeros([runs, steps, 4])
    auc_prom = np.zeros([runs, steps, 2])
    auc_mprom = np.zeros([runs, steps, 2])

    ############################### RUNS #########################################

    # One seed per run; stored with the results so a run can be repeated.
    s = np.random.randint(100000, size=[runs], dtype=np.int32)

    for r in range(runs):

        # Simulate, shuffle and corrupt the data under the run-specific seed.
        seed(s[r])
        X, y = simulate_xor_data(n_samples=1000, noise=0.25)
        X, y = shuffle(X, y)
        Xc = currupt_data(X, method=missing_data_mechanism, missing_rate=missing_rate)

        # First 500 samples are used for training, the remainder for testing.
        X_train, X_test = X[0:500, :], X[500:, :]
        y_train, y_test = y[0:500], y[500:]
        Xc_train, Xc_test = Xc[0:500, :], Xc[500:, :]

        # ---- Baselines: complete-data network plus one network per imputer ----
        seed(s[r])
        set_seed(s[r])
        NN = sterlizedNeuron(use_nanDense=False)
        params['NN'].append(NN.model.get_weights())
        # One network per imputation strategy, in the order zero, mean, KNN, MICE
        # (this is also the order of the last axis of auc_NN_imputed).
        imputed_models = [sterlizedNeuron(use_nanDense=False)
                          for _ in range(auc_NN_imputed.shape[2])]

        imputers = [
            SimpleImputer(missing_values=np.nan, strategy='constant'),
            SimpleImputer(missing_values=np.nan, strategy='mean'),
            KNNImputer(missing_values=np.nan),
            IterativeImputer(missing_values=np.nan, sample_posterior=True),
        ]
        imputed_train = []
        imputed_test = []
        for imputer in imputers:
            # Fit on the training split only, then apply to the test split.
            imputed_train.append(imputer.fit_transform(Xc_train))
            imputed_test.append(imputer.transform(Xc_test))

        # Step 0: scores of the untrained networks.
        for k, model in enumerate(imputed_models):
            auc_NN_imputed[r, 0, k] = _test_auc(model, imputed_test[k], y_test)
        auc_NN_full[r, 0, 0] = _test_auc(NN, X_test, y_test)
        for k, model in enumerate(imputed_models):
            auc_NN_full[r, 0, k + 1] = _test_auc(model, X_test, y_test)

        print('%s: NN => Run:%d, step:%d' % (missing_data_mechanism, r, 0))

        for i in range(1, steps):

            NN = NN.fit(X_train, y_train, epochs=10)
            params['NN'].append(NN.model.get_weights())
            auc_NN_full[r, i, 0] = _test_auc(NN, X_test, y_test)

            for k, model in enumerate(imputed_models):
                # fit() is rebound as in the original code, in case it does
                # not return the same object.
                model = model.fit(imputed_train[k], y_train, epochs=10)
                imputed_models[k] = model
                auc_NN_full[r, i, k + 1] = _test_auc(model, X_test, y_test)
                auc_NN_imputed[r, i, k] = _test_auc(model, imputed_test[k], y_test)

            print('%s: NN => Run:%d, step:%d' % (missing_data_mechanism, r, i))

        # ---- PROMISSING variants, trained directly on the corrupted data ----
        for label, model_kwargs, auc in (('PROMISSING', {}, auc_prom),
                                         ('mPROMISSING', {'use_c': True}, auc_mprom)):
            # Re-seed so each variant is built from the same RNG state as
            # the baseline network above.
            seed(s[r])
            set_seed(s[r])
            model = sterlizedNeuron(**model_kwargs)

            for i in range(steps):
                if i > 0:  # step 0 records the untrained model
                    model = model.fit(Xc_train, y_train, epochs=10)
                params[label].append(model.model.get_weights())
                auc[r, i, 1] = _test_auc(model, Xc_test, y_test)
                auc[r, i, 0] = _test_auc(model, X_test, y_test)
                print('%s: %s => Run:%d, step:%d, AUC_Full:%f, AUC_Missing:%f'
                      % (missing_data_mechanism, label, r, i, auc[r, i, 0], auc[r, i, 1]))

        # Checkpoint after every run: the file is rewritten with everything
        # collected so far (rows of runs not yet executed are still zero).
        results_file = (save_path + 'Experiment1_' + missing_data_mechanism + '_' +
                        str(missing_rate) + '_results.pkl')
        with open(results_file, 'wb') as file:
            pickle.dump({'auc_NN_full': auc_NN_full, 'auc_NN_imputed': auc_NN_imputed,
                         'auc_prom': auc_prom, 'auc_mprom': auc_mprom,
                         'params': params, 'seed': s}, file)
        
        
############################# PLOTTING LEARNING CURVES ########################  


def _plot_mean_band(axis, x, curves, color, label, linestyle=None):
    """Plot the across-run mean of ``curves`` with a +/- one std band.

    ``curves`` is indexed [run, step]; mean and standard deviation are taken
    over axis 0. ``linestyle`` is forwarded to both the line and the band
    only when given, so matplotlib defaults apply otherwise.
    """
    style = {} if linestyle is None else {'linestyle': linestyle}
    mean = np.mean(curves, axis=0)
    std = np.std(curves, axis=0)
    axis.plot(x, mean, label=label, color=color, **style)
    axis.fill_between(x, mean - std, mean + std, alpha=0.05, color=color, **style)


# When True, results from the separate 'Experiment1_GAIN_*' pickle files are
# appended as an extra baseline curve.
with_gain = False

# Curve styling. The first entry of the "full" lists is the network trained
# on complete data; the remaining entries are the imputation baselines.
full_colors = ['#000000', '#B22222', '#DAA520', '#008080', '#BA55D3', '#778899']
full_names = ['Bayes', 'Zero', 'Mean', 'KNN', 'MICE', 'GAIN']
imputed_colors = ['#B22222', '#DAA520', '#008080', '#BA55D3', '#778899']
imputed_names = ['Zero', 'Mean', 'KNN', 'MICE', 'GAIN']

# One column per condition. squeeze=False keeps `ax` two-dimensional even
# when only a single condition is configured.
fig, ax = plt.subplots(2, len(conditions), sharex=True, sharey=True, dpi=150,
                       squeeze=False)

# Evaluation step i was recorded after i*10 training epochs.
x = np.arange(steps) * 10

for m, mdm in enumerate(conditions):

    # NOTE: pickle.load can execute arbitrary code; only load result files
    # that were produced by this script.
    with open(save_path + 'Experiment1_' + mdm + '_' + str(missing_rate) +
              '_results.pkl', 'rb') as file:
        data = pickle.load(file)

    auc_full = data['auc_NN_full']
    auc_imputed = data['auc_NN_imputed']
    auc_nn = data['auc_prom']
    auc_mnn = data['auc_mprom']

    if with_gain:
        with open(save_path + 'Experiment1_GAIN_' + mdm + '_' + str(missing_rate) +
                  '_results.pkl', 'rb') as file:
            gain_data = pickle.load(file)
        auc_full = np.concatenate([auc_full, gain_data['auc_NN_full']], axis=2)
        auc_imputed = np.concatenate([auc_imputed, gain_data['auc_NN_imputed']], axis=2)

    top, bottom = ax[0, m], ax[1, m]

    # Top row: every model evaluated on the complete test data.
    for i in range(auc_full.shape[2]):
        _plot_mean_band(top, x, auc_full[:, 0:steps, i], full_colors[i],
                        full_names[i], linestyle='--' if i > 0 else '-')
    _plot_mean_band(top, x, auc_nn[:, 0:steps, 0], '#D2691E', 'PROMISSING')
    _plot_mean_band(top, x, auc_mnn[:, 0:steps, 0], '#6495ED', 'mPROMISSING')

    # Bottom row: imputation baselines on imputed test data, PROMISSING
    # variants on the corrupted test data.
    for i in range(auc_imputed.shape[2]):
        _plot_mean_band(bottom, x, auc_imputed[:, 0:steps, i], imputed_colors[i],
                        imputed_names[i], linestyle='--')
    _plot_mean_band(bottom, x, auc_nn[:, 0:steps, 1], '#D2691E', 'PROMISSING')
    _plot_mean_band(bottom, x, auc_mnn[:, 0:steps, 1], '#6495ED', 'mPROMISSING')

    top.set_title(mdm)
    bottom.set_xlabel('Epochs')
    for axis in (top, bottom):
        axis.grid(linestyle='--', linewidth=0.5)
        for spine in axis.spines.values():
            spine.set_visible(False)

ax[0, 0].legend(fontsize=7)
ax[0, 0].set_ylabel('AUC')
ax[1, 0].set_ylabel('AUC')

plt.savefig(save_path + 'Experiment1_' + str(missing_rate) + '_results.png', dpi=300)


###############################################################################
