{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "path_to_add = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
    "if path_to_add not in sys.path:\n",
    "    sys.path.append(path_to_add)\n",
    "\n",
    "import time\n",
    "\n",
    "from efficientnet.tfkeras import EfficientNetB4\n",
    "from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor, AutoImageProcessor, DetrForSegmentation, DetrFeatureExtractor\n",
    "import torch\n",
    "from segment_anything import SamAutomaticMaskGenerator, sam_model_registry\n",
    "\n",
    "from Utilities.utilities import *\n",
    "from Utilities.lime_segmentation import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration, one sub-dictionary per concern.\n",
    "# Built with dict(...) keyword syntax; the keys and their order match the option list below.\n",
    "config = dict(\n",
    "    # Explanation methods to run.\n",
    "    XAI_algorithm=dict(\n",
    "        DSEG=True,\n",
    "        LIME=True,\n",
    "        SLIME=True,\n",
    "        BayesLime=True,\n",
    "        GLIME=True,\n",
    "    ),\n",
    "\n",
    "    # Hardware: when gpu_device is True, gpu_num is exported as CUDA_VISIBLE_DEVICES.\n",
    "    computation=dict(\n",
    "        num_workers=3,\n",
    "        gpu_device=False,\n",
    "        gpu_num=\"4\",\n",
    "    ),\n",
    "\n",
    "    # Classifier to explain; the first True entry (EfficientNet, ResNet, VisionTransformer order) is loaded.\n",
    "    model_to_explain=dict(\n",
    "        EfficientNet=True,\n",
    "        ResNet=False,\n",
    "        VisionTransformer=False,\n",
    "    ),\n",
    "\n",
    "    lime_segmentation=dict(\n",
    "        # LIME parameters\n",
    "        num_samples=128,\n",
    "        num_features=1000,\n",
    "        min_weight=0.01,\n",
    "        top_labels=1,\n",
    "        hide_color=None,\n",
    "        batch_size=10,\n",
    "        verbose=True,\n",
    "\n",
    "        # Classical segmentation algorithm switches\n",
    "        slic=True,\n",
    "        quickshift=False,\n",
    "        felzenszwalb=False,\n",
    "        watershed=False,\n",
    "\n",
    "        # Segmentation backbone (DETR or SAM); points_per_side and min_size configure the SAM mask generator\n",
    "        all_dseg=False,\n",
    "        DETR=False,\n",
    "        SAM=True,\n",
    "        points_per_side=32,\n",
    "        min_size=512,\n",
    "\n",
    "        # Segmentation tuning\n",
    "        fit_segmentation=True,\n",
    "        slic_compactness=16,\n",
    "        num_segments=20,\n",
    "        markers=16,\n",
    "        kernel_size=6,\n",
    "        max_dist=32,\n",
    "\n",
    "        # LIME segmentation parameters\n",
    "        iterations=1,\n",
    "        shuffle=False,\n",
    "        max_segments=8,\n",
    "        min_segments=1,\n",
    "        auto_segment=False,\n",
    "\n",
    "        # LIME explanation parameters\n",
    "        num_features_explanation=2,\n",
    "        adaptive_num_features=False,\n",
    "        adaptive_fraction=True,\n",
    "\n",
    "        hide_rest=True,\n",
    "        positive_only=True,\n",
    "    ),\n",
    "\n",
    "    evaluation=dict(\n",
    "        noisy_background=True,\n",
    "\n",
    "        # Correctness\n",
    "        model_randomization=True,\n",
    "        explanation_randomization=True,\n",
    "\n",
    "        single_deletion=True,\n",
    "        fraction=0.1,\n",
    "        fraction_std=0.05,\n",
    "\n",
    "        incremental_deletion=True,\n",
    "        incremental_deletion_fraction=0.15,\n",
    "\n",
    "        stability=True,\n",
    "        repetitions=8,\n",
    "\n",
    "        # Output completeness\n",
    "        preservation_check=True,\n",
    "        deletion_check=True,\n",
    "\n",
    "        # Consistency\n",
    "        variation_stability=True,\n",
    "\n",
    "        # Contrastivity\n",
    "        target_discrimination=True,\n",
    "\n",
    "        # Compactness\n",
    "        size=True,\n",
    "    ),\n",
    ")\n",
    "\n",
    "\n",
    "from Utilities.utilities import *\n",
    "from Utilities.lime_utilities import *\n",
    "from Utilities.lime_segmentation import *\n",
    "from Utilities.lime_base import *\n",
    "from Utilities.GLIME import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime environment. NOTE(review): these variables only affect CUDA if they are set\n",
    "# before TensorFlow/PyTorch initialise the GPU, so keep this cell ahead of any GPU work.\n",
    "use_gpu = config['computation']['gpu_device']\n",
    "CUDA_TOOLKIT_DIR = '/usr/local/cuda-11.4'  # only consulted by XLA when use_gpu is True\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = config['computation']['gpu_num'] if use_gpu else ''\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' if use_gpu else ''\n",
    "os.environ['XLA_FLAGS'] = f'--xla_gpu_cuda_data_dir={CUDA_TOOLKIT_DIR}' if use_gpu else ''\n",
    "# TF_XLA_FLAGS is whitespace-separated: flags must not be joined with commas.\n",
    "os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_enable_xla_devices' if use_gpu else ''\n",
    "os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'\n",
    "\n",
    "# Classifier to explain: the first enabled entry of config['model_to_explain'] is loaded.\n",
    "if config['model_to_explain']['EfficientNet']:\n",
    "    model_explain = EfficientNetB4(weights='imagenet')\n",
    "    model_explain_processor = None\n",
    "elif config['model_to_explain']['ResNet']:\n",
    "    model_explain = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)\n",
    "    model_explain.eval()\n",
    "    model_explain_processor = None\n",
    "elif config['model_to_explain']['VisionTransformer']:\n",
    "    # ViTImageProcessor replaces the deprecated ViTFeatureExtractor alias.\n",
    "    model_explain_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch32-384')\n",
    "    model_explain = ViTForImageClassification.from_pretrained('google/vit-base-patch32-384')\n",
    "else:\n",
    "    # Fail here rather than with a NameError on model_explain in a later cell.\n",
    "    raise ValueError(\"config['model_to_explain'] must enable one of: EfficientNet, ResNet, VisionTransformer\")\n",
    "\n",
    "# Segmentation backbone (feature_extractor + model) used by the all_dseg explanation branch.\n",
    "if config['lime_segmentation']['DETR']:\n",
    "    feature_extractor = AutoImageProcessor.from_pretrained(\"facebook/detr-resnet-50-panoptic\")\n",
    "    model = DetrForSegmentation.from_pretrained(\"facebook/detr-resnet-50-panoptic\")\n",
    "else:\n",
    "    sam_checkpoint = \"../Models/pretrained/sam_vit_h_4b8939.pth\"\n",
    "    model_type = \"vit_h\"\n",
    "    model = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
    "\n",
    "    # Use the GPU only when it is both requested and available; otherwise fall back to CPU\n",
    "    # instead of crashing on a missing CUDA device.\n",
    "    DEVICE = torch.device('cuda:0' if use_gpu and torch.cuda.is_available() else 'cpu')\n",
    "    model.to(device=DEVICE)\n",
    "\n",
    "    feature_extractor = SamAutomaticMaskGenerator(model,\n",
    "                                                  min_mask_region_area=config['lime_segmentation']['min_size'],\n",
    "                                                  points_per_side=config['lime_segmentation']['points_per_side'],\n",
    "                                                  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image to explain -- edit this path to try your own images.\n",
    "path_image_id = \"../Dataset/Evaluation/n02134084_ice_bear.JPEG\"\n",
    "\n",
    "predict_image(\n",
    "    path_image_id,\n",
    "    model_explain,\n",
    "    config,\n",
    "    True,\n",
    "    model_processor=model_explain_processor,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the explainer, load the image, and time a single explanation run.\n",
    "explainer = LimeImageExplainer()\n",
    "\n",
    "data, data_raw = load_and_preprocess_image(path_image_id, config, plot=False, model=model_explain_processor)\n",
    "\n",
    "lime_cfg = config['lime_segmentation']\n",
    "RANDOM_SEED = 42  # fixed seed passed to explain_instance for reproducible runs\n",
    "\n",
    "# Keyword arguments shared by both segmentation modes.\n",
    "common_kwargs = dict(\n",
    "    config=config,\n",
    "    top_labels=lime_cfg['top_labels'],\n",
    "    hide_color=lime_cfg['hide_color'],\n",
    "    num_samples=lime_cfg['num_samples'],\n",
    "    random_seed=RANDOM_SEED,\n",
    "    model_regressor=\"Bayes_ridge\",\n",
    ")\n",
    "\n",
    "# perf_counter is monotonic and high-resolution, unlike the wall-clock time.time().\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "if lime_cfg['all_dseg']:\n",
    "    # Segmentation driven by the DETR/SAM backbone (feature_extractor, model) loaded earlier.\n",
    "    explanation = explainer.explain_instance(data,\n",
    "                                             model_explain,\n",
    "                                             feature_extractor,\n",
    "                                             model,\n",
    "                                             shuffle=lime_cfg['shuffle'],\n",
    "                                             image_path=path_image_id,\n",
    "                                             iterations=lime_cfg['iterations'],\n",
    "                                             segmentation_fn_seed=segment_seed_dynamic,\n",
    "                                             segmentation_fn_dynamic=segment_image_dynamic,\n",
    "                                             **common_kwargs)\n",
    "else:\n",
    "    # Plain segmentation function; no segmentation backbone is passed.\n",
    "    explanation = explainer.explain_instance(data,\n",
    "                                             model_explain,\n",
    "                                             None,\n",
    "                                             None,\n",
    "                                             segmentation_fn=lime_segmentation,\n",
    "                                             **common_kwargs)\n",
    "\n",
    "end_time = time.perf_counter()\n",
    "print(\"Time: \", end_time - start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualise the explanation: masked image on the left, local_exp values on the right.\n",
    "temp, mask = explanation.get_image_and_mask(explanation.top_labels[0],\n",
    "                                            positive_only=config['lime_segmentation']['positive_only'],\n",
    "                                            num_features=1,\n",
    "                                            hide_rest=config['lime_segmentation']['hide_rest'])\n",
    "\n",
    "black = np.array([0, 0, 0], dtype=np.uint8)\n",
    "gray = np.array([230, 230, 230], dtype=np.uint8)\n",
    "\n",
    "# Bring the raw image to the mask's resolution without mutating data_raw (keeps the cell re-runnable).\n",
    "if config['model_to_explain']['EfficientNet']:\n",
    "    overlay_base = data_raw\n",
    "elif config['model_to_explain']['ResNet'] or config['model_to_explain']['VisionTransformer']:\n",
    "    # NOTE(review): data_raw is presumably a PIL Image here (ndarray.resize would return None) -- confirm.\n",
    "    # PIL's resize takes (width, height) while mask.shape is (height, width), hence the swapped order.\n",
    "    overlay_base = data_raw.resize((mask.shape[1], mask.shape[0]))\n",
    "else:\n",
    "    raise ValueError(\"config['model_to_explain'] must enable one of: EfficientNet, ResNet, VisionTransformer\")\n",
    "\n",
    "# Keep only the pixels selected by the explanation mask.\n",
    "changed = (np.asarray(overlay_base) * mask[:, :, np.newaxis]).astype(np.uint8)\n",
    "\n",
    "# Repaint every fully black pixel light gray (vectorised; replaces a per-pixel Python loop).\n",
    "changed[(changed == black).all(axis=-1)] = gray\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(12, 5))  # 1 row, 2 columns\n",
    "\n",
    "# Left panel: the masked image.\n",
    "ax[0].imshow(changed)\n",
    "ax[0].axis('off')\n",
    "ax[0].set_title('Marked Image')\n",
    "\n",
    "# Right panel: explanation.local_exp for the top label; each entry is unpacked as a\n",
    "# 3-tuple and only the first two fields are plotted.\n",
    "labels, values, _ = zip(*explanation.local_exp[explanation.top_labels[0]])\n",
    "ax[1].bar(labels, values, color='skyblue')\n",
    "ax[1].set_title('Bar Chart of Values')\n",
    "ax[1].set_xlabel('Label')\n",
    "ax[1].set_ylabel('Value')\n",
    "ax[1].grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "ax[1].axhline(0, color='grey', linewidth=0.8)  # zero line separates positive and negative values\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
