# -*- coding: utf-8 -*-
"""slice.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1d93sHxTgwIzLe2BjkP0r1PgqzNBSi_9F
"""

import os
import pickle
import numpy as np
import cv2
import keras
from keras.applications.imagenet_utils import decode_predictions
import skimage.io
from skimage.segmentation import quickshift, mark_boundaries
from skimage.measure import regionprops
import copy
import random
import sklearn
import sklearn.metrics
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from skimage import filters
import pandas as pd
import warnings
import tensorflow as tf
import pickle
import sys
from scipy.stats import kendalltau

from matplotlib import pyplot as plt
import time
from sklearn.utils import resample
from scipy.stats import norm, gaussian_kde
from sklearn.neighbors import KernelDensity
import csv
from slice.slice_explainer import SliceExplainer
from slice.vit_img_classifier import ViTImageClassifier
from transformers import ViTFeatureExtractor, TFViTForImageClassification


# Usage: directory holding the images to explain.
img_dir = "images_oxpets/"
try:
    img_filenames = os.listdir(img_dir)
except OSError:
    # os.listdir reports a missing/unreadable directory with an OSError
    # subclass (e.g. FileNotFoundError), never IndexError. Fall back to an
    # empty list so the name is always bound.
    print("No files found in the directory.")
    img_filenames = []

# Debugging: restrict the run to a single image. Remove this override to
# process every file found in img_dir.
img_filenames = ['newfoundland_181.jpg']

# Explanation method tag used when building result paths:
# put "lime" for lime, "baylime" for baylime,
# "sliceblurfe" for slice and "semlime" for belief.
algo_name = "semlime"
results_dir = "results/"
num_runs = 20  # number of repeated explanation runs per image
sample_size = 500  # change as needed
# Tolerance parameter for the feature elimination algorithm.
# This is used in slice and not in other methods.
tol = 3


def get_model_params(model_name='resnet50'):
    """Build the classifier and its preprocessing settings for ``model_name``.

    Args:
        model_name: ``'resnet50'``, ``'inception_v3'`` / ``'inceptionv3'``
            or ``'vitp16'``.

    Returns:
        A ``(model, preprocess_input, target_img_size)`` tuple: the
        instantiated model, the callable used to preprocess its inputs, and
        the image size as a 2-tuple of ints.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'resnet50':
        model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
        preprocess_input = tf.keras.applications.resnet.preprocess_input
        target_img_size = (224, 224)
    elif model_name in ('inception_v3', 'inceptionv3'):
        model = tf.keras.applications.InceptionV3(weights='imagenet')
        preprocess_input = tf.keras.applications.inception_v3.preprocess_input
        target_img_size = (299, 299)
    elif model_name == 'vitp16':
        # The checkpoint id is kept in its own name so the caller-supplied
        # model_name is not overwritten.
        checkpoint = "google/vit-base-patch16-224"
        model = ViTImageClassifier(checkpoint)
        preprocess_input = model.preprocess_image
        target_img_size = (224, 224)
    else:
        # Fail loudly here instead of reaching the return statement with
        # unbound locals (which surfaced as an UnboundLocalError).
        raise ValueError(
            f"Unknown Model: {model_name!r}; expected 'resnet50', "
            "'inception_v3', 'inceptionv3' or 'vitp16'"
        )

    return model, preprocess_input, target_img_size


# Precomputed per-image information, looked up below with keys of the form
# "<model name>_<image file stem>".
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files from a trusted source.
img_info_filepath = "img_info_dict/imgs_info.pkl"
with open(img_info_filepath, 'rb') as f:
    img_info_dict = pickle.load(f)

model_names = ['inceptionv3']  # ['resnet50', 'inceptionv3', 'vitp16']

# Dataset tag taken from the image directory name ("images_oxpets/" ->
# "oxpets"). It does not depend on the model or the image, so compute it once.
img_dir_name = (img_dir.split("_")[1]).split("/")[0]

for model_name in model_names:
    model, preprocess_input, target_img_size = get_model_params(model_name)

    for img_filename in img_filenames:
        img_dict = {}
        # File name with everything from ".jpg" (or otherwise ".png") dropped.
        img_filename_split = (img_filename.split(".jpg") if ".jpg" in img_filename
                              else img_filename.split(".png"))[0]

        pkl_filename = (f"{results_dir}{img_dir_name}/{algo_name}_{model_name}/"
                        f"{img_dir_name}_{algo_name}_{img_filename_split}_{model_name}.pkl")
        print(pkl_filename)
        # Optional resume support (disabled): skip images already processed.
        # if os.path.exists(pkl_filename):
        #     continue

        # "vitp16" looks up the entry stored under the "resnet50" prefix;
        # every other model uses its own name as the prefix.
        info_prefix = "resnet50" if model_name == "vitp16" else model_name
        img_key = info_prefix + "_" + img_filename_split

        img_info = img_info_dict[img_key]
        segments = img_info[0]['segments']
        sel_sigma = img_info[0]['sel_sigma']

        run_dict = {}
        for i in range(num_runs):
            exp = SliceExplainer(image_path=img_dir + img_filename, segments=segments, model=model,
                                 target_img_size=target_img_size, preprocess=preprocess_input)
            print(img_filename, " : ", len(np.unique(exp.superpixels)), " : ", sample_size)

            pos_feature_ranks, neg_feature_ranks, pos_dict, neg_dict = exp.get_semlime_explanations(sigma=sel_sigma)
            # Explicit len() checks: the rank containers are array-like (they
            # expose .astype), so plain truthiness is avoided on purpose.
            ranks = {'pos': pos_feature_ranks.astype('int') if len(pos_feature_ranks) > 0 else np.array([]),
                     'neg': neg_feature_ranks.astype('int') if len(neg_feature_ranks) > 0 else np.array([]),
                     'pos_dict': pos_dict, 'neg_dict': neg_dict, 'sel_sigma': sel_sigma}

            print(f"Pos : {pos_feature_ranks}")
            print(f"Neg : {neg_feature_ranks}")

            # Drop the explainer before the next iteration builds a new one.
            del exp

            # run_dict is fresh per image and each run index is unique, so the
            # entry can be assigned directly.
            run_dict[f'run_{i}'] = [ranks]

        # Key for the per-image results: the file name up to its first dot.
        # Kept in its own variable so the img_info_dict lookup key above is not
        # overwritten and the result key does not depend on num_runs.
        result_key = img_filename.split('.')[0]
        img_dict[f'run_{result_key}'] = [run_dict]

        # Save the selected feature ranks in a dict and store it in a pkl file.
        # NOTE(review): saving is currently disabled, so img_dict is discarded
        # at the next iteration.
        print(pkl_filename)
        # with open(pkl_filename, 'wb') as f1:
        #     pickle.dump(img_dict, f1)