{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 186975,
     "status": "ok",
     "timestamp": 1652349123752,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "8hceYqOsrvR9",
    "outputId": "6df5082a-a5e0-4e3c-ff53-73621808d249"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 8796,
     "status": "ok",
     "timestamp": 1652360733714,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "fy_wPf0hr3cK"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import deel.lip as lip\n",
    "from deel.lip.layers import *\n",
    "from deel.lip.losses import *\n",
    "from deel.lip.utils import *\n",
    "\n",
    "import dlt\n",
    "import dlt.data.loader as loader\n",
    "import dlt.data.pipeline as pipeline\n",
    "import dlt.data.augmentation as aug\n",
    "import dlt.infrastructure.distributed_training as distributed\n",
    "from dlt.extras.layers import skip_connections as skips\n",
    "from dlt.model_factory import *\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "from xplique.attributions import *\n",
    "from xplique.metrics import *\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import spearmanr\n",
    "from scipy.ndimage import gaussian_filter\n",
    "\n",
    "import tensorflow.keras.optimizers as optimizers\n",
    "from tensorflow.keras.models import Sequential, Model\n",
    "from tensorflow.keras.layers import Dense, Activation\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "\n",
    "# The bare 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and removed in 3.8;\n",
    "# use whichever of the two names this matplotlib installation provides.\n",
    "plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "def set_size(w,h):\n",
    "  \"\"\"Make `w` x `h` the default matplotlib figure size.\"\"\"\n",
    "  plt.rcParams.update({\"figure.figsize\": [w, h]})\n",
    "\n",
    "def show(img, p=False, smooth=False, **kwargs):\n",
    "  \"\"\"Display an image or an attribution map with `plt.imshow`, axes hidden.\n",
    "\n",
    "  Args:\n",
    "    img: array-like image; it is copied into a float32 numpy array, so the\n",
    "      caller's data is never modified. A leading axis of size 1 is dropped, a\n",
    "      leading axis of size 3 is moved last, and a trailing axis of size 1 is\n",
    "      dropped.\n",
    "    p: `False` to disable, otherwise a percentile; values are clipped to the\n",
    "      [p, 100 - p] percentile range.\n",
    "    smooth: falsy to disable, otherwise the sigma handed to `gaussian_filter`\n",
    "      (applied to 2-D images only).\n",
    "    **kwargs: forwarded to `plt.imshow` (e.g. `cmap`, `alpha`).\n",
    "  \"\"\"\n",
    "  img = np.array(img, dtype=np.float32)\n",
    "\n",
    "  # channel-first layouts: (1, H, W) -> (H, W) and (3, H, W) -> (H, W, 3)\n",
    "  if img.shape[0] == 1:\n",
    "    img = img[0]\n",
    "  elif img.shape[0] == 3:\n",
    "    img = np.moveaxis(img, 0, 2)\n",
    "  # single-channel map: (H, W, 1) -> (H, W)\n",
    "  if img.shape[-1] == 1:\n",
    "    img = img[:,:,0]\n",
    "  # rescale to [0, 1] only when some value falls outside that range\n",
    "  if img.max() > 1 or img.min() < 0:\n",
    "    img -= img.min()\n",
    "    peak = img.max()\n",
    "    # a constant image has peak == 0 after the shift; dividing would fill it with NaN\n",
    "    if peak > 0:\n",
    "      img /= peak\n",
    "  # optional percentile clipping\n",
    "  if p is not False:\n",
    "    img = np.clip(img, np.percentile(img, p), np.percentile(img, 100-p))\n",
    "\n",
    "  if smooth and len(img.shape) == 2:\n",
    "    img = gaussian_filter(img, smooth)\n",
    "\n",
    "  plt.imshow(img, **kwargs)\n",
    "  plt.axis('off')\n",
    "  # grid(None) toggles the grid state; False switches it off whatever the previous state was\n",
    "  plt.grid(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1652360735186,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "JORuiURyvCTh"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import (\n",
    "    load_model as keras_load_model,\n",
    "    model_from_json as keras_model_from_json,\n",
    "    model_from_yaml as keras_model_from_yaml,\n",
    ")\n",
    "\n",
    "model_path = \"../../outputs\"\n",
    "\n",
    "batch_size = 512\n",
    "norm_factor = 255.0\n",
    "alphaHKR = 20.0\n",
    "min_margin = 0.2\n",
    "learning_rate = 1e-2\n",
    "selected_classes = np.array([2, 4, 7, 8, 9, 10, 11, 15, 17, 18, 20, 21, 22, 28, 29, 30, 31, 34, 35, 36, 38, 39])\n",
    "\n",
    "def load_network(net_path,net_name, use_json=False):\n",
    "    \"\"\"Load a Keras model stored under `net_path` with base name `net_name`.\n",
    "\n",
    "    Args:\n",
    "        net_path: directory containing the model files.\n",
    "        net_name: file name without extension.\n",
    "        use_json: when True, rebuild the architecture from `<name>.json` and\n",
    "            load the weights from `<name>.h5`; otherwise load the whole model\n",
    "            from `<name>.h5` without compiling it.\n",
    "\n",
    "    Returns:\n",
    "        The loaded Keras model.\n",
    "    \"\"\"\n",
    "    model_base = os.path.join(net_path, net_name)\n",
    "    if use_json:\n",
    "        # the context manager closes the file even if read() raises\n",
    "        with open(model_base + '.json', 'r') as json_file:\n",
    "            loaded_model = keras_model_from_json(json_file.read())\n",
    "        # load weights into new model\n",
    "        loaded_model.load_weights(model_base + '.h5')\n",
    "    else:\n",
    "        loaded_model = keras_load_model(model_base + '.h5', compile=False)\n",
    "    print(\"Loaded model from disk\")\n",
    "    return loaded_model"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2845,
     "status": "ok",
     "timestamp": 1652360738025,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "GnQ6riROu2-L",
    "outputId": "79152adc-ec77-4e93-99cd-e09b94c50444"
   },
   "source": [
    "AUTO = tf.data.AUTOTUNE\n",
    "\n",
    "gcs_base_dir = \"gs://celeb_a_dataset/\"\n",
    "celeb_a_builder = tfds.builder(\"celeb_a\", data_dir=gcs_base_dir, version='2.0.0')\n",
    "celeb_a_builder.download_and_prepare()\n",
    "\n",
    "num_test_shards_dict = {'0.3.0': 4, '2.0.0': 2} # Used because we download the test dataset separately\n",
    "version = str(celeb_a_builder.info.version)\n",
    "print('Celeb_A dataset version: %s' % version)\n",
    "\n",
    "def make_ds():\n",
    "  ds = celeb_a_builder.as_dataset(split='test')\n",
    "  return ds\n",
    "\n",
    "def make_binary_dataset(ds, label_name):\n",
    "  ds = ds.map(lambda data: (tf.cast(data['image'], tf.float32) / norm_factor, data['attributes'][label_name]), num_parallel_calls=AUTO)\n",
    "  ds = ds.map(lambda x,y: (x, tf.one_hot(tf.cast(y, tf.uint8), 2)), num_parallel_calls=AUTO)\n",
    "  return ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1841,
     "status": "ok",
     "timestamp": 1652360739862,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "UlMP7NPS8G4D",
    "outputId": "de9b8ed4-767a-4e75-afd0-4bef96f049b4"
   },
   "outputs": [],
   "source": [
    "ds_train, ds_test, metadata = loader.get_celeb_a_multilabel()\n",
    "\n",
    "# Peek at the first training element: image shape and value statistics, then its label vector.\n",
    "x, y = next(iter(ds_train))\n",
    "print(x.shape, tf.reduce_min(x), tf.reduce_max(x), tf.reduce_mean(x), sep=\"\\n\")\n",
    "print(y.shape, y, sep=\"\\n\")\n",
    "\n",
    "class_names = metadata[\"class_names\"]\n",
    "print(class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1841,
     "status": "ok",
     "timestamp": 1652360739862,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "UlMP7NPS8G4D",
    "outputId": "de9b8ed4-767a-4e75-afd0-4bef96f049b4"
   },
   "outputs": [],
   "source": [
    "def setup_data_singlelabel(ds_train, ds_test, chosen_label = 0, batch_size =1):\n",
    "    \"\"\"Turn the multi-label datasets into a one-hot two-class problem for one attribute.\n",
    "\n",
    "    Images are cast to float32 and divided by the global `norm_factor`. Each\n",
    "    label vector is filtered with the global `all_selected_classes` mask,\n",
    "    indexed at `chosen_label` and one-hot encoded over 2 classes. No\n",
    "    augmentation is applied and the test split is not shuffled.\n",
    "\n",
    "    Returns:\n",
    "        The (train, test) pair produced by `pipeline.prepare_data`.\n",
    "    \"\"\"\n",
    "    def scale_image(x):\n",
    "        return tf.cast(x, dtype=tf.float32) / norm_factor\n",
    "\n",
    "    def pick_label(y):\n",
    "        # globals are looked up when the pipeline runs, not when this function is defined\n",
    "        return tf.one_hot(tf.cast(y[all_selected_classes][chosen_label], dtype=tf.int32),2)\n",
    "\n",
    "    train_single, test_single = pipeline.prepare_data(\n",
    "        ds_train,\n",
    "        ds_test,\n",
    "        preparation_x=[scale_image],\n",
    "        preparation_y=[pick_label],\n",
    "        augmentation_x=[],\n",
    "        batch_size=batch_size,\n",
    "        shuffle_test=False\n",
    "    )\n",
    "    return train_single, test_single\n"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1841,
     "status": "ok",
     "timestamp": 1652360739862,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "UlMP7NPS8G4D",
    "outputId": "de9b8ed4-767a-4e75-afd0-4bef96f049b4"
   },
   "source": [
    "ds = make_ds()\n",
    "data = [d for d in ds.take(1)][0]\n",
    "\n",
    "class_names = np.array([k for k in data['attributes'].keys()])\n",
    "assert class_names[0] == '5_o_Clock_Shadow'\n",
    "assert class_names[-1] == 'Young'\n",
    "\n",
    "print('x:', data['image'].shape, ' attributes:', class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1652360739862,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "khgc9Zu_uDDM",
    "outputId": "fe342990-cc96-4b86-9e13-5940ee1517b4"
   },
   "outputs": [],
   "source": [
    "# Build a boolean mask over `class_names` marking the attributes the model was trained on.\n",
    "# The builtin `bool` is used as dtype: the `np.bool` alias was deprecated in NumPy 1.20\n",
    "# and removed in 1.24, where it raises AttributeError.\n",
    "if selected_classes is None:\n",
    "    nb_classes = len(class_names)\n",
    "    all_selected_classes = np.ones((nb_classes,)).astype(bool)\n",
    "else:\n",
    "    print(list(selected_classes))\n",
    "    nb_classes = len(set(selected_classes))\n",
    "    # one row per selected index; summing the rows marks every selected column\n",
    "    one_hot = np.eye(len(class_names))[selected_classes]\n",
    "    all_selected_classes = np.sum(one_hot,axis=0).astype(bool)\n",
    "print(all_selected_classes)\n",
    "print(\"nb_classes \", nb_classes)\n",
    "print(\"Selected classes :\", np.asarray(class_names)[all_selected_classes])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1652360739863,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "TzIJ_zt1-fE7"
   },
   "outputs": [],
   "source": [
    "def make_binary_classif_model(loaded_model, last_layer, label, reverse=1.0, biais=0.0):\n",
    "    \"\"\"Stack a fixed 2-unit linear head on one logit of `loaded_model`.\n",
    "\n",
    "    A Dense layer named 'classif' is attached to the output of\n",
    "    `loaded_model.layers[last_layer]`. Its kernel is zero except for row\n",
    "    `label`, set to `reverse * [-1, 1]`, and its bias is\n",
    "    `[reverse * biais, -reverse * biais]`.\n",
    "\n",
    "    Returns:\n",
    "        A `tf.keras.Model` mapping `loaded_model.inputs` to the two outputs.\n",
    "    \"\"\"\n",
    "    logits = loaded_model.layers[last_layer].output\n",
    "    nb_logits = logits.shape[-1]\n",
    "    assert label<nb_logits\n",
    "\n",
    "    out = Dense(2, activation=None, name='classif')(logits)\n",
    "    model = tf.keras.Model(loaded_model.inputs, out)\n",
    "\n",
    "    kernel = np.zeros((nb_logits,2))\n",
    "    kernel[label,:] = reverse * np.asarray([-1, 1])\n",
    "    bias = np.asarray([reverse * biais, -reverse * biais])\n",
    "\n",
    "    model.layers[-1].set_weights([kernel, bias])\n",
    "    model.built = True\n",
    "    return model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "MpJoZtbS7LUI"
   },
   "source": [
    "# Loading Model and preparing dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1652360739864,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "_JuBpzUOPxL5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 93649,
     "status": "ok",
     "timestamp": 1652360833506,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "LNzbg_AxA_T-",
    "outputId": "3d07da11-c149-4ece-c76a-486710b50693"
   },
   "outputs": [],
   "source": [
    "vmodel_name = \"vmodel_22labels_None_HingeVar_alpha20.0_mm0.2_epoch100_fullmodel_thresh\"\n",
    "model_name = \"model_22labels_None_HingeVar_alpha20.0_mm0.2_epoch100_fullmodel\"\n",
    "chosen_label_name = 'Mustache' #'Wearing_Hat' # 'Gray_Hair'  #'Bald' #'Heavy_Makeup' #'Wearing_Lipstick' #'Mustache'\n",
    "\n",
    "selected_class_names = np.asarray(class_names)[all_selected_classes]\n",
    "# np.argmax returns 0 when nothing matches, which would silently analyse the first selected class\n",
    "if not np.any(selected_class_names == chosen_label_name):\n",
    "    raise ValueError(f\"'{chosen_label_name}' is not one of the selected classes: {list(selected_class_names)}\")\n",
    "chosen_label = np.argmax(selected_class_names == chosen_label_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 93649,
     "status": "ok",
     "timestamp": 1652360833506,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "LNzbg_AxA_T-",
    "outputId": "3d07da11-c149-4ece-c76a-486710b50693"
   },
   "outputs": [],
   "source": [
    "ds_train_single, ds_test_single = setup_data_singlelabel(ds_train, ds_test, chosen_label = chosen_label, batch_size = -1)\n",
    "def make_ds():\n",
    "    \"\"\"Return the global `ds_test_single` dataset.\"\"\"\n",
    "    return ds_test_single\n",
    "def make_binary_dataset(ds, label_name):\n",
    "    \"\"\"Return `ds` unchanged; `label_name` is ignored.\"\"\"\n",
    "    return ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 93649,
     "status": "ok",
     "timestamp": 1652360833506,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "LNzbg_AxA_T-",
    "outputId": "3d07da11-c149-4ece-c76a-486710b50693"
   },
   "outputs": [],
   "source": [
    "# Optional per-model threshold file saved next to the weights; `best_thresh` stays undefined when it is absent.\n",
    "thresh_path = os.path.join(model_path,vmodel_name+\".npy\")\n",
    "if os.path.exists(thresh_path):\n",
    "    best_thresh = np.load(thresh_path)\n",
    "\n",
    "model = load_network(model_path,vmodel_name)\n",
    "binary_model = make_binary_classif_model(model, last_layer=-1, label=chosen_label)\n",
    "\n",
    "# the 'classif' head has no activation, hence from_logits=True\n",
    "# (alternative loss: HKR_multilabel(alpha = alphaHKR,  min_margin=min_margin))\n",
    "loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
    "binary_model.compile(\n",
    "    optimizer=optimizers.Adam(learning_rate=learning_rate),\n",
    "    loss=loss,\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "\n",
    "ds_binary = make_binary_dataset(make_ds(), chosen_label_name)\n",
    "binary_model.evaluate(ds_binary.batch(256))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "0a-8rud_-DYu"
   },
   "source": [
    "# Test gradient explanations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 34692,
     "status": "ok",
     "timestamp": 1652349495682,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "GgLtQim3s87L",
    "outputId": "948952de-7308-4df1-ba91-5feec99d777e"
   },
   "outputs": [],
   "source": [
    "explainer = Saliency(binary_model, batch_size=128)\n",
    "set_size(20, 20)\n",
    "\n",
    "# Overlay the saliency map of each of the first 25 test images on a 5x5 grid.\n",
    "phis = None\n",
    "for bx, by in ds_binary.batch(25).take(1):\n",
    "  phis = explainer(bx, by)\n",
    "\n",
    "  for i, (x, y, h) in enumerate(zip(bx, by, phis), start=1):\n",
    "    plt.subplot(5, 5, i)\n",
    "    show(x)\n",
    "    show(h, cmap='jet', alpha=0.42, p=0.0001, smooth=False)\n",
    "  plt.savefig('robust_grad.png', dpi=300)\n",
    "  plt.show()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "scJyLDqw-F6L"
   },
   "source": [
    "# (1) Computing Explanation Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 392,
     "status": "ok",
     "timestamp": 1652349574914,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "EDuujeAe-FYi"
   },
   "outputs": [],
   "source": [
    "from xplique.attributions import (Saliency, GradientInput, GradCAM, IntegratedGradients, \n",
    "                                  SmoothGrad, VarGrad, SquareGrad, Occlusion, Rise, \n",
    "                                  GuidedBackprop, GradCAMPP, Lime, KernelShap, SobolAttributionMethod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13211985,
     "status": "ok",
     "timestamp": 1651969229424,
     "user": {
      "displayName": "thomas fel",
      "userId": "05846348858113125044"
     },
     "user_tz": -120
    },
    "id": "NjggdtAN7Zve",
    "outputId": "e6d7cf41-e468-468c-f087-ae8d3671a32b"
   },
   "outputs": [],
   "source": [
    "nb_elements = 1000\n",
    "batch_size  = 128\n",
    "\n",
    "# Materialise the first `nb_elements` test samples as float32 arrays and keep a copy on disk.\n",
    "X_test = []\n",
    "Y_test = []\n",
    "for x,y in ds_binary.take(nb_elements):\n",
    "  X_test.append(x)\n",
    "  Y_test.append(y)\n",
    "X_test = np.array(X_test, np.float32)\n",
    "Y_test = np.array(Y_test, np.float32)\n",
    "\n",
    "np.save('X_test.npy', X_test)\n",
    "np.save('Y_test.npy', Y_test)\n",
    "\n",
    "explainers = [\n",
    "    Saliency(binary_model, -1, batch_size=batch_size),\n",
    "    GradientInput(binary_model, -1, batch_size=batch_size),\n",
    "    GradCAM(binary_model, -1, batch_size=batch_size),\n",
    "    GradCAMPP(binary_model, -1, batch_size=batch_size),\n",
    "    IntegratedGradients(binary_model, -1, steps=20, batch_size=batch_size),\n",
    "    SmoothGrad(binary_model, -1, nb_samples=20, batch_size=batch_size),\n",
    "    VarGrad(binary_model, -1, nb_samples=20, batch_size=batch_size),\n",
    "    SquareGrad(binary_model, -1, nb_samples=20, batch_size=batch_size),\n",
    "    Occlusion(binary_model, batch_size=batch_size, patch_size=10, patch_stride=10),\n",
    "    Rise(binary_model, batch_size=batch_size, nb_samples=3000),\n",
    "    SobolAttributionMethod(binary_model, nb_design=32, batch_size=batch_size, grid_size=7),\n",
    "]\n",
    "\n",
    "# One file of absolute attributions per explainer, stored as float16.\n",
    "for explainer in explainers:\n",
    "  explainer_name = explainer.__class__.__name__\n",
    "  phis = np.abs(explainer(X_test, Y_test))\n",
    "\n",
    "  # 4-D attributions: keep the max over the last axis\n",
    "  if len(phis.shape) == 4:\n",
    "    phis = np.max(phis, -1)\n",
    "\n",
    "  # swap axes 1 and 2 when axis 1 does not match the images' axis 1\n",
    "  if phis.shape[1] != X_test.shape[1]:\n",
    "    phis = np.moveaxis(phis, 1, 2)\n",
    "\n",
    "  phis = np.array(phis, np.float16)\n",
    "\n",
    "  np.save(f'{model_name}_{explainer_name}.npy', phis)\n",
    "  print('done for ', explainer_name, phis.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kzL_yV6lC-s0"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JMdv4nY1C_g9"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3320,
     "status": "ok",
     "timestamp": 1652360836815,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "p65nzdy2YCZr",
    "outputId": "99ed681b-ca17-4489-f83f-91e754ce6a30"
   },
   "outputs": [],
   "source": [
    "# Built with the same calls as `binary_model`, then given a softmax activation on the 'classif' head.\n",
    "metric_model = make_binary_classif_model(\n",
    "    load_network(model_path,vmodel_name), last_layer=-1, label=chosen_label\n",
    ")\n",
    "metric_model.layers[-1].activation = tf.keras.activations.softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f9ROcDVouSr3"
   },
   "outputs": [],
   "source": [
    "#!cp *.npy \"/content/drive/MyDrive/HKR XAI/data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2048235,
     "status": "ok",
     "timestamp": 1651971450945,
     "user": {
      "displayName": "thomas fel",
      "userId": "05846348858113125044"
     },
     "user_tz": -120
    },
    "id": "XJicd5ls_j-9",
    "outputId": "d28ed156-1bdc-4c1f-ff96-976144b4b624"
   },
   "outputs": [],
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "x_max, x_min = X_test.max(), X_test.min()\n",
    "baseline_zero    = 0.0\n",
    "# Uniform noise spanning the data range [x_min, x_max]: scale by the range, then shift up by x_min.\n",
    "# (Subtracting x_min instead would give the range [-x_min, x_max - 2*x_min].)\n",
    "baseline_uniform = lambda x: np.random.uniform(size=x.shape) * (x_max-x_min) + x_min\n",
    "\n",
    "metrics = [\n",
    "           MuFidelity(binary_model, X_test, Y_test, batch_size = batch_size, grid_size = 8, nb_samples = 50, baseline_mode = baseline_uniform),\n",
    "           Deletion(binary_model, X_test, Y_test, steps = 6, baseline_mode = baseline_uniform),\n",
    "           Insertion(binary_model, X_test, Y_test, steps = 6, baseline_mode = baseline_uniform),\n",
    "]\n",
    "\n",
    "# Score every saved attribution file with every metric; one results file per metric.\n",
    "for metric in metrics:\n",
    "  metric_name = metric.__class__.__name__\n",
    "  table = []\n",
    "\n",
    "  for explainer in explainers:\n",
    "    explainer_name = explainer.__class__.__name__\n",
    "    phis = np.load(f'{model_name}_{explainer_name}.npy')\n",
    "    # back to float32 with a trailing axis of size 1\n",
    "    phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "    score = metric(phis)\n",
    "    table.append((metric_name, model_name, explainer_name, score))\n",
    "    print(metric_name, explainer_name, score)\n",
    "\n",
    "  np.save(f'results_{model_name}_{metric_name}_uniform', table)\n",
    "  print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1934754,
     "status": "ok",
     "timestamp": 1651973385689,
     "user": {
      "displayName": "thomas fel",
      "userId": "05846348858113125044"
     },
     "user_tz": -120
    },
    "id": "Vp-Y1V4zSf0Z",
    "outputId": "46ed23dd-31a9-4277-8f03-7f6570e4e773"
   },
   "outputs": [],
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "x_max, x_min = X_test.max(), X_test.min()\n",
    "baseline_zero    = 0.0\n",
    "# Uniform noise spanning the data range [x_min, x_max]: scale by the range, then shift up by x_min.\n",
    "# (Subtracting x_min instead would give the range [-x_min, x_max - 2*x_min].)\n",
    "baseline_uniform = lambda x: np.random.uniform(size=x.shape) * (x_max-x_min) + x_min\n",
    "\n",
    "metrics = [\n",
    "           MuFidelity(binary_model, X_test, Y_test, batch_size = batch_size, grid_size = 8, nb_samples = 50, baseline_mode = baseline_zero),\n",
    "           Deletion(binary_model, X_test, Y_test, steps = 6, baseline_mode = baseline_zero),\n",
    "           Insertion(binary_model, X_test, Y_test, steps = 6, baseline_mode = baseline_zero),\n",
    "]\n",
    "\n",
    "# Score every saved attribution file with every metric; one results file per metric.\n",
    "for metric in metrics:\n",
    "  metric_name = metric.__class__.__name__\n",
    "  table = []\n",
    "\n",
    "  for explainer in explainers:\n",
    "    explainer_name = explainer.__class__.__name__\n",
    "    phis = np.load(f'{model_name}_{explainer_name}.npy')\n",
    "    # back to float32 with a trailing axis of size 1\n",
    "    phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "    score = metric(phis)\n",
    "    table.append((metric_name, model_name, explainer_name, score))\n",
    "    print(metric_name, explainer_name, score)\n",
    "\n",
    "  np.save(f'results_{model_name}_{metric_name}_zero', table)\n",
    "  print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ayh5DuKi5RXW"
   },
   "outputs": [],
   "source": [
    "!cp results_* \"/content/drive/MyDrive/HKR XAI/results\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "lO9yVJJ2U0gq"
   },
   "source": [
    "# (2) Robustness-sr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 404,
     "status": "ok",
     "timestamp": 1652367349634,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "2U5orZWBVjeB"
   },
   "outputs": [],
   "source": [
    "\n",
    "def epsilons_l2(dimensions, epsilon, nb_samples, masks):\n",
    "  \"\"\"Draw masked random perturbations from the l2 ball of radius `epsilon`.\n",
    "\n",
    "  Args:\n",
    "    dimensions: number of coordinates of each sample.\n",
    "    epsilon: radius of the ball.\n",
    "    nb_samples: number of points to draw.\n",
    "    masks: array multiplied element-wise with the (nb_samples, dimensions) points.\n",
    "\n",
    "  Returns:\n",
    "    A float32 tensor holding the masked points.\n",
    "  \"\"\"\n",
    "  normal_deviations = np.random.normal(size=(nb_samples, dimensions))\n",
    "  # unit directions: a normalised Gaussian vector is uniform on the sphere\n",
    "  directions = normal_deviations / np.linalg.norm(normal_deviations, axis=1, keepdims=True)\n",
    "  # ONE radius per sample, distributed as U^(1/d), makes the points uniform in the ball;\n",
    "  # drawing a separate factor per coordinate would distort the directions\n",
    "  radii = epsilon * np.random.rand(nb_samples, 1) ** (1.0 / dimensions)\n",
    "\n",
    "  points = directions * radii\n",
    "  return tf.cast(points * masks, tf.float32)\n",
    "\n",
    "@tf.function\n",
    "def _pgd_step(model, inputs, labels):\n",
    "  \"\"\"Return the gradient of the binary cross-entropy w.r.t. `inputs`, and the loss itself.\n",
    "\n",
    "  The outputs of `model` are passed to the loss with `from_logits=False`.\n",
    "  \"\"\"\n",
    "  with tf.GradientTape() as tape:\n",
    "    # `inputs` is a plain tensor, so it has to be watched explicitly\n",
    "    tape.watch(inputs)\n",
    "    loss = tf.keras.losses.binary_crossentropy(labels, model(inputs), from_logits=False)\n",
    "\n",
    "  return tape.gradient(loss, inputs), loss\n",
    "\n",
    "\n",
    "def _pgd_projection(grads, step_size, adversarials, inputs, epsilons):\n",
    "    \"\"\"Take one normalised gradient step and project the perturbation onto the l2 ball.\n",
    "\n",
    "    Args:\n",
    "        grads: 4-D gradient tensor, one entry per sample on axis 0.\n",
    "        step_size: l2 norm given to each sample's step.\n",
    "        adversarials: current adversarial examples, same shape as `grads`.\n",
    "        inputs: clean inputs, same shape as `grads`.\n",
    "        epsilons: per-sample radius of the l2 ball.\n",
    "\n",
    "    Returns:\n",
    "        delta: the perturbation `adversarials + step - inputs`, rescaled to l2\n",
    "        norm `epsilons` for the samples where it exceeds that radius. It is\n",
    "        relative to `inputs`, not to `adversarials`.\n",
    "    \"\"\"\n",
    "    # tf.shape / -1 keep the reshapes valid on symbolic tensors inside tf.function\n",
    "    grad_l2_norm = tf.linalg.norm(tf.reshape(grads, (tf.shape(grads)[0], -1)), axis=-1)\n",
    "    grad_l2_norm = grad_l2_norm[:, tf.newaxis, tf.newaxis, tf.newaxis]\n",
    "    # divide_no_nan: a vanishing gradient gives a zero step instead of NaN\n",
    "    grads = tf.math.divide_no_nan(grads * step_size, grad_l2_norm)\n",
    "\n",
    "    delta = (adversarials + grads) - inputs\n",
    "\n",
    "    # project the element to the closest point of the l2 ball if necessary\n",
    "    delta_norm = tf.linalg.norm(tf.reshape(delta, (tf.shape(delta)[0], -1)), axis=-1)\n",
    "    # 1.0 where delta already lies inside the ball (no projection needed)\n",
    "    inside_ball = tf.cast(delta_norm <= epsilons, tf.float32)\n",
    "\n",
    "    # inside the ball the factor is epsilons / epsilons = 1, outside it is epsilons / delta_norm\n",
    "    scaling_factor = delta_norm * (1.0 - inside_ball) + epsilons * inside_ball\n",
    "\n",
    "    delta *= tf.reshape(epsilons / scaling_factor, (-1, 1, 1, 1))\n",
    "\n",
    "    return delta\n",
    "\n",
    "@tf.function\n",
    "def _iter_delta(model, adversarials, labels, masks, step_size, inputs, epsilons):\n",
    "  \"\"\"Run one masked PGD iteration and return the perturbation computed by `_pgd_projection`.\"\"\"\n",
    "  grads, _ = _pgd_step(model, adversarials, labels)\n",
    "  # only the coordinates selected by `masks` may be perturbed\n",
    "  masked_grads = grads * masks\n",
    "\n",
    "  return _pgd_projection(masked_grads, step_size, adversarials, inputs, epsilons)\n",
    "\n",
    "#@tf.function\n",
    "def projected_gradient_descent(model, inputs, labels, epsilons, masks,\n",
    "                               iterations=100, step_size=1e-2, batch_size = 100):\n",
    "  \"\"\"Masked l2 PGD attack, run batch by batch.\n",
    "\n",
    "  Args:\n",
    "    model: callable returning class scores for a batch of inputs.\n",
    "    inputs: clean inputs (cast to float32).\n",
    "    labels: targets passed to the loss in `_pgd_step`.\n",
    "    epsilons: per-sample radius of the l2 ball.\n",
    "    masks: per-sample mask of the coordinates allowed to change (cast to float32).\n",
    "    iterations: number of PGD iterations per batch.\n",
    "    step_size: l2 norm of each gradient step.\n",
    "    batch_size: number of samples attacked at once.\n",
    "\n",
    "  Returns:\n",
    "    (success, adversarials): a boolean tensor telling whether the arg-max\n",
    "    prediction changed, and the adversarial examples.\n",
    "  \"\"\"\n",
    "  inputs = tf.cast(inputs, tf.float32)\n",
    "  masks = tf.cast(masks, tf.float32)\n",
    "  nb_inputs = len(inputs)\n",
    "\n",
    "  # nothing to attack: tf.concat would fail on an empty list\n",
    "  if nb_inputs == 0:\n",
    "    return tf.zeros((0,), dtype=tf.bool), inputs\n",
    "\n",
    "  adv_list = []\n",
    "  success_list = []\n",
    "\n",
    "  for start_idx in range(0, nb_inputs, batch_size):\n",
    "    end_idx = min(start_idx + batch_size, nb_inputs)\n",
    "    batch_inputs = inputs[start_idx:end_idx]\n",
    "    batch_labels = labels[start_idx:end_idx]\n",
    "    batch_masks = masks[start_idx:end_idx]\n",
    "    batch_epsilons = epsilons[start_idx:end_idx]\n",
    "\n",
    "    advs = batch_inputs\n",
    "    for _ in range(iterations):\n",
    "      delta = _iter_delta(model, advs, batch_labels, batch_masks, step_size, batch_inputs, batch_epsilons)\n",
    "      # `delta` is the projected perturbation relative to the CLEAN inputs, so the\n",
    "      # iterate is rebuilt from them; adding it to the previous iterate would let\n",
    "      # the perturbation grow past epsilon\n",
    "      advs = batch_inputs + delta\n",
    "    adv_list.append(advs)\n",
    "\n",
    "    # the attack succeeds when the arg-max prediction changes\n",
    "    success = tf.argmax(model(batch_inputs), -1) != tf.argmax(model(advs), -1)\n",
    "    success_list.append(success)\n",
    "\n",
    "  adversarials = tf.concat(adv_list,axis=0)\n",
    "  success = tf.concat(success_list,axis=0)\n",
    "  return success, adversarials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1652367350206,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "IzyNPRuC0xAj"
   },
   "outputs": [],
   "source": [
    "def pgd_bisection(model, inputs, masks, labels,\n",
    "                  start_epsilon=1e-3,\n",
    "                  tolerance=1e-5, max_steps=100, \n",
    "                  upscale_rate=10.0, downscale_rate=10.0,\n",
    "                  iterations=10):\n",
    "  \"\"\"\n",
    "  Search, per input, for the smallest epsilon at which PGD changes the prediction.\n",
    "\n",
    "  Every epsilon starts at `start_epsilon`. It is multiplied by `upscale_rate`\n",
    "  after a failed attack and divided by `downscale_rate` after a successful one\n",
    "  until both a failing and a succeeding value are known; from then on the\n",
    "  midpoint of the two is tried. An input is dropped once its bracket is no\n",
    "  wider than `tolerance`, and the search stops after `max_steps` rounds.\n",
    "\n",
    "  Returns (upper_epsilons, optimal_adversarials): the smallest successful\n",
    "  epsilon per input (inf when no attack succeeded) and the adversarial found\n",
    "  with it (zeros when none was found).\n",
    "  \"\"\"\n",
    "  nb_elements = len(inputs)\n",
    "\n",
    "  # bracket [lower, upper] of the minimal epsilon, and whether each bound is known yet\n",
    "  upper_found = np.zeros(nb_elements, dtype=bool)\n",
    "  lower_found = np.zeros(nb_elements, dtype=bool)\n",
    "  upper_epsilons = np.full(nb_elements, np.inf)\n",
    "  lower_epsilons = np.full(nb_elements, -np.inf)\n",
    "\n",
    "  epsilons = np.ones(nb_elements, dtype=np.float32) * start_epsilon\n",
    "  optimal_adversarials = np.zeros(inputs.shape)\n",
    "\n",
    "  step = 0\n",
    "  # every element is searched at first\n",
    "  ids_to_finds = np.arange(nb_elements)\n",
    "\n",
    "  while len(ids_to_finds):\n",
    "    successes, adversarials = projected_gradient_descent(model, inputs[ids_to_finds], labels[ids_to_finds],\n",
    "                                                         epsilons=epsilons[ids_to_finds], masks=masks[ids_to_finds],\n",
    "                                                         iterations=iterations)\n",
    "    # convert once instead of indexing the tensors element by element\n",
    "    successes = np.asarray(successes)\n",
    "    adversarials = np.asarray(adversarials)\n",
    "\n",
    "    for i, input_id in enumerate(ids_to_finds):\n",
    "      if successes[i]:\n",
    "        # an adversarial exists at this epsilon: it is an upper bound\n",
    "        optimal_adversarials[input_id] = adversarials[i]\n",
    "        upper_epsilons[input_id] = epsilons[input_id]\n",
    "        upper_found[input_id] = True\n",
    "\n",
    "        if lower_found[input_id]:\n",
    "          # bracket known: bisect\n",
    "          epsilons[input_id] = 0.5 * (lower_epsilons[input_id] + upper_epsilons[input_id])\n",
    "        else:\n",
    "          # no failing epsilon yet: keep shrinking\n",
    "          epsilons[input_id] /= downscale_rate\n",
    "      else:\n",
    "        # no adversarial at this epsilon: it is a lower bound\n",
    "        lower_epsilons[input_id] = epsilons[input_id]\n",
    "        lower_found[input_id] = True\n",
    "\n",
    "        if upper_found[input_id]:\n",
    "          # bracket known: bisect\n",
    "          epsilons[input_id] = 0.5 * (lower_epsilons[input_id] + upper_epsilons[input_id])\n",
    "        else:\n",
    "          # no succeeding epsilon yet: keep growing\n",
    "          epsilons[input_id] *= upscale_rate\n",
    "\n",
    "    step += 1\n",
    "    # keep only the elements whose bracket is still wider than the tolerance\n",
    "    ids_to_finds = np.flatnonzero((upper_epsilons - lower_epsilons) > tolerance)\n",
    "\n",
    "    if step >= max_steps:\n",
    "        break\n",
    "\n",
    "  return upper_epsilons, optimal_adversarials\n",
    "\n",
    "def robustness_sr(model, inputs, phis, labels, start_epsilon=5e-3,\n",
    "                  tolerance=5e-5, max_steps=25,\n",
    "                  upscale_rate=10.0, downscale_rate=10.0,\n",
    "                  iterations=100,\n",
    "                  verbose=False):\n",
    "    \"\"\"Score attributions by the PGD epsilons obtained when only their top-ranked entries are unmasked.\n",
    "\n",
    "    For each fraction in 5%, 10%, 15% and 20%, a binary mask marks the highest-valued\n",
    "    entries of every flattened attribution map. `pgd_bisection` is run with those masks\n",
    "    (reshaped back to the shape of `phis`) and the epsilons it returns are clipped\n",
    "    to [0, 100].\n",
    "\n",
    "    Args:\n",
    "        model: forwarded to `pgd_bisection`.\n",
    "        inputs, labels: converted to numpy arrays and forwarded to `pgd_bisection`.\n",
    "        phis: attribution maps; the first axis indexes the samples.\n",
    "        start_epsilon, tolerance, max_steps, upscale_rate, downscale_rate, iterations:\n",
    "            forwarded unchanged to `pgd_bisection`.\n",
    "        verbose: when true, print the collected epsilons and show a log-scale histogram.\n",
    "\n",
    "    Returns:\n",
    "        The average, over consecutive fractions, of the midpoint between their mean\n",
    "        epsilons (trapezoidal rule).\n",
    "    \"\"\"\n",
    "    inputs, labels, phis = np.array(inputs), np.array(labels), np.array(phis)\n",
    "\n",
    "    # one row per sample, so the top-k entries can be picked with a 1-D argsort\n",
    "    attribution_shape = phis.shape\n",
    "    flat_phis = phis.reshape((len(phis), -1))\n",
    "\n",
    "    fractions = [0.05 * i for i in range(1, 5)]\n",
    "    epsilons = []\n",
    "\n",
    "    for fraction in fractions:\n",
    "        masks = np.zeros_like(flat_phis, dtype=np.float32)\n",
    "        for row, scores in enumerate(flat_phis):\n",
    "            nb_selected = int(len(scores) * fraction)\n",
    "            top_ids = np.argsort(scores)[::-1][:nb_selected]\n",
    "            masks[row, top_ids] = 1.0\n",
    "\n",
    "        found_eps, _ = pgd_bisection(model, inputs, masks.reshape(attribution_shape), labels,\n",
    "                                     start_epsilon=start_epsilon,\n",
    "                                     tolerance=tolerance, max_steps=max_steps,\n",
    "                                     upscale_rate=upscale_rate, downscale_rate=downscale_rate,\n",
    "                                     iterations=iterations)\n",
    "\n",
    "        # samples without an adversarial keep an infinite epsilon; cap it at 100\n",
    "        epsilons.append(np.clip(found_eps, 0, 100.0))\n",
    "\n",
    "    if verbose:\n",
    "        print('epsilons:', epsilons)\n",
    "        plt.hist(epsilons, bins=len(epsilons))\n",
    "        plt.xscale('log')\n",
    "        plt.show()\n",
    "\n",
    "    mean_eps_per_fraction = np.mean(epsilons, -1)\n",
    "    # trapezoidal rule over consecutive fractions\n",
    "    return np.mean(mean_eps_per_fraction[:-1] + mean_eps_per_fraction[1:]) * 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2305,
     "status": "ok",
     "timestamp": 1652367353856,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "i2OHYGNqYCQm"
   },
   "outputs": [],
   "source": [
    "# Load the test set from disk.\n",
    "X_test = np.load('../data/X_test.npy')\n",
    "Y_test = np.load('../data/Y_test.npy')\n",
    "\n",
    "# The labels loaded above are replaced by metric_model's own predictions\n",
    "# (argmax over the last axis), one-hot encoded with depth 2.\n",
    "Y_test = tf.one_hot(np.argmax(metric_model.predict(X_test, batch_size=256), -1), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2305,
     "status": "ok",
     "timestamp": 1652367353856,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "i2OHYGNqYCQm"
   },
   "outputs": [],
   "source": [
    "# Explainers evaluated by the robustness metric below.\n",
    "# Currently disabled: GradCAM, GradientInput, IntegratedGradients, Rise.\n",
    "explainers_names = ['Saliency', 'SmoothGrad']\n",
    "print(explainers_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 416
    },
    "executionInfo": {
     "elapsed": 596834,
     "status": "error",
     "timestamp": 1652367988073,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "EQUANJmdwJ5z",
    "outputId": "a19264dc-e7ce-4508-95fa-39dfd93668e2"
   },
   "outputs": [],
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "metric_name = 'robustness_sr'\n",
    "# number of leading test samples scored for each explainer\n",
    "l = 1000\n",
    "\n",
    "table = []\n",
    "for explainer_name in explainers_names:\n",
    "    # array stored on disk for this explainer, cast to float32 with a trailing axis appended\n",
    "    phis = np.load(f'../data/{model_name}_{explainer_name}.npy')\n",
    "    phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "    score = robustness_sr(\n",
    "        metric_model, X_test[:l], phis[:l], Y_test[:l],\n",
    "        start_epsilon=2.0, tolerance=1e-2, max_steps=20,\n",
    "        iterations=200,\n",
    "    )\n",
    "\n",
    "    table.append((metric_name, model_name, explainer_name, score))\n",
    "    print(metric_name, explainer_name, score)\n",
    "\n",
    "# one (metric, model, explainer, score) row per explainer\n",
    "np.save(f'results_{model_name}_{metric_name}', table)\n",
    "print('\\n\\n')"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 602
    },
    "executionInfo": {
     "elapsed": 70331,
     "status": "error",
     "timestamp": 1652367060754,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "-ZeuPxVFU3X4",
    "outputId": "d91ce15c-d262-4b03-a006-4c1c97fb7e84"
   },
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "metric_name = 'robustness_sr'\n",
    "table = []\n",
    "\n",
    "l = 5\n",
    "\n",
    "for explainer_name in explainers_names:\n",
    "  #explainer_name = explainer.__class__.__name__ \n",
    "  phis = np.load(f'data/{model_name}_{explainer_name}.npy')\n",
    "  phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "  score = robustness_sr(metric_model, X_test[:l], phis[:l], Y_test[:l], \n",
    "                        start_epsilon = 5e-3, tolerance = 1e-2, max_steps = 10,\n",
    "                        iterations = 100)\n",
    "\n",
    "  #score = metric(phis)\n",
    "  table.append((metric_name, model_name, explainer_name, score))\n",
    "  print(metric_name, explainer_name, score)\n",
    "\n",
    "np.save(f'results_{model_name}_{metric_name}', table)\n",
    "print('\\n\\n')\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tRh2s7kFR7eJ"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nrGDT0i4R7gn"
   },
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "G7WXaleiOz5V"
   },
   "source": [
    "# (2.bis) MuFidelity re-run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 45993,
     "status": "ok",
     "timestamp": 1652349668597,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "gx-9p2ZhPAaf"
   },
   "outputs": [],
   "source": [
    "# Copy the `data` directory from the Drive folder into the current working directory.\n",
    "!cp -r \"/content/drive/MyDrive/HKR XAI/data\" ./"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 631,
     "status": "ok",
     "timestamp": 1652360837443,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "99O8n9RGO8O8"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1652360837444,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "sHxrqwSBQ9M_",
    "outputId": "fb9c1f20-6346-4fee-d11c-3a612c0661a7"
   },
   "outputs": [],
   "source": [
    "# Keep the files of '../data' whose name contains 'model' and the current model name but\n",
    "# not 'results'; the explainer name is the part after the last '_', without '.npy'.\n",
    "explainers_names = [\n",
    "    fname.split('_')[-1].split('.npy')[0]\n",
    "    for fname in os.listdir('../data')\n",
    "    if 'results' not in fname and 'model' in fname and model_name in fname\n",
    "]\n",
    "explainers_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 591842,
     "status": "ok",
     "timestamp": 1652350260437,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "fOg8SAMfO3Qd",
    "outputId": "dd99c1dd-bf32-4b6c-d716-c43b122f4a1b"
   },
   "outputs": [],
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "l = 100\n",
    "batch_size  = 128\n",
    "\n",
    "x_max, x_min = X_test.max(), X_test.min()\n",
    "baseline_zero    = 0.0\n",
    "# Uniform noise rescaled to the data range [x_min, x_max]: u * (max - min) + min.\n",
    "# The offset used to be subtracted, which moved the noise out of that range whenever x_min != 0.\n",
    "baseline_uniform = lambda x: np.random.uniform(size=x.shape) * (x_max - x_min) + x_min\n",
    "\n",
    "for grid_size in [20]:\n",
    "\n",
    "  # MuFidelity on the first `l` test samples, using the uniform-noise baseline\n",
    "  metric = MuFidelity(binary_model, X_test[:l], Y_test[:l], batch_size = batch_size, grid_size = grid_size, nb_samples = 200, baseline_mode = baseline_uniform)\n",
    "  metric_name = metric.__class__.__name__\n",
    "\n",
    "  table = []\n",
    "  \n",
    "  for explainer_name in explainers_names:\n",
    "    # array stored on disk for this explainer, truncated to `l` samples, trailing axis appended\n",
    "    phis = np.load(f'data/{model_name}_{explainer_name}.npy')[:l]\n",
    "    phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "    score = metric(phis)\n",
    "    table.append((metric_name, model_name, explainer_name, score))\n",
    "    print(metric_name, explainer_name, score)\n",
    "  \n",
    "  np.save(f'results_{model_name}_{metric_name}_uniform_{grid_size}', table)\n",
    "  print('done for ', grid_size)\n",
    "  print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 576912,
     "status": "ok",
     "timestamp": 1652350899943,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "EBzFnKsVwL8o",
    "outputId": "80e563ba-f67a-44bd-94f7-660528c8fb9a"
   },
   "outputs": [],
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "l = 1000\n",
    "batch_size  = 128\n",
    "\n",
    "x_max, x_min = X_test.max(), X_test.min()\n",
    "baseline_zero    = 0.0\n",
    "# Uniform noise rescaled to the data range [x_min, x_max]: u * (max - min) + min.\n",
    "# The offset used to be subtracted, which moved the noise out of that range whenever x_min != 0.\n",
    "# This cell only uses baseline_zero; the definition is kept so the name stays available.\n",
    "baseline_uniform = lambda x: np.random.uniform(size=x.shape) * (x_max - x_min) + x_min\n",
    "\n",
    "for grid_size in [20]:\n",
    "\n",
    "  # MuFidelity on the first `l` test samples, using the constant 0.0 baseline\n",
    "  metric = MuFidelity(binary_model, X_test[:l], Y_test[:l], batch_size = batch_size, grid_size = grid_size, nb_samples = 200, baseline_mode = baseline_zero)\n",
    "  metric_name = metric.__class__.__name__\n",
    "\n",
    "  table = []\n",
    "  \n",
    "  for explainer_name in explainers_names:\n",
    "    # array stored on disk for this explainer, truncated to `l` samples, trailing axis appended\n",
    "    phis = np.load(f'data/{model_name}_{explainer_name}.npy')[:l]\n",
    "    phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "    score = metric(phis)\n",
    "    table.append((metric_name, model_name, explainer_name, score))\n",
    "    print(metric_name, explainer_name, score)\n",
    "  \n",
    "  np.save(f'results_{model_name}_{metric_name}_zero_{grid_size}', table)\n",
    "  print('done for ', grid_size)\n",
    "  print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 391,
     "status": "ok",
     "timestamp": 1652351173701,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "HSm3NxeszdtT"
   },
   "outputs": [],
   "source": [
    "# number of leading samples kept from each attribution file\n",
    "l = 200\n",
    "\n",
    "# Saliency maps, cast to float32\n",
    "sa_phis = tf.cast(np.load(f'data/{model_name}_Saliency.npy')[:l], tf.float32)[:,:,:]\n",
    "\n",
    "# SmoothGrad maps, cast to float32\n",
    "sg_phis = tf.cast(np.load(f'data/{model_name}_SmoothGrad.npy')[:l], tf.float32)[:,:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 369
    },
    "executionInfo": {
     "elapsed": 11402,
     "status": "ok",
     "timestamp": 1652351286662,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "Bu_UanQVze4u",
    "outputId": "8b416b41-e529-4f34-b61b-a44e73fb10e8"
   },
   "outputs": [],
   "source": [
    "set_size(20, 5)\n",
    "\n",
    "# index of the first test sample displayed\n",
    "j = 10\n",
    "\n",
    "# 2 x 5 grid: first row overlays sa_phis on the inputs, second row overlays sg_phis\n",
    "for i in range(5):\n",
    "  plt.subplot(2, 5, i + 1)\n",
    "  show(X_test[i + j])\n",
    "  show(sa_phis[i + j], cmap='jet', alpha=0.4, p=0.5)\n",
    "\n",
    "  plt.subplot(2, 5, i + 6)\n",
    "  show(X_test[i + j])\n",
    "  show(sg_phis[i + j], cmap='jet', alpha=0.4, p=0.5)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('sa_vs_sg_robust.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9426,
     "status": "ok",
     "timestamp": 1652352070803,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "UfjNLMR90hH3",
    "outputId": "e9640c7b-31d9-45f7-aa18-70b80678d3df"
   },
   "outputs": [],
   "source": [
    "import scipy\n",
    "from scipy.stats import spearmanr\n",
    "\n",
    "l = 1000\n",
    "sa_phis = np.load(f'data/{model_name}_Saliency.npy')\n",
    "sa_phis = tf.cast(sa_phis, tf.float32).numpy()\n",
    "\n",
    "sg_phis = np.load(f'data/{model_name}_SmoothGrad.npy')\n",
    "sg_phis = tf.cast(sg_phis, tf.float32).numpy()\n",
    "\n",
    "# per-sample comparison of the SmoothGrad map with the Saliency map\n",
    "spearman_dist = []\n",
    "l2_dist = []\n",
    "\n",
    "for h1, h2 in zip(sg_phis, sa_phis):\n",
    "  # rank correlation of the flattened maps (first element of the spearmanr result)\n",
    "  spearman_dist.append(spearmanr(h1.flatten(), h2.flatten())[0])\n",
    "  # Root-mean-square difference. The sqrt used to be applied element-wise before the\n",
    "  # mean, which reduced the value to the mean absolute difference instead of an L2 one.\n",
    "  l2_dist.append(np.sqrt(np.mean(np.square(h1 - h2))))\n",
    "\n",
    "\n",
    "np.save('sa_sg_spearman_robust.npy', np.array(spearman_dist))\n",
    "np.save('sa_sg_l2_robust.npy', np.array(l2_dist))\n",
    "\n",
    "print('mean spearman', np.mean(spearman_dist), \"Mean l2\", np.mean(l2_dist))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "HLTZfq0d3Pci"
   },
   "source": [
    "# Kolmogorov estimation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3592,
     "status": "ok",
     "timestamp": 1652352463571,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "RyJgerlJ3Rd3",
    "outputId": "4dd76d7b-aa24-450b-a3f9-948bb0132d39"
   },
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "sizes = []\n",
    "\n",
    "sa_phis = np.load(f'data/{model_name}_Saliency.npy')\n",
    "\n",
    "for phi in sa_phis:\n",
    "  # Min-max rescale to [0, 255] on a new array, so the rows of sa_phis are not modified in place.\n",
    "  phi = phi - phi.min()\n",
    "  peak = phi.max()\n",
    "  if peak > 0:\n",
    "    # a constant map has peak == 0: skipping the division avoids NaNs and keeps it all-zero\n",
    "    phi = phi / peak\n",
    "  phi = np.array(phi * 255.0).astype(np.uint8)\n",
    "\n",
    "  # the size of the JPEG-encoded map is the quantity recorded for each sample\n",
    "  cv2.imwrite('img.jpg', phi)\n",
    "  sz = os.path.getsize(\"img.jpg\")\n",
    "\n",
    "  sizes.append(sz)\n",
    "\n",
    "np.save('robust_saliency_jpg_size.npy', sizes)\n",
    "print('mean bytes size', np.mean(sizes), \"std\", np.std(sizes))\n",
    "print('mean kilo-bytes', np.mean(sizes) / 1_000, \"std\", np.std(sizes) / 1_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BBs_CCkRgvzr",
    "outputId": "bc900fc3-5ae3-4327-8713-a9702673bb77"
   },
   "outputs": [],
   "source": [
    "from xplique.metrics import MuFidelity, Deletion, Insertion\n",
    "\n",
    "l = 1000\n",
    "batch_size  = 128\n",
    "\n",
    "x_max, x_min = X_test.max(), X_test.min()\n",
    "baseline_zero    = 0.0\n",
    "# Uniform noise rescaled to the data range [x_min, x_max]: u * (max - min) + min.\n",
    "# The offset used to be subtracted, which moved the noise out of that range whenever x_min != 0.\n",
    "baseline_uniform = lambda x: np.random.uniform(size=x.shape) * (x_max - x_min) + x_min\n",
    "\n",
    "grid_size = 20\n",
    "\n",
    "# MuFidelity on the first `l` test samples, using the uniform-noise baseline\n",
    "metric = MuFidelity(binary_model, X_test[:l], Y_test[:l], batch_size = batch_size, grid_size = grid_size, nb_samples = 250, baseline_mode = baseline_uniform)\n",
    "metric_name = metric.__class__.__name__\n",
    "\n",
    "table = []\n",
    "\n",
    "for explainer_name in explainers_names:\n",
    "  # array stored on disk for this explainer, truncated to `l` samples, trailing axis appended\n",
    "  phis = np.load(f'data/{model_name}_{explainer_name}.npy')[:l]\n",
    "  phis = tf.cast(phis, tf.float32)[:,:,:,None]\n",
    "\n",
    "  score = metric(phis)\n",
    "  table.append((metric_name, model_name, explainer_name, score))\n",
    "  print(metric_name, explainer_name, score)\n",
    "\n",
    "\n",
    "filename = f'results_{model_name}_{metric_name}_uniform'\n",
    "np.save(filename, table)\n",
    "print('done for ', grid_size)\n",
    "print('\\n\\n')\n",
    "# np.save appends '.npy' to a file name that lacks it, so the file to copy is '<filename>.npy'\n",
    "!cp {filename}.npy \"/content/drive/MyDrive/HKR XAI/results\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0W-R6htwR7jF"
   },
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "c_5FiAHUU3vC"
   },
   "source": [
    "# (3) Stability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 480952,
     "status": "ok",
     "timestamp": 1652369478850,
     "user": {
      "displayName": "River Dread",
      "userId": "09094954806195578933"
     },
     "user_tz": -120
    },
    "id": "aG9FYbeVU5fl",
    "outputId": "88a1bc44-82ce-4a48-aa30-56f130dc0336"
   },
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr\n",
    "\n",
    "l = 500\n",
    "\n",
    "def spearman_dist(phi1, phi2):\n",
    "  \"\"\"Return 1 - |rho|, where rho is the Spearman rank correlation of the two flattened maps.\n",
    "\n",
    "  Rank-3 inputs are first averaged over their last axis. The result is a scalar.\n",
    "  \"\"\"\n",
    "  if len(phi1.shape) == 3:\n",
    "    phi1 = np.mean(phi1, -1)\n",
    "    phi2 = np.mean(phi2, -1)\n",
    "  # spearmanr returns (correlation, pvalue); keep the correlation only. Passing the whole\n",
    "  # result to np.abs produced a 2-element array that mixed the p-value into the distance.\n",
    "  rho = spearmanr(np.array(phi1).flatten(), np.array(phi2).flatten())[0]\n",
    "  return 1.0 - np.abs(rho)\n",
    "\n",
    "def l2_dist(phi1, phi2):\n",
    "  \"\"\"Return the mean of the squared element-wise differences between the two maps.\n",
    "\n",
    "  Rank-3 inputs are first averaged over their last axis.\n",
    "  \"\"\"\n",
    "  has_trailing_axis = len(phi1.shape) == 3\n",
    "  if has_trailing_axis:\n",
    "    phi1, phi2 = np.mean(phi1, -1), np.mean(phi2, -1)\n",
    "  delta = phi1 - phi2\n",
    "  return np.mean(delta**2.0)\n",
    "\n",
    "# --- AverageStability with the Spearman-based distance (10 samples) ---\n",
    "metric = AverageStability(\n",
    "    binary_model, X_test[:l], Y_test[:l],\n",
    "    batch_size=128, radius=0.1,\n",
    "    distance=spearman_dist, nb_samples=10,\n",
    ")\n",
    "\n",
    "results = []\n",
    "\n",
    "for explainer in [\n",
    "    Saliency(binary_model),\n",
    "    IntegratedGradients(binary_model, steps=80),\n",
    "    SmoothGrad(binary_model, nb_samples=80),\n",
    "]:\n",
    "    explainer_name = explainer.__class__.__name__\n",
    "\n",
    "    score = metric.evaluate(explainer)\n",
    "    results.append((explainer_name, score))\n",
    "    print(explainer_name, score)\n",
    "\n",
    "np.save('stability_spearman_robust.npy', results)\n",
    "\n",
    "from scipy.stats import spearmanr\n",
    "\n",
    "# --- AverageStability with l2_dist (20 samples) ---\n",
    "metric = AverageStability(\n",
    "    binary_model, X_test[:l], Y_test[:l],\n",
    "    batch_size=128, radius=0.1,\n",
    "    distance=l2_dist, nb_samples=20,\n",
    ")\n",
    "\n",
    "results = []\n",
    "\n",
    "for explainer in [\n",
    "    Saliency(binary_model),\n",
    "    IntegratedGradients(binary_model, steps=80),\n",
    "    SmoothGrad(binary_model, nb_samples=80),\n",
    "]:\n",
    "    explainer_name = explainer.__class__.__name__\n",
    "\n",
    "    score = metric.evaluate(explainer)\n",
    "    results.append((explainer_name, score))\n",
    "    print(explainer_name, score)\n",
    "\n",
    "np.save('stability_l2_robust.npy', results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qoQ8LRjBR7op"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qOJGwm5VDPeJ"
   },
   "outputs": [],
   "source": [
    "# Copy every file whose name starts with 'results_' to the Drive results folder.\n",
    "# NOTE(review): outputs saved above under other prefixes ('stability_*', 'sa_sg_*',\n",
    "# 'robust_saliency_jpg_size.npy') are not matched by this pattern - confirm that is intended.\n",
    "!cp results_* \"/content/drive/MyDrive/HKR XAI/results\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FGLBKkVTL4En"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "001 (Robust) CelebA compute explanations.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
