{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6102dd3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ssh gpu6 \"source bulatov/memexp_env/bin/activate; cd bulatov/TXL/pytorch; bash run_wt103_base_memtrans_1by1.sh train --work_dir test_1by1 --device_ids 4 5 6 7 --data ~/bulatov/TXL/data/wt103\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8cc5facf",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([1.], requires_grad=True)\n",
    "x.grad = torch.tensor([1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "effb7f8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c8c2bc16",
   "metadata": {},
   "outputs": [],
   "source": [
    "l = x*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe7e612",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e6a2b4fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.stack([torch.zeros(5, 4), torch.zeros(5, 4)]).sum(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5b805415",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "log_softmax() received an invalid combination of arguments - got (Tensor), but expected one of:\n * (int dim, torch.dtype dtype)\n * (name dim, *, torch.dtype dtype)\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-b5f823fa83cc>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlog_softmax\u001b[0;34m(input, dim, _stacklevel, dtype)\u001b[0m\n\u001b[1;32m   1767\u001b[0m         \u001b[0mdim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_softmax_dim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"log_softmax\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_stacklevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1768\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1769\u001b[0;31m         \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1770\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1771\u001b[0m         \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: log_softmax() received an invalid combination of arguments - got (Tensor), but expected one of:\n * (int dim, torch.dtype dtype)\n * (name dim, *, torch.dtype dtype)\n"
     ]
    }
   ],
   "source": [
    "l = torch.nn.functional.log_softmax(x, 2*x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "43aed31c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1., grad_fn=<MseLossBackward0>)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([1.])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(l(x, x*2))\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "956bc892",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'MSELoss' object has no attribute 'backward'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-6511facd51ad>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1175\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1176\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1177\u001b[0;31m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m   1178\u001b[0m             type(self).__name__, name))\n\u001b[1;32m   1179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'MSELoss' object has no attribute 'backward'"
     ]
    }
   ],
   "source": [
    "l.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a7ca5776",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ayd98/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:417.)\n",
      "  return self._grad\n"
     ]
    }
   ],
   "source": [
    "l.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d7e2c64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# bash run_wt103_base.sh train --work_dir WORK_DIR --device_ids 0 1 --data PATH_TO_DATA --mem_len 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1817292f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open('/home/ayd98/Desktop/MIPT/TXL/data/enwik8/train.txt', 'r') as f:\n",
    "#     res = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4ebb6a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import re\n",
    "# re.sub(' +', ' ', 'The     quick brown    fox')\n",
    "# res[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "edf5dd51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# res = list(res)\n",
    "# res = [s if s != ' ' else '\\spce' for s in res]\n",
    "# res = ' '.join(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "81d4f177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# l = len(res)\n",
    "# train = res[:int(0.9*l)]\n",
    "# test = res[int(0.9*l):int(0.95*l)]\n",
    "# val = res[int(0.95*l):]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ac96cd14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for var, name in zip([train, test, val], ['train', 'test', 'valid']):\n",
    "#     with open(f'/home/ayd98/Desktop/MIPT/TXL/data/enwik8/{name}.txt', 'w') as f:\n",
    "#         f.write(var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8fde71f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/data\n",
    "# MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/model\n",
    "\n",
    "# # Model\n",
    "# N_LAYER=24\n",
    "# D_MODEL=1024\n",
    "# D_EMBED=1024\n",
    "# N_HEAD=8\n",
    "# D_HEAD=128\n",
    "# D_INNER=3072\n",
    "\n",
    "# # Testing\n",
    "# TEST_TGT_LEN=128\n",
    "# TEST_MEM_LEN=3800\n",
    "# TEST_CLAMP_LEN=1000\n",
    "\n",
    "# TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0\n",
    "# TEST_BSZ=16\n",
    "# TEST_NUM_CORE=2\n",
    "\n",
    "\n",
    "# python data_utils.py \\\n",
    "#   --data_dir=$'data/enwik8' \\\n",
    "#   --dataset=enwik8 \\\n",
    "#   --tgt_len=${TEST_TGT_LEN} \\\n",
    "#   --per_host_test_bsz=${TEST_BSZ} \\\n",
    "#   --num_passes=1 \\\n",
    "#   --use_tpu=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "45321acc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torchtext.datasets as datasets\n",
    "\n",
    "# sets = datasets.WikiText103(root='data/wt103', split=('train', 'valid', 'test'))\n",
    "\n",
    "# gen = sets[0]\n",
    "# with open('data/wt103/train.txt', 'w') as f:\n",
    "#     for i, t in enumerate(gen):\n",
    "#         f.write(t)\n",
    "\n",
    "# gen = sets[1]\n",
    "# with open('data/wt103/valid.txt', 'w') as f:\n",
    "#     for i, t in enumerate(gen):\n",
    "#         f.write(t)\n",
    "\n",
    "# gen = sets[2]\n",
    "# with open('data/wt103/test.txt', 'w') as f:\n",
    "#     for i, t in enumerate(gen):\n",
    "#         f.write(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "65ae757f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1 - int(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a57acd04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 2., 2., 2., 2., 2.],\n",
       "        [0., 0., 0., 0., 2., 2., 2., 2.],\n",
       "        [0., 0., 0., 0., 0., 2., 2., 2.],\n",
       "        [0., 0., 0., 1., 1., 1., 1., 1.],\n",
       "        [0., 0., 0., 1., 1., 1., 1., 1.],\n",
       "        [0., 0., 0., 1., 1., 1., 1., 1.]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask = torch.triu(torch.ones(6,8)*2, diagonal=3)\n",
    "mask[2+1:, 2+1:] = 1\n",
    "mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "aa0f8329",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([True, True])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(torch.ones(2) * 2 ).bool()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b9dfb631",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import torch\n",
    "import torch\n",
    "import sys\n",
    "sys.path.append('pytorch')\n",
    "sys.path.append('pytorch/utils')\n",
    "\n",
    "from pytorch.mem_transformer import *\n",
    "from pytorch import data_utils\n",
    "\n",
    "from experiment_utils.run_experiment import *\n",
    "from experiment_utils.generate_data import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "746e6986",
   "metadata": {},
   "source": [
    "## Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1de49706",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "TAG = '10tkn_len24_ext'\n",
    "\n",
    "TASK_NAME = 'copy'\n",
    "TRAIN_SIZE = 1000\n",
    "VAL_SIZE = 200\n",
    "TEST_SIZE = 100\n",
    "NUM_INITS = 1\n",
    "\n",
    "\n",
    "NUM_BATCHES = int(4e5)\n",
    "BATCH_SIZE = 128\n",
    "GENERATE_EVERY  = 10000\n",
    "NUM_TOKENS = 10 + 2\n",
    "ENC_SEQ_LEN = 24\n",
    "DEC_SEQ_LEN = 48\n",
    "\n",
    "INPUT_LEN = 24"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a87cb25",
   "metadata": {},
   "source": [
    "#### Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "31b7097a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # !mkdir data24\n",
    "# np.random.seed(42)\n",
    "\n",
    "# generator = copy_generator(batch_size=BATCH_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN, num_tokens=NUM_TOKENS)\n",
    "# generate_data(generator, path=f'data{INPUT_LEN}', task_name=TASK_NAME, train_size=TRAIN_SIZE, test_size=TEST_SIZE, val_size=VAL_SIZE, batch_size=BATCH_SIZE)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "17253635",
   "metadata": {},
   "outputs": [],
   "source": [
    "class data_loader:\n",
    "    def __init__(self, task_name, path='data', batch_size=32, none_mask=True):\n",
    "        self.X, self.y = np.load(f'{path}/{task_name}_X.npy'), np.load(f'{path}/{task_name}_y.npy')\n",
    "        self.data_size = self.X.shape[0]\n",
    "        self.data_ptr = 0\n",
    "\n",
    "        if none_mask:\n",
    "            self.src_mask, self.tgt_mask = None, None\n",
    "        else:\n",
    "            self.src_masks, self.tgt_mask = np.load(f'{path}/{task_name}_mask.npy'), None\n",
    "\n",
    "        self.batch_size = batch_size\n",
    "        self.none_mask = none_mask\n",
    "\n",
    "    def __next__(self):\n",
    "        X = self.X[self.data_ptr: self.data_ptr+self.batch_size]\n",
    "        y = self.y[self.data_ptr: self.data_ptr+self.batch_size]\n",
    "        \n",
    "        if not self.none_mask:\n",
    "            sm = self.src_masks[self.data_ptr: self.data_ptr+self.batch_size]\n",
    "            sm = torch.tensor(sm).cuda()\n",
    "        else:\n",
    "            sm = None\n",
    "            \n",
    "        self.data_ptr = (self.data_ptr + self.batch_size) % self.data_size\n",
    "\n",
    "        return torch.tensor(X),\\\n",
    "                torch.tensor(y),\\\n",
    "                sm, self.tgt_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c118128",
   "metadata": {},
   "source": [
    "### Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "082c59ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_parameters = ParameterGrid({\n",
    "                'n_layer': [2],\n",
    "                'n_head': [4],\n",
    "                'd_head': [128],\n",
    "                'num_mem_tokens': [0],\n",
    "                'mem_len': [10]})\n",
    "\n",
    "param = list(model_parameters)[0]\n",
    "\n",
    "fixed_parameters = {'n_token': NUM_TOKENS,\n",
    "                    'd_model': param['d_head'],# + param['num_mem_tokens']-1,\n",
    "                    'd_inner': param['d_head'],\n",
    "                    'dropout': 0,\n",
    "                    'dropatt': 0,\n",
    "                    'tie_weight': True,\n",
    "                    'div_val': 1, # ???????\n",
    "                    'tie_projs': [False],\n",
    "                    'tgt_len': DEC_SEQ_LEN,\n",
    "                    'ext_len': 0, \n",
    "                    'cutoffs': [],\n",
    "                    'attn_type': 0,}\n",
    "\n",
    "model = MemTransformerLM(**param, **fixed_parameters)#.cuda()\n",
    "\n",
    "gen_train = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_train', batch_size=21)\n",
    "src, tgt, _, _ = next(gen_train)\n",
    "# src, tgt = src.cuda(), tgt.cuda()\n",
    "src, tgt = src.cpu().T, tgt.cpu().T\n",
    "\n",
    "mems = tuple()\n",
    "# model(src, tgt.contiguous(), *mems)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a5a5d96b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('data/wt103/test.txt', 'r') as f:\n",
    "    t = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4c277e19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' \\n = Robert Boulter = \\n \\n Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John Deed in 2002 . In 2004 Boulter landed a role as \" Craig \" in the episode \" Teddy \\'s Story \" of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the <unk> Chocolate Factory in London . He was directed by John Tiffany and starred alongside Ben Whishaw , Shane Zaza , Harry Kent , Fraser Ayres , Sophie Stanton and Dominic Hall . \\n In 2006 , Boulter starred alongside Whishaw in the play Citizenship written by Mark Ravenhill . He appeared on a 200'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t[:1000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5399822c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Producing dataset wt103...\n",
      "building vocab with min_freq=0, max_size=None\n",
      "final vocab size 267735 from 267734 unique tokens\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'args' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-27fd60e76918>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mcorpus\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_lm_corpus\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data/wt103/'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wt103'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mntokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorpus\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_token\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mntokens\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0meval_batch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'args' is not defined"
     ]
    }
   ],
   "source": [
    "from data_utils import *\n",
    "\n",
    "corpus = get_lm_corpus('data/wt103/', 'wt103')\n",
    "ntokens = len(corpus.vocab)\n",
    "args.n_token = ntokens\n",
    "\n",
    "eval_batch_size = 10\n",
    "tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,\n",
    "    device=device, ext_len=args.ext_len)\n",
    "va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,\n",
    "    device=device, ext_len=args.ext_len)\n",
    "te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,\n",
    "    device=device, ext_len=args.ext_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d95852cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "it = corpus.get_iterator('train', 21, 150,\n",
    "    device='cpu', ext_len=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "709f493f",
   "metadata": {},
   "outputs": [],
   "source": [
    "lt = list(enumerate(it))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "6b5df850",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[     7,     24,   1072,   1483,   1944],\n",
       "        [   286,      2,   3977,      2,      3],\n",
       "        [     2,   1532,  38476,      1,    332],\n",
       "        [  2377, 114037,      5,   7703,   1445],\n",
       "        [    72,      0,  15128,  18109,     22],\n",
       "        [     8,   1532,  37946,  12987,  46597],\n",
       "        [   211,     24,   7426,  24601,   4597],\n",
       "        [  1410,      2,     16,      5,   1233],\n",
       "        [     4,   1532,      1,      1,      3],\n",
       "        [     1,     24,    161,  20224,     21]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src, tgt, l = lt[0][1]\n",
    "src[-10:, :5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6b5df850",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  1410,      2,     16,      5,   1233],\n",
       "        [     4,   1532,      1,      1,      3],\n",
       "        [     1,     24,    161,  20224,     21],\n",
       "        [   141,      2,     11,   1621,   4265],\n",
       "        [   992,   1532,     38,   8866,     27],\n",
       "        [    15,     24,    587,   8353,      8],\n",
       "        [ 52174,      0,   1101,      1,    884],\n",
       "        [ 11392,   1532,     17,    732,    118],\n",
       "        [   358, 180578,     56,   3660,      4],\n",
       "        [     3,      2,  16986,    334,   7138]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src, tgt, l = lt[1][1]\n",
    "src[:10, :5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "c182ddf6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([24, 21])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cbc3cde2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ayd98/Desktop/MIPT/TXL/pytorch/mem_transformer.py:267: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at  ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:1273.)\n",
      "  attn_score = attn_score.float().masked_fill(\n"
     ]
    }
   ],
   "source": [
    "mems = model.init_mems(src.device)\n",
    "res = model._forward(src, mems)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "a21d8dd3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([24, 21, 128]),\n",
       " [torch.Size([10, 21, 128]),\n",
       "  torch.Size([10, 21, 128]),\n",
       "  torch.Size([10, 21, 128])])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[0].shape, [r.shape for r in res[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "6f850e16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[[  6.8407,   8.3905,   0.9672,  ...,   8.8531,   5.1211, -14.8174],\n",
       "          [  6.8407,   8.3905,   0.9672,  ...,   8.8531,   5.1211, -14.8174],\n",
       "          [  6.8407,   8.3905,   0.9672,  ...,   8.8531,   5.1211, -14.8174],\n",
       "          ...,\n",
       "          [ -2.5426,  -9.4414,  -6.9474,  ...,  -8.7893,  -8.3210,  21.0291],\n",
       "          [ -2.5426,  -9.4414,  -6.9474,  ...,  -8.7893,  -8.3210,  21.0291],\n",
       "          [-24.5933,   0.5914,   0.3800,  ...,  -2.4394,  -3.9996, -29.2228]],\n",
       " \n",
       "         [[  6.8407,   8.3905,   0.9672,  ...,   8.8531,   5.1211, -14.8174],\n",
       "          [  0.8790,  -2.7545,   1.5977,  ...,  11.7079, -12.7973,  -4.8640],\n",
       "          [  0.8790,  -2.7545,   1.5977,  ...,  11.7079, -12.7973,  -4.8640],\n",
       "          ...,\n",
       "          [ -1.6289, -18.3521,  -4.9785,  ...,   4.8849,  -1.5276,   0.5263],\n",
       "          [ -4.0415,   0.3774,  11.2921,  ...,   6.9845,   7.1489,  -2.0019],\n",
       "          [ -2.5426,  -9.4414,  -6.9474,  ...,  -8.7893,  -8.3210,  21.0291]],\n",
       " \n",
       "         [[  0.8790,  -2.7545,   1.5977,  ...,  11.7079, -12.7973,  -4.8640],\n",
       "          [ 17.8083, -10.4142, -10.0157,  ...,  -4.9378,   9.3704,   3.5837],\n",
       "          [ -2.5426,  -9.4414,  -6.9474,  ...,  -8.7893,  -8.3210,  21.0291],\n",
       "          ...,\n",
       "          [ 17.8083, -10.4142, -10.0157,  ...,  -4.9378,   9.3704,   3.5837],\n",
       "          [  0.8790,  -2.7545,   1.5977,  ...,  11.7079, -12.7973,  -4.8640],\n",
       "          [  0.8790,  -2.7545,   1.5977,  ...,  11.7079, -12.7973,  -4.8640]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[  5.1366,  -5.2081,  -6.0639,  ...,   4.5898,   5.8368,  -6.5319],\n",
       "          [-24.7386,  -1.5784,  21.2404,  ..., -17.6516,   7.0825,  -3.1931],\n",
       "          [ -1.6289, -18.3521,  -4.9785,  ...,   4.8849,  -1.5276,   0.5263],\n",
       "          ...,\n",
       "          [ -2.5426,  -9.4414,  -6.9474,  ...,  -8.7893,  -8.3210,  21.0291],\n",
       "          [ -2.5426,  -9.4414,  -6.9474,  ...,  -8.7893,  -8.3210,  21.0291],\n",
       "          [ -4.0415,   0.3774,  11.2921,  ...,   6.9845,   7.1489,  -2.0019]],\n",
       " \n",
       "         [[-24.5933,   0.5914,   0.3800,  ...,  -2.4394,  -3.9996, -29.2228],\n",
       "          [ -9.0818, -15.5877, -15.5770,  ...,  10.3125,  17.6973,   6.9862],\n",
       "          [ -1.6289, -18.3521,  -4.9785,  ...,   4.8849,  -1.5276,   0.5263],\n",
       "          ...,\n",
       "          [ -4.0415,   0.3774,  11.2921,  ...,   6.9845,   7.1489,  -2.0019],\n",
       "          [ -1.6289, -18.3521,  -4.9785,  ...,   4.8849,  -1.5276,   0.5263],\n",
       "          [  6.8407,   8.3905,   0.9672,  ...,   8.8531,   5.1211, -14.8174]],\n",
       " \n",
       "         [[ -4.0415,   0.3774,  11.2921,  ...,   6.9845,   7.1489,  -2.0019],\n",
       "          [ -1.6289, -18.3521,  -4.9785,  ...,   4.8849,  -1.5276,   0.5263],\n",
       "          [-24.5933,   0.5914,   0.3800,  ...,  -2.4394,  -3.9996, -29.2228],\n",
       "          ...,\n",
       "          [ 17.8083, -10.4142, -10.0157,  ...,  -4.9378,   9.3704,   3.5837],\n",
       "          [  5.1366,  -5.2081,  -6.0639,  ...,   4.5898,   5.8368,  -6.5319],\n",
       "          [ -9.0818, -15.5877, -15.5770,  ...,  10.3125,  17.6973,   6.9862]]]),\n",
       " tensor([[[ 1.2785e+00,  5.9332e-01,  4.2627e-02,  ...,  1.4906e+00,\n",
       "            2.3133e-01, -1.2870e+00],\n",
       "          [ 1.2347e+00,  5.8977e-01, -4.5671e-02,  ...,  1.3704e+00,\n",
       "            7.4016e-01, -1.1662e+00],\n",
       "          [ 1.2723e+00,  7.3471e-01,  8.6937e-02,  ...,  1.1889e+00,\n",
       "            3.2899e-01, -1.4282e+00],\n",
       "          ...,\n",
       "          [-6.2175e-01, -9.5482e-01, -5.6799e-01,  ..., -9.4344e-02,\n",
       "           -4.2523e-01,  2.3610e+00],\n",
       "          [-5.2416e-01, -8.8864e-01, -7.8648e-01,  ..., -7.2400e-01,\n",
       "           -5.1085e-01,  1.6938e+00],\n",
       "          [-1.7181e+00,  3.0203e-01,  3.2562e-01,  ..., -2.5607e-01,\n",
       "            2.9916e-01, -2.2527e+00]],\n",
       " \n",
       "         [[ 1.1854e+00,  4.1617e-01, -1.5536e-02,  ...,  1.3224e+00,\n",
       "            1.5059e-01, -1.6926e+00],\n",
       "          [ 4.5696e-02, -4.6233e-01, -2.4409e-01,  ...,  9.7856e-01,\n",
       "           -9.5832e-01, -3.9013e-01],\n",
       "          [ 2.7974e-02, -6.4812e-01, -3.3786e-01,  ...,  6.1185e-01,\n",
       "           -1.1907e+00, -7.3533e-01],\n",
       "          ...,\n",
       "          [ 2.8524e-02, -1.4939e+00, -2.5522e-01,  ...,  7.2474e-01,\n",
       "            1.6038e-02,  1.0347e-01],\n",
       "          [-2.3675e-01, -4.4639e-01,  1.5938e+00,  ...,  1.3809e+00,\n",
       "            8.9056e-01, -1.9522e-01],\n",
       "          [-4.2677e-01, -9.3345e-01, -4.9776e-01,  ..., -8.3493e-01,\n",
       "            3.0709e-03,  2.0261e+00]],\n",
       " \n",
       "         [[ 1.2220e-03, -3.1906e-01, -3.2837e-01,  ...,  4.9292e-01,\n",
       "           -5.7564e-01, -3.2458e-01],\n",
       "          [ 1.1925e+00, -1.3974e+00, -7.0635e-01,  ..., -2.5182e-01,\n",
       "            9.0700e-01,  4.8928e-01],\n",
       "          [-6.9065e-01, -1.3946e+00, -1.4612e-01,  ..., -8.0589e-01,\n",
       "           -7.2277e-01,  1.5892e+00],\n",
       "          ...,\n",
       "          [ 1.3485e+00, -8.6850e-01, -9.3196e-01,  ..., -6.1315e-01,\n",
       "            1.4623e+00,  5.1686e-01],\n",
       "          [ 8.2957e-02, -4.6859e-01,  2.2191e-01,  ...,  1.2775e+00,\n",
       "           -1.1926e+00, -1.6271e-01],\n",
       "          [-1.8239e-01, -2.0843e-01,  2.7331e-01,  ...,  9.2431e-01,\n",
       "           -1.0230e+00,  1.1980e-01]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[ 7.8928e-01, -1.0479e+00, -4.2151e-01,  ...,  4.8231e-01,\n",
       "            1.0443e+00, -5.4026e-01],\n",
       "          [-2.0388e+00, -2.3192e-01,  1.5957e+00,  ..., -1.8473e+00,\n",
       "            2.3316e-01, -1.3483e-01],\n",
       "          [-1.7782e-01, -1.7108e+00, -8.3758e-02,  ..., -2.4040e-01,\n",
       "           -7.4238e-02,  4.2087e-01],\n",
       "          ...,\n",
       "          [-3.9208e-01, -5.0616e-01, -3.5195e-01,  ..., -6.5545e-01,\n",
       "           -1.0406e+00,  1.5654e+00],\n",
       "          [-3.9258e-01, -8.1834e-01, -6.5740e-01,  ..., -8.6025e-01,\n",
       "           -6.5001e-01,  1.7907e+00],\n",
       "          [-4.4208e-01, -5.8736e-01,  1.0567e+00,  ...,  2.2531e-01,\n",
       "            7.9491e-01, -4.0145e-01]],\n",
       " \n",
       "         [[-1.7963e+00,  2.0393e-01,  8.8917e-02,  ..., -5.5250e-01,\n",
       "           -5.6517e-01, -2.4461e+00],\n",
       "          [-8.7755e-01, -1.1344e+00, -1.6494e+00,  ...,  6.3753e-01,\n",
       "            7.0964e-01,  5.9973e-01],\n",
       "          [-8.3459e-02, -1.9525e+00, -2.0927e-01,  ...,  2.4299e-01,\n",
       "            4.8310e-01,  4.5441e-01],\n",
       "          ...,\n",
       "          [-7.0692e-03, -2.1543e-02,  1.0678e+00,  ...,  1.1633e+00,\n",
       "            6.9272e-01,  1.0799e-01],\n",
       "          [-1.2540e-01, -1.5891e+00, -4.2694e-01,  ...,  3.9266e-01,\n",
       "            1.7103e-01,  2.7676e-01],\n",
       "          [ 8.6784e-01,  4.0098e-01,  2.6664e-01,  ...,  6.2947e-01,\n",
       "            2.8081e-01, -1.2663e+00]],\n",
       " \n",
       "         [[-2.1091e-01, -3.5306e-02,  1.0995e+00,  ...,  7.7252e-01,\n",
       "            9.9584e-01, -1.0183e-01],\n",
       "          [-6.8453e-02, -1.5333e+00, -2.3162e-01,  ...,  7.3537e-01,\n",
       "            3.8302e-01,  4.1425e-01],\n",
       "          [-1.4437e+00,  9.4299e-02,  2.2928e-01,  ..., -1.0674e-01,\n",
       "            7.2555e-01, -2.4190e+00],\n",
       "          ...,\n",
       "          [ 1.2347e+00, -1.1191e+00, -9.6012e-01,  ..., -3.7186e-01,\n",
       "            1.0626e+00,  6.1096e-01],\n",
       "          [ 5.6683e-01, -3.6975e-01, -4.5467e-02,  ...,  2.9452e-01,\n",
       "            8.2185e-01, -3.2482e-01],\n",
       "          [-8.9454e-01, -1.4466e+00, -2.0806e+00,  ...,  8.2537e-01,\n",
       "            8.1450e-01,  8.3462e-01]]]),\n",
       " tensor([[[ 0.7041,  0.6989,  0.1407,  ...,  1.3700,  0.3641, -1.0948],\n",
       "          [ 0.3603,  0.9199,  0.0094,  ...,  1.4011,  0.4964, -0.6155],\n",
       "          [ 0.5805,  0.9393,  0.5422,  ...,  1.3410,  0.2879, -0.8131],\n",
       "          ...,\n",
       "          [-0.5858, -0.7802, -0.4361,  ...,  0.2318, -0.1225,  1.7490],\n",
       "          [ 0.0616, -0.6292, -0.2605,  ..., -0.3716, -0.6563,  1.7631],\n",
       "          [-1.3702,  0.0744,  1.2398,  ..., -0.3798,  0.5016, -2.0123]],\n",
       " \n",
       "         [[ 0.5109,  0.0738,  0.2555,  ...,  1.3590,  0.7060, -1.3735],\n",
       "          [-0.4136,  0.0278,  0.3721,  ...,  1.4233, -1.2898, -0.0424],\n",
       "          [-0.0640, -0.3587, -0.0433,  ...,  0.7139, -1.0874, -0.8693],\n",
       "          ...,\n",
       "          [ 0.2934, -1.6926, -0.4154,  ...,  0.6257,  0.5591,  0.0217],\n",
       "          [-0.7373, -0.4493,  1.3318,  ...,  1.7896,  0.2605,  0.2831],\n",
       "          [-0.2578, -0.2940, -0.6409,  ..., -0.9671,  0.0165,  2.5596]],\n",
       " \n",
       "         [[ 0.4024,  0.1275,  0.0086,  ...,  0.2672,  0.0070,  0.1657],\n",
       "          [ 0.5027, -0.6764, -0.0071,  ...,  0.0413,  0.1995,  0.8950],\n",
       "          [-0.1509, -1.2011, -0.0354,  ..., -0.9556, -0.5193,  1.6732],\n",
       "          ...,\n",
       "          [ 0.9319, -0.1354, -0.2405,  ..., -0.8186,  1.2471,  1.2072],\n",
       "          [-0.1635, -0.1610,  1.1424,  ...,  1.2271, -0.9374, -0.4952],\n",
       "          [-0.7790,  0.3632,  0.9243,  ...,  0.9111, -0.5122, -0.1068]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[ 0.6326, -1.0416, -0.5684,  ...,  0.2937,  1.3642, -0.2104],\n",
       "          [-2.0443, -0.5180,  2.5247,  ..., -1.4989,  0.2000, -0.0842],\n",
       "          [-0.6140, -1.0049, -0.0540,  ..., -0.1409,  0.1465,  0.7793],\n",
       "          ...,\n",
       "          [-0.0998, -0.7187,  0.0173,  ..., -0.5667, -0.5233,  1.6767],\n",
       "          [-0.1905, -0.5788, -0.4652,  ..., -1.1091, -0.6468,  1.7087],\n",
       "          [-0.4955, -0.2691,  0.3105,  ..., -0.6354,  0.8244, -0.5577]],\n",
       " \n",
       "         [[-1.9714, -0.1760,  0.5144,  ..., -0.6304, -0.3655, -1.6228],\n",
       "          [-1.1057, -0.0922, -1.2548,  ...,  0.6069,  0.5572,  1.2498],\n",
       "          [-0.7914, -1.8775, -0.5085,  ...,  0.4845,  0.4712,  0.0513],\n",
       "          ...,\n",
       "          [-0.0647,  0.1347,  0.5888,  ...,  1.2965,  1.0676,  0.9392],\n",
       "          [-0.4968, -1.3124, -0.6181,  ...,  0.5184,  0.3255,  0.3190],\n",
       "          [-0.0116,  0.9168,  0.5130,  ...,  0.6351,  0.1532, -1.0923]],\n",
       " \n",
       "         [[-0.2400,  0.3152,  0.9115,  ...,  0.6669,  0.5959,  0.5056],\n",
       "          [-0.9522, -1.4066, -0.1097,  ...,  1.1949,  0.1393,  0.5310],\n",
       "          [-1.3984,  0.4449,  0.8298,  ..., -0.0266,  0.5308, -1.8083],\n",
       "          ...,\n",
       "          [ 1.0714, -0.5544,  0.0158,  ..., -0.3898,  1.3635,  0.8936],\n",
       "          [ 0.0470, -0.4521,  0.4500,  ...,  0.2632,  1.1812,  0.2824],\n",
       "          [-0.7854, -0.7039, -2.1405,  ...,  0.8198,  0.9147,  1.1584]]])]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ffa64210",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'NoneType' object has no attribute 'shape'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-14-776dc3d37f9f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmems\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'shape'"
     ]
    }
   ],
   "source": [
    "out.shape, mems.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "ae2a3cd4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dp_model.module.num_mem_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "35079c41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hasattr(model, \"num_mem_tokens_\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1b7f9164",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([24, 21])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "070c2634",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([504]), torch.Size([504]))"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src.contiguous().view(-1).shape, src.reshape(-1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5aab019c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<generator object Module.parameters at 0x7fc4ba0e0190>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.layers[0].parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8bd0f1c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9088, -0.4226, -0.1848, -0.7608, -0.4045]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.randn(1, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f495e145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# self = model\n",
    "# target = tgt.contiguous()\n",
    "# data = src\n",
    "# mem_tokens = None\n",
    "\n",
    "\n",
    "# if not mems: mems = self.init_mems()\n",
    "# if self.mem_tokens is None: self.init_mem_tokens()\n",
    "\n",
    "# tgt_len = target.size(0)\n",
    "# hidden, new_mems = self._forward(data, mems=mems, mem_tokens=mem_tokens)\n",
    "# pred_hid = hidden[-tgt_len:]\n",
    "# mem_tokens = hidden[-tgt_len - self.num_mem_tokens: -tgt_len]\n",
    "# # if self.sample_softmax > 0 and self.training:\n",
    "# #     assert self.tie_weight\n",
    "# #     logit = sample_logits(self.word_emb,\n",
    "# #         self.out_layer.bias, target, pred_hid, self.sampler)\n",
    "# #     loss = -F.log_softmax(logit, -1)[:, :, 0]\n",
    "# # else:\n",
    "# loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))\n",
    "# loss = loss.view(tgt_len, -1)\n",
    "\n",
    "# output = [loss]\n",
    "\n",
    "# if new_mems is not None:\n",
    "#     output += new_mems\n",
    "# if self.num_mem_tokens not in (0, None):\n",
    "#     output = [mem_tokens] + output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4d2bff68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tgt_len, hidden.shape, pred_hid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3bec0086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pred_hid.view(-1, pred_hid.size(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "47ab465f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# {**param, **fixed_parameters}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "41a54127",
   "metadata": {},
   "outputs": [],
   "source": [
    "# out = model(src, src.contiguous())\n",
    "# mem_tokens, loss, mems = out[0], out[1], out[2:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e7860bae",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerXL(MemTransformerLM):\n",
    "    \"\"\"Encoder-decoder pair built from two MemTransformerLM models.\n",
    "\n",
    "    The encoder reads `data`. Its last `target.size(0)` hidden states are passed\n",
    "    to the decoder as `context`, and the `num_mem_tokens` states just before them\n",
    "    are passed to the decoder as `mem_tokens`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, enc_kwargs, dec_kwargs):\n",
    "        # NOTE(review): this also builds a full MemTransformerLM on `self` whose\n",
    "        # weights forward() never uses (they still show up in self.parameters()).\n",
    "        # Kept so inherited attributes stay available; consider removing it.\n",
    "        super().__init__(**enc_kwargs)\n",
    "\n",
    "        self.Encoder = MemTransformerLM(**enc_kwargs)\n",
    "        self.Decoder = MemTransformerLM(**dec_kwargs)\n",
    "\n",
    "        self.Encoder.init_mem_tokens()\n",
    "        self.Decoder.init_mem_tokens()\n",
    "\n",
    "        self.enc_kwargs = enc_kwargs\n",
    "        self.dec_kwargs = dec_kwargs\n",
    "\n",
    "    def forward(self, data, target, mems=None, mem_tokens=None, dec_mems=None):\n",
    "        \"\"\"Encode `data`, then run the decoder on `target[:-1]`.\n",
    "\n",
    "        Args:\n",
    "            data: encoder input tokens for Encoder._forward.\n",
    "            target: decoder tokens, sliced along dim 0; target[:-1] is fed to\n",
    "                the decoder (target[1:] would be the aligned labels; no loss is\n",
    "                computed here).\n",
    "            mems: encoder memory; replaced by Encoder.init_mems() when empty.\n",
    "            mem_tokens: memory tokens passed to Encoder._forward.\n",
    "            dec_mems: decoder memory; replaced by Decoder.init_mems() when empty.\n",
    "\n",
    "        Returns:\n",
    "            (pred, dec_mems) exactly as returned by Decoder._forward. The\n",
    "            encoder's new memory is not returned.\n",
    "        \"\"\"\n",
    "        if not mems: mems = self.Encoder.init_mems()\n",
    "        if self.Encoder.mem_tokens is None: self.Encoder.init_mem_tokens()\n",
    "\n",
    "        tgt_len = target.size(0)\n",
    "        # Use the arguments, not the notebook globals `src` / `tgt`.\n",
    "        hidden, _enc_mems = self.Encoder._forward(data, mems=mems, mem_tokens=mem_tokens)\n",
    "\n",
    "        # Last tgt_len encoder states -> decoder context; the num_mem_tokens\n",
    "        # states right before them -> memory tokens handed to the decoder.\n",
    "        pred_hid = hidden[-tgt_len:]\n",
    "        mem_tokens = hidden[-tgt_len - self.Encoder.num_mem_tokens: -tgt_len]\n",
    "\n",
    "        # The decoder keeps its own memory instead of reusing the encoder's `mems`.\n",
    "        if not dec_mems: dec_mems = self.Decoder.init_mems()\n",
    "        if self.Decoder.mem_tokens is None: self.Decoder.init_mem_tokens()\n",
    "\n",
    "        tgt_st = target[:-1]\n",
    "\n",
    "        pred, dec_mems = self.Decoder._forward(tgt_st, mems=dec_mems, mem_tokens=mem_tokens, context=pred_hid)\n",
    "        return pred, dec_mems\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2739e2ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_param = {**param, **fixed_parameters}\n",
    "dec_param = {**param, **fixed_parameters}\n",
    "\n",
    "model = TransformerXL(enc_param, dec_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "681d728e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ayd98/Desktop/MIPT/TXL/pytorch/mem_transformer.py:281: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at  /pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1104.)\n",
      "  attn_score = attn_score.float().masked_fill(\n"
     ]
    }
   ],
   "source": [
    "mems = tuple()\n",
    "hidden, mems = model.Encoder._forward(src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4271bfd5",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (33) must match the size of tensor b (48) at non-singleton dimension 1",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-2520061434df>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtgt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1049\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1050\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1052\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1053\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-12-3ffeeabf6d72>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, data, target, mems, mem_tokens)\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0mtgt_gt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtgt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m         \u001b[0mpred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_mems\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDecoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtgt_st\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmems\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmems\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmem_tokens\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpred_hid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mpred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_mems\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/MIPT/TXL/pytorch/mem_transformer.py\u001b[0m in \u001b[0;36m_forward\u001b[0;34m(self, dec_inp, mems, mem_tokens, context)\u001b[0m\n\u001b[1;32m    718\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    719\u001b[0m                 \u001b[0mmems_i\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmems\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mmems\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 720\u001b[0;31m                 core_out = layer(core_out, pos_emb, self.r_w_bias,\n\u001b[0m\u001b[1;32m    721\u001b[0m                         self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i, context=context)\n\u001b[1;32m    722\u001b[0m                 \u001b[0mhids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcore_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1049\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1050\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1052\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1053\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/MIPT/TXL/pytorch/mem_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask, mems, context)\u001b[0m\n\u001b[1;32m    447\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_inp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_w_bias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_r_bias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_attn_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmems\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    448\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 449\u001b[0;31m         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,\n\u001b[0m\u001b[1;32m    450\u001b[0m                                \u001b[0mattn_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdec_attn_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    451\u001b[0m                                mems=mems, context=context)\n",
      "\u001b[0;32m~/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1049\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1050\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1052\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1053\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/MIPT/TXL/pytorch/mem_transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, w, r, r_w_bias, r_r_bias, attn_mask, mems, context)\u001b[0m\n\u001b[1;32m    270\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    271\u001b[0m         \u001b[0;31m# [qlen x klen x bsz x n_head]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 272\u001b[0;31m         \u001b[0mattn_score\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAC\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mBD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    273\u001b[0m         \u001b[0mattn_score\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    274\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (33) must match the size of tensor b (48) at non-singleton dimension 1"
     ]
    }
   ],
   "source": [
    "model(src, tgt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "da86adf1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([33, 21, 128])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hidden.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef8f146d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "235c2b40",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fbb40ae",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "855d9f27",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03fa53c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "307f8b65",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92978633",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_train = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE)\n",
    "gen_val = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE)\n",
    "gen_test = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE)\n",
    "\n",
    "print_file = f'logs/{TASK_NAME}_{TAG}_memory_logs.txt'\n",
    "\n",
    "\n",
    "def append_log(msg):\n",
    "    \"\"\"Append `msg` verbatim to the run log at `print_file`.\"\"\"\n",
    "    with open(print_file, 'a') as f:\n",
    "        f.write(msg)\n",
    "\n",
    "\n",
    "# Materialise the configurations once so every init sweeps the same list.\n",
    "param_grid = list(model_parameters)\n",
    "\n",
    "t = time.time()\n",
    "with torch.cuda.device(0):\n",
    "    for init_num in range(NUM_INITS):\n",
    "        append_log('\\n\\nInit number ' + str(init_num)+'\\n')\n",
    "        for i, base_param in enumerate(param_grid):\n",
    "            # Work on a copy: popping 'depth,heads' from a shared dict would make\n",
    "            # the next init_num iteration fail with KeyError.\n",
    "            param = dict(base_param)\n",
    "            append_log('\\n\\n' + str(param)+'\\n')\n",
    "            param['enc_depth'], param['enc_heads'] = param['depth,heads']\n",
    "            param['dec_depth'], param['dec_heads'] = param['depth,heads']\n",
    "            param.pop('depth,heads')\n",
    "\n",
    "            append_log(f'{i / len(param_grid) * 100}%')\n",
    "            model = XTransformer(**param).cuda()\n",
    "\n",
    "            model_name = f\"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_{TAG}_v{init_num}\"\n",
    "\n",
    "            optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "            train_validate_model(model, \n",
    "                            train_generator=gen_train, \n",
    "                            val_generator=gen_val, \n",
    "                            optim=optim, \n",
    "                            model_name=model_name, \n",
    "                            config=param,\n",
    "                            num_batches=NUM_BATCHES,\n",
    "                            generate_every=GENERATE_EVERY,\n",
    "                            print_file=print_file,\n",
    "                            tag=TAG,\n",
    "                            overfit_stop=False)\n",
    "            test_model(model, gen_test, model_name, param, TASK_NAME, tag=TAG)\n",
    "            append_log(f'\\nTotal time: {time.time() - t}\\n')\n",
    "            t = time.time()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
