{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Tutorial Getting Started\n",
        "\n",
        "This notebook shows how to get started with GEF.\n",
        "\n",
        "[\"Explanation Faithfulness is Alignment: A Unifying and Geometric Perspective on Interpretability Evaluation\"](link)"
      ],
      "metadata": {
        "id": "nXaRtM8T_uGl"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Installs and Imports"
      ],
      "metadata": {
        "id": "SwbOS_83AIKE"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DMJO65Ss_oon"
      },
      "outputs": [],
      "source": [
        "# Install the project requirements; %pip (unlike !pip) installs into the running kernel's environment.\n",
        "%pip install -r ../requirements.txt --quiet"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Load supporting functions.\n",
        "import sys\n",
        "sys.path.append('../')\n",
        "from src import *"
      ],
      "metadata": {
        "id": "i4p5GwOy3m3Q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Load toy data and model"
      ],
      "metadata": {
        "id": "nyQo-W1h3oAj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Pick the GPU when one is available, otherwise fall back to the CPU.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Load the saved weights onto the CPU first, then switch the model to eval mode and move it to `device`.\n",
        "model_mnist = LeNet()\n",
        "model_mnist.load_state_dict(torch.load('mnist_lenet', map_location=torch.device('cpu')))\n",
        "model_mnist.eval().to(device)\n",
        "\n",
        "# Load random samples and put the inputs on the same device as the model\n",
        "# (assumes x_batch is a torch.Tensor -- TODO confirm against load_mnist_samples).\n",
        "x_batch, y_batch = load_mnist_samples(n=50)\n",
        "x_batch = x_batch.to(device)\n",
        "\n",
        "# Predict classes without tracking gradients.\n",
        "# Assumes the model returns (batch, num_classes) logits, so the class axis is dim=1 -- TODO confirm.\n",
        "with torch.no_grad():\n",
        "    y_preds = model_mnist(x_batch).argmax(dim=1)"
      ],
      "metadata": {
        "id": "z2IumT4uFDlF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Initialise the metric and evaluate"
      ],
      "metadata": {
        "id": "GoMNtrBM3sa6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Initialise GEF.\n",
        "metric = GEF()\n",
        "\n",
        "# Evaluate generalised explanation faithfulness for the loaded model on its predicted classes.\n",
        "gef_scores = metric(model_mnist, x_batch, y_preds)\n",
        "\n",
        "# Evaluate again with fast_mode=True.\n",
        "fast_gef_scores = metric(model_mnist, x_batch, y_preds, fast_mode=True)"
      ],
      "metadata": {
        "id": "Mt6J7sGX3VXa"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}