{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TextAttack with Custom Dataset and Word Embedding\n",
    "\n",
    "This tutorial shows how to use TextAttack with a dataset and a word embedding of your own.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QData/TextAttack/blob/master/docs/2notebook/4_Custom_Datasets_Word_Embedding.ipynb)\n",
    "\n",
    "[![View Source on GitHub](https://img.shields.io/badge/github-view%20source-black.svg)](https://github.com/QData/TextAttack/blob/master/docs/2notebook/4_Custom_Datasets_Word_Embedding.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_WVki6Bvbjur"
   },
   "source": [
    "Remember to run **pip3 install textattack[tensorflow]** in your notebook environment before running the code below.\n",
    "\n",
    "## **Importing the Model**\n",
    "\n",
    "We start by choosing a pretrained model to attack. In this example we use the ALBERT base v2 model `textattack/albert-base-v2-imdb` from HuggingFace. This model was trained on IMDb, a dataset of movie reviews labeled as either positive or negative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: textattack in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (0.3.0)\n",
      "Requirement already satisfied: editdistance in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (0.5.3)\n",
      "Requirement already satisfied: num2words in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (0.5.10)\n",
      "Requirement already satisfied: flair in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (0.6.1.post1)\n",
      "Requirement already satisfied: bert-score>=0.3.5 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (0.3.7)\n",
      "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (4.49.0)\n",
      "Requirement already satisfied: nltk in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (3.5)\n",
      "Requirement already satisfied: transformers>=3.3.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (4.1.1)\n",
      "Requirement already satisfied: word2number in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.1)\n",
      "Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (3.0.12)\n",
      "Requirement already satisfied: terminaltables in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (3.1.0)\n",
      "Requirement already satisfied: lru-dict in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.1.6)\n",
      "Requirement already satisfied: datasets in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.1.3)\n",
      "Requirement already satisfied: language-tool-python in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (2.4.7)\n",
      "Requirement already satisfied: more-itertools in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (8.6.0)\n",
      "Requirement already satisfied: pandas>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.2.0)\n",
      "Requirement already satisfied: torch!=1.8,>=1.7.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.7.1)\n",
      "Requirement already satisfied: lemminflect in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (0.2.1)\n",
      "Requirement already satisfied: numpy<1.19.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.18.5)\n",
      "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.7.1)\n",
      "Requirement already satisfied: scipy==1.4.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from textattack) (1.4.1)\n",
      "Requirement already satisfied: docopt>=0.6.2 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from num2words->textattack) (0.6.2)\n",
      "Requirement already satisfied: gdown in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (3.12.2)\n",
      "Requirement already satisfied: tabulate in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.8.7)\n",
      "Requirement already satisfied: konoha<5.0.0,>=4.0.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (4.6.2)\n",
      "Requirement already satisfied: python-dateutil>=2.6.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (2.8.1)\n",
      "Requirement already satisfied: ftfy in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (5.8)\n",
      "Requirement already satisfied: regex in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (2020.11.13)\n",
      "Requirement already satisfied: segtok>=1.5.7 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (1.5.10)\n",
      "Requirement already satisfied: gensim>=3.4.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (3.8.3)\n",
      "Requirement already satisfied: matplotlib>=2.2.3 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (3.3.3)\n",
      "Requirement already satisfied: bpemb>=0.3.2 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.3.2)\n",
      "Requirement already satisfied: janome in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.4.1)\n",
      "Requirement already satisfied: langdetect in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (1.0.8)\n",
      "Requirement already satisfied: hyperopt>=0.1.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.2.5)\n",
      "Requirement already satisfied: sentencepiece!=0.1.92 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.1.94)\n",
      "Requirement already satisfied: lxml in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (4.6.2)\n",
      "Requirement already satisfied: deprecated>=1.2.4 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (1.2.10)\n",
      "Requirement already satisfied: scikit-learn>=0.21.3 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.24.0)\n",
      "Requirement already satisfied: mpld3==0.3 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (0.3)\n",
      "Requirement already satisfied: sqlitedict>=1.6.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from flair->textattack) (1.7.0)\n",
      "Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from bert-score>=0.3.5->textattack) (2.25.1)\n",
      "Requirement already satisfied: click in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from nltk->textattack) (7.1.2)\n",
      "Requirement already satisfied: joblib in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from nltk->textattack) (1.0.0)\n",
      "Requirement already satisfied: packaging in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from transformers>=3.3.0->textattack) (20.8)\n",
      "Requirement already satisfied: sacremoses in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from transformers>=3.3.0->textattack) (0.0.43)\n",
      "Requirement already satisfied: tokenizers==0.9.4 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from transformers>=3.3.0->textattack) (0.9.4)\n",
      "Requirement already satisfied: multiprocess in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from datasets->textattack) (0.70.11.1)\n",
      "Requirement already satisfied: pyarrow>=0.17.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from datasets->textattack) (2.0.0)\n",
      "Requirement already satisfied: xxhash in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from datasets->textattack) (2.0.0)\n",
      "Requirement already satisfied: dill in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from datasets->textattack) (0.3.3)\n",
      "Requirement already satisfied: pytz>=2017.3 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from pandas>=1.0.1->textattack) (2020.5)\n",
      "Requirement already satisfied: typing-extensions in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from torch!=1.8,>=1.7.0->textattack) (3.7.4.3)\n",
      "Requirement already satisfied: six in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gdown->flair->textattack) (1.15.0)\n",
      "Requirement already satisfied: overrides==3.0.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from konoha<5.0.0,>=4.0.0->flair->textattack) (3.0.0)\n",
      "Requirement already satisfied: wcwidth in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from ftfy->flair->textattack) (0.2.5)\n",
      "Requirement already satisfied: smart-open>=1.8.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from gensim>=3.4.0->flair->textattack) (4.1.0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: kiwisolver>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib>=2.2.3->flair->textattack) (1.3.1)\n",
      "Requirement already satisfied: cycler>=0.10 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib>=2.2.3->flair->textattack) (0.10.0)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib>=2.2.3->flair->textattack) (8.0.1)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from matplotlib>=2.2.3->flair->textattack) (2.4.7)\n",
      "Requirement already satisfied: future in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from hyperopt>=0.1.1->flair->textattack) (0.18.2)\n",
      "Requirement already satisfied: networkx>=2.2 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from hyperopt>=0.1.1->flair->textattack) (2.5)\n",
      "Requirement already satisfied: cloudpickle in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from hyperopt>=0.1.1->flair->textattack) (1.6.0)\n",
      "Requirement already satisfied: wrapt<2,>=1.10 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from deprecated>=1.2.4->flair->textattack) (1.12.1)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from scikit-learn>=0.21.3->flair->textattack) (2.1.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from requests->bert-score>=0.3.5->textattack) (2020.12.5)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from requests->bert-score>=0.3.5->textattack) (4.0.0)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from requests->bert-score>=0.3.5->textattack) (2.10)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from requests->bert-score>=0.3.5->textattack) (1.26.2)\n",
      "Requirement already satisfied: decorator>=4.3.0 in /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages (from networkx>=2.2->hyperopt>=0.1.1->flair->textattack) (4.4.2)\n",
      "\u001b[33mWARNING: You are using pip version 20.1.1; however, version 21.1.3 is available.\n",
      "You should consider upgrading via the '/Library/Frameworks/Python.framework/Versions/3.7/bin/python3.7 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip3 install textattack[tensorflow]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 585,
     "referenced_widgets": [
      "1905ff29aaa242a88dc93f3247065364",
      "917713cc9b1344c7a7801144f04252bc",
      "b65d55c5b9f445a6bfd585f6237d22ca",
      "38b56a89b2ae4a8ca93c03182db26983",
      "26082a081d1c49bd907043a925cf88df",
      "1c3edce071ad4a2a99bf3e34ea40242c",
      "f9c265a003444a03bde78e18ed3f5a7e",
      "3cb9eb594c8640ffbfd4a0b1139d571a",
      "7d29511ba83a4eaeb4a2e5cd89ca1990",
      "136f44f7b8fa433ebff6d0a534c0588b",
      "2658e486ee77468a99ab4edc7b5191d8",
      "39bfd8c439b847e4bdfeee6e66ae86f3",
      "7ca4ce3d902d42758eb1fc02b9b211d3",
      "222cacceca11402db10ff88a92a2d31d",
      "108d2b83dff244edbebf4f8909dce789",
      "c06317aaf0064cb9b6d86d032821a8e2",
      "c18ac12f8c6148b9aa2d69885351fbcb",
      "b11ad31ee69441df8f0447a4ae62ce75",
      "a7e846fdbda740a38644e28e11a67707",
      "b38d5158e5584461bfe0b2f8ed3b0dc2",
      "3bdef9b4157e41f3a01f25b07e8efa48",
      "69e19afa8e2c49fbab0e910a5929200f",
      "2627a092f0c041c0a5f67451b1bd8b2b",
      "1780cb5670714c0a9b7a94b92ffc1819",
      "1ac87e683d2e4951ac94e25e8fe88d69",
      "02daee23726349a69d4473814ede81c3",
      "1fac551ad9d840f38b540ea5c364af70",
      "1027e6f245924195a930aca8c3844f44",
      "5b863870023e4c438ed75d830c13c5ac",
      "9ec55c6e2c4e40daa284596372728213",
      "5e2d17ed769d496db38d053cc69a914c",
      "dedaafae3bcc47f59b7d9b025b31fd0c",
      "8c2f5cda0ae9472fa7ec2b864d0bdc0e",
      "2a35d22dd2604950bae55c7c51f4af2c",
      "4c23ca1540fd48b1ac90d9365c9c6427",
      "3e4881a27c36472ab4c24167da6817cf",
      "af32025d22534f9da9e769b02f5e6422",
      "7af34c47299f458789e03987026c3519",
      "ed0ab8c7456a42618d6cbf6fd496b7b3",
      "25fc5fdac77247f9b029ada61af630fd"
     ]
    },
    "id": "4ZEnCFoYv-y7",
    "outputId": "c6c57cb9-6d6e-4efd-988f-c794356d4719"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae9361c362a14921a5dfac77403ba997",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=727.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ebddb1a98c94467af5bbc97661ab2e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=46747112.0, style=ProgressStyle(descrip…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a9ff6dbfc444f2ea483f16055f4cf4a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760289.0, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1345259be09b41b1adefed6763766ee2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=156.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d21cada2e5bd47f8bee3facdf0b382a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import transformers\n",
    "from textattack.models.wrappers import HuggingFaceModelWrapper\n",
    "\n",
    "# https://huggingface.co/textattack\n",
    "model = transformers.AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"textattack/albert-base-v2-imdb\"\n",
    ")\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"textattack/albert-base-v2-imdb\")\n",
    "# We wrap the model so it can be used by textattack\n",
    "model_wrapper = HuggingFaceModelWrapper(model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D61VLa8FexyK"
   },
   "source": [
    "## **Creating A Custom Dataset**\n",
    "\n",
    "TextAttack takes a dataset in the form of a list of tuples. Each tuple can have the form (\"string\", label) or (\"string\", label, label). Here we use the former, since we want to create a custom movie review dataset in which label 0 represents a negative review and label 1 represents a positive review.\n",
    "\n",
    "For simplicity, the dataset consists of 4 reviews. The 1st and 4th reviews have \"correct\" labels, while the 2nd and 3rd reviews have \"incorrect\" labels. We will see how this affects the attack later in this tutorial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "nk_MUu5Duf1V"
   },
   "outputs": [],
   "source": [
    "# dataset: An iterable of (text, ground_truth_output) pairs.\n",
    "# 0 means the review is negative\n",
    "# 1 means the review is positive\n",
    "custom_dataset = [\n",
    "    (\"I hate this movie\", 0),  # A negative comment, with a negative label\n",
    "    (\"I hate this movie\", 1),  # A negative comment, with a positive label\n",
    "    (\"I love this movie\", 0),  # A positive comment, with a negative label\n",
    "    (\"I love this movie\", 1),  # A positive comment, with a positive label\n",
    "]"
   ]
  },
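  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The list above is written by hand, but the same structure can be built from a file. The cell below is a minimal sketch and not part of the TextAttack API: `load_pairs` is a hypothetical helper, and the `text` and `label` column names are assumptions about how your CSV is laid out. It reads each row into the `(text, label)` tuple form described above, converting the label to an integer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import io\n",
    "\n",
    "\n",
    "def load_pairs(csv_file, text_column=\"text\", label_column=\"label\"):\n",
    "    \"\"\"Read (text, label) pairs from an open CSV file with a header row.\"\"\"\n",
    "    pairs = []\n",
    "    for row in csv.DictReader(csv_file):\n",
    "        # csv yields strings, so the label is converted to an integer class id.\n",
    "        pairs.append((row[text_column], int(row[label_column])))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "# An in-memory file stands in for a real CSV on disk (assumed layout).\n",
    "sample_csv = io.StringIO(\"text,label\\nI hate this movie,0\\nI love this movie,1\\n\")\n",
    "loaded_dataset = load_pairs(sample_csv)\n",
    "print(loaded_dataset)"
   ]
  },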
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ijVmi6PbiUYZ"
   },
   "source": [
    "## **Creating An Attack**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-iEH_hf6iMEw",
    "outputId": "0c836c5b-ddd5-414d-f73d-da04067054d8"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "textattack: Unknown if model of class <class 'transformers.models.albert.modeling_albert.AlbertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.\n"
     ]
    }
   ],
   "source": [
    "from textattack import Attack\n",
    "from textattack.search_methods import GreedySearch\n",
    "from textattack.constraints.pre_transformation import (\n",
    "    RepeatModification,\n",
    "    StopwordModification,\n",
    ")\n",
    "from textattack.goal_functions import UntargetedClassification\n",
    "from textattack.transformations import WordSwapEmbedding\n",
    "\n",
    "# We'll use untargeted classification as the goal function.\n",
    "goal_function = UntargetedClassification(model_wrapper)\n",
    "# We'll use WordSwapEmbedding as the attack transformation.\n",
    "transformation = WordSwapEmbedding()\n",
    "# We'll constrain modification of already modified indices and stopwords.\n",
    "constraints = [RepeatModification(), StopwordModification()]\n",
    "# We'll use the greedy search method.\n",
    "search_method = GreedySearch()\n",
    "# Now, let's make the attack from the 4 components:\n",
    "attack = Attack(goal_function, constraints, transformation, search_method)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4hUA8ntnfJzH"
   },
   "source": [
    "## **Attack Results With Custom Dataset**\n",
    "\n",
    "As you can see, the attack fools the model by changing a few words in the 1st and 4th reviews.\n",
    "\n",
    "The attack skipped the 2nd and 3rd reviews: because they were labeled incorrectly, the model's prediction already disagreed with the given label before any modification was made."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-ivoHEOXfIfN",
    "outputId": "9ec660b6-44fc-4354-9dd1-1641b6f4c986"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[91m0 (99%)\u001b[0m --> \u001b[92m1 (81%)\u001b[0m\n",
      "\n",
      "\u001b[91mI\u001b[0m \u001b[91mhate\u001b[0m this \u001b[91mmovie\u001b[0m\n",
      "\n",
      "\u001b[92mdid\u001b[0m \u001b[92mhateful\u001b[0m this \u001b[92mfootage\u001b[0m\n",
      "\u001b[91m0 (99%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "I hate this movie\n",
      "\u001b[92m1 (96%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "I love this movie\n",
      "\u001b[92m1 (96%)\u001b[0m --> \u001b[91m0 (99%)\u001b[0m\n",
      "\n",
      "I \u001b[92mlove\u001b[0m this movie\n",
      "\n",
      "I \u001b[91miove\u001b[0m this movie\n"
     ]
    }
   ],
   "source": [
    "for example, label in custom_dataset:\n",
    "    result = attack.attack(example, label)\n",
    "    print(result.__str__(color_method=\"ansi\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "foFZmk8vY5z0"
   },
   "source": [
    "## **Creating A Custom Word Embedding**\n",
    "\n",
    "In TextAttack, a pre-trained word embedding is used by the transformation to find synonym replacements, and by constraints to check the semantic validity of the transformation. To use custom pre-trained word embeddings, you can either create a new class that inherits from the `AbstractWordEmbedding` class, or use the `WordEmbedding` class, which takes 4 parameters: an embedding matrix, a word-to-index dictionary, an index-to-word dictionary, and a nearest-neighbour matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "owj_jMHRxEF5"
   },
   "outputs": [],
   "source": [
    "from textattack.shared import WordEmbedding\n",
    "\n",
    "embedding_matrix = [\n",
    "    [1.0],\n",
    "    [2.0],\n",
    "    [3.0],\n",
    "    [4.0],\n",
    "]  # 2-D array of shape N x D, where N is the vocabulary size and D is the dimension of the embedding vectors.\n",
    "word2index = {\n",
    "    \"hate\": 0,\n",
    "    \"despise\": 1,\n",
    "    \"like\": 2,\n",
    "    \"love\": 3,\n",
    "}  # Dictionary that maps each word to its index within the embedding matrix.\n",
    "index2word = {\n",
    "    0: \"hate\",\n",
    "    1: \"despise\",\n",
    "    2: \"like\",\n",
    "    3: \"love\",\n",
    "}  # Dictionary that maps each index back to its word.\n",
    "nn_matrix = [\n",
    "    [0, 1, 2, 3],\n",
    "    [1, 0, 2, 3],\n",
    "    [2, 1, 3, 0],\n",
    "    [3, 2, 1, 0],\n",
    "]  # 2-D integer array of shape N x K, where N is the vocabulary size and K is the number of nearest neighbours stored per word.\n",
    "\n",
    "embedding = WordEmbedding(embedding_matrix, word2index, index2word, nn_matrix)"
   ]
  },
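  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a real vocabulary you would not write `nn_matrix` by hand. The cell below is a minimal pure-Python sketch of one way to derive it, assuming Euclidean distance with ties broken by the lower index. `build_nn_matrix` and the `sketch_*` names are hypothetical and not part of TextAttack. On the toy vectors used above it reproduces the hand-written `nn_matrix`, in which each row starts with the word's own index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_nn_matrix(vectors):\n",
    "    \"\"\"Return, for every row, all row indices ordered from nearest to farthest.\"\"\"\n",
    "\n",
    "    def squared_distance(a, b):\n",
    "        return sum((x - y) ** 2 for x, y in zip(a, b))\n",
    "\n",
    "    nn = []\n",
    "    for vector in vectors:\n",
    "        # Sort by distance first, then by index so that ties are deterministic.\n",
    "        order = sorted(\n",
    "            range(len(vectors)),\n",
    "            key=lambda j: (squared_distance(vector, vectors[j]), j),\n",
    "        )\n",
    "        nn.append(order)\n",
    "    return nn\n",
    "\n",
    "\n",
    "words = [\"hate\", \"despise\", \"like\", \"love\"]\n",
    "vectors = [[1.0], [2.0], [3.0], [4.0]]\n",
    "\n",
    "sketch_word2index = {word: i for i, word in enumerate(words)}\n",
    "sketch_index2word = {i: word for i, word in enumerate(words)}\n",
    "sketch_nn_matrix = build_nn_matrix(vectors)\n",
    "print(sketch_nn_matrix)"
   ]
  },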
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s9ZEV_ykhmBn"
   },
   "source": [
    "## **Attack Results With Custom Dataset and Word Embedding**\n",
    "\n",
    "Now if we run the attack again with the custom word embedding, you will notice the modifications are limited to the vocab provided by our custom word embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gZ98UZ6I5sIn",
    "outputId": "59a653cb-85cb-46b5-d81b-c1a05ebe8a3e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[91m0 (99%)\u001b[0m --> \u001b[92m1 (98%)\u001b[0m\n",
      "\n",
      "I \u001b[91mhate\u001b[0m this movie\n",
      "\n",
      "I \u001b[92mlike\u001b[0m this movie\n",
      "\u001b[91m0 (99%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "I hate this movie\n",
      "\u001b[92m1 (96%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "I love this movie\n",
      "\u001b[92m1 (96%)\u001b[0m --> \u001b[91m0 (99%)\u001b[0m\n",
      "\n",
      "I \u001b[92mlove\u001b[0m this movie\n",
      "\n",
      "I \u001b[91mdespise\u001b[0m this movie\n"
     ]
    }
   ],
   "source": [
    "from textattack.attack_results import SuccessfulAttackResult\n",
    "\n",
    "# Allow at most 3 replacement candidates per word, drawn from the custom embedding.\n",
    "transformation = WordSwapEmbedding(max_candidates=3, embedding=embedding)\n",
    "\n",
    "attack = Attack(goal_function, constraints, transformation, search_method)\n",
    "\n",
    "# Legacy approach: attack each example directly to see how the attack runs in detail.\n",
    "for example, label in custom_dataset:\n",
    "    result = attack.attack(example, label)\n",
    "    print(result.__str__(color_method=\"ansi\"))"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The currently recommended, API-centric way to run a custom attack.\n",
    "\n",
    "from textattack.loggers import CSVLogger  # tracks a dataframe for us.\n",
    "from textattack.attack_results import SuccessfulAttackResult\n",
    "from textattack import Attacker, AttackArgs\n",
    "from textattack.datasets import Dataset\n",
    "\n",
    "attack_args = AttackArgs(\n",
    "    num_successful_examples=5, log_to_csv=\"results.csv\", csv_coloring_style=\"html\"\n",
    ")\n",
    "# Attacker expects a textattack.datasets.Dataset, so we wrap the list of tuples.\n",
    "attacker = Attacker(attack, Dataset(custom_dataset), attack_args)\n",
    "\n",
    "attack_results = attacker.attack_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now we visualize the attack results.\n",
    "\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "pd.options.display.max_colwidth = (\n",
    "    480  # increase column width so we can actually read the examples\n",
    ")\n",
    "\n",
    "logger = CSVLogger(color_method=\"html\")\n",
    "\n",
    "# Log only the successful attacks so the table shows perturbed examples.\n",
    "for result in attack_results:\n",
    "    if isinstance(result, SuccessfulAttackResult):\n",
    "        logger.log_attack_result(result)\n",
    "\n",
    "results = pd.DataFrame.from_records(logger.row_list)\n",
    "display(HTML(results[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Custom Data and Embedding with TextAttack.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
