{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparing TensorFlow (original) and PyTorch models\n",
    "\n",
    "You can use this small notebook to check the conversion of the model's weights from the TensorFlow model to the PyTorch model. In the following, we compare the weights of the last layer on a simple example (in `input.txt`), but both models return all the hidden layers, so you can check every stage of the model.\n",
    "\n",
    "To run this notebook, follow these instructions:\n",
    "- make sure that your Python environment has both TensorFlow and PyTorch installed,\n",
    "- download the original TensorFlow implementation,\n",
    "- download a pre-trained TensorFlow model as indicated in the TensorFlow implementation readme,\n",
    "- run the script `convert_tf_checkpoint_to_pytorch.py` as indicated in the `README` to convert the pre-trained TensorFlow model to PyTorch.\n",
    "\n",
    "If needed, change the relative paths indicated in this notebook (at the beginning of Sections 1 and 2) to point to the relevant models and code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:56:48.412622Z",
     "start_time": "2018-11-15T14:56:48.400110Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "# Step up one directory so that the relative paths used below resolve (presumably to the repository root - confirm).\n",
    "# The path is relative, so re-running this cell moves up one more level each time.\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1/ TensorFlow code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:56:49.483829Z",
     "start_time": "2018-11-15T14:56:49.471296Z"
    }
   },
   "outputs": [],
   "source": [
    "# Directory of the original TensorFlow implementation; extract_features_tensorflow.py is loaded from it below.\n",
    "original_tf_inplem_dir = \"./tensorflow_code/\"\n",
    "# Directory of the pre-trained TensorFlow model. Keep the trailing slash: file names are appended by plain string concatenation.\n",
    "model_dir = \"../google_models/uncased_L-12_H-768_A-12/\"\n",
    "\n",
    "vocab_file = model_dir + \"vocab.txt\"\n",
    "bert_config_file = model_dir + \"bert_config.json\"\n",
    "init_checkpoint = model_dir + \"bert_model.ckpt\"\n",
    "\n",
    "# Example text to run through the model, and the sequence length used for feature conversion and the input function.\n",
    "input_file = \"./samples/input.txt\"\n",
    "max_seq_length = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:57:51.597932Z",
     "start_time": "2018-11-15T14:57:51.549466Z"
    }
   },
   "outputs": [],
   "source": [
    "import importlib.util\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Load the original TensorFlow script by file path and register it under its real module name.\n",
    "# The guard makes this cell safe to re-run: executing the script a second time would\n",
    "# re-register its absl flags and raise DuplicateFlagError.\n",
    "tf_module_name = 'extract_features_tensorflow'\n",
    "if tf_module_name not in sys.modules:\n",
    "    spec = importlib.util.spec_from_file_location(\n",
    "        tf_module_name, os.path.join(original_tf_inplem_dir, tf_module_name + '.py'))\n",
    "    module = importlib.util.module_from_spec(spec)\n",
    "    spec.loader.exec_module(module)\n",
    "    # Register only after the script ran successfully, so a failed load is not cached.\n",
    "    sys.modules[tf_module_name] = module\n",
    "\n",
    "# The cells below rely on the names exported by the script (tf, modeling, tokenization, read_examples, ...).\n",
    "from extract_features_tensorflow import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:58:05.650987Z",
     "start_time": "2018-11-15T14:58:05.541620Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Example ***\n",
      "INFO:tensorflow:unique_id: 0\n",
      "INFO:tensorflow:tokens: [CLS] who was jim henson ? [SEP] jim henson was a puppet ##eer [SEP]\n",
      "INFO:tensorflow:input_ids: 101 2040 2001 3958 27227 1029 102 3958 27227 2001 1037 13997 11510 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n"
     ]
    }
   ],
   "source": [
    "# Configuration and tokenizer of the original TensorFlow model.\n",
    "bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
    "tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)\n",
    "\n",
    "# Indexes (0-11) of the layers passed to the model function and extracted below.\n",
    "layer_indexes = list(range(12))\n",
    "\n",
    "# Read the examples from the input file and convert them to model input features.\n",
    "examples = read_examples(input_file)\n",
    "features = convert_examples_to_features(\n",
    "    examples=examples,\n",
    "    seq_length=max_seq_length,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Lookup table from each feature's unique id to the feature itself.\n",
    "unique_id_to_feature = {feature.unique_id: feature for feature in features}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:58:11.562443Z",
     "start_time": "2018-11-15T14:58:08.036485Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x11ea7f1e0>) includes params argument, but params are not passed to Estimator.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x121b163c8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
      "WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n",
      "INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
      "WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
     ]
    }
   ],
   "source": [
    "# Run configuration for the estimator; use_tpu=False is passed below, so no TPU is requested.\n",
    "is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2\n",
    "run_config = tf.contrib.tpu.RunConfig(\n",
    "    master=None,\n",
    "    tpu_config=tf.contrib.tpu.TPUConfig(\n",
    "        num_shards=1,\n",
    "        per_host_input_for_training=is_per_host))\n",
    "\n",
    "# Model function built from the BERT config, the pre-trained checkpoint path and the requested layer indexes.\n",
    "model_fn = model_fn_builder(\n",
    "    bert_config=bert_config,\n",
    "    init_checkpoint=init_checkpoint,\n",
    "    layer_indexes=layer_indexes,\n",
    "    use_tpu=False,\n",
    "    use_one_hot_embeddings=False)\n",
    "\n",
    "# If TPU is not available, this will fall back to normal Estimator on CPU\n",
    "# or GPU.\n",
    "estimator = tf.contrib.tpu.TPUEstimator(\n",
    "    use_tpu=False,\n",
    "    model_fn=model_fn,\n",
    "    config=run_config,\n",
    "    predict_batch_size=1)\n",
    "\n",
    "# Input function over the features converted above, using the same sequence length.\n",
    "input_fn = input_fn_builder(\n",
    "    features=features, seq_length=max_seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:58:21.736543Z",
     "start_time": "2018-11-15T14:58:16.723829Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9, running initialization to predict.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Running infer on CPU\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "extracting layer 0\n",
      "extracting layer 1\n",
      "extracting layer 2\n",
      "extracting layer 3\n",
      "extracting layer 4\n",
      "extracting layer 5\n",
      "extracting layer 6\n",
      "extracting layer 7\n",
      "extracting layer 8\n",
      "extracting layer 9\n",
      "extracting layer 10\n",
      "extracting layer 11\n",
      "INFO:tensorflow:prediction_loop marked as finished\n",
      "INFO:tensorflow:prediction_loop marked as finished\n"
     ]
    }
   ],
   "source": [
    "# Run the estimator and collect, for every example, the output of each requested layer.\n",
    "tensorflow_all_out = []\n",
    "for result in estimator.predict(input_fn, yield_single_examples=True):\n",
    "    unique_id = int(result[\"unique_id\"])\n",
    "    # Looked up but not used further in this cell; raises KeyError for an unknown unique id.\n",
    "    feature = unique_id_to_feature[unique_id]\n",
    "    output_json = collections.OrderedDict()\n",
    "    output_json[\"linex_index\"] = unique_id\n",
    "    tensorflow_all_out_features = []\n",
    "    # No per-token loop here: a single 'features' entry per example holds every layer's output.\n",
    "    all_layers = []\n",
    "    for (j, layer_index) in enumerate(layer_indexes):\n",
    "        print(\"extracting layer {}\".format(j))\n",
    "        # Outputs are keyed by position j in layer_indexes; the stored 'index' is the layer index itself.\n",
    "        layer_output = result[\"layer_output_%d\" % j]\n",
    "        layers = collections.OrderedDict()\n",
    "        layers[\"index\"] = layer_index\n",
    "        layers[\"values\"] = layer_output\n",
    "        all_layers.append(layers)\n",
    "    tensorflow_out_features = collections.OrderedDict()\n",
    "    tensorflow_out_features[\"layers\"] = all_layers\n",
    "    tensorflow_all_out_features.append(tensorflow_out_features)\n",
    "\n",
    "    output_json[\"features\"] = tensorflow_all_out_features\n",
    "    tensorflow_all_out.append(output_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:58:23.970714Z",
     "start_time": "2018-11-15T14:58:23.931930Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "odict_keys(['linex_index', 'features'])\n",
      "number of tokens 1\n",
      "number of layers 12\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(128, 768)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the structure of the collected TensorFlow outputs.\n",
    "print(len(tensorflow_all_out))\n",
    "print(len(tensorflow_all_out[0]))\n",
    "print(tensorflow_all_out[0].keys())\n",
    "# NOTE(review): despite the label, this counts the entries of the 'features' list, not tokens.\n",
    "print(\"number of tokens\", len(tensorflow_all_out[0]['features']))\n",
    "print(\"number of layers\", len(tensorflow_all_out[0]['features'][0]['layers']))\n",
    "# Shape of the first stored layer output of the first example.\n",
    "tensorflow_all_out[0]['features'][0]['layers'][0]['values'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T14:58:25.547012Z",
     "start_time": "2018-11-15T14:58:25.516076Z"
    }
   },
   "outputs": [],
   "source": [
    "# Gather the stored values of every requested layer for the first example.\n",
    "first_example_layers = tensorflow_all_out[0]['features'][0]['layers']\n",
    "tensorflow_outputs = [first_example_layers[t]['values'] for t in layer_indexes]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2/ PyTorch code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch into ./examples before importing extract_features in the next cell (presumably where that script lives - confirm).\n",
    "# The path is relative, so this cell only works once per kernel session from the directory set at the top.\n",
    "os.chdir('./examples')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:03:49.528679Z",
     "start_time": "2018-11-15T15:03:49.497697Z"
    }
   },
   "outputs": [],
   "source": [
    "import extract_features\n",
    "import pytorch_pretrained_bert as ppb\n",
    "from extract_features import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:18.001177Z",
     "start_time": "2018-11-15T15:21:17.970369Z"
    }
   },
   "outputs": [],
   "source": [
    "# Same model directory as model_dir in Section 1, written relative to ./examples (one level deeper).\n",
    "init_checkpoint_pt = \"../../google_models/uncased_L-12_H-768_A-12/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:20.893669Z",
     "start_time": "2018-11-15T15:21:18.786623Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling -   loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
      "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling -   Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
       "  (embeddings): BertEmbeddings(\n",
       "    (word_embeddings): Embedding(30522, 768)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
       "    (LayerNorm): BertLayerNorm()\n",
       "    (dropout): Dropout(p=0.1)\n",
       "  )\n",
       "  (encoder): BertEncoder(\n",
       "    (layer): ModuleList(\n",
       "      (0): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (1): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (2): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (3): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (4): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (5): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (6): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (7): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (8): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (9): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (10): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (11): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pooler): BertPooler(\n",
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the PyTorch BertModel from `init_checkpoint_pt` (set earlier in the notebook;\n",
    "# presumably the checkpoint produced by the conversion script) and keep it on CPU.\n",
    "device = torch.device(\"cpu\")\n",
    "model = ppb.BertModel.from_pretrained(init_checkpoint_pt)\n",
    "# Last expression of the cell: the module it returns is echoed as the cell output.\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:26.963427Z",
     "start_time": "2018-11-15T15:21:26.922494Z"
    },
    "code_folding": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
       "  (embeddings): BertEmbeddings(\n",
       "    (word_embeddings): Embedding(30522, 768)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
       "    (LayerNorm): BertLayerNorm()\n",
       "    (dropout): Dropout(p=0.1)\n",
       "  )\n",
       "  (encoder): BertEncoder(\n",
       "    (layer): ModuleList(\n",
       "      (0): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (1): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (2): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (3): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (4): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (5): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (6): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (7): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (8): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (9): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (10): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "      (11): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): BertLayerNorm()\n",
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pooler): BertPooler(\n",
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack each per-example field of `features` into one int64 tensor (one row per example).\n",
    "all_input_ids, all_input_mask, all_input_type_ids = (\n",
    "    torch.tensor([getattr(f, field) for f in features], dtype=torch.long)\n",
    "    for field in (\"input_ids\", \"input_mask\", \"input_type_ids\")\n",
    ")\n",
    "# Row number of every example, used later to look the example up in `features`.\n",
    "all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n",
    "\n",
    "# Sequential sampling with batches of one keeps the examples in their original order.\n",
    "eval_data = TensorDataset(all_input_ids, all_input_mask, all_input_type_ids, all_example_index)\n",
    "eval_sampler = SequentialSampler(eval_data)\n",
    "eval_dataloader = DataLoader(eval_data, batch_size=1, sampler=eval_sampler)\n",
    "\n",
    "# Switch to evaluation mode; as the last expression, the returned module is displayed.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:30.718724Z",
     "start_time": "2018-11-15T15:21:30.329205Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  101,  2040,  2001,  3958, 27227,  1029,   102,  3958, 27227,  2001,\n",
      "          1037, 13997, 11510,   102,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0]])\n",
      "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0]])\n",
      "tensor([0])\n",
      "layer 0 0\n",
      "layer 1 1\n",
      "layer 2 2\n",
      "layer 3 3\n",
      "layer 4 4\n",
      "layer 5 5\n",
      "layer 6 6\n",
      "layer 7 7\n",
      "layer 8 8\n",
      "layer 9 9\n",
      "layer 10 10\n",
      "layer 11 11\n"
     ]
    }
   ],
   "source": [
    "layer_indexes = list(range(12))\n",
    "\n",
    "# Run every example through the PyTorch model and keep the hidden states of the\n",
    "# encoder layers listed in `layer_indexes`; `pytorch_all_out` gets one entry per example.\n",
    "pytorch_all_out = []\n",
    "for input_ids, input_mask, input_type_ids, example_indices in eval_dataloader:\n",
    "    print(input_ids)\n",
    "    print(input_mask)\n",
    "    print(example_indices)\n",
    "    # Every tensor handed to the model has to live on the same device as the model.\n",
    "    input_ids = input_ids.to(device)\n",
    "    input_mask = input_mask.to(device)\n",
    "    input_type_ids = input_type_ids.to(device)\n",
    "\n",
    "    # Inference only: no autograd graph is needed for the comparison.\n",
    "    with torch.no_grad():\n",
    "        all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n",
    "\n",
    "    for b, example_index in enumerate(example_indices):\n",
    "        feature = features[example_index.item()]\n",
    "        unique_id = int(feature.unique_id)\n",
    "\n",
    "        all_layers = []\n",
    "        for (j, layer_index) in enumerate(layer_indexes):\n",
    "            print(\"layer\", j, layer_index)\n",
    "            layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()\n",
    "            layers = collections.OrderedDict()\n",
    "            layers[\"index\"] = layer_index\n",
    "            # Keep only the activations of example `b` of the batch.\n",
    "            layers[\"values\"] = layer_output[b]\n",
    "            all_layers.append(layers)\n",
    "\n",
    "        # Built once per example, after all layers are collected, so that `features`\n",
    "        # holds a single entry instead of one aliased duplicate per layer.\n",
    "        out_features = collections.OrderedDict()\n",
    "        out_features[\"layers\"] = all_layers\n",
    "        all_out_features = [out_features]\n",
    "\n",
    "        output_json = collections.OrderedDict()\n",
    "        # Key spelling kept as-is: it is part of the structure displayed by later cells.\n",
    "        output_json[\"linex_index\"] = unique_id\n",
    "        output_json[\"features\"] = all_out_features\n",
    "        pytorch_all_out.append(output_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:35.703615Z",
     "start_time": "2018-11-15T15:21:35.666150Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "odict_keys(['linex_index', 'features'])\n",
      "number of examples 1\n",
      "number of layers 12\n",
      "sequence length 128\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(128, 768)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(len(pytorch_all_out))\n",
    "print(len(pytorch_all_out[0]))\n",
    "print(pytorch_all_out[0].keys())\n",
    "\n",
    "# `pytorch_all_out` holds one entry per example, and each layer's values are laid out\n",
    "# as (sequence length, hidden size), so len() of them is the sequence length.\n",
    "first_example_layers = pytorch_all_out[0]['features'][0]['layers']\n",
    "print(\"number of examples\", len(pytorch_all_out))\n",
    "print(\"number of layers\", len(first_example_layers))\n",
    "print(\"sequence length\", len(first_example_layers[0]['values']))\n",
    "first_example_layers[0]['values'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:36.999073Z",
     "start_time": "2018-11-15T15:21:36.966762Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 768)\n",
      "(128, 768)\n"
     ]
    }
   ],
   "source": [
    "# Hidden states of the first example, one array per entry of `layer_indexes`.\n",
    "first_example_layers = pytorch_all_out[0]['features'][0]['layers']\n",
    "pytorch_outputs = [first_example_layers[t]['values'] for t in layer_indexes]\n",
    "print(pytorch_outputs[0].shape)\n",
    "print(pytorch_outputs[1].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:37.936522Z",
     "start_time": "2018-11-15T15:21:37.905269Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 768)\n",
      "(128, 768)\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the TensorFlow activations (`tensorflow_outputs`, built earlier in the\n",
    "# notebook) should have the same per-layer shape as the PyTorch ones printed above.\n",
    "print(tensorflow_outputs[0].shape)\n",
    "print(tensorflow_outputs[1].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3/ Comparing the outputs of both models, layer by layer (root-mean-square deviation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:39.437137Z",
     "start_time": "2018-11-15T15:21:39.406150Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-15T15:21:40.181870Z",
     "start_time": "2018-11-15T15:21:40.137023Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shape tensorflow layer, shape pytorch layer, standard deviation\n",
      "((128, 768), (128, 768), 1.5258875e-07)\n",
      "((128, 768), (128, 768), 2.342731e-07)\n",
      "((128, 768), (128, 768), 2.801949e-07)\n",
      "((128, 768), (128, 768), 3.5904986e-07)\n",
      "((128, 768), (128, 768), 4.2842768e-07)\n",
      "((128, 768), (128, 768), 5.127951e-07)\n",
      "((128, 768), (128, 768), 6.14668e-07)\n",
      "((128, 768), (128, 768), 7.063922e-07)\n",
      "((128, 768), (128, 768), 7.906173e-07)\n",
      "((128, 768), (128, 768), 8.475192e-07)\n",
      "((128, 768), (128, 768), 8.975489e-07)\n",
      "((128, 768), (128, 768), 4.1671223e-07)\n"
     ]
    }
   ],
   "source": [
    "print('shape tensorflow layer, shape pytorch layer, standard deviation')\n",
    "# One row per extracted layer: the number of rows follows `pytorch_outputs` instead of a\n",
    "# hard-coded layer count, and a missing TensorFlow layer still fails loudly (IndexError).\n",
    "for i, pytorch_layer in enumerate(pytorch_outputs):\n",
    "    tensorflow_layer = np.array(tensorflow_outputs[i])\n",
    "    pytorch_layer = np.array(pytorch_layer)\n",
    "    # Root-mean-square of the element-wise difference between the two frameworks.\n",
    "    rms_difference = np.sqrt(np.mean((tensorflow_layer - pytorch_layer)**2.0))\n",
    "    print((tensorflow_layer.shape, pytorch_layer.shape, rms_difference))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "48px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
