{
 "cells": [
  {
    "cell_type": "code",
    "execution_count": 0,
    "metadata": {
      "cellView": "form",
      "colab": {
        "autoexec": {
          "startup": false,
          "wait_interval": 0
        }
      },
      "colab_type": "code",
      "id": "6uNrFWq5BRba"
    },
    "outputs": [],
    "source": [
      "#@title\n",
      "# Copyright 2018 Google LLC.\n",
      "\n",
      "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
      "# you may not use this file except in compliance with the License.\n",
      "# You may obtain a copy of the License at\n",
      "\n",
      "# https://www.apache.org/licenses/LICENSE-2.0\n",
      "\n",
      "# Unless required by applicable law or agreed to in writing, software\n",
      "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
      "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
      "# See the License for the specific language governing permissions and\n",
      "# limitations under the License."
    ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Your Own Visualizations!\n",
    "Instructions:\n",
    "1. Install tensor2tensor and train a Transformer model following the instructions in the repository https://github.com/tensorflow/tensor2tensor.\n",
    "2. Update cell 3 to point to your checkpoint; it is currently set up to read from the default checkpoint location that would be created by following the instructions above.\n",
    "3. If you used custom hyperparameters, update cell 4.\n",
    "4. Run the notebook!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.bin import t2t_decoder  # To register the hparams set\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor.utils import trainer_lib\n",
    "from tensor2tensor.visualization import attention\n",
    "from tensor2tensor.visualization import visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "// Map the RequireJS module id 'd3' to d3 v3.4.8 on the cdnjs CDN.\n",
       "// The scheme is spelled out (https://) rather than protocol-relative (//)\n",
       "// so the URL does not inherit the scheme of the page hosting the notebook.\n",
       "require.config({\n",
       "  paths: {\n",
       "      d3: 'https://cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'\n",
       "  }\n",
       "});"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "// Map the RequireJS module id 'd3' to d3 v3.4.8 on the cdnjs CDN.\n",
    "// The scheme is spelled out (https://) rather than protocol-relative (//)\n",
    "// so the URL does not inherit the scheme of the page hosting the notebook.\n",
    "require.config({\n",
    "  paths: {\n",
    "      d3: 'https://cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min'\n",
    "  }\n",
    "});"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HParams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PUT THE MODEL YOU WANT TO LOAD HERE!\n",
    "# Directory containing the trained checkpoint; a leading '~' is expanded to\n",
    "# the current user's home directory.\n",
    "CHECKPOINT = os.path.expanduser(\n",
    "    '~/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HParams\n",
    "# Update these values if the model was trained with a different problem,\n",
    "# data directory, model or hyperparameter set.\n",
    "problem_name = 'translate_ende_wmt32k'        # registered problem name\n",
    "data_dir = os.path.expanduser('~/t2t_data/')  # '~' expands to the home directory\n",
    "model_name = 'transformer'                    # registered model name\n",
    "hparams_set = 'transformer_base_single_gpu'   # registered hparams set name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'eval'\n",
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n",
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n",
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n",
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n",
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n",
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n",
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_33708_512.bottom\n",
      "INFO:tensorflow:Transforming 'targets' with symbol_modality_33708_512.targets_bottom\n",
      "INFO:tensorflow:Building model body\n",
      "WARNING:tensorflow:From /tmp/t2t/tensor2tensor/layers/common_layers.py:512: calling reduce_mean (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "keep_dims is deprecated, use keepdims instead\n",
      "INFO:tensorflow:Transforming body output with symbol_modality_33708_512.top\n",
      "WARNING:tensorflow:From /tmp/t2t/tensor2tensor/layers/common_layers.py:1707: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See tf.nn.softmax_cross_entropy_with_logits_v2.\n",
      "\n",
      "INFO:tensorflow:Greedy Decoding\n"
     ]
    }
   ],
   "source": [
    "# Build the attention visualizer from the settings in the HParams cell.\n",
    "# beam_size=1 makes the model decode greedily.\n",
    "visualizer = visualization.AttentionVisualizer(\n",
    "    hparams_set,\n",
    "    model_name,\n",
    "    data_dir,\n",
    "    problem_name,\n",
    "    beam_size=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Restoring parameters from /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt-1\n"
     ]
    }
   ],
   "source": [
    "# Create a non-trainable int64 variable named 'global_step' before the session is built.\n",
    "# NOTE(review): presumably needed because MonitoredTrainingSession expects a\n",
    "# global step to exist in the graph -- confirm.\n",
    "tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')\n",
    "\n",
    "# Open a session that restores the model parameters found under CHECKPOINT.\n",
    "# save_checkpoint_secs=None switches off the default checkpoint saver, so\n",
    "# running the visualization never writes a new checkpoint into the training\n",
    "# directory. save_summaries_secs=0 is kept from the original configuration.\n",
    "sess = tf.train.MonitoredTrainingSession(\n",
    "    checkpoint_dir=CHECKPOINT,\n",
    "    save_checkpoint_secs=None,\n",
    "    save_summaries_secs=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Saving checkpoints for 1 into /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt.\n"
     ]
    }
   ],
   "source": [
    "# Run the model on a single input sentence. The call returns the decoded\n",
    "# output string together with the input text, output text and attention data\n",
    "# (att_mats) that the final cell passes to attention.show().\n",
    "input_sentence = \"I have two dogs.\"\n",
    "(output_string, inp_text, out_text, att_mats) = (\n",
    "    visualizer.get_vis_data_from_string(sess, input_sentence))\n",
    "print(output_string)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpreting the Visualizations\n",
    "- The layers dropdown allows you to view the different Transformer layers, 0-indexed of course.\n",
    "  - Tip: The first layer, last layer and 2nd to last layer are usually the most interpretable.\n",
    "- The attention dropdown allows you to select different pairs of encoder-decoder attentions:\n",
    "  - All: Shows all types of attentions together. NOTE: Heads of the same color in the decoder self-attention and the decoder-encoder attention are unrelated, since the two do not share parameters.\n",
    "  - Input - Input: Shows only the encoder self-attention.\n",
    "  - Input - Output: Shows the decoder’s attention on the encoder. NOTE: Every decoder layer attends to the final layer of the encoder, so the visualization will show the attention on the final encoder layer regardless of what layer is selected in the dropdown.\n",
    "  - Output - Output: Shows only the decoder self-attention. NOTE: The visualization might be slightly misleading in the first layer, since the text shown is the target of the decoder; the input to the decoder at layer 0 is this text with a GO symbol prepended.\n",
    "- The colored squares represent the different attention heads.\n",
    "  - You can hide or show a given head by clicking on its color.\n",
    "  - Double-clicking a color will hide all other colors; double-clicking a color when it’s the only head showing will show all the heads again.\n",
    "- You can hover over a word to see the individual attention weights for just that position.\n",
    "  - Hovering over the words on the left will show what that position attended to.\n",
    "  - Hovering over the words on the right will show what positions attended to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the attention visualization for the sentence decoded above;\n",
    "# the entries of att_mats are passed as separate positional arguments.\n",
    "attention.show(\n",
    "    inp_text,\n",
    "    out_text,\n",
    "    *att_mats\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
