# -*- coding: utf-8 -*-
"""ICLR_dataset1.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/16-b9ueadC6f1yA-vmBn1YjTeNCG4j74P
"""

# `! pip install shap` is IPython/Colab shell-escape syntax and is a
# SyntaxError when this exported file is run as plain Python. Invoke pip
# through the current interpreter instead so the package is installed into
# the same environment; check=True makes a failed install stop here instead
# of surfacing later as an ImportError on `import shap`.
import subprocess
import sys
subprocess.run([sys.executable, "-m", "pip", "install", "shap"], check=True)

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import models as nam_models
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import sklearn
import pandas as pd
import shap
import sklearn

"""##Loading the data"""

# Read the raw credit data and keep only fully populated rows.
orgdata = pd.read_csv("give_me_some_credit.csv").dropna(axis=0)
data = orgdata

# Target is column 1; features start at column 2 (column 0 is not used --
# presumably a saved index column, confirm against the CSV).
y = data.iloc[:, 1]
x = data.iloc[:, 2:]
x.insert(x.shape[1], 'label', y)  # append the target as the last column
x = np.array(x)

# Reorder the first ten columns in a single step. This order is the
# composition of the pairwise column swaps (0,2), (1,8), (2,6), (3,4), (4,9)
# applied in that sequence; any further columns (e.g. the label) stay put.
new_order = np.arange(x.shape[1])
new_order[:10] = [2, 8, 6, 4, 9, 5, 0, 7, 1, 3]
x = x[:, new_order]

"""##Manipulating the data"""

# Drop every row in which any of the first three columns is >= 20
# (NaN compares False and is therefore kept, as with np.delete/np.where).
too_large = (x[:, :3] >= 20).any(axis=1)
x1 = x[~too_large]

# Cap columns 0, 1, 2 and 4 at 5 (NaN propagates through np.minimum and so
# stays NaN, matching a masked ">= 5" assignment).
capped_cols = [0, 1, 2, 4]
x1[:, capped_cols] = np.minimum(x1[:, capped_cols], 5)

# Discretise column 3 into ordinal levels using the cut points
# 2500 / 5000 / 7500 / 10000 / 50000:
#   [0, 2500) -> 5, [2500, 5000) -> 4, [5000, 7500) -> 3,
#   [7500, 10000) -> 2, [10000, 50000) -> 1, >= 50000 -> 0.
# Negative (and NaN) entries are left untouched.
col3 = x1[:, 3]
level = 5 - np.digitize(col3, [2500, 5000, 7500, 10000, 50000])
x1[:, 3] = np.where(col3 >= 0, level, col3)

# Features are every column but the last; the last column is the label.
X = x1[:, :-1]
Y = x1[:, -1]

# Build the NAM architecture that the trained weights are loaded into below.
tf.compat.v1.reset_default_graph()
# The grouping of input features into sub-networks is set by `kwargs`.
# In this demo the first three features are strongly (pairwise) monotonic, so
# they are combined in one DNN; each of the remaining seven features gets its own.
Number_of_DNN=8       # equals len(kwargs) below: one sub-network per group
Number_of_Unit=0      # NOTE(review): semantics defined in models.NAM -- confirm
Trainable=True        # NOTE(review): passed positionally to models.NAM -- confirm meaning
Use_Shallow=False     # NOTE(review): passed positionally to models.NAM -- confirm meaning
Model=nam_models.NAM(Number_of_DNN,Number_of_Unit,Trainable,Use_Shallow,feature_dropout = 0.0,dropout = 0.0,kwargs=[3,1,1,1,1,1,1,1])
# Call the model once on the feature matrix, presumably so that its weights
# are created before summary()/load_weights() -- confirm.
Model(X)

Model.summary()

"""##Loading the trained model"""

# Restore previously trained weights from "my_checkpoint" (resolved relative
# to the working directory). The architecture built above must match the one
# that was saved. NOTE(review): presumably a TF-format checkpoint prefix -- confirm.
Model.load_weights("my_checkpoint")

"""##Calculating GMShap and SHAP values

###Function for calculating GMShap
"""

from itertools import  product
from scipy.special import comb
import copy
import numpy as np
def myshap(spli, X, basevalue, model=None):
  """Compute exact group-level Shapley values (and GShap terms) for one instance.

  Consecutive features of ``X`` are grouped with sizes ``spli``. Each group
  acts as a single player: when a group is absent from a coalition, its
  features take their values from ``basevalue``. All coalitions are
  enumerated, so the model is evaluated on O(2**len(spli)) rows per group.

  Args:
    spli: Group sizes, e.g. ``[3, 1, 1, 1, 1, 1, 1, 1]``. ``sum(spli)`` must
      equal ``len(X)``.
    X: 1-D feature vector of the instance to explain.
    basevalue: Per-feature baseline (length ``len(X)``, or a scalar) used for
      absent groups.
    model: Callable that maps a 2-D batch of feature rows to one output per
      row. Defaults to the module-level ``Model`` built earlier in this script.

  Returns:
    If ``spli[0] == 1``: array of per-group Shapley values.
    Otherwise: ``(shap, g_shap)``. ``shap`` is the per-group array.
    ``g_shap[k]`` (for ``k < spli[0] - 1``) is the Shapley-weighted
    contribution of the first group when features ``k+1 .. spli[0]-1`` of that
    group are summed into feature ``k`` and set to zero.
  """
  if model is None:
    model = Model

  n_groups = len(spli)
  first_size = spli[0]

  # Every on/off pattern over the groups except the grand coalition
  # (all ones): 1 = group present, 0 = group replaced by the baseline.
  mask = np.array(list(product(range(2), repeat=n_groups)))
  mask = mask[~np.all(mask == 1, axis=1)]

  # Shapley weight |S|! (M-|S|-1)! / M!  ==  1 / (C(M, |S|) * (M - |S|)).
  # M - |S| > 0 because the all-ones row was removed.
  sizes = mask.sum(axis=1)
  weights = 1.0 / (comb(n_groups, sizes) * (n_groups - sizes))

  def _fill(group_mask):
    # Expand a group-level mask to feature level, then take X where the group
    # is present and basevalue where it is absent. Selecting on the mask
    # itself (not on "masked X == 0") keeps present features whose value is
    # exactly 0 instead of overwriting them with a non-zero baseline.
    feature_mask = np.repeat(group_mask, spli, axis=1).astype(bool)
    return np.where(feature_mask, X, basevalue)

  # v(S) for every coalition S. This is the same for every group, so it is
  # evaluated once.
  base_pred = np.array(model(_fill(mask)))

  shapp = []
  shapp_g = np.zeros(max(first_size - 1, 0))
  for i in range(n_groups):
    absent = mask[:, i] == 0            # coalitions S that do not contain group i
    w = weights[absent]
    with_i = mask[absent].copy()
    with_i[:, i] = 1                    # S ∪ {i}
    rows_with_i = _fill(with_i)
    marginal = np.array(model(rows_with_i)) - base_pred[absent]
    shapp.append(np.tensordot(w, marginal, axes=1))

    if i == 0 and first_size > 1:
      # GShap terms for the first (multi-feature) group: fold the trailing
      # features of the group into feature k and zero them out.
      # NOTE(review): folded features are set to 0, not to basevalue --
      # confirm this is intended for non-zero baselines.
      for k in range(first_size - 1):
        collapsed = rows_with_i.copy()
        collapsed[:, k] += collapsed[:, k + 1:first_size].sum(axis=1)
        collapsed[:, k + 1:first_size] = 0
        g_marginal = np.array(model(collapsed)) - base_pred[absent]
        shapp_g[k] = np.tensordot(w, g_marginal, axes=1)

  if first_size == 1:
    return np.array(shapp)
  return np.array(shapp), shapp_g

"""###Calculating for a single data point"""

# Explain one instance with both GMShap (grouped + redistributed) and plain
# per-feature SHAP, using an all-zero baseline.
datapoint=np.array([2,2,5,4,4,11,1.00946276,0,30,0.57013026])
#this is the data point used in the paper
# The single-iteration loop only exists so that `continue` can skip the
# computation for degenerate points. NOTE(review): if it does skip, GM_shap
# stays all zeros and ORG_shap is never assigned, so the visualization cells
# below would fail with a NameError.
for i in range(1):
  GM_shap=np.zeros(10)
  # v1..v3: values of the three grouped features (first three columns, code order).
  v1=datapoint[0]
  v2=datapoint[1]
  v3=datapoint[2]
  # Skip when at least two of v1..v3 are zero, or when all three are equal.
  if((v1==0 and v2==0) or (v2==0 and v3==0) or(v1==0 and v3==0) or (v1==v2 and v2==v3)):
    continue
  # Grouped run: features 0-2 form one player. This gives shapp (8 group
  # values) and shapp_g (the 2 GShap terms for that group).
  shapp,shapp_g=myshap([3,1,1,1,1,1,1,1],datapoint,[0,0,0,0,0,0,0,0,0,0])
  # Plain SHAP: every feature is its own player.
  ORG_shap=myshap([1,1,1,1,1,1,1,1,1,1],datapoint,[0,0,0,0,0,0,0,0,0,0])
  # Shapley value of the whole 3-feature group.
  shapp_g3=shapp[0]
  if(v2+v3!=0):
    # Split the group value into per-feature GMShap values with a
    # lower-triangular linear map of (shapp_g[0], shapp_g[1], shapp_g3).
    # The columns of `matri` sum to 0, 0 and 1, so
    # shapp_1 + shapp_2 + shapp_3 == shapp_g3.
    matri=[[v1/(v1+v2+v3),0,0],
          [v2/(v1+v2+v3)-v2/(v2+v3),v2/(v2+v3),0],
          [v3/(v1+v2+v3)-v3/(v2+v3),(v3/(v2+v3))-1,1]]
    shp_1_2_3=np.dot(matri,[shapp_g[0],shapp_g[1],shapp_g3])
    shapp_1=shp_1_2_3[0]
    shapp_2=shp_1_2_3[1]
    shapp_3=shp_1_2_3[2]
  elif(v2+v3==0 and v1!=0):
    # Given the skip above, this branch is only reachable when v2 == -v3 != 0,
    # i.e. never for non-negative counts. Here v1/(v1+v2+v3) == 1, so
    # shapp_1 == shapp_g[0].
    matri=[[v1/(v1+v2+v3),0,0]]
    shp_1_2_3=np.dot(matri,[shapp_g[0],shapp_g[1],shapp_g3])
    shapp_1=shp_1_2_3[0]
    shapp_2=0
    shapp_3=0
  else:
    # v1 == 0 and v2 == -v3 != 0. This is also unreachable for non-negative counts.
    shapp_1=0
    shapp_2=0
    shapp_3=0
  # GM_shap = three per-feature GMShap values, then the 7 ungrouped values.
  GM_shap[0]=shapp_1
  GM_shap[1]=shapp_2
  GM_shap[2]=shapp_3
  GM_shap[3:]=shapp[1:]

  #Please note that:
  #the order of x_1 to x_3 in the code is pastdue30-60; pastdue60-90; pastdue90+
  #the order of x_1 to x_3 in the paper is pastdue90+; pastdue60-90; pastdue30-60

  print("GM_shap",GM_shap)
  print("Shap",ORG_shap)
  print("-----------------")

"""##A sample visualization

###GMShap for the single data point
"""

# Build a shap Explanation for `datapoint`, then replace its values with the
# GMShap attributions computed above so shap's bar plot can display them.
explainer1 = shap.Explainer(Model,np.array([[0,0,0,0,0,0,0,0,0,0]]))
explainer1.feature_names = [f"x_{k}" for k in range(1, 11)]

# shap computes its own attributions here. Only the Explanation container is
# reused; its values are overwritten below.
shap_values1 = explainer1([datapoint])

# visualize the first prediction's explanation
# The x_1..x_3 order in the paper is pastdue90+; pastdue60-90; pastdue30-60,
# while the code uses pastdue30-60; pastdue60-90; pastdue90+. To match the
# paper, swap x_1 and x_3 in both the displayed data and the attributions.
# The displayed row is derived from `datapoint` rather than typed out again,
# so the plot stays in sync if the explained point changes.
display_point = np.array(datapoint, dtype=float)
display_point[[0, 2]] = display_point[[2, 0]]
shap_values1.data = display_point[np.newaxis, :]
GM_shap[[0, 2]] = GM_shap[[2, 0]]  # swapped in place, as in the original
shap_values1.values = np.array([GM_shap])  # Explanation expects an ndarray

shap.plots.bar(shap_values1[0],show_data=True)

"""###SHAP for the single data point"""

# Build a shap Explanation for `datapoint`, then replace its values with the
# per-feature SHAP values returned by myshap (ORG_shap) for comparison.
explainer2 = shap.Explainer(Model,np.array([[0,0,0,0,0,0,0,0,0,0]]))
explainer2.feature_names = [f"x_{k}" for k in range(1, 11)]
shap_values2 = explainer2([datapoint])

# The x_1..x_3 order in the paper is pastdue90+; pastdue60-90; pastdue30-60,
# while the code uses pastdue30-60; pastdue60-90; pastdue90+. To match the
# paper, swap x_1 and x_3 in both the displayed data and the attributions.
# The displayed row is derived from `datapoint` rather than typed out again,
# so the plot stays in sync if the explained point changes.
display_point = np.array(datapoint, dtype=float)
display_point[[0, 2]] = display_point[[2, 0]]
shap_values2.data = display_point[np.newaxis, :]
ORG_shap[[0, 2]] = ORG_shap[[2, 0]]  # swapped in place, as in the original
shap_values2.values = np.array([ORG_shap])  # Explanation expects an ndarray

shap.plots.bar(shap_values2[0],show_data=True)