{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2019-present, Facebook, Inc.\n",
    "# All rights reserved.\n",
    "#\n",
    "# This source code is licensed under the license found in the\n",
    "# LICENSE file in the root directory of this source tree.\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Code to generate sentence representations from a pretrained model.\n",
    "# This can be used to initialize a cross-lingual classifier, for instance.\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "FAISS library was not found.\n",
      "FAISS not available. Switching to standard nearest neighbors search implementation.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "from src.utils import AttrDict\n",
    "from src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD\n",
    "from src.model.transformer import TransformerModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reload a pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Supported languages: af, als, am, an, ang, ar, arz, ast, az, bar, be, bg, bn, br, bs, ca, ceb, ckb, cs, cy, da, de, el, en, eo, es, et, eu, fa, fi, fr, fy, ga, gan, gl, gu, he, hi, hr, hu, hy, ia, id, is, it, ja, jv, ka, kk, kn, ko, ku, la, lb, lt, lv, mk, ml, mn, mr, ms, my, nds, ne, nl, nn, no, oc, pl, pt, ro, ru, scn, sco, sh, si, simple, sk, sl, sq, sr, sv, sw, ta, te, th, tl, tr, tt, uk, ur, uz, vi, war, wuu, yi, zh, zh_classical, zh_min_nan, zh_yue\n"
     ]
    }
   ],
   "source": [
    "model_path = 'models/mlm_100_1280.pth'\n",
    "reloaded = torch.load(model_path)\n",
    "params = AttrDict(reloaded['params'])\n",
    "print(\"Supported languages: %s\" % \", \".join(params.lang2id.keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build dictionary / update parameters / build model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build dictionary / update parameters\n",
    "dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])\n",
    "params.n_words = len(dico)\n",
    "params.bos_index = dico.index(BOS_WORD)\n",
    "params.eos_index = dico.index(EOS_WORD)\n",
    "params.pad_index = dico.index(PAD_WORD)\n",
    "params.unk_index = dico.index(UNK_WORD)\n",
    "params.mask_index = dico.index(MASK_WORD)\n",
    "\n",
    "# build model / reload weights\n",
    "model = TransformerModel(params, dico, True, True)\n",
    "model.eval()\n",
    "model.load_state_dict(reloaded['model'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Get sentence representations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sentences have to be in the BPE format, i.e. tokenized sentences on which you applied fastBPE.\n",
    "\n",
    "Below you can see an example for English, French, Spanish, German, Arabic and Chinese sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Below is one way to bpe-ize sentences\n",
    "codes = \"\" # path to the codes of the model\n",
    "fastbpe = os.path.join(os.getcwd(), 'tools/fastBPE/fast')\n",
    "\n",
    "def to_bpe(sentences):\n",
    "    # write sentences to tmp file\n",
    "    with open('/tmp/sentences.bpe', 'w') as fwrite:\n",
    "        for sent in sentences:\n",
    "            fwrite.write(sent + '\\n')\n",
    "    \n",
    "    # apply bpe to tmp file\n",
    "    os.system('%s applybpe /tmp/sentences.bpe /tmp/sentences %s' % (fastbpe, codes))\n",
    "    \n",
    "    # load bpe-ized sentences\n",
    "    sentences_bpe = []\n",
    "    with open('/tmp/sentences.bpe') as f:\n",
    "        for line in f:\n",
    "            sentences_bpe.append(line.rstrip())\n",
    "    \n",
    "    return sentences_bpe\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "once he had worn tr@@ end@@ y italian le@@ ather sho@@ es and je@@ ans from paris that had cost three hundred euros .\n",
      "\n",
      "Le français est la seule langue étrang@@ ère propo@@ sée dans le système é@@ duc@@ atif .\n",
      "\n",
      "El cad@@ mio produce efectos tó@@ x@@ icos en los organismos vivos , aun en concentra@@ ciones muy pequeñas .\n",
      "\n",
      "Nach dem Zweiten Weltkrieg verbre@@ it@@ ete sich Bon@@ sai als Hob@@ by in der ganzen Welt .\n",
      "\n",
      "وقد فاز في الانتخابات في الج@@ ولة الثانية من التص@@ ويت من قبل سيدي ولد الشيخ عبد الله ، مع أحمد ولد دا@@ دا@@ ه في المرتبة الثانية .\n",
      "\n",
      "羅@@ 伯特 · 皮@@ 爾 斯 生於 186@@ 3年 , 在 英國 曼@@ 徹@@ 斯特 學習 而 成為 一 位 工程@@ 師 . 193@@ 3年 , 皮@@ 爾@@ 斯 在 直@@ 布@@ 羅@@ 陀@@ 去世 .\n",
      "Number of out-of-vocab words: 0/144\n"
     ]
    }
   ],
   "source": [
    "# Below are already BPE-ized sentences\n",
    "\n",
    "# list of (sentences, lang)\n",
    "sentences = [\n",
    "    'once he had worn trendy italian leather shoes and jeans from paris that had cost three hundred euros .', # en\n",
    "    'Le français est la seule langue étrangère proposée dans le système éducatif .', # fr\n",
    "    'El cadmio produce efectos tóxicos en los organismos vivos , aun en concentraciones muy pequeñas .', # es\n",
    "    'Nach dem Zweiten Weltkrieg verbreitete sich Bonsai als Hobby in der ganzen Welt .', # de\n",
    "    'وقد فاز في الانتخابات في الجولة الثانية من التصويت من قبل سيدي ولد الشيخ عبد الله ، مع أحمد ولد داداه في المرتبة الثانية .', # ar\n",
    "    '羅伯特 · 皮爾 斯 生於 1863年 , 在 英國 曼徹斯特 學習 而 成為 一 位 工程師 . 1933年 , 皮爾斯 在 直布羅陀去世 .', # zh\n",
    "]\n",
    "\n",
    "# bpe-ize sentences\n",
    "sentences = to_bpe(sentences)\n",
    "print('\\n\\n'.join(sentences))\n",
    "\n",
    "# check how many tokens are OOV\n",
    "n_w = len([w for w in ' '.join(sentences).split()])\n",
    "n_oov = len([w for w in ' '.join(sentences).split() if w not in dico.word2id])\n",
    "print('Number of out-of-vocab words: %s/%s' % (n_oov, n_w))\n",
    "\n",
    "# add </s> sentence delimiters\n",
    "sentences = [(('</s> %s </s>' % sent.strip()).split()) for sent in sentences]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = len(sentences)\n",
    "slen = max([len(sent) for sent in sentences])\n",
    "\n",
    "word_ids = torch.LongTensor(slen, bs).fill_(params.pad_index)\n",
    "for i in range(len(sentences)):\n",
    "    sent = torch.LongTensor([dico.index(w) for w in sentences[i]])\n",
    "    word_ids[:len(sent), i] = sent\n",
    "\n",
    "lengths = torch.LongTensor([len(sent) for sent in sentences])\n",
    "                             \n",
    "# NOTE: No more language id (removed it in a later version)\n",
    "# langs = torch.LongTensor([params.lang2id[lang] for _, lang in sentences]).unsqueeze(0).expand(slen, bs) if params.n_langs > 1 else None\n",
    "langs = None\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([38, 6, 1280])\n"
     ]
    }
   ],
   "source": [
    "tensor = model('fwd', x=word_ids, lengths=lengths, langs=langs, causal=False).contiguous()\n",
    "print(tensor.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The variable `tensor` is of shape `(sequence_length, batch_size, model_dimension)`.\n",
    "\n",
    "`tensor[0]` is a tensor of shape `(batch_size, model_dimension)` that corresponds to the first hidden state of the last layer of each sentence.\n",
    "\n",
    "This is this vector that we use to finetune on the GLUE and XNLI tasks."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
