{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BERT-explainability.ipynb",
      "provenance": [],
      "authorship_tag": "ABX9TyO4wn4/FB0J5pQCjUZNYulP",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YCdGaMuy56TA",
        "outputId": "0f2a5e1b-053d-4819-8d9b-729345cb2754"
      },
      "source": [
        "import os\r\n",
        "\r\n",
        "REPO_DIR = 'Transformer-Explainability'\r\n",
        "\r\n",
        "# Clone and enter the repository only when we are not already inside it,\r\n",
        "# so that re-running this cell does not create (and move into) a nested clone.\r\n",
        "if os.path.basename(os.getcwd()) != REPO_DIR:\r\n",
        "    if not os.path.isdir(REPO_DIR):\r\n",
        "        !git clone https://github.com/hila-chefer/Transformer-Explainability.git\r\n",
        "    # the notebook imports BERT_explainability relative to the repository root\r\n",
        "    os.chdir(REPO_DIR)\r\n",
        "\r\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel\r\n",
        "%pip install -r requirements.txt"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'Transformer-Explainability'...\n",
            "remote: Enumerating objects: 281, done.\u001b[K\n",
            "remote: Counting objects: 100% (281/281), done.\u001b[K\n",
            "remote: Compressing objects: 100% (199/199), done.\u001b[K\n",
            "remote: Total 281 (delta 131), reused 181 (delta 69), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (281/281), 2.46 MiB | 3.62 MiB/s, done.\n",
            "Resolving deltas: 100% (131/131), done.\n",
            "Requirement already satisfied: Pillow==8.0.1 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 1)) (8.0.1)\n",
            "Requirement already satisfied: einops==0.3.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 2)) (0.3.0)\n",
            "Requirement already satisfied: h5py==2.8.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (2.8.0)\n",
            "Requirement already satisfied: imageio==2.9.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 4)) (2.9.0)\n",
            "Requirement already satisfied: matplotlib==3.3.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 5)) (3.3.2)\n",
            "Requirement already satisfied: numpy==1.19.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 6)) (1.19.2)\n",
            "Requirement already satisfied: opencv_python==3.4.2.17 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 7)) (3.4.2.17)\n",
            "Requirement already satisfied: scikit_image==0.17.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 8)) (0.17.2)\n",
            "Requirement already satisfied: scipy==1.5.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 9)) (1.5.2)\n",
            "Requirement already satisfied: sklearn in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 10)) (0.0)\n",
            "Requirement already satisfied: torch==1.7.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 11)) (1.7.0)\n",
            "Requirement already satisfied: torchvision==0.8.1 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 12)) (0.8.1)\n",
            "Requirement already satisfied: tqdm==4.51.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 13)) (4.51.0)\n",
            "Requirement already satisfied: transformers==3.5.1 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 14)) (3.5.1)\n",
            "Requirement already satisfied: utils==1.0.1 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 15)) (1.0.1)\n",
            "Requirement already satisfied: pygments==2.7.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 16)) (2.7.2)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.15.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (1.3.1)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (0.10.0)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2.4.7)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2.8.1)\n",
            "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2020.12.5)\n",
            "Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.7/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 8)) (2021.3.4)\n",
            "Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 8)) (1.1.1)\n",
            "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.7/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 8)) (2.5)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sklearn->-r requirements.txt (line 10)) (0.22.2.post1)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch==1.7.0->-r requirements.txt (line 11)) (0.16.0)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.7/dist-packages (from torch==1.7.0->-r requirements.txt (line 11)) (0.6)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.7.0->-r requirements.txt (line 11)) (3.7.4.3)\n",
            "Requirement already satisfied: tokenizers==0.9.3 in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (0.9.3)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (20.9)\n",
            "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (0.0.43)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (2.23.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (3.0.12)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (2019.12.20)\n",
            "Requirement already satisfied: sentencepiece==0.1.91 in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (0.1.91)\n",
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.7/dist-packages (from transformers==3.5.1->-r requirements.txt (line 14)) (3.12.4)\n",
            "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.7/dist-packages (from networkx>=2.0->scikit_image==0.17.2->-r requirements.txt (line 8)) (4.4.2)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->sklearn->-r requirements.txt (line 10)) (1.0.1)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 14)) (7.1.2)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 14)) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 14)) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 14)) (1.24.3)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from protobuf->transformers==3.5.1->-r requirements.txt (line 14)) (54.0.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qEMoBJE859cj",
        "outputId": "4743747f-95f8-4586-f465-443b2b3facc9"
      },
      "source": [
        "# explicit %pip magic (does not rely on automagic); pinned to the captum release\r\n",
        "# this notebook was run with so the pinned torch from requirements.txt is kept\r\n",
        "%pip install captum==0.3.1"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: captum in /usr/local/lib/python3.7/dist-packages (0.3.1)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from captum) (3.3.2)\n",
            "Requirement already satisfied: torch>=1.2 in /usr/local/lib/python3.7/dist-packages (from captum) (1.7.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from captum) (1.19.2)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (1.3.1)\n",
            "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (2020.12.5)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (2.8.1)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (0.10.0)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (2.4.7)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (8.0.1)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch>=1.2->captum) (0.16.0)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.7/dist-packages (from torch>=1.2->captum) (0.6)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.2->captum) (3.7.4.3)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib->captum) (1.15.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4-XGl_Zw6Aht"
      },
      "source": [
        "# third-party\r\n",
        "from transformers import AutoTokenizer, BertTokenizer\r\n",
        "from captum.attr import visualization\r\n",
        "import torch\r\n",
        "\r\n",
        "# local package from the cloned Transformer-Explainability repository\r\n",
        "from BERT_explainability.modules.BERT.ExplanationGenerator import Generator\r\n",
        "from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification"
      ],
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BrC0d9Sp6Cjc"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VakYjrkC6C3S"
      },
      "source": [
        "# single checkpoint id shared by the model and its tokenizer so the two cannot drift apart\r\n",
        "MODEL_NAME = \"textattack/bert-base-uncased-SST-2\"\r\n",
        "\r\n",
        "# load the classifier onto the GPU and switch it to evaluation mode\r\n",
        "model = BertForSequenceClassification.from_pretrained(MODEL_NAME).to(\"cuda\")\r\n",
        "model.eval()\r\n",
        "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\r\n",
        "# initialize the explanations generator\r\n",
        "explanations = Generator(model)\r\n",
        "\r\n",
        "# class names, indexed by the predicted label id in the example cells\r\n",
        "classifications = [\"NEGATIVE\", \"POSITIVE\"]"
      ],
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jGRp376FPOvV"
      },
      "source": [
        "**Positive sentiment example**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uSLZtv546H2z",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 216
        },
        "outputId": "a9f7b34e-dd18-4b95-9d6b-914d85e95ff8"
      },
      "source": [
        "# encode a sentence\r\n",
        "text_batch = [\"This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.\"]\r\n",
        "encoding = tokenizer(text_batch, return_tensors='pt')\r\n",
        "input_ids = encoding['input_ids'].to(\"cuda\")\r\n",
        "attention_mask = encoding['attention_mask'].to(\"cuda\")\r\n",
        "\r\n",
        "# true class is positive - 1\r\n",
        "true_class = 1\r\n",
        "\r\n",
        "# generate an explanation for the input\r\n",
        "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]\r\n",
        "# min-max normalize the scores to the range [0, 1]\r\n",
        "expl = (expl - expl.min()) / (expl.max() - expl.min())\r\n",
        "\r\n",
        "# get the model classification: softmax over the first element of the model output\r\n",
        "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\r\n",
        "classification = output.argmax(dim=-1).item()\r\n",
        "# get class name\r\n",
        "class_name = classifications[classification]\r\n",
        "# if the classification is negative, higher explanation scores are more negative\r\n",
        "# flip for visualization\r\n",
        "if class_name == \"NEGATIVE\":\r\n",
        "  expl *= (-1)\r\n",
        "\r\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\r\n",
        "# pair every token with its (normalized) relevance score\r\n",
        "print([(token, score.item()) for token, score in zip(tokens, expl)])\r\n",
        "vis_data_records = [visualization.VisualizationDataRecord(\r\n",
        "                                expl,                       # per-token scores (\"Word Importance\")\r\n",
        "                                output[0][classification],  # probability of the predicted class\r\n",
        "                                classification,             # predicted label\r\n",
        "                                true_class,                 # true label\r\n",
        "                                true_class,                 # attribution label\r\n",
        "                                1,                          # attribution score\r\n",
        "                                tokens,                     # tokens rendered in the table\r\n",
        "                                1)]                         # NOTE(review): presumably captum's convergence score - confirm\r\n",
        "# visualize_text() displays the table and also returns it; the trailing ';'\r\n",
        "# keeps the notebook from rendering the same table a second time\r\n",
        "visualization.visualize_text(vis_data_records);"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[('[CLS]', 0.0), ('this', 0.4398406744003296), ('movie', 0.3385171890258789), ('was', 0.2850261628627777), ('the', 0.3722951412200928), ('best', 0.6413642764091492), ('movie', 0.3098682463169098), ('i', 0.20284101366996765), ('have', 0.12214731425046921), ('ever', 0.15835356712341309), ('seen', 0.2082878053188324), ('!', 0.6001579761505127), ('some', 0.021879158914089203), ('scenes', 0.05488050356507301), ('were', 0.0371897891163826), ('ridiculous', 0.03780526667833328), (',', 0.02076297625899315), ('but', 0.44531309604644775), ('acting', 0.45006945729255676), ('was', 0.5168584585189819), ('great', 1.0), ('.', 0.035734280943870544), ('[SEP]', 0.10382220149040222)]\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(120, 75%, 84%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 82%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(120, 75%, 68%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> 
best                    </font></mark><mark style=\"background-color: hsl(120, 75%, 85%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> have                    </font></mark><mark style=\"background-color: hsl(120, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ever                    </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> seen                    </font></mark><mark style=\"background-color: hsl(120, 75%, 70%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> !                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> scenes                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ridiculous                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(120, 75%, 78%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                 
   </font></mark><mark style=\"background-color: hsl(120, 75%, 78%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> acting                    </font></mark><mark style=\"background-color: hsl(120, 75%, 75%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 79%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(120, 75%, 84%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 82%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(120, 75%, 68%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> 
best                    </font></mark><mark style=\"background-color: hsl(120, 75%, 85%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> have                    </font></mark><mark style=\"background-color: hsl(120, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ever                    </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> seen                    </font></mark><mark style=\"background-color: hsl(120, 75%, 70%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> !                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> scenes                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ridiculous                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(120, 75%, 78%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                 
   </font></mark><mark style=\"background-color: hsl(120, 75%, 78%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> acting                    </font></mark><mark style=\"background-color: hsl(120, 75%, 75%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oO_k1BtSPVt3"
      },
      "source": [
        "**Negative sentiment example**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 216
        },
        "id": "gD4xcvovI1KI",
        "outputId": "67842699-855a-4f22-de6d-6fd63d78f944"
      },
      "source": [
        "# encode a sentence\r\n",
        "text_batch = [\"I really didn't like this movie. Some of the actors were good, but overall the movie was boring.\"]\r\n",
        "encoding = tokenizer(text_batch, return_tensors='pt')\r\n",
        "input_ids = encoding['input_ids'].to(\"cuda\")\r\n",
        "attention_mask = encoding['attention_mask'].to(\"cuda\")\r\n",
        "\r\n",
        "# true class is negative - 0 (index of \"NEGATIVE\" in the class-name list)\r\n",
        "true_class = 0\r\n",
        "\r\n",
        "# generate an explanation for the input\r\n",
        "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]\r\n",
        "# min-max normalize the scores to the range [0, 1]\r\n",
        "expl = (expl - expl.min()) / (expl.max() - expl.min())\r\n",
        "\r\n",
        "# get the model classification: softmax over the first element of the model output\r\n",
        "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\r\n",
        "classification = output.argmax(dim=-1).item()\r\n",
        "# get class name\r\n",
        "class_name = classifications[classification]\r\n",
        "# if the classification is negative, higher explanation scores are more negative\r\n",
        "# flip for visualization\r\n",
        "if class_name == \"NEGATIVE\":\r\n",
        "  expl *= (-1)\r\n",
        "\r\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\r\n",
        "# pair every token with its (normalized, sign-flipped) relevance score\r\n",
        "print([(token, score.item()) for token, score in zip(tokens, expl)])\r\n",
        "vis_data_records = [visualization.VisualizationDataRecord(\r\n",
        "                                expl,                       # per-token scores (\"Word Importance\")\r\n",
        "                                output[0][classification],  # probability of the predicted class\r\n",
        "                                classification,             # predicted label\r\n",
        "                                true_class,                 # true label\r\n",
        "                                true_class,                 # attribution label\r\n",
        "                                1,                          # attribution score\r\n",
        "                                tokens,                     # tokens rendered in the table\r\n",
        "                                1)]                         # NOTE(review): presumably captum's convergence score - confirm\r\n",
        "# visualize_text() displays the table and also returns it; the trailing ';'\r\n",
        "# keeps the notebook from rendering the same table a second time\r\n",
        "visualization.visualize_text(vis_data_records);"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[('[CLS]', -0.0), ('i', -0.20092236995697021), ('really', -0.19856177270412445), ('didn', -0.2984919846057892), (\"'\", -0.018205545842647552), ('t', -0.37614259123802185), ('like', -0.16261045634746552), ('this', -0.19905252754688263), ('movie', -0.05162941664457321), ('.', -0.030849019065499306), ('some', -0.03427424281835556), ('of', -0.01179436780512333), ('the', -0.03457795828580856), ('actors', -0.10515207052230835), ('were', -0.025286149233579636), ('good', -0.04380534216761589), (',', -0.01580347865819931), ('but', -0.06978689879179001), ('overall', -0.34404104948043823), ('the', -0.2279864102602005), ('movie', -0.06604310870170593), ('was', -0.6306629776954651), ('boring', -1.0), ('.', -0.04221245273947716), ('[SEP]', -0.0635751336812973)]\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> really                    </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> didn                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> '                    </font></mark><mark style=\"background-color: hsl(0, 75%, 85%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> t             
       </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> like                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> of                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> actors                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> good                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                    </font></mark><mark 
style=\"background-color: hsl(0, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> overall                    </font></mark><mark style=\"background-color: hsl(0, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 75%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> boring                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>0 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1.00</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> really                    </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> didn                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> '                    </font></mark><mark style=\"background-color: hsl(0, 75%, 85%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> t             
       </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> like                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> this                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> some                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> of                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> actors                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> good                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> but                    </font></mark><mark 
style=\"background-color: hsl(0, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> overall                    </font></mark><mark style=\"background-color: hsl(0, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(0, 75%, 75%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(0, 75%, 60%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> boring                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    }
  ]
}