{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation Pipeline\n",
    "\n",
    "Here, I implement the metrics used to evaluate the quality of the generated explanations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#imports \n",
    "\n",
    "import os\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "from efficientnet.tfkeras import EfficientNetB4, EfficientNetB3\n",
    "from transformers import DetrFeatureExtractor, DetrForSegmentation\n",
    "from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
    "from sam2.build_sam import build_sam2\n",
    "from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
    "from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "from Utilities.utilities import *\n",
    "from Utilities.lime_segmentation import *\n",
    "from Utilities.xai_evaluation import *\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: explainers, explained model, segmentation and evaluation settings.\n",
    "config = {\n",
    "    # Explanation methods to run\n",
    "    \"XAI_algorithm\": {\n",
    "        \"DSEG\": True,\n",
    "        \"LIME\": True,\n",
    "        \"SLIME\": True,\n",
    "        \"BayesLime\": True,\n",
    "        \"GLIME\": True,\n",
    "        \"EAC\": False,\n",
    "        \"SHAP\": False,\n",
    "    },\n",
    "\n",
    "    # Hardware settings (gpu_num is exported as CUDA_VISIBLE_DEVICES)\n",
    "    \"computation\": {\n",
    "        \"num_workers\": 2,\n",
    "        \"gpu_device\": True,\n",
    "        \"gpu_num\": \"6\",\n",
    "    },\n",
    "\n",
    "    # Model whose predictions are explained (first enabled entry wins)\n",
    "    \"model_to_explain\": {\n",
    "        \"EfficientNet\": True,\n",
    "        \"ResNet\": False,\n",
    "        \"VisionTransformer\": False,\n",
    "    },\n",
    "\n",
    "    \"lime_segmentation\": {\n",
    "        # LIME parameters\n",
    "        \"num_samples\": 256,\n",
    "        \"num_features\": 1000,\n",
    "        \"min_weight\": 0.01,\n",
    "        \"top_labels\": 1,\n",
    "        \"hide_color\": None,\n",
    "        \"batch_size\": 10,\n",
    "        \"verbose\": True,\n",
    "\n",
    "        # Superpixel segmentation algorithms\n",
    "        \"slic\": True,\n",
    "        \"quickshift\": False,\n",
    "        \"felzenszwalb\": False,\n",
    "        \"watershed\": False,\n",
    "\n",
    "        # Segmentation-model options (DETR / SAM)\n",
    "        \"all_dseg\": True,\n",
    "        \"DETR\": False,\n",
    "        \"SAM\": True,\n",
    "        \"points_per_side\": 32,\n",
    "        \"min_size\": 512,\n",
    "\n",
    "        # Segmentation hyperparameters\n",
    "        \"fit_segmentation\": True,\n",
    "        \"slic_compactness\": 16,\n",
    "        \"num_segments\": 20,\n",
    "        \"markers\": 16,\n",
    "        \"kernel_size\": 6,\n",
    "        \"max_dist\": 32,\n",
    "\n",
    "        # LIME segmentation parameters\n",
    "        \"iterations\": 1,\n",
    "        \"shuffle\": False,\n",
    "        \"max_segments\": 8,\n",
    "        \"min_segments\": 1,\n",
    "        \"auto_segment\": False,\n",
    "\n",
    "        # LIME explanation parameters\n",
    "        \"num_features_explanation\": 2,\n",
    "        \"adaptive_num_features\": False,\n",
    "        \"adaptive_fraction\": True,\n",
    "        \"hide_rest\": True,\n",
    "        \"positive_only\": True,\n",
    "    },\n",
    "\n",
    "    \"evaluation\": {\n",
    "        \"noisy_background\": True,\n",
    "\n",
    "        # Correctness\n",
    "        \"model_randomization\": True,\n",
    "        \"explanation_randomization\": True,\n",
    "        \"single_deletion\": True,\n",
    "        \"fraction\": 0.1,\n",
    "        \"fraction_std\": 0.05,\n",
    "        \"incremental_deletion\": True,\n",
    "        \"incremental_deletion_fraction\": 0.15,\n",
    "        \"stability\": True,\n",
    "        \"repetitions\": 8,\n",
    "\n",
    "        # Output completeness\n",
    "        \"preservation_check\": True,\n",
    "        \"deletion_check\": True,\n",
    "\n",
    "        # Consistency\n",
    "        \"variation_stability\": True,\n",
    "\n",
    "        # Contrastivity\n",
    "        \"target_discrimination\": True,\n",
    "\n",
    "        # Compactness\n",
    "        \"size\": True,\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_dataset_parallel_in_steps(model_eff, feature_extractor, dataset_path, model, model_test, config, steps = 5, date_path = None, model_explain_processor = None):\n",
    "    \"\"\"Evaluate the next `steps` not-yet-processed images of `dataset_path`.\n",
    "\n",
    "    Results accumulate in `./Dataset/Results_50/<date_path>/eval_results.pkl`, so the\n",
    "    function can be called repeatedly with the same `date_path` to work through a\n",
    "    dataset in chunks; images already present in that file are skipped.\n",
    "\n",
    "    Args:\n",
    "        model_eff: Model whose predictions are explained (forwarded to `evaluate_sub_image`).\n",
    "        feature_extractor: Mask generator / feature extractor (forwarded).\n",
    "        dataset_path: Directory whose entries are the images to evaluate.\n",
    "        model: Segmentation model (forwarded).\n",
    "        model_test: Additional model passed through to the evaluation (forwarded).\n",
    "        config: Experiment configuration; also saved as `config.json` in the results dir.\n",
    "        steps: Maximum number of new images evaluated in this call.\n",
    "        date_path: Results sub-directory name; None starts a new timestamped run.\n",
    "        model_explain_processor: Optional input processor of the explained model (forwarded).\n",
    "\n",
    "    Returns:\n",
    "        `dataset_path` if at least one image was evaluated, otherwise None.\n",
    "    \"\"\"\n",
    "    if date_path is None:\n",
    "        date_path = time.strftime(\"%y_%m_%d_%H_%M_%S\")\n",
    "    results_dir = './Dataset/Results_50/' + date_path\n",
    "    results_file = results_dir + '/eval_results.pkl'\n",
    "\n",
    "    # Resume a previous run when its results file exists; otherwise start empty.\n",
    "    image_dictionary = {}\n",
    "    if os.path.exists(results_file):\n",
    "        # Only unpickle results files written by this pipeline (pickle can execute code).\n",
    "        with open(results_file, \"rb\") as f:\n",
    "            image_dictionary = pickle.load(f)\n",
    "        print(image_dictionary.keys())\n",
    "\n",
    "    # Take at most `steps` images (sorted for a reproducible order) that are not yet evaluated.\n",
    "    sub_images = [\n",
    "        name for name in sorted(os.listdir(dataset_path))\n",
    "        if name not in image_dictionary and name != \".DS_Store\"\n",
    "    ][:max(steps, 0)]\n",
    "    print(sub_images)\n",
    "\n",
    "    if not sub_images:\n",
    "        print(\"No unprocessed images left in\", dataset_path)\n",
    "        return None\n",
    "\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    with open(results_dir + '/config.json', 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "    # 'threading' runs all workers in this process, so the loaded models are shared.\n",
    "    parallel_data_generation = Parallel(n_jobs=config['computation']['num_workers'], verbose=3, backend='threading')\n",
    "    results = parallel_data_generation(\n",
    "        delayed(evaluate_sub_image)(sub_image, dataset_path, date_path, model_eff, feature_extractor, model, model_test, config, model_explain_processor = model_explain_processor)\n",
    "        for sub_image in sub_images\n",
    "    )\n",
    "\n",
    "    # Each call returns a (sub_image, result) pair; merge them into the accumulated results.\n",
    "    for sub_image, result in results:\n",
    "        image_dictionary[sub_image] = result\n",
    "\n",
    "    with open(results_file, \"wb\") as f:\n",
    "        pickle.dump(image_dictionary, f)\n",
    "\n",
    "    return dataset_path\n",
    "    \n",
    "\n",
    "def evaluate_stable_dataset_parallel_in_steps(model_eff, mask_generator, dataset_path, model_sam, model_eff_b3, config, model_explain_processor = None, date_path = None, steps = 3, num_cpus=None, attempts = 4):\n",
    "    \"\"\"Call `evaluate_dataset_parallel_in_steps` repeatedly, surviving failed attempts.\n",
    "\n",
    "    All attempts share one results directory (`date_path`), so every successful call\n",
    "    continues where the previous one stopped.\n",
    "\n",
    "    Args:\n",
    "        model_eff, mask_generator, dataset_path, model_sam, model_eff_b3, config,\n",
    "        model_explain_processor: Forwarded to `evaluate_dataset_parallel_in_steps`.\n",
    "        date_path: Results sub-directory name; None creates a new timestamped one.\n",
    "        steps: Number of new images evaluated per attempt.\n",
    "        num_cpus: Unused; kept for backward compatibility.\n",
    "        attempts: Maximum number of calls.\n",
    "\n",
    "    Returns:\n",
    "        The `date_path` the results were written to.\n",
    "    \"\"\"\n",
    "    if date_path is None:\n",
    "        date_path = time.strftime(\"%y_%m_%d_%H_%M_%S\")\n",
    "\n",
    "    for attempt in range(attempts):\n",
    "        try:\n",
    "            result = evaluate_dataset_parallel_in_steps(model_eff, mask_generator, dataset_path, model_sam, model_eff_b3, config, steps = steps, model_explain_processor = model_explain_processor, date_path = date_path)\n",
    "        except Exception as exc:  # report the failure and move on to the next attempt\n",
    "            print(f\"Error in attempt {attempt + 1}/{attempts}: {exc!r}\")\n",
    "            continue\n",
    "        if result is None:\n",
    "            # Nothing left to evaluate.\n",
    "            break\n",
    "\n",
    "    return date_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select GPUs and configure TensorFlow/XLA before any model is created.\n",
    "use_gpu = config['computation']['gpu_device']\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = config['computation']['gpu_num'] if use_gpu else ''\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' if use_gpu else ''\n",
    "\n",
    "# Adjust the CUDA path below to the local installation.\n",
    "os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda-11.4' if use_gpu else ''\n",
    "os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 ,--tf_xla_enable_xla_devices' if use_gpu else ''\n",
    "os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'\n",
    "\n",
    "# With CUDA_VISIBLE_DEVICES set, \"cuda:0\" is the first *visible* GPU, i.e. gpu_num.\n",
    "sam_device = \"cuda:0\" if use_gpu else \"cpu\"\n",
    "\n",
    "image_path = \"./Dataset/COCO/val2017/\"\n",
    "annotation_path = \"./Dataset/COCO/stuff_val2017_pixelmaps/\"\n",
    "json_path = \"./Dataset/COCO/stuff_val2017.json\"\n",
    "\n",
    "model_eff_b3 = EfficientNetB3(weights='imagenet')\n",
    "\n",
    "# Model to explain: the first enabled entry of config['model_to_explain'] is used.\n",
    "if config['model_to_explain']['EfficientNet']:\n",
    "    model_explain = EfficientNetB4(weights='imagenet')\n",
    "    model_explain_processor = None\n",
    "elif config['model_to_explain']['ResNet']:\n",
    "    model_explain = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)\n",
    "    model_explain.eval()\n",
    "    model_explain_processor = None\n",
    "elif config['model_to_explain']['VisionTransformer']:\n",
    "    # ViTImageProcessor is the non-deprecated replacement for ViTFeatureExtractor.\n",
    "    model_explain_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch32-384')\n",
    "    model_explain = ViTForImageClassification.from_pretrained('google/vit-base-patch32-384')\n",
    "else:\n",
    "    raise ValueError(\"config['model_to_explain'] must enable EfficientNet, ResNet or VisionTransformer\")\n",
    "\n",
    "# Segmentation model and its mask generator.\n",
    "if config['lime_segmentation']['DETR']:\n",
    "    mask_generator = DetrFeatureExtractor.from_pretrained(\"facebook/detr-resnet-50-panoptic\")\n",
    "    model = DetrForSegmentation.from_pretrained(\"facebook/detr-resnet-50-panoptic\")\n",
    "else:\n",
    "    sam2_checkpoint = \"./pretrained/sam2_hiera_large.pt\"\n",
    "    model_cfg = \"sam2_hiera_l.yaml\"\n",
    "    # Build on the configured device (build_sam2 moves the model there itself);\n",
    "    # a hard-coded \"cuda:0\" fails when the GPU is disabled in the config.\n",
    "    model = build_sam2(model_cfg, sam2_checkpoint, device=sam_device, apply_postprocessing=False)\n",
    "    mask_generator = SAM2AutomaticMaskGenerator(model, stability_score_thresh=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_path = \"./Dataset/Evaluation/\"\n",
    "\n",
    "# Evaluate at most one image into a new timestamped results directory (date_path is None).\n",
    "data_path = evaluate_dataset_parallel_in_steps(\n",
    "    model_explain,\n",
    "    mask_generator,\n",
    "    dataset_path,\n",
    "    model,\n",
    "    model_eff_b3,\n",
    "    config,\n",
    "    steps = 1,\n",
    "    model_explain_processor = model_explain_processor,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_path = \"./Dataset/Evaluation/\"\n",
    "# Results sub-directory under ./Dataset/Results_50/ to create or resume. TODO: set the run name.\n",
    "run_name = \"dataset name\"\n",
    "\n",
    "# One image per call, so results are saved after every image; stop once nothing is left.\n",
    "for _ in range(50):\n",
    "    data_path = evaluate_dataset_parallel_in_steps(model_explain, mask_generator, dataset_path, model, model_eff_b3, config, steps = 1, model_explain_processor = model_explain_processor, date_path = run_name)\n",
    "    if data_path is None:\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "XAI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
