import aix360
import numpy as np
from keras.models import model_from_json
from PIL import Image
from matplotlib import pyplot as plt
from skimage.color import gray2rgb
import argparse
from keras.optimizers import SGD
import tensorflow as tf

import os
import sys

### Fill in path to AIX360
# Replace the placeholder below with a string holding the directory that
# contains the `aix360` package folder; it is joined with the
# aix360/algorithms/contrastive and aix360/models/CEM sub-paths further down.
# NOTE: the script is not valid Python until the placeholder is replaced.
aix360_path = ***FILL IN PATH***

# Path components are passed to os.path.join one by one so that the separator
# of the running OS is used (a literal 'aix360\algorithms\contrastive' only
# resolves on Windows).
os.chdir(os.path.join(aix360_path, 'aix360', 'algorithms', 'contrastive'))

# CEM_pp_path is imported as a plain module from the contrastive directory,
# so that directory (now the working directory) must be on sys.path first.
sys.path.insert(0, os.getcwd())
from CEM_pp_path import CEMExplainer_pp_path
from aix360.algorithms.contrastive import CEMExplainer
from aix360.algorithms.contrastive.classifiers import KerasClassifier

# Directory holding the model json/weight files loaded later in this script
model_path = os.path.join(aix360_path, 'aix360', 'models', 'CEM')

from aix360.datasets import MNISTDataset

# image to predict on and explain (index into data.test_data)
image_id = 18

# load MNIST data; per the original note the dataset is normalized to the
# range [-0.5, 0.5] (done by MNISTDataset, not by this script)
data = MNISTDataset()

def load_model(model_json_file, model_wt_file):
    """Rebuild a Keras model from a JSON architecture file and a weights file.

    Args:
        model_json_file: Path to the file whose text is passed to
            `model_from_json`.
        model_wt_file: Path handed to the model's `load_weights`.

    Returns:
        The model object returned by `model_from_json`, with weights loaded.
    """
    # Pull the serialized architecture into memory
    with open(model_json_file, 'r') as json_handle:
        architecture = json_handle.read()

    # Instantiate the network, then restore its weights
    net = model_from_json(architecture)
    net.load_weights(model_wt_file)
    return net

def fn(correct, predicted):
    """Loss passed to `compile`: softmax cross-entropy on logits.

    Args:
        correct: Passed as the `labels` argument.
        predicted: Passed as the `logits` argument.

    Returns:
        The result of `tf.nn.softmax_cross_entropy_with_logits`.
    """
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=predicted,
                                                   labels=correct)
    return loss

# Optimizer handed to compile() below
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# Rebuild the MNIST classifier from its json and weight files, then compile it
mnist_json = os.path.join(model_path, 'mnist.json')
mnist_weights = os.path.join(model_path, 'mnist')
mnist_model = load_model(mnist_json, mnist_weights)
mnist_model.compile(loss=fn, optimizer=sgd, metrics=['accuracy'])

# Show the classifier's layer summary
mnist_model.summary()

# Rebuild the trained convolutional autoencoder the same way
ae_json = os.path.join(model_path, 'mnist_AE_1_decoder.json')
ae_weights = os.path.join(model_path, 'mnist_AE_1_decoder.h5')
ae_model = load_model(ae_json, ae_weights)

# Show the autoencoder's layer summary
ae_model.summary()

# Wrap mnist_model in the framework-independent classifier interface
mymodel = KerasClassifier(mnist_model)

# Create the explainer around the wrapped classifier
explainer = CEMExplainer_pp_path(mymodel)

# EXPLAIN AN INSTANCE
input_image = data.test_data[image_id]

# Check what the model predicts for this image (fed as a batch of one)
input_batch = np.expand_dims(input_image, axis=0)
print("Predicted class:", mymodel.predict_classes(input_batch))
print("Predicted logits:", mymodel.predict(input_batch))

# Set parameters for PSEM

# Maximum number of iterations
arg_max_iter = 100
# Initial coefficient value for main loss term that encourages class change
arg_init_const = 10.0
# No. of updates to the coefficient of the main loss term
arg_b = 9

# Minimum confidence gap between the PNs (changed) class probability and
# original class' probability
arg_kappa = 0.75
# Path of betas for the path of PP (one value per step along the path)
arg_beta = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
# Controls how much to adhere to a (optionally trained) autoencoder
arg_gamma = 100
# The model predicts on data in [-0.5, 0.5]
arg_offset = 0.5
# Controls distance between input image and PP (added for Path PP)
arg_eta = 10.0
# Remaining settings are forwarded to explain_instance unchanged
arg_alpha = 0.0
arg_threshold = 1.0
arg_normalize = True

# Find psem path
arg_mode = "PP_PATH"
adv_psem, delta_psem, info_psem = explainer.explain_instance(
    np.expand_dims(input_image, axis=0),
    arg_mode,
    ae_model,
    arg_kappa,
    arg_b,
    arg_max_iter,
    arg_init_const,
    arg_beta,
    arg_gamma,
    arg_alpha,
    arg_threshold,
    arg_offset,
    arg_eta,
    arg_normalize,
)

def _rescale_to_unit(arr):
    """Return a copy of `arr` shifted so its min is 0 and scaled so its max is 1.

    A constant array has zero range after the shift; it is returned as all
    zeros rather than being divided by zero (which would fill it with NaNs).
    """
    scaled = arr.copy()
    scaled -= np.min(scaled)
    peak = np.max(scaled)
    if peak > 0:
        scaled /= peak
    return scaled


# One panel for the original image plus one per beta value
fig, axarr = plt.subplots(1, len(arg_beta)+1, figsize=(8,3))

# Every panel starts from delta_psem[0] (the original image); it does not
# change between iterations, so normalise and convert it to RGB only once.
base_rgb = gray2rgb(_rescale_to_unit(delta_psem[0]).reshape((28, 28)))

fig_delta = {}
for i, delta in enumerate(delta_psem):
    panel = base_rgb.copy()
    if i > 0:
        # For images along the PSEM path, paint pure red every pixel whose
        # normalised delta exceeds 0.2 to highlight what PSEM learned.
        highlight = _rescale_to_unit(delta).reshape((28, 28)) > .2
        panel[highlight] = (1., 0., 0.)
    fig_delta[i] = panel

    axarr[i].axis('off')
    title = 'Original' if i == 0 else 'PSEM-{}'.format(i)
    axarr[i].set_title(title, fontsize=12, fontweight="bold")
    axarr[i].imshow(panel)
plt.show()