{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:11:32.524250Z",
     "start_time": "2021-11-18T20:11:32.495248Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>.container { width:100% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:13:29.542150Z",
     "start_time": "2021-11-18T20:13:15.672637Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-11-18 14:13:15--  https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle\n",
      "Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.114.128, 142.250.113.128, 142.251.46.144, ...\n",
      "Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.114.128|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 804479532 (767M) [application/octet-stream]\n",
      "Saving to: ‘deepmind_assets/language_perceiver_io_bytes.pickle’\n",
      "\n",
      "100%[======================================>] 804,479,532 64.4MB/s   in 13s    \n",
      "\n",
      "2021-11-18 14:13:29 (57.5 MB/s) - ‘deepmind_assets/language_perceiver_io_bytes.pickle’ saved [804479532/804479532]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# downlaod deepmind's pretrained language model\n",
    "!wget -O deepmind_assets/language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:14:59.769020Z",
     "start_time": "2021-11-18T20:14:30.279947Z"
    }
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'haiku'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-bb9b818c869b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"deepmind_assets/language_perceiver_io_bytes.pickle\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'haiku'"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "with open(\"deepmind_assets/language_perceiver_io_bytes.pickle\", \"rb\") as f:\n",
    "    params = pickle.loads(f.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:16:10.038091Z",
     "start_time": "2021-11-18T20:16:10.033679Z"
    }
   },
   "outputs": [],
   "source": [
    "from perceiver_io.perceiver_lm import PerceiverLM\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:16:11.772225Z",
     "start_time": "2021-11-18T20:16:10.040343Z"
    }
   },
   "outputs": [],
   "source": [
    "model = PerceiverLM(vocab_size=262, \n",
    "                    max_seq_len=2048, \n",
    "                    embedding_dim=768, \n",
    "                    num_latents=256, \n",
    "                    latent_dim=1280, \n",
    "                    qk_out_dim=256, \n",
    "                    num_self_attn_per_block=26)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:17:18.029141Z",
     "start_time": "2021-11-18T20:17:17.350667Z"
    }
   },
   "outputs": [],
   "source": [
    "state_dict = {}\n",
    "model_enc_base = 'perceiver.encoder.'\n",
    "params_enc_base = 'perceiver_encoder/~/'\n",
    "\n",
    "state_dict['token_embedding.weight'] = params['embed']['embeddings']\n",
    "state_dict['decoder_token_bias'] = params['embedding_decoder']['bias']\n",
    "state_dict['position_embedding.weight'] = params['trainable_position_encoding']['pos_embs']\n",
    "state_dict['query_embedding.weight'] = params['basic_decoder/~/trainable_position_encoding']['pos_embs']\n",
    "state_dict[f'{model_enc_base}latents'] = params[f'{params_enc_base}trainable_position_encoding']['pos_embs']\n",
    "\n",
    "def copy_attention_params(model_base, params_base):\n",
    "    global state_dict\n",
    "    state_dict[f'{model_base}attention.q.weight'] = params[f'{params_base}attention/linear']['w'].T\n",
    "    state_dict[f'{model_base}attention.q.bias'] = params[f'{params_base}attention/linear']['b']\n",
    "    state_dict[f'{model_base}attention.k.weight'] = params[f'{params_base}attention/linear_1']['w'].T\n",
    "    state_dict[f'{model_base}attention.k.bias'] = params[f'{params_base}attention/linear_1']['b']\n",
    "    state_dict[f'{model_base}attention.v.weight'] = params[f'{params_base}attention/linear_2']['w'].T\n",
    "    state_dict[f'{model_base}attention.v.bias'] = params[f'{params_base}attention/linear_2']['b']\n",
    "    state_dict[f'{model_base}attention.projection.weight'] = params[f'{params_base}attention/linear_3']['w'].T\n",
    "    state_dict[f'{model_base}attention.projection.bias'] = params[f'{params_base}attention/linear_3']['b']\n",
    "\n",
    "    if 'self_attention' in params_base:\n",
    "        state_dict[f'{model_base}layer_norm.weight'] = params[f'{params_base}layer_norm']['scale']\n",
    "        state_dict[f'{model_base}layer_norm.bias'] = params[f'{params_base}layer_norm']['offset']\n",
    "        state_dict[f'{model_base}qkv_layer_norm.weight'] = params[f'{params_base}layer_norm_1']['scale']\n",
    "        state_dict[f'{model_base}qkv_layer_norm.bias'] = params[f'{params_base}layer_norm_1']['offset']\n",
    "    else:\n",
    "        state_dict[f'{model_base}q_layer_norm.weight'] = params[f'{params_base}layer_norm']['scale']\n",
    "        state_dict[f'{model_base}q_layer_norm.bias'] = params[f'{params_base}layer_norm']['offset']\n",
    "        state_dict[f'{model_base}kv_layer_norm.weight'] = params[f'{params_base}layer_norm_1']['scale']\n",
    "        state_dict[f'{model_base}kv_layer_norm.bias'] = params[f'{params_base}layer_norm_1']['offset']\n",
    "        state_dict[f'{model_base}qkv_layer_norm.weight'] = params[f'{params_base}layer_norm_2']['scale']\n",
    "        state_dict[f'{model_base}qkv_layer_norm.bias'] = params[f'{params_base}layer_norm_2']['offset']\n",
    "\n",
    "    state_dict[f'{model_base}mlp.mlp.0.weight'] = params[f'{params_base}mlp/linear']['w'].T\n",
    "    state_dict[f'{model_base}mlp.mlp.0.bias'] = params[f'{params_base}mlp/linear']['b']\n",
    "    state_dict[f'{model_base}mlp.mlp.2.weight'] = params[f'{params_base}mlp/linear_1']['w'].T\n",
    "    state_dict[f'{model_base}mlp.mlp.2.bias'] = params[f'{params_base}mlp/linear_1']['b']\n",
    "\n",
    "copy_attention_params(f'{model_enc_base}cross_attn.', f'{params_enc_base}cross_attention/')\n",
    "copy_attention_params(f'perceiver.decoder.cross_attention.', f'basic_decoder/cross_attention/')\n",
    "\n",
    "for i in range(26):\n",
    "    copy_attention_params(f'{model_enc_base}self_attention_block.{i}.', f'{params_enc_base}self_attention{\"_%d\"%i if i else \"\"}/')\n",
    "    \n",
    "state_dict = {k: torch.tensor(v) for k,v in state_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:17:41.200266Z",
     "start_time": "2021-11-18T20:17:40.004088Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:18:05.180146Z",
     "start_time": "2021-11-18T20:18:05.153492Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenized string without masked bytes:\n",
      "This is an incomplete sentence where some words are\n"
     ]
    }
   ],
   "source": [
    "from deepmind_assets import bytes_tokenizer\n",
    "import numpy as np\n",
    "MAX_SEQ_LEN = 2048\n",
    "# The tokenizer is just UTF-8 encoding (with an offset)\n",
    "tokenizer = bytes_tokenizer.BytesTokenizer()\n",
    "input_str = \"This is an incomplete sentence where some words are missing.\"\n",
    "input_tokens = tokenizer.to_int(input_str)\n",
    "\n",
    "# Mask \" missing.\". Note that the model performs much better if the masked chunk\n",
    "# starts with a space.\n",
    "input_tokens[51:60] = tokenizer.mask_token\n",
    "print(\"Tokenized string without masked bytes:\")\n",
    "print(tokenizer.to_string(input_tokens))\n",
    "\n",
    "#@title Pad and reshape inputs\n",
    "inputs = input_tokens[None]\n",
    "input_mask = np.ones_like(inputs)\n",
    "\n",
    "def pad(max_sequence_length: int, inputs, input_mask):\n",
    "    input_len = inputs.shape[1]\n",
    "    assert input_len <= max_sequence_length\n",
    "    pad_len = max_sequence_length - input_len\n",
    "    padded_inputs = np.pad(\n",
    "      inputs,\n",
    "      pad_width=((0, 0), (0, pad_len)),\n",
    "      constant_values=tokenizer.pad_token)\n",
    "    padded_mask = np.pad(\n",
    "      input_mask,\n",
    "      pad_width=((0, 0), (0, pad_len)),\n",
    "      constant_values=0)\n",
    "    return padded_inputs, padded_mask\n",
    "\n",
    "inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:18:27.094849Z",
     "start_time": "2021-11-18T20:18:23.023049Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Greedy predictions:\n",
      "tensor([ 38, 115, 111, 121, 121, 111, 116, 109,  52])\n",
      "\n",
      "Predicted string:\n",
      " missing.\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "out = model.forward(torch.tensor(inputs), torch.tensor(input_mask))\n",
    "\n",
    "masked_tokens_predictions = out[0, 51:60].argmax(dim=-1)\n",
    "print(\"Greedy predictions:\")\n",
    "print(masked_tokens_predictions)\n",
    "print()\n",
    "print(\"Predicted string:\")\n",
    "print(tokenizer.to_string(masked_tokens_predictions.cpu().detach().numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-11-18T20:19:31.769992Z",
     "start_time": "2021-11-18T20:19:31.268132Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-11.8336, -11.6850, -11.8483,  ..., -11.5524, -11.7844, -11.7093],\n",
       "         [-12.8149, -12.5863, -12.7904,  ..., -12.6056, -12.8341, -12.6477],\n",
       "         [-12.8440, -12.6410, -12.8646,  ..., -12.5758, -12.8579, -12.6943],\n",
       "         ...,\n",
       "         [-11.4762, -11.4972, -11.4584,  ..., -11.8531, -12.1219, -11.3704],\n",
       "         [-11.4762, -11.4972, -11.4584,  ..., -11.8531, -12.1219, -11.3704],\n",
       "         [-11.4762, -11.4972, -11.4584,  ..., -11.8531, -12.1219, -11.3704]]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
