import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
import math
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
import collections
import torch.nn.functional as F
import shap
# Make figure text and axis labels bold by default for every plot below.
plt.rcParams.update({"font.weight": "bold", "axes.labelweight": "bold"})

# GUIDE Bayes feature selection performance.
# auc_guide[d, k] holds the ROC AUC on dataset d+1 (data_1 .. data_10),
# computed from the k-th row of the pooled predictions of its 10 models.
# auc_shap is allocated here as well and filled by the SHAP section below.
auc_shap = np.zeros((10, 49))
auc_guide = np.zeros((10, 49))
iterr = 0  # datasets processed so far; also the sample size used for the CI
for m_n in range(1, 11):
    pred = []
    c_true = []
    for j in range(10):
        # Parse each vis.mat once and take both arrays from the same dict,
        # instead of reading the file from disk twice.
        vis = loadmat(f'guide_importance/data_{m_n}/model{j}/vis.mat')
        pred.append(vis['pred'])
        c_true.append(vis['true_class'])

    # Concatenate along axis 1 so the samples of all 10 models sit side by side.
    p = np.concatenate(tuple(pred), 1)    # predictions, one row per evaluated K
    c = np.concatenate(tuple(c_true), 1)  # true class labels; row 0 is used

    for i, scores in enumerate(p):
        fpr_test, tpr_test, thresholds_test = metrics.roc_curve(c[0, :], scores, pos_label=1)
        auc_guide[iterr, i] = metrics.auc(fpr_test, tpr_test)
    iterr += 1
    
from scipy import stats

# GUIDE curve: mean AUC over datasets for each K, with a 95% confidence band.
x = np.array(range(5, 246, 5))  # K = 5, 10, ..., 245 -> one point per column of auc_guide
y = auc_guide
y_mean = np.mean(y, 0)

# Two-sided 95% Student-t interval for the mean over `iterr` datasets.
# The t interval is defined with the sample standard deviation (ddof=1);
# np.std's default ddof=0 would make the band too narrow. The critical value
# is derived from the sample size (it equals ~2.262 for 10 datasets).
t_crit = stats.t.ppf(0.975, iterr - 1)
ci = t_crit * np.std(y, 0, ddof=1) / np.sqrt(iterr)

fig, ax = plt.subplots()
ax.plot(x, y_mean, linewidth=10, color='b')
ax.fill_between(x, y_mean - ci, y_mean + ci, color='b', alpha=.2)

ax.set_ylim(0.6, 0.8)





# SHAP feature selection performance.
# auc_shap[d, k] holds the ROC AUC on dataset d+1 (data_1 .. data_10),
# computed from the k-th row of the pooled predictions of its 10 models.
iterr = 0  # datasets processed so far; also the sample size used for the CI
for m_n in range(1, 11):
    pred = []
    c_true = []
    for j in range(10):
        # Parse each vis.mat once and take both arrays from the same dict,
        # instead of reading the file from disk twice.
        vis = loadmat(f'shap_importance/data_{m_n}/model{j}/vis.mat')
        pred.append(vis['pred'])
        c_true.append(vis['true_class'])

    p = np.concatenate(tuple(pred), 1)    # Predicted probabilities, one row per K
    c = np.concatenate(tuple(c_true), 1)  # True class labels; row 0 is used

    # Loop over the different values of K in the top-K feature selection.
    for i, scores in enumerate(p):
        fpr_test, tpr_test, thresholds_test = metrics.roc_curve(c[0, :], scores, pos_label=1)
        auc_shap[iterr, i] = metrics.auc(fpr_test, tpr_test)
    iterr += 1
    

from scipy import stats

# K-SHAP curve, drawn on the same axes as the GUIDE curve, then axis styling.
x = np.array(range(5, 246, 5))  # K = 5, 10, ..., 245 -> one point per column of auc_shap
y = auc_shap
y_mean = np.mean(y, 0)

# Two-sided 95% Student-t interval for the mean over `iterr` datasets, using
# the sample standard deviation (ddof=1) that the t interval is defined with.
t_crit = stats.t.ppf(0.975, iterr - 1)
ci = t_crit * np.std(y, 0, ddof=1) / np.sqrt(iterr)

ax.plot(x, y_mean, color='r', linewidth=10, linestyle='dashed')
ax.fill_between(x, y_mean - ci, y_mean + ci, color='r', alpha=.2)

# One call sets the major tick label size on both axes.
ax.tick_params(axis='both', which='major', labelsize=50)
ax.set_xlabel("K", fontsize=20, fontweight='bold')
ax.set_ylabel("AUC", fontsize=20, fontweight='bold')
legend_properties = {'fontsize': 50, 'weight': 'bold'}  # NOTE(review): not used by ax.legend below
ax.legend(['GUIDE Importance', 'K-SHAP Importance'], fontsize=50)
plt.xticks(weight='bold')
plt.yticks(weight='bold')