import numpy as np
import skimage
import skimage.io
import skimage.transform
import matplotlib.pyplot as plt
import matplotlib.cm as c_map
from skimage.segmentation import mark_boundaries

import tensorflow.keras
from tensorflow.keras.applications.imagenet_utils import decode_predictions

from lime.lime_image import LimeImageExplainer
from explainer.SvsvlExp_image import SvsvlImageExplainer  # choquex.choquex_image import ChoquexImageExplainer


def generate_prediction_sample(exp, exp_class, weight=0.1, show_positive=True,
                               hide_background=True, num_features=6):
    """Display the super-pixels an explanation highlights for one class.

    Calls ``exp.get_image_and_mask`` for ``exp_class``, outlines the returned
    mask on the returned image with ``mark_boundaries`` and shows the result
    with matplotlib (axes hidden).

    Parameters
    ----------
    exp
        Explanation object exposing ``get_image_and_mask`` (e.g. the result
        of ``LimeImageExplainer.explain_instance``).
    exp_class
        Label whose explanation is displayed, e.g. ``exp.top_labels[0]``.
    weight
        Forwarded to ``get_image_and_mask`` as ``min_weight``.
    show_positive
        Forwarded as ``positive_only``.
    hide_background
        Forwarded as ``hide_rest``.
    num_features
        Forwarded as ``num_features`` (previously hard-coded to 6).

    Notes
    -----
    The image is plotted exactly as ``get_image_and_mask`` returns it. If
    the explanation was computed on Inception-preprocessed input (values in
    [-1, 1]) it presumably needs rescaling to [0, 1] (``img / 2 + 0.5``)
    for correct colours -- TODO confirm for the explainer in use.

    Example
    -------
    >>> generate_prediction_sample(exp, exp.top_labels[0],
    ...                            show_positive=True, hide_background=True)
    """
    image, mask = exp.get_image_and_mask(
        exp_class,
        positive_only=show_positive,
        num_features=num_features,
        hide_rest=hide_background,
        min_weight=weight,
    )
    plt.imshow(mark_boundaries(image, mask))
    plt.axis('off')
    plt.show()

# Load the input image and apply InceptionV3 pre-processing.
Xi = skimage.io.imread("cat-and-dog.jpeg")
# Assumes resize() yields floats in [0, 1] (its default for integer images),
# so the affine map below produces the [-1, 1] range Inception expects.
Xi = skimage.transform.resize(Xi, (299, 299))
Xi = (Xi - 0.5) * 2  # Inception pre-processing
# skimage.io.imshow(Xi / 2 + 0.5)  # Show image before inception preprocessing

# Predict classes for the image using InceptionV3.
# Seeded presumably so the explainer's random sampling further down is
# reproducible.
np.random.seed(222)
inceptionV3_model = tensorflow.keras.applications.inception_v3.InceptionV3()  # Load pretrained model
preds = inceptionV3_model.predict(Xi[np.newaxis, :, :, :])  # add batch axis
top_pred_classes = preds[0].argsort()[-5:][::-1]  # Save ids of top 5 classes

# Print the top 5 classes. A bare expression is only echoed in a notebook;
# in a script its value was silently discarded.
for _class_id, _class_name, _score in decode_predictions(preds)[0]:
    print(f"{_class_name} ({_class_id}): {_score:.4f}")


# explainer = LimeImageExplainer()
# exp = explainer.explain_instance(Xi, 
#                                  inceptionV3_model.predict, 
#                                  top_labels=240, 
#                                  hide_color=0, 
#                                  num_samples=1000)

# from skimage.segmentation import mark_boundaries
# c = next(iter(exp.local_exp))
# temp, mask = exp.get_image_and_mask(exp.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
# plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))



# Explain the InceptionV3 prediction for Xi with the project's explainer.
svsvlExp = SvsvlImageExplainer()
svsvlExplanation = svsvlExp.explain_instance(
    Xi,
    inceptionV3_model.predict,
    top_labels=5,
    hide_color=0,
    num_samples=5000,
)

# Image and mask for the top predicted label.
temp, mask = svsvlExplanation.get_image_and_mask(
    svsvlExplanation.top_labels[0],
    negative_only=False,
    positive_only=True,
    num_features=1,
    hide_rest=True,
    interacting_features=3,
)
# temp / 2 + 0.5 inverts the (Xi - 0.5) * 2 pre-processing so the image is
# back in [0, 1] for display (assumes temp is in the same range as Xi).
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))
plt.axis('off')
# Needed when run as a script; otherwise the figure is never displayed.
plt.show()

print("Done!")

