{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BERT-explainability.ipynb",
      "provenance": [],
      "authorship_tag": "ABX9TyOm8dIRrumd5XNcc+fntVA5",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YCdGaMuy56TA",
        "outputId": "8f802262-55eb-4366-b772-89c4756224b3"
      },
      "source": [
        "!git clone https://github.com/hila-chefer/Transformer-Explainability.git\n",
        "\n",
        "import os\n",
        "os.chdir(f'./Transformer-Explainability')\n",
        "\n",
        "!pip install -r requirements.txt\n",
        "!pip install captum"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fatal: destination path 'Transformer-Explainability' already exists and is not an empty directory.\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: Pillow>=8.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 1)) (9.4.0)\n",
            "Requirement already satisfied: einops==0.3.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 2)) (0.3.0)\n",
            "Requirement already satisfied: h5py==2.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 3)) (2.8.0)\n",
            "Requirement already satisfied: imageio==2.9.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 4)) (2.9.0)\n",
            "Collecting matplotlib==3.3.2\n",
            "  Using cached matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)\n",
            "Requirement already satisfied: opencv_python in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 6)) (4.6.0.66)\n",
            "Requirement already satisfied: scikit_image==0.17.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 7)) (0.17.2)\n",
            "Requirement already satisfied: scipy==1.5.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 8)) (1.5.2)\n",
            "Requirement already satisfied: sklearn in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 9)) (0.0.post1)\n",
            "Requirement already satisfied: torch==1.7.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 10)) (1.7.0)\n",
            "Requirement already satisfied: torchvision==0.8.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 11)) (0.8.1)\n",
            "Requirement already satisfied: tqdm==4.51.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 12)) (4.51.0)\n",
            "Requirement already satisfied: transformers==3.5.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 13)) (3.5.1)\n",
            "Requirement already satisfied: utils==1.0.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 14)) (1.0.1)\n",
            "Requirement already satisfied: Pygments>=2.7.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 15)) (2.14.0)\n",
            "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.8/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.21.6)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.8/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.15.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (1.4.4)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (3.0.9)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2.8.2)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (0.11.0)\n",
            "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2022.12.7)\n",
            "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (3.0)\n",
            "Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (2022.10.10)\n",
            "Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (1.4.1)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (0.6)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (4.4.0)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (0.16.0)\n",
            "Requirement already satisfied: sacremoses in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.0.53)\n",
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (3.19.6)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (3.9.0)\n",
            "Requirement already satisfied: sentencepiece==0.1.91 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.1.91)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (21.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (2022.6.2)\n",
            "Requirement already satisfied: tokenizers==0.9.3 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.9.3)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (2.25.1)\n",
            "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (4.0.0)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (1.24.3)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (2.10)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 13)) (1.2.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 13)) (7.1.2)\n",
            "Installing collected packages: matplotlib\n",
            "  Attempting uninstall: matplotlib\n",
            "    Found existing installation: matplotlib 3.6.3\n",
            "    Uninstalling matplotlib-3.6.3:\n",
            "      Successfully uninstalled matplotlib-3.6.3\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "fastai 2.7.10 requires torchvision>=0.8.2, but you have torchvision 0.8.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed matplotlib-3.3.2\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: captum in /usr/local/lib/python3.8/dist-packages (0.6.0)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum) (3.3.2)\n",
            "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum) (1.7.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum) (1.21.6)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (0.16.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (4.4.0)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (0.6)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (0.11.0)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (9.4.0)\n",
            "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (2022.12.7)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (1.4.4)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (3.0.9)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (2.8.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib->captum) (1.15.0)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install captum==0.6.0\n",
        "!pip install matplotlib==3.3.2"
      ],
      "metadata": {
        "id": "zDPnh4lofcNw",
        "outputId": "3d585bbc-ff3b-4a09-b5bf-57bb4d46e830",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: captum==0.6.0 in /usr/local/lib/python3.8/dist-packages (0.6.0)\n",
            "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (1.7.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (1.21.6)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (3.6.3)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (4.4.0)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (0.16.0)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (0.6)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (1.4.4)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (1.0.7)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (9.4.0)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (2.8.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (21.3)\n",
            "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (3.0.9)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (4.38.0)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (0.11.0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7->matplotlib->captum==0.6.0) (1.15.0)\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Collecting matplotlib==3.3.2\n",
            "  Using cached matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (9.4.0)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (0.11.0)\n",
            "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (1.21.6)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (3.0.9)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (1.4.4)\n",
            "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (2022.12.7)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (2.8.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib==3.3.2) (1.15.0)\n",
            "Installing collected packages: matplotlib\n",
            "  Attempting uninstall: matplotlib\n",
            "    Found existing installation: matplotlib 3.6.3\n",
            "    Uninstalling matplotlib-3.6.3:\n",
            "      Successfully uninstalled matplotlib-3.6.3\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "fastai 2.7.10 requires torchvision>=0.8.2, but you have torchvision 0.8.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed matplotlib-3.3.2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4-XGl_Zw6Aht"
      },
      "source": [
        "from transformers import BertTokenizer\n",
        "from BERT_explainability.modules.BERT.ExplanationGenerator import Generator\n",
        "from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification\n",
        "from transformers import BertTokenizer\n",
        "from BERT_explainability.modules.BERT.ExplanationGenerator import Generator\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "from captum.attr import visualization\n",
        "import torch"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VakYjrkC6C3S"
      },
      "source": [
        "model = BertForSequenceClassification.from_pretrained(\"textattack/bert-base-uncased-SST-2\").to(\"cuda\")\n",
        "model.eval()\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"textattack/bert-base-uncased-SST-2\")\n",
        "# initialize the explanations generator\n",
        "explanations = Generator(model)\n",
        "\n",
        "classifications = [\"NEGATIVE\", \"POSITIVE\"]\n"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jGRp376FPOvV"
      },
      "source": [
        "#Positive sentiment example"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uSLZtv546H2z",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 219
        },
        "outputId": "26712e90-0b77-40b0-a908-fef13dd88bcd"
      },
      "source": [
        "# encode a sentence\n",
        "text_batch = [\"This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.\"]\n",
        "encoding = tokenizer(text_batch, return_tensors='pt')\n",
        "input_ids = encoding['input_ids'].to(\"cuda\")\n",
        "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
        "\n",
        "# true class is positive - 1\n",
        "true_class = 1\n",
        "\n",
        "# generate an explanation for the input\n",
        "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]\n",
        "# normalize scores\n",
        "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
        "\n",
        "# get the model classification\n",
        "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
        "classification = output.argmax(dim=-1).item()\n",
        "# get class name\n",
        "class_name = classifications[classification]\n",
        "# if the classification is negative, higher explanation scores are more negative\n",
        "# flip for visualization\n",
        "if class_name == \"NEGATIVE\":\n",
        "  expl *= (-1)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
        "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
        "vis_data_records = [visualization.VisualizationDataRecord(\n",
        "                                expl,\n",
        "                                output[0][classification],\n",
        "                                classification,\n",
        "                                true_class,\n",
        "                                true_class,\n",
        "                                1,       \n",
        "                                tokens,\n",
        "                                1)]\n",
        "visualization.visualize_text(vis_data_records)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[('[CLS]', 0.0), ('this', 0.4267549514770508), ('movie', 0.30920878052711487), ('was', 0.2684089243412018), ('the', 0.33637329936027527), ('best', 0.6280889511108398), ('movie', 0.28546375036239624), ('i', 0.1863601952791214), ('have', 0.10115814208984375), ('ever', 0.1419338583946228), ('seen', 0.1898290067911148), ('!', 0.5944811105728149), ('some', 0.003896803595125675), ('scenes', 0.033401958644390106), ('were', 0.018588582053780556), ('ridiculous', 0.018908796831965446), (',', 0.0), ('but', 0.42920616269111633), ('acting', 0.43855082988739014), ('was', 0.500239372253418), ('great', 1.0), ('.', 0.014817383140325546), ('[SEP]', 0.0868983045220375)]\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(120, 75%, 85%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 84%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(120, 75%, 69%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> best                    </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> have                    </font></mark><mark style=\"background-color: hsl(120, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ever                    </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> seen                    </font></mark><mark style=\"background-color: hsl(120, 75%, 71%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> !                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> scenes                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ridiculous                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> acting                    </font></mark><mark style=\"background-color: hsl(120, 75%, 75%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(120, 75%, 85%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 84%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(120, 75%, 69%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> best                    </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> have                    </font></mark><mark style=\"background-color: hsl(120, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ever                    </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> seen                    </font></mark><mark style=\"background-color: hsl(120, 75%, 71%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> !                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> scenes                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ridiculous                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> acting                    </font></mark><mark style=\"background-color: hsl(120, 75%, 75%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oO_k1BtSPVt3"
      },
      "source": [
        "#Negative sentiment example"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 219
        },
        "id": "gD4xcvovI1KI",
        "outputId": "e4a50a94-da4c-460e-b602-052b09cec28f"
      },
      "source": [
        "# encode a sentence\n",
        "text_batch = [\"I really didn't like this movie. Some of the actors were good, but overall the movie was boring.\"]\n",
        "encoding = tokenizer(text_batch, return_tensors='pt')\n",
        "input_ids = encoding['input_ids'].to(\"cuda\")\n",
        "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
        "\n",
        "# generate an explanation for the input\n",
        "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]\n",
        "# normalize scores\n",
        "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
        "\n",
        "# get the model classification\n",
        "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
        "classification = output.argmax(dim=-1).item()\n",
        "# get class name\n",
        "class_name = classifications[classification]\n",
        "# if the classification is negative, higher explanation scores are more negative\n",
        "# flip for visualization\n",
        "if class_name == \"NEGATIVE\":\n",
        "  expl *= (-1)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
        "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
        "vis_data_records = [visualization.VisualizationDataRecord(\n",
        "                                expl,\n",
        "                                output[0][classification],\n",
        "                                classification,\n",
        "                                1,\n",
        "                                1,\n",
        "                                1,       \n",
        "                                tokens,\n",
        "                                1)]\n",
        "visualization.visualize_text(vis_data_records)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[('[CLS]', -0.0), ('i', -0.19109757244586945), ('really', -0.1888734996318817), ('didn', -0.2894313633441925), (\"'\", -0.006574898026883602), ('t', -0.36788827180862427), ('like', -0.15249046683311462), ('this', -0.18922168016433716), ('movie', -0.0404353104531765), ('.', -0.019592661410570145), ('some', -0.02311306819319725), ('of', -0.0), ('the', -0.02295113168656826), ('actors', -0.09577538073062897), ('were', -0.013370633125305176), ('good', -0.0323222391307354), (',', -0.004366681911051273), ('but', -0.05878860130906105), ('overall', -0.33596664667129517), ('the', -0.21820111572742462), ('movie', -0.05482065677642822), ('was', -0.6248231530189514), ('boring', -1.0), ('.', -0.031107747927308083), ('[SEP]', -0.052539654076099396)]\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> really                    </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> didn                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> '                    </font></mark><mark style=\"background-color: hsl(0, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> t                    </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> like                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> of                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> actors                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> good                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                    </font></mark><mark style=\"background-color: hsl(0, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> overall                    </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 76%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> boring                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> really                    </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> didn                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> '                    </font></mark><mark style=\"background-color: hsl(0, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> t                    </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> like                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> of                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> actors                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> good                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                    </font></mark><mark style=\"background-color: hsl(0, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> overall                    </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 76%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> boring                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Choosing class for visualization example"
      ],
      "metadata": {
        "id": "UUn2_SMPNG-Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# encode a sentence\n",
        "text_batch = [\"I hate that I love you.\"]\n",
        "encoding = tokenizer(text_batch, return_tensors='pt')\n",
        "input_ids = encoding['input_ids'].to(\"cuda\")\n",
        "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
        "\n",
        "# true class is positive - 1\n",
        "true_class = 1\n",
        "\n",
        "# generate an explanation for the input\n",
        "target_class = 0\n",
        "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]\n",
        "# normalize scores\n",
        "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
        "\n",
        "# get the model classification\n",
        "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
        "\n",
        "# get class name\n",
        "class_name = classifications[target_class]\n",
        "# if the classification is negative, higher explanation scores are more negative\n",
        "# flip for visualization\n",
        "if class_name == \"NEGATIVE\":\n",
        "  expl *= (-1)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
        "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
        "vis_data_records = [visualization.VisualizationDataRecord(\n",
        "                                expl,\n",
        "                                output[0][classification],\n",
        "                                classification,\n",
        "                                true_class,\n",
        "                                true_class,\n",
        "                                1,       \n",
        "                                tokens,\n",
        "                                1)]\n",
        "visualization.visualize_text(vis_data_records)"
      ],
      "metadata": {
        "id": "VQVmMFnzhPoV",
        "outputId": "26a43f8a-340c-4821-b39c-80105a565810",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 219
        }
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[('[CLS]', -0.0), ('i', -0.19790242612361908), ('hate', -1.0), ('that', -0.40287283062934875), ('i', -0.12505637109279633), ('love', -0.1307140290737152), ('you', -0.05467141419649124), ('.', -6.108225989009952e-06), ('[SEP]', -0.0)]\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> hate                    </font></mark><mark style=\"background-color: hsl(0, 75%, 84%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> that                    </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> love                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> you                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> hate                    </font></mark><mark style=\"background-color: hsl(0, 75%, 84%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> that                    </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> love                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> you                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# encode a sentence\n",
        "text_batch = [\"I hate that I love you.\"]\n",
        "encoding = tokenizer(text_batch, return_tensors='pt')\n",
        "input_ids = encoding['input_ids'].to(\"cuda\")\n",
        "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
        "\n",
        "# true class is positive - 1\n",
        "true_class = 1\n",
        "\n",
        "# generate an explanation for the input\n",
        "target_class = 1\n",
        "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]\n",
        "# normalize scores\n",
        "expl = (expl - expl.min()) / (expl.max() - expl.min())\n",
        "\n",
        "# get the model classification\n",
        "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
        "\n",
        "# get class name\n",
        "class_name = classifications[target_class]\n",
        "# if the classification is negative, higher explanation scores are more negative\n",
        "# flip for visualization\n",
        "if class_name == \"NEGATIVE\":\n",
        "  expl *= (-1)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
        "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
        "vis_data_records = [visualization.VisualizationDataRecord(\n",
        "                                expl,\n",
        "                                output[0][classification],\n",
        "                                classification,\n",
        "                                true_class,\n",
        "                                true_class,\n",
        "                                1,       \n",
        "                                tokens,\n",
        "                                1)]\n",
        "visualization.visualize_text(vis_data_records)"
      ],
      "metadata": {
        "id": "WiQAWw0-imCg",
        "outputId": "a8c66996-dcd0-4132-a8b0-2346d9bf9c7b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 219
        }
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[('[CLS]', 0.0), ('i', 0.2725590765476227), ('hate', 0.17270179092884064), ('that', 0.23211266100406647), ('i', 0.17642731964588165), ('love', 1.0), ('you', 0.2465524971485138), ('.', 0.0), ('[SEP]', 0.00015733683540020138)]\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> hate                    </font></mark><mark style=\"background-color: hsl(120, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> that                    </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> love                    </font></mark><mark style=\"background-color: hsl(120, 75%, 88%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> you                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.91)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> hate                    </font></mark><mark style=\"background-color: hsl(120, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> that                    </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> love                    </font></mark><mark style=\"background-color: hsl(120, 75%, 88%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> you                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    }
  ]
}