{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ca832428",
   "metadata": {},
   "source": [
    "# data generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c999ab90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# coding: utf-8\n",
    "import argparse\n",
    "from genericpath import exists\n",
    "import time\n",
    "import math\n",
    "import os, sys\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "sys.path.append('pytorch')\n",
    "sys.path.append('pytorch/utils')\n",
    "from data_utils import *\n",
    "from mem_transformer import *\n",
    "from utils.exp_utils import *\n",
    "from utils.data_parallel import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ef11a660",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# if args.d_embed < 0:\n",
    "#     args.d_embed = args.d_model\n",
    "\n",
    "# assert args.ext_len >= 0, 'extended context length must be non-negative'\n",
    "# assert args.batch_size % args.batch_chunk == 0\n",
    "\n",
    "# args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)\n",
    "# args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))\n",
    "# logging = create_exp_dir(args.work_dir,\n",
    "#     scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)\n",
    "\n",
    "# # Set the random seed manually for reproducibility.\n",
    "# np.random.seed(args.seed)\n",
    "# torch.manual_seed(args.seed)\n",
    "# if torch.cuda.is_available():\n",
    "#     if not args.cuda:\n",
    "#         print('WARNING: You have a CUDA device, so you should probably run with --cuda')\n",
    "#     else:\n",
    "#         torch.cuda.manual_seed_all(args.seed)\n",
    "\n",
    "# # Validate `--fp16` option\n",
    "# if args.fp16:\n",
    "#     if not args.cuda:\n",
    "#         print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')\n",
    "#         args.fp16 = False\n",
    "#     else:\n",
    "#         try:\n",
    "#             from apex.fp16_utils import FP16_Optimizer\n",
    "#         except:\n",
    "#             print('WARNING: apex not installed, ignoring --fp16 option')\n",
    "#             args.fp16 = False\n",
    "\n",
    "# device = torch.device('cuda' if args.cuda else 'cpu')\n",
    "\n",
    "###############################################################################\n",
    "# Load data\n",
    "###############################################################################\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "80165998",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Producing dataset enwik8...\n",
      "building vocab with min_freq=0, max_size=None\n",
      "final vocab size 1356566 from 1356566 unique tokens\n"
     ]
    }
   ],
   "source": [
    "\n",
    "class dummy:\n",
    "    def __init__(self) -> None:\n",
    "        pass\n",
    "\n",
    "args = dummy()\n",
    "args.data = 'data/'\n",
    "args.dataset = 'enwik8'\n",
    "args.batch_size = 128\n",
    "args.tgt_len = 51\n",
    "device = 'cpu'\n",
    "args.ext_len = 5\n",
    "\n",
    "\n",
    "corpus = get_lm_corpus(args.data, args.dataset)\n",
    "\n",
    "ntokens = len(corpus.vocab)\n",
    "args.n_token = ntokens\n",
    "\n",
    "eval_batch_size = 10\n",
    "# tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,\n",
    "#     device=device, ext_len=args.ext_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9c392dcd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('enwik8', 'enwik8')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args.dataset, corpus.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "981205b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_iterator(self, split, *args, **kwargs):\n",
    "self = corpus\n",
    "split = 'train'\n",
    "\n",
    "args = {}\n",
    "kwargs = {'bsz': 128, 'bptt':1}\n",
    "if split == 'train':\n",
    "    if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:\n",
    "        data_iter = LMOrderedIterator(self.train, *args, **kwargs)\n",
    "    elif self.dataset == 'lm1b':\n",
    "        kwargs['shuffle'] = True\n",
    "        data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)\n",
    "elif split in ['valid', 'test']:\n",
    "    data = self.valid if split == 'valid' else self.test\n",
    "    if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8', 'enwiki8']:\n",
    "        data_iter = LMOrderedIterator(data, *args, **kwargs)\n",
    "    elif self.dataset == 'lm1b':\n",
    "        data_iter = LMShuffledIterator(data, *args, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9aff8a13",
   "metadata": {},
   "outputs": [],
   "source": [
    "i, (s, t, l) = list(enumerate(data_iter))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "9b122c41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 128]), torch.Size([1, 128]))"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s.shape, t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aa64673",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b9dfb631",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import torch\n",
    "import torch\n",
    "import sys\n",
    "sys.path.append('pytorch')\n",
    "sys.path.append('pytorch/utils')\n",
    "\n",
    "from pytorch.mem_transformer import *\n",
    "from pytorch import data_utils\n",
    "\n",
    "from experiment_utils.run_experiment import *\n",
    "from experiment_utils.generate_data import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "746e6986",
   "metadata": {},
   "source": [
    "## Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1de49706",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "TAG = '10tkn_len24_ext'\n",
    "\n",
    "TASK_NAME = 'copy'\n",
    "TRAIN_SIZE = 1000\n",
    "VAL_SIZE = 200\n",
    "TEST_SIZE = 100\n",
    "NUM_INITS = 1\n",
    "\n",
    "\n",
    "NUM_BATCHES = int(4e5)\n",
    "BATCH_SIZE = 128\n",
    "GENERATE_EVERY  = 10000\n",
    "NUM_TOKENS = 10 + 2\n",
    "ENC_SEQ_LEN = 24\n",
    "DEC_SEQ_LEN = 48\n",
    "\n",
    "INPUT_LEN = 24"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a87cb25",
   "metadata": {},
   "source": [
    "#### Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "31b7097a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # !mkdir data24\n",
    "# np.random.seed(42)\n",
    "\n",
    "# generator = copy_generator(batch_size=BATCH_SIZE, enc_seq_len=ENC_SEQ_LEN, dec_seq_len=DEC_SEQ_LEN, num_tokens=NUM_TOKENS)\n",
    "# generate_data(generator, path=f'data{INPUT_LEN}', task_name=TASK_NAME, train_size=TRAIN_SIZE, test_size=TEST_SIZE, val_size=VAL_SIZE, batch_size=BATCH_SIZE)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "17253635",
   "metadata": {},
   "outputs": [],
   "source": [
    "class data_loader:\n",
    "    def __init__(self, task_name, path='data', batch_size=32, none_mask=True):\n",
    "        self.X, self.y = np.load(f'{path}/{task_name}_X.npy'), np.load(f'{path}/{task_name}_y.npy')\n",
    "        self.data_size = self.X.shape[0]\n",
    "        self.data_ptr = 0\n",
    "\n",
    "        if none_mask:\n",
    "            self.src_mask, self.tgt_mask = None, None\n",
    "        else:\n",
    "            self.src_masks, self.tgt_mask = np.load(f'{path}/{task_name}_mask.npy'), None\n",
    "\n",
    "        self.batch_size = batch_size\n",
    "        self.none_mask = none_mask\n",
    "\n",
    "    def __next__(self):\n",
    "        X = self.X[self.data_ptr: self.data_ptr+self.batch_size]\n",
    "        y = self.y[self.data_ptr: self.data_ptr+self.batch_size]\n",
    "        \n",
    "        if not self.none_mask:\n",
    "            sm = self.src_masks[self.data_ptr: self.data_ptr+self.batch_size]\n",
    "            sm = torch.tensor(sm).cuda()\n",
    "        else:\n",
    "            sm = None\n",
    "            \n",
    "        self.data_ptr = (self.data_ptr + self.batch_size) % self.data_size\n",
    "\n",
    "        return torch.tensor(X),\\\n",
    "                torch.tensor(y),\\\n",
    "                sm, self.tgt_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c118128",
   "metadata": {},
   "source": [
    "### Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "082c59ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_parameters = ParameterGrid({\n",
    "                'n_layer': [2],\n",
    "                'n_head': [4],\n",
    "                'd_head': [128],\n",
    "                'num_mem_tokens': [9, 0], \n",
    "                'mem_len': [0]})\n",
    "\n",
    "param = list(model_parameters)[0]\n",
    "\n",
    "fixed_parameters = {'n_token': NUM_TOKENS,\n",
    "                    'd_model': param['d_head'],# + param['num_mem_tokens']-1,\n",
    "                    'd_inner': param['d_head'],\n",
    "                    'dropout': 0,\n",
    "                    'dropatt': 0,\n",
    "                    'tie_weight': True,\n",
    "                    'div_val': 1, # ???????\n",
    "                    'tie_projs': [False],\n",
    "                    'tgt_len': DEC_SEQ_LEN,\n",
    "                    'ext_len': 0, \n",
    "                    'cutoffs': [],\n",
    "                    'attn_type': 0,}\n",
    "\n",
    "model = MemTransformerLM(**param, **fixed_parameters)#.cuda()\n",
    "\n",
    "gen_train = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_train', batch_size=21)\n",
    "src, tgt, _, _ = next(gen_train)\n",
    "# src, tgt = src.cuda(), tgt.cuda()\n",
    "src, tgt = src.cpu().T, tgt.cpu().T\n",
    "\n",
    "mems = tuple()\n",
    "# model(src, tgt.contiguous(), *mems)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d8bd627e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.mem_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "41a54127",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r_head_k tensor([[[-0.1610,  0.2663,  0.7317,  ..., -0.2776,  0.3494,  0.2219],\n",
      "         [-0.4773,  0.0291, -0.0052,  ...,  0.1036, -0.2353, -0.8195],\n",
      "         [-0.5486,  0.5329,  0.8749,  ..., -0.0964, -0.0407, -0.5861],\n",
      "         [-0.2102, -0.8608,  0.5378,  ...,  0.3381,  0.6476,  0.5246]],\n",
      "\n",
      "        [[-0.1123,  0.2265,  0.7049,  ..., -0.2572,  0.4042,  0.1958],\n",
      "         [-0.5291, -0.0010, -0.1122,  ...,  0.0622, -0.1191, -0.7520],\n",
      "         [-0.4054,  0.6335,  0.8337,  ..., -0.1049,  0.0407, -0.6469],\n",
      "         [-0.0840, -0.8060,  0.4163,  ...,  0.3796,  0.5944,  0.4686]],\n",
      "\n",
      "        [[-0.1388,  0.2334,  0.6685,  ..., -0.1457,  0.5046,  0.1005],\n",
      "         [-0.5478,  0.0256, -0.2161,  ...,  0.1366,  0.0666, -0.6371],\n",
      "         [-0.3217,  0.6339,  0.8744,  ..., -0.2130,  0.0851, -0.6712],\n",
      "         [-0.0076, -0.6470,  0.3289,  ...,  0.5306,  0.5669,  0.2955]]],\n",
      "       grad_fn=<SliceBackward>)\n",
      "r_head_k tensor([[[-0.1321, -0.1172,  0.4521,  ..., -0.1367, -0.1690,  0.0713],\n",
      "         [-0.5776,  0.0916,  0.1349,  ...,  0.1096, -0.3698,  0.2901],\n",
      "         [ 0.4843,  0.7358,  0.8618,  ...,  0.0190,  0.3507,  0.6138],\n",
      "         [-0.7395, -0.5627, -0.0489,  ..., -0.8122,  0.1539, -0.5622]],\n",
      "\n",
      "        [[-0.1288,  0.0712,  0.4355,  ..., -0.2548, -0.0871, -0.1787],\n",
      "         [-0.5107,  0.1085,  0.2819,  ...,  0.0491, -0.3565,  0.4212],\n",
      "         [ 0.4674,  0.7506,  0.8286,  ..., -0.1179,  0.2200,  0.5562],\n",
      "         [-0.8426, -0.3459,  0.0613,  ..., -0.6488,  0.0937, -0.5646]],\n",
      "\n",
      "        [[-0.1729,  0.1996,  0.3467,  ..., -0.2872,  0.0104, -0.3622],\n",
      "         [-0.4069,  0.1934,  0.4718,  ..., -0.1103, -0.3117,  0.5549],\n",
      "         [ 0.4537,  0.7325,  0.7082,  ..., -0.1618,  0.1851,  0.4399],\n",
      "         [-0.8449, -0.0798,  0.1570,  ..., -0.4112,  0.0400, -0.4244]]],\n",
      "       grad_fn=<SliceBackward>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ayd98/Desktop/MIPT/TXL/pytorch/mem_transformer.py:279: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at  /pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1104.)\n",
      "  attn_score = attn_score.float().masked_fill(\n"
     ]
    }
   ],
   "source": [
    "out, mems = model._forward(src[:, :1], mems=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "44163a43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([33, 1, 128])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "41c8e6b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-1.5600,  1.8477,  0.3658,  0.9825, -0.0475,  1.3295, -0.1483, -1.7067,\n",
       "          0.6093, -1.0218, -0.1678, -1.3956,  1.8529,  0.7607,  0.5487, -0.6880,\n",
       "          1.1657, -1.3241, -0.9002, -0.1482, -1.6440,  0.9515,  0.2808,  0.8881,\n",
       "          0.3642, -1.5069, -0.1935,  0.9324, -0.9401,  0.1613,  1.4500, -1.0660,\n",
       "         -0.4886,  0.6992,  0.9827,  1.3505,  0.0044,  0.8011,  0.8062, -0.0596,\n",
       "         -0.8604,  0.5002,  0.0251, -1.0938,  2.0272, -1.1475, -0.7047, -0.2692,\n",
       "          0.0681, -0.5827,  0.3875,  0.2979,  0.9092,  1.3300, -1.9834,  1.0851,\n",
       "         -1.7017, -0.6386,  0.3707,  0.4793, -1.6834,  0.1440, -1.2206,  0.6837,\n",
       "          0.6079, -1.7838, -0.2298, -1.5133, -3.4377, -0.6613,  0.1642, -0.2854,\n",
       "         -1.0116, -0.2610, -1.0397, -0.0208,  0.5359,  0.5490,  0.0329,  2.2298,\n",
       "         -0.3427,  0.2675,  0.5686,  1.7059, -0.0446, -0.6895, -0.4770, -0.4737,\n",
       "         -1.1829, -0.1803, -0.3275, -1.0151,  0.6898, -1.0112,  0.7736, -0.0343,\n",
       "          0.5162, -0.5593,  0.2612, -1.5110,  1.3973, -0.2930,  0.8542,  1.0782,\n",
       "         -0.8496,  1.3213,  0.7477, -0.6357,  1.7142,  0.9712,  0.2211, -1.0114,\n",
       "          0.2519,  0.1532,  0.0817,  0.6176,  1.7447,  0.5115, -1.0278,  0.4659,\n",
       "         -1.1791,  0.3442, -0.2401,  0.8124, -0.1939, -1.2341,  1.8422,  0.1750],\n",
       "        grad_fn=<SelectBackward>),\n",
       " tensor([-1.5600,  1.8477,  0.3658,  0.9825, -0.0475,  1.3295, -0.1483, -1.7067,\n",
       "          0.6093, -1.0218, -0.1678, -1.3956,  1.8529,  0.7607,  0.5487, -0.6880,\n",
       "          1.1657, -1.3241, -0.9002, -0.1482, -1.6440,  0.9515,  0.2808,  0.8881,\n",
       "          0.3642, -1.5069, -0.1935,  0.9324, -0.9401,  0.1613,  1.4500, -1.0660,\n",
       "         -0.4886,  0.6992,  0.9827,  1.3505,  0.0044,  0.8011,  0.8062, -0.0596,\n",
       "         -0.8604,  0.5002,  0.0251, -1.0938,  2.0272, -1.1475, -0.7047, -0.2692,\n",
       "          0.0681, -0.5827,  0.3875,  0.2979,  0.9092,  1.3300, -1.9834,  1.0851,\n",
       "         -1.7017, -0.6386,  0.3707,  0.4793, -1.6834,  0.1440, -1.2206,  0.6837,\n",
       "          0.6079, -1.7838, -0.2298, -1.5133, -3.4377, -0.6613,  0.1642, -0.2854,\n",
       "         -1.0116, -0.2610, -1.0397, -0.0208,  0.5359,  0.5490,  0.0329,  2.2298,\n",
       "         -0.3427,  0.2675,  0.5686,  1.7059, -0.0446, -0.6895, -0.4770, -0.4737,\n",
       "         -1.1829, -0.1803, -0.3275, -1.0151,  0.6898, -1.0112,  0.7736, -0.0343,\n",
       "          0.5162, -0.5593,  0.2612, -1.5110,  1.3973, -0.2930,  0.8542,  1.0782,\n",
       "         -0.8496,  1.3213,  0.7477, -0.6357,  1.7142,  0.9712,  0.2211, -1.0114,\n",
       "          0.2519,  0.1532,  0.0817,  0.6176,  1.7447,  0.5115, -1.0278,  0.4659,\n",
       "         -1.1791,  0.3442, -0.2401,  0.8124, -0.1939, -1.2341,  1.8422,  0.1750],\n",
       "        grad_fn=<SelectBackward>),\n",
       " tensor([ 2.3842e-07,  0.0000e+00,  5.9605e-08,  0.0000e+00, -4.2468e-07,\n",
       "          1.1921e-07, -4.4703e-08,  1.1921e-07,  0.0000e+00,  1.1921e-07,\n",
       "         -1.9372e-07,  1.1921e-07,  1.1921e-07,  5.9605e-08,  1.7881e-07,\n",
       "          0.0000e+00,  1.1921e-07,  1.1921e-07, -1.1921e-07,  1.4901e-08,\n",
       "          1.1921e-07,  0.0000e+00,  2.9802e-08, -1.7881e-07,  2.9802e-08,\n",
       "          1.1921e-07,  2.2352e-07, -1.1921e-07, -5.9605e-08,  1.3411e-07,\n",
       "         -1.1921e-07,  2.3842e-07, -1.1921e-07,  5.9605e-08, -5.9605e-08,\n",
       "          1.1921e-07,  9.5926e-08,  0.0000e+00,  1.7881e-07, -1.6019e-07,\n",
       "         -5.9605e-08,  5.9605e-08,  2.7940e-08, -2.3842e-07,  0.0000e+00,\n",
       "          0.0000e+00,  3.5763e-07,  8.9407e-08,  2.2352e-08, -1.1921e-07,\n",
       "         -2.9802e-08, -2.9802e-08,  5.9605e-08,  1.1921e-07,  0.0000e+00,\n",
       "          2.3842e-07,  1.1921e-07, -5.9605e-08,  5.9605e-08,  1.7881e-07,\n",
       "          1.1921e-07,  1.0431e-07,  0.0000e+00,  1.7881e-07, -1.1921e-07,\n",
       "         -2.3842e-07,  2.3842e-07, -1.1921e-07,  2.3842e-07, -1.7881e-07,\n",
       "          2.9802e-08,  0.0000e+00, -3.5763e-07,  8.9407e-08,  0.0000e+00,\n",
       "          3.3528e-08,  2.3842e-07,  5.9605e-08, -1.0431e-07,  0.0000e+00,\n",
       "         -2.9802e-08,  1.1921e-07, -5.9605e-08,  0.0000e+00,  1.4901e-08,\n",
       "         -2.3842e-07,  2.9802e-08,  1.1921e-07,  1.1921e-07,  2.9802e-08,\n",
       "         -1.1921e-07, -2.3842e-07,  0.0000e+00,  1.1921e-07,  0.0000e+00,\n",
       "          2.9802e-08,  0.0000e+00, -3.5763e-07,  0.0000e+00, -1.1921e-07,\n",
       "          1.1921e-07, -2.9802e-08, -1.7881e-07, -1.1921e-07,  5.9605e-08,\n",
       "         -1.1921e-07, -5.9605e-08, -1.7881e-07,  0.0000e+00, -1.7881e-07,\n",
       "         -8.9407e-08,  1.1921e-07, -5.9605e-08,  2.2352e-07,  7.4506e-08,\n",
       "          1.7881e-07, -1.1921e-07,  5.9605e-08, -1.1921e-07, -2.9802e-08,\n",
       "          1.1921e-07, -2.0862e-07, -1.6391e-07,  1.1921e-07, -1.4901e-07,\n",
       "         -2.3842e-07, -2.3842e-07,  1.4901e-08], grad_fn=<SubBackward0>))"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out[0][0], out[1][0], out[0][0] - out[1][0] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1ca26955",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.3842e-07,  0.0000e+00,  5.9605e-08,  0.0000e+00, -4.2468e-07,\n",
       "         1.1921e-07, -4.4703e-08,  1.1921e-07,  0.0000e+00,  1.1921e-07,\n",
       "        -1.9372e-07,  1.1921e-07,  1.1921e-07,  5.9605e-08,  1.7881e-07,\n",
       "         0.0000e+00,  1.1921e-07,  1.1921e-07, -1.1921e-07,  1.4901e-08,\n",
       "         1.1921e-07,  0.0000e+00,  2.9802e-08, -1.7881e-07,  2.9802e-08,\n",
       "         1.1921e-07,  2.2352e-07, -1.1921e-07, -5.9605e-08,  1.3411e-07,\n",
       "        -1.1921e-07,  2.3842e-07, -1.1921e-07,  5.9605e-08, -5.9605e-08,\n",
       "         1.1921e-07,  9.5926e-08,  0.0000e+00,  1.7881e-07, -1.6019e-07,\n",
       "        -5.9605e-08,  5.9605e-08,  2.7940e-08, -2.3842e-07,  0.0000e+00,\n",
       "         0.0000e+00,  3.5763e-07,  8.9407e-08,  2.2352e-08, -1.1921e-07,\n",
       "        -2.9802e-08, -2.9802e-08,  5.9605e-08,  1.1921e-07,  0.0000e+00,\n",
       "         2.3842e-07,  1.1921e-07, -5.9605e-08,  5.9605e-08,  1.7881e-07,\n",
       "         1.1921e-07,  1.0431e-07,  0.0000e+00,  1.7881e-07, -1.1921e-07,\n",
       "        -2.3842e-07,  2.3842e-07, -1.1921e-07,  2.3842e-07, -1.7881e-07,\n",
       "         2.9802e-08,  0.0000e+00, -3.5763e-07,  8.9407e-08,  0.0000e+00,\n",
       "         3.3528e-08,  2.3842e-07,  5.9605e-08, -1.0431e-07,  0.0000e+00,\n",
       "        -2.9802e-08,  1.1921e-07, -5.9605e-08,  0.0000e+00,  1.4901e-08,\n",
       "        -2.3842e-07,  2.9802e-08,  1.1921e-07,  1.1921e-07,  2.9802e-08,\n",
       "        -1.1921e-07, -2.3842e-07,  0.0000e+00,  1.1921e-07,  0.0000e+00,\n",
       "         2.9802e-08,  0.0000e+00, -3.5763e-07,  0.0000e+00, -1.1921e-07,\n",
       "         1.1921e-07, -2.9802e-08, -1.7881e-07, -1.1921e-07,  5.9605e-08,\n",
       "        -1.1921e-07, -5.9605e-08, -1.7881e-07,  0.0000e+00, -1.7881e-07,\n",
       "        -8.9407e-08,  1.1921e-07, -5.9605e-08,  2.2352e-07,  7.4506e-08,\n",
       "         1.7881e-07, -1.1921e-07,  5.9605e-08, -1.1921e-07, -2.9802e-08,\n",
       "         1.1921e-07, -2.0862e-07, -1.6391e-07,  1.1921e-07, -1.4901e-07,\n",
       "        -2.3842e-07, -2.3842e-07,  1.4901e-08], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out[0][0] - out[1][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d8d619fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "mem_tokens = None\n",
    "dec_inp = src[:, :2]\n",
    "\n",
    "self = model\n",
    "mems = None\n",
    "mem_tokens = None\n",
    "model.init_mem_tokens()\n",
    "# \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4921e112",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "word_emb = self.word_emb(dec_inp)\n",
    "\n",
    "mlen = mems[0].size(0) if mems is not None else 0\n",
    "\n",
    "# Concat with mem_tokens\n",
    "if mem_tokens is not None:\n",
    "    word_emb = torch.cat((mem_tokens.detach(), word_emb), dim=0)\n",
    "elif self.num_mem_tokens not in (0, None):\n",
    "    mem_tokens = self.mem_tokens.reshape(self.num_mem_tokens, 1, -1).repeat(1, dec_inp.shape[1], 1)\n",
    "    word_emb = torch.cat((mem_tokens.detach(), word_emb), dim=0)\n",
    "\n",
    "# qlen, bsz = dec_inp.size()\n",
    "qlen = word_emb.shape[0]\n",
    "klen = mlen + qlen\n",
    "if self.same_length:\n",
    "    all_ones = word_emb.new_ones(qlen, klen)\n",
    "    mask_len = klen - self.mem_len\n",
    "    if mask_len > 0:\n",
    "        mask_shift_len = qlen - mask_len\n",
    "    else:\n",
    "        mask_shift_len = qlen\n",
    "    dec_attn_mask = (torch.triu(all_ones, 1+mlen)\n",
    "            + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1\n",
    "else:\n",
    "    dec_attn_mask = torch.triu(\n",
    "        word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]\n",
    "\n",
    "hids = []\n",
    "if self.attn_type == 0: # default\n",
    "    pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, \n",
    "                            dtype=word_emb.dtype)\n",
    "    if self.clamp_len > 0:\n",
    "        pos_seq.clamp_(max=self.clamp_len)\n",
    "    pos_emb = self.pos_emb(pos_seq)\n",
    "\n",
    "    core_out = self.drop(word_emb)\n",
    "    pos_emb = self.drop(pos_emb)\n",
    "\n",
    "    hids.append(core_out)\n",
    "    # for i, layer in enumerate(self.layers):\n",
    "    #     mems_i = None if mems is None else mems[i]\n",
    "    #     core_out = layer(core_out, pos_emb, self.r_w_bias,\n",
    "    #             self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)\n",
    "    #     hids.append(core_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0e01ef2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([33, 2, 128])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "start = core_out\n",
    "start.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7dbbcfa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "w, r, r_w_bias, r_r_bias, attn_mask, mems = start, pos_emb, self.r_w_bias, self.r_r_bias,  None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "978443ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r_head_k tensor([[[-0.1610,  0.2663,  0.7317,  ..., -0.2776,  0.3494,  0.2219],\n",
      "         [-0.4773,  0.0291, -0.0052,  ...,  0.1036, -0.2353, -0.8195],\n",
      "         [-0.5486,  0.5329,  0.8749,  ..., -0.0964, -0.0407, -0.5861],\n",
      "         [-0.2102, -0.8608,  0.5378,  ...,  0.3381,  0.6476,  0.5246]],\n",
      "\n",
      "        [[-0.1123,  0.2265,  0.7049,  ..., -0.2572,  0.4042,  0.1958],\n",
      "         [-0.5291, -0.0010, -0.1122,  ...,  0.0622, -0.1191, -0.7520],\n",
      "         [-0.4054,  0.6335,  0.8337,  ..., -0.1049,  0.0407, -0.6469],\n",
      "         [-0.0840, -0.8060,  0.4163,  ...,  0.3796,  0.5944,  0.4686]],\n",
      "\n",
      "        [[-0.1388,  0.2334,  0.6685,  ..., -0.1457,  0.5046,  0.1005],\n",
      "         [-0.5478,  0.0256, -0.2161,  ...,  0.1366,  0.0666, -0.6371],\n",
      "         [-0.3217,  0.6339,  0.8744,  ..., -0.2130,  0.0851, -0.6712],\n",
      "         [-0.0076, -0.6470,  0.3289,  ...,  0.5306,  0.5669,  0.2955]]],\n",
      "       grad_fn=<SliceBackward>)\n"
     ]
    }
   ],
   "source": [
    "res = model.layers[0].dec_attn(start, pos_emb, self.r_w_bias, self.r_r_bias,  None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5d41f826",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.4400, -0.2063, -0.1463,  ...,  0.3321, -1.0940, -0.4041],\n",
       "         [-0.8133,  0.4882, -0.3307,  ...,  0.4120, -1.5649,  0.0682]],\n",
       "\n",
       "        [[-0.4213, -0.2204, -0.1544,  ...,  0.3347, -1.0822, -0.4061],\n",
       "         [-0.8187,  0.4928, -0.3231,  ...,  0.4113, -1.5746,  0.0631]],\n",
       "\n",
       "        [[-0.3957, -0.2444, -0.1613,  ...,  0.3342, -1.0687, -0.4108],\n",
       "         [-0.8310,  0.4951, -0.3280,  ...,  0.3992, -1.5953,  0.0467]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[-0.1819, -0.6903,  0.4016,  ...,  0.3736,  0.2657, -1.9263],\n",
       "         [ 0.6139, -0.6959, -1.4468,  ...,  0.5378, -0.5685, -1.7276]],\n",
       "\n",
       "        [[-1.0107, -0.7791, -2.1517,  ...,  1.3996,  0.1962,  0.2709],\n",
       "         [-1.3758,  0.0292, -1.4295,  ..., -0.0346, -0.0155, -0.0834]],\n",
       "\n",
       "        [[-1.6117, -1.1003, -0.2831,  ..., -0.9778, -1.4025,  0.3320],\n",
       "         [-0.2470, -0.1687,  0.0569,  ...,  0.2756,  0.2057, -0.2320]]],\n",
       "       grad_fn=<NativeLayerNormBackward>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be9dfc1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)\n",
    "qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "5ffed20a",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'MemTransformerLM' object has no attribute 'pre_lnorm'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-34-60ad793c5ac9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0mw_head_q\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mw_head_q\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mqlen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m     \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpre_lnorm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m         \u001b[0mw_heads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_norm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/cudaenv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1128\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1129\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m   1131\u001b[0m             type(self).__name__, name))\n\u001b[1;32m   1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'MemTransformerLM' object has no attribute 'pre_lnorm'"
     ]
    }
   ],
   "source": [
    "# Scratch, top-level replay of a relative-position attention forward pass.\n",
    "# Relies on names already present in the kernel namespace: w, r, mems,\n",
    "# r_w_bias, r_r_bias, attn_mask, F and self.\n",
    "# NOTE(review): the recorded output of this cell is an AttributeError --\n",
    "# `self` was a MemTransformerLM, which has no `pre_lnorm`. The attributes read\n",
    "# here presumably belong to a single attention layer; confirm what `self`\n",
    "# should be bound to before running this cell.\n",
    "qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)\n",
    "\n",
    "if mems is not None:\n",
    "    # Keys/values are computed over memory + current input.\n",
    "    cat = torch.cat([mems, w], 0)\n",
    "    if self.pre_lnorm:\n",
    "        w_heads = self.qkv_net(self.layer_norm(cat))\n",
    "    else:\n",
    "        w_heads = self.qkv_net(cat)\n",
    "    r_head_k = self.r_net(r)\n",
    "    # print(r_head_k.shape, w.shape, cat.shape, r_head_k)\n",
    "\n",
    "    w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)\n",
    "    # Queries are kept only for the last qlen positions (the current input).\n",
    "    w_head_q = w_head_q[-qlen:]\n",
    "else:\n",
    "    if self.pre_lnorm:\n",
    "        w_heads = self.qkv_net(self.layer_norm(w))\n",
    "    else:\n",
    "        w_heads = self.qkv_net(w)\n",
    "    r_head_k = self.r_net(r)\n",
    "    # print(r_head_k.shape, w.shape, r.shape, r_head_k)\n",
    "\n",
    "    w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)\n",
    "\n",
    "klen = w_head_k.size(0)\n",
    "\n",
    "w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head\n",
    "w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head\n",
    "w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head\n",
    "\n",
    "r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head\n",
    "\n",
    "#### compute attention score\n",
    "rw_head_q = w_head_q + r_w_bias \n",
    "# print('rw_head_q',rw_head_q)                                        # qlen x bsz x n_head x d_head\n",
    "# print('w head k', w_head_k)\n",
    "# # print('rw bias', r_w_bias)\n",
    "AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head\n",
    "\n",
    "rr_head_q = w_head_q + r_r_bias\n",
    "BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x rlen x bsz x n_head\n",
    "# print('rr_head_q', rr_head_q[:3])\n",
    "print('r_head_k', r_head_k[:3])\n",
    "# print('BD ', BD[:3])\n",
    "# print(BD[0] - BD[1])\n",
    "# print(BD[0][0] - BD[0][1])\n",
    "# NOTE(review): _rel_shift is defined outside this notebook; presumably it\n",
    "# realigns the position term to relative offsets -- confirm in mem_transformer.py.\n",
    "BD = self._rel_shift(BD)\n",
    "# print('BD ', BD)\n",
    "\n",
    "# [qlen x klen x bsz x n_head] (the sum requires rlen == klen)\n",
    "attn_score = AC + BD\n",
    "attn_score.mul_(self.scale)  # in-place scaling\n",
    "print('attn score ', attn_score)\n",
    "\n",
    "#### compute attention probability\n",
    "# Masked positions are filled with -inf so that softmax gives them zero weight.\n",
    "if attn_mask is not None and attn_mask.any().item():\n",
    "    if attn_mask.dim() == 2:\n",
    "        # 2-D mask: a leading axis and a trailing head axis are added for broadcasting\n",
    "        attn_score = attn_score.float().masked_fill(\n",
    "            attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)\n",
    "    elif attn_mask.dim() == 3:\n",
    "        # 3-D mask: only the trailing head axis is added\n",
    "        attn_score = attn_score.float().masked_fill(\n",
    "            attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)\n",
    "\n",
    "# [qlen x klen x bsz x n_head], normalised over the key axis (dim=1)\n",
    "attn_prob = F.softmax(attn_score, dim=1)\n",
    "attn_prob = self.dropatt(attn_prob)\n",
    "\n",
    "#### compute attention vector\n",
    "attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))\n",
    "\n",
    "# [qlen x bsz x n_head x d_head] -> [qlen x bsz x (n_head * d_head)]\n",
    "attn_vec = attn_vec.contiguous().view(\n",
    "    attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)\n",
    "\n",
    "##### linear projection\n",
    "attn_out = self.o_net(attn_vec)\n",
    "attn_out = self.drop(attn_out)\n",
    "\n",
    "if self.pre_lnorm:\n",
    "    ##### residual connection\n",
    "    output = w + attn_out\n",
    "else:\n",
    "    ##### residual connection + layer normalization\n",
    "    output = self.layer_norm(w + attn_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "7d075378",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'AC' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-33-c8b725298fd8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mAC\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'AC' is not defined"
     ]
    }
   ],
   "source": [
    "AC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "572ae2b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "r_head_k tensor([[[ 0.2262, -0.5403,  0.0405,  ...,  0.0798,  0.0777,  1.1830],\n",
      "         [ 0.5002, -0.2644,  0.0691,  ..., -0.0247, -0.0708,  0.0897],\n",
      "         [-0.3503,  0.0289, -0.5878,  ..., -0.2117, -0.1172, -0.1238],\n",
      "         [ 0.5241, -0.5015, -0.0266,  ...,  0.2789,  0.2498, -0.1229]],\n",
      "\n",
      "        [[ 0.2521, -0.5876, -0.0805,  ...,  0.0289,  0.2788,  1.2227],\n",
      "         [ 0.6034, -0.1618,  0.0834,  ..., -0.1337, -0.1154,  0.1049],\n",
      "         [-0.2763, -0.2468, -0.5324,  ..., -0.1908, -0.2773, -0.0356],\n",
      "         [ 0.5368, -0.3575, -0.0864,  ...,  0.3076,  0.2850, -0.1621]],\n",
      "\n",
      "        [[ 0.3549, -0.5847, -0.0754,  ..., -0.1469,  0.3942,  1.1786],\n",
      "         [ 0.6535,  0.0195,  0.0049,  ..., -0.2342, -0.1317,  0.0346],\n",
      "         [-0.1697, -0.4524, -0.5574,  ..., -0.1915, -0.4983,  0.0562],\n",
      "         [ 0.5492, -0.2864, -0.1481,  ...,  0.2096,  0.4107, -0.1843]]],\n",
      "       grad_fn=<SliceBackward>)\n",
      "attn score  tensor([[[[ 4.5926e-01, -7.4895e-01,  4.0674e-01,  1.8538e-01],\n",
      "          [ 4.5926e-01, -7.4895e-01,  4.0674e-01,  1.8538e-01]],\n",
      "\n",
      "         [[ 1.6746e-01, -3.2276e-01,  3.9268e-01, -6.4377e-02],\n",
      "          [ 1.6746e-01, -3.2276e-01,  3.9268e-01, -6.4377e-02]],\n",
      "\n",
      "         [[ 2.6527e-01, -6.9557e-01,  2.7688e-01,  3.2766e-01],\n",
      "          [ 2.6527e-01, -6.9557e-01,  2.7688e-01,  3.2766e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 6.9418e-01, -1.3079e+00, -1.7586e+00,  3.6488e+00],\n",
      "          [ 3.0289e-02, -6.4190e+00,  3.8662e+00,  1.4831e+00]],\n",
      "\n",
      "         [[-2.0986e+00, -9.9517e-02, -4.7451e+00,  4.5032e+00],\n",
      "          [ 3.2633e+00, -3.3871e+00, -6.9066e-01,  1.5513e+00]],\n",
      "\n",
      "         [[-7.2892e+00,  3.9491e+00,  1.7186e+00,  1.2491e+00],\n",
      "          [ 4.2058e+00, -3.4959e+00,  4.7210e-01, -4.1301e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 4.4663e-01, -8.1422e-01,  5.0567e-01,  2.7344e-01],\n",
      "          [ 4.4663e-01, -8.1422e-01,  5.0567e-01,  2.7344e-01]],\n",
      "\n",
      "         [[ 4.5926e-01, -7.4895e-01,  4.0674e-01,  1.8538e-01],\n",
      "          [ 4.5926e-01, -7.4895e-01,  4.0674e-01,  1.8538e-01]],\n",
      "\n",
      "         [[ 1.6746e-01, -3.2276e-01,  3.9268e-01, -6.4377e-02],\n",
      "          [ 1.6746e-01, -3.2276e-01,  3.9268e-01, -6.4377e-02]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 6.6262e-01, -1.2501e+00, -1.8373e+00,  3.6490e+00],\n",
      "          [-1.2706e-03, -6.3612e+00,  3.7876e+00,  1.4833e+00]],\n",
      "\n",
      "         [[-2.1634e+00, -4.6209e-02, -4.7932e+00,  4.4596e+00],\n",
      "          [ 3.1984e+00, -3.3338e+00, -7.3880e-01,  1.5078e+00]],\n",
      "\n",
      "         [[-7.3600e+00,  3.9653e+00,  1.7531e+00,  1.2230e+00],\n",
      "          [ 4.1350e+00, -3.4797e+00,  5.0658e-01, -4.1562e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 3.9922e-01, -8.4769e-01,  6.0729e-01,  3.1375e-01],\n",
      "          [ 3.9922e-01, -8.4769e-01,  6.0729e-01,  3.1375e-01]],\n",
      "\n",
      "         [[ 4.4663e-01, -8.1422e-01,  5.0567e-01,  2.7344e-01],\n",
      "          [ 4.4663e-01, -8.1422e-01,  5.0567e-01,  2.7344e-01]],\n",
      "\n",
      "         [[ 4.5926e-01, -7.4895e-01,  4.0674e-01,  1.8538e-01],\n",
      "          [ 4.5926e-01, -7.4895e-01,  4.0674e-01,  1.8538e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 6.6724e-01, -1.2159e+00, -1.8698e+00,  3.7028e+00],\n",
      "          [ 3.3577e-03, -6.3270e+00,  3.7551e+00,  1.5372e+00]],\n",
      "\n",
      "         [[-2.1950e+00,  1.1606e-02, -4.8719e+00,  4.4598e+00],\n",
      "          [ 3.1669e+00, -3.2760e+00, -8.1745e-01,  1.5079e+00]],\n",
      "\n",
      "         [[-7.4249e+00,  4.0186e+00,  1.7050e+00,  1.1794e+00],\n",
      "          [ 4.0701e+00, -3.4264e+00,  4.5845e-01, -4.1997e+00]]],\n",
      "\n",
      "\n",
      "        ...,\n",
      "\n",
      "\n",
      "        [[[ 3.9344e+00, -1.6987e+01,  8.2212e-01,  2.6925e-01],\n",
      "          [ 1.8994e+00, -1.4906e+00, -1.7139e+00, -3.9948e+00]],\n",
      "\n",
      "         [[ 3.2679e+00, -1.7256e+01,  7.2213e-01,  7.7800e-01],\n",
      "          [ 1.7202e+00, -1.7654e+00, -2.7328e+00, -3.7180e+00]],\n",
      "\n",
      "         [[ 2.5407e+00, -1.8179e+01,  1.6145e+00,  1.2467e+00],\n",
      "          [ 2.3142e+00, -3.0812e+00, -3.4293e+00, -3.4361e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[-3.9185e+01,  1.5453e+01,  6.1362e+00,  3.2932e+01],\n",
      "          [ 3.2284e+01,  2.4696e+01,  2.6776e+01, -4.0733e+00]],\n",
      "\n",
      "         [[-1.0713e+01, -6.9300e+01,  3.2201e+01,  1.6425e+01],\n",
      "          [-2.4703e+01,  2.5067e+01, -1.8456e+01,  3.9301e+01]],\n",
      "\n",
      "         [[ 2.0539e+01, -2.7520e+01,  5.8287e+01, -4.8176e+01],\n",
      "          [ 2.6689e+01,  4.6780e+01,  6.3200e+00,  1.1838e+02]]],\n",
      "\n",
      "\n",
      "        [[[ 3.1445e+00, -1.0125e+01,  2.7909e+00,  3.2318e+00],\n",
      "          [ 1.4233e+00,  1.3477e+00, -1.9102e+00, -5.0790e+00]],\n",
      "\n",
      "         [[ 3.5956e+00, -1.0074e+01,  3.0350e+00,  3.6104e+00],\n",
      "          [ 1.8009e+00,  6.2681e-01, -1.3466e+00, -3.9930e+00]],\n",
      "\n",
      "         [[ 3.8502e+00, -9.8174e+00,  3.4652e+00,  3.9425e+00],\n",
      "          [ 2.4303e+00,  5.7494e-01, -1.4234e+00, -3.1369e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 5.9431e+01, -5.9329e+01,  3.2711e+01,  2.8909e+01],\n",
      "          [ 3.0623e+00, -1.2856e+01, -5.3952e+01, -5.7508e+01]],\n",
      "\n",
      "         [[-4.3660e+01,  1.6798e+01,  3.8689e+01, -3.8569e+01],\n",
      "          [ 2.8029e+01, -2.6512e+01,  3.4649e+01, -7.1437e+00]],\n",
      "\n",
      "         [[-2.0008e+01, -9.0858e+00,  5.6368e+01, -9.2771e+00],\n",
      "          [ 5.7385e+01,  5.3339e+00, -3.0548e+01, -1.4880e+01]]],\n",
      "\n",
      "\n",
      "        [[[ 2.4009e+00,  5.7452e+00, -3.5478e+00, -4.0724e-01],\n",
      "          [ 1.6064e+00,  6.6179e+00, -2.5703e+00,  2.2452e+00]],\n",
      "\n",
      "         [[ 1.5347e+00,  6.0475e+00, -3.7623e+00, -1.1660e-02],\n",
      "          [ 1.5736e+00,  6.2382e+00, -3.0802e+00,  1.5857e+00]],\n",
      "\n",
      "         [[ 1.7910e+00,  5.8255e+00, -3.7303e+00, -1.2666e-01],\n",
      "          [ 2.3877e+00,  5.7367e+00, -3.2145e+00,  6.5799e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 1.3110e+01, -7.6679e+01, -1.4859e+01,  4.8115e+01],\n",
      "          [-1.3063e+00, -4.4439e+01, -4.0180e+01, -5.6466e+01]],\n",
      "\n",
      "         [[-8.1210e+00,  2.4731e+01,  7.7767e+01,  2.9580e+01],\n",
      "          [-2.1810e+01, -1.7348e+01,  2.1157e+01,  2.0215e+00]],\n",
      "\n",
      "         [[ 4.2586e+01, -6.4056e+01, -2.3489e+01, -2.7994e+01],\n",
      "          [-1.1617e+01, -7.6856e+00,  3.4243e+01,  4.9952e+00]]]],\n",
      "       grad_fn=<MulBackward0>)\n",
      "r_head_k tensor([[[ 0.0376,  0.5418,  0.0020,  ..., -0.0733, -0.3008,  0.5742],\n",
      "         [ 0.5026,  0.5347, -0.4856,  ...,  0.0093,  0.2657, -1.0471],\n",
      "         [ 0.2351,  0.3034, -0.2594,  ..., -0.5405, -0.3086,  0.1489],\n",
      "         [-0.3599,  0.0196,  0.1991,  ...,  0.0716, -0.0530,  0.3230]],\n",
      "\n",
      "        [[-0.1133,  0.3320, -0.0076,  ..., -0.0600, -0.1846,  0.7320],\n",
      "         [ 0.4232,  0.5023, -0.6066,  ...,  0.0614,  0.3908, -0.8849],\n",
      "         [ 0.2550,  0.2186, -0.3656,  ..., -0.4737, -0.3037,  0.0724],\n",
      "         [-0.1894,  0.1493,  0.0934,  ...,  0.2659, -0.0306,  0.4982]],\n",
      "\n",
      "        [[-0.1270,  0.0502, -0.0620,  ..., -0.0760, -0.0190,  0.8129],\n",
      "         [ 0.3343,  0.4200, -0.6106,  ...,  0.1099,  0.4562, -0.7498],\n",
      "         [ 0.1582,  0.1719, -0.4557,  ..., -0.3775, -0.2726,  0.1561],\n",
      "         [-0.0246,  0.1821, -0.1044,  ...,  0.3741,  0.0095,  0.5820]]],\n",
      "       grad_fn=<SliceBackward>)\n",
      "attn score  tensor([[[[-2.2884e-01,  4.3547e-01, -1.0116e+00, -1.3271e-01],\n",
      "          [-2.2884e-01,  4.3547e-01, -1.0116e+00, -1.3271e-01]],\n",
      "\n",
      "         [[-1.3572e-01,  4.8241e-01, -5.6016e-01, -4.1970e-01],\n",
      "          [-1.3572e-01,  4.8241e-01, -5.6016e-01, -4.1970e-01]],\n",
      "\n",
      "         [[-5.5069e-02,  8.7283e-01, -1.1392e+00, -2.1891e-01],\n",
      "          [-5.5069e-02,  8.7283e-01, -1.1392e+00, -2.1891e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[-4.6686e-01,  6.1955e-01, -4.0405e-01,  5.0529e-01],\n",
      "          [ 5.2529e-01, -1.9901e-01,  1.0883e-01,  4.6836e-01]],\n",
      "\n",
      "         [[-1.2102e-01,  1.9040e-01, -6.6272e-01,  4.3710e-01],\n",
      "          [-8.7621e-01,  3.4843e-01, -6.0371e-01,  1.2740e-01]],\n",
      "\n",
      "         [[-2.3267e-01,  3.5128e-01, -4.0467e-01,  3.0937e-01],\n",
      "          [-4.2602e-01, -4.1877e-01, -1.8211e-01,  3.7009e-01]]],\n",
      "\n",
      "\n",
      "        [[[-2.7259e-01,  4.8051e-01, -1.1058e+00, -9.5562e-02],\n",
      "          [-2.7259e-01,  4.8051e-01, -1.1058e+00, -9.5562e-02]],\n",
      "\n",
      "         [[-2.2884e-01,  4.3547e-01, -1.0116e+00, -1.3271e-01],\n",
      "          [-2.2884e-01,  4.3547e-01, -1.0116e+00, -1.3271e-01]],\n",
      "\n",
      "         [[-1.3572e-01,  4.8241e-01, -5.6016e-01, -4.1970e-01],\n",
      "          [-1.3572e-01,  4.8241e-01, -5.6016e-01, -4.1970e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[-3.9939e-01,  5.8459e-01, -3.2497e-01,  5.2258e-01],\n",
      "          [ 5.9276e-01, -2.3397e-01,  1.8791e-01,  4.8565e-01]],\n",
      "\n",
      "         [[-1.1930e-01,  2.2887e-01, -5.2682e-01,  4.3972e-01],\n",
      "          [-8.7449e-01,  3.8690e-01, -4.6781e-01,  1.3001e-01]],\n",
      "\n",
      "         [[-3.0852e-01,  4.4559e-01, -2.9955e-01,  3.0465e-01],\n",
      "          [-5.0188e-01, -3.2446e-01, -7.6993e-02,  3.6537e-01]]],\n",
      "\n",
      "\n",
      "        [[[-3.7374e-01,  5.7660e-01, -1.1003e+00, -8.4738e-02],\n",
      "          [-3.7374e-01,  5.7660e-01, -1.1003e+00, -8.4738e-02]],\n",
      "\n",
      "         [[-2.7259e-01,  4.8051e-01, -1.1058e+00, -9.5562e-02],\n",
      "          [-2.7259e-01,  4.8051e-01, -1.1058e+00, -9.5562e-02]],\n",
      "\n",
      "         [[-2.2884e-01,  4.3547e-01, -1.0116e+00, -1.3271e-01],\n",
      "          [-2.2884e-01,  4.3547e-01, -1.0116e+00, -1.3271e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[-3.2180e-01,  5.0689e-01, -3.4120e-01,  5.3789e-01],\n",
      "          [ 6.7035e-01, -3.1168e-01,  1.7168e-01,  5.0096e-01]],\n",
      "\n",
      "         [[-5.1834e-02,  1.9391e-01, -4.4774e-01,  4.5701e-01],\n",
      "          [-8.0702e-01,  3.5194e-01, -3.8873e-01,  1.4730e-01]],\n",
      "\n",
      "         [[-3.0681e-01,  4.8406e-01, -1.6365e-01,  3.0726e-01],\n",
      "          [-5.0017e-01, -2.8599e-01,  5.8904e-02,  3.6798e-01]]],\n",
      "\n",
      "\n",
      "        ...,\n",
      "\n",
      "\n",
      "        [[[ 1.9624e-01,  7.8250e-01,  7.1541e-02, -1.4807e-01],\n",
      "          [ 6.2480e-01,  2.3242e-01, -2.8722e-01,  7.3112e-02]],\n",
      "\n",
      "         [[ 1.7278e-01,  7.8999e-01,  1.1641e-01, -2.2632e-01],\n",
      "          [ 6.7263e-01,  1.9593e-01, -2.7226e-01,  2.4658e-02]],\n",
      "\n",
      "         [[ 1.5072e-01,  7.8929e-01,  1.9564e-01, -2.9874e-01],\n",
      "          [ 6.8149e-01,  1.0850e-01, -2.8036e-01, -3.8156e-02]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 1.4664e-01,  2.4456e-02,  7.9919e-01, -3.0577e-01],\n",
      "          [ 2.0518e-01, -4.3200e-02, -2.3106e-01, -1.1075e-03]],\n",
      "\n",
      "         [[-4.1230e-01, -1.2259e-02, -4.2129e-01,  3.6297e-02],\n",
      "          [-4.3176e-01, -4.9160e-01, -1.7472e-01,  2.0604e-01]],\n",
      "\n",
      "         [[-2.6241e-01, -5.3381e-01, -7.8697e-02, -3.0808e-01],\n",
      "          [-8.6617e-02, -7.4793e-01,  3.0748e-03,  1.9606e-01]]],\n",
      "\n",
      "\n",
      "        [[[-4.6318e-01, -2.9306e-02,  3.2402e-01, -1.5487e-01],\n",
      "          [ 4.8165e-01, -3.7710e-01,  4.5637e-01,  4.0392e-01]],\n",
      "\n",
      "         [[-4.5778e-01, -2.7279e-02,  2.6631e-01, -1.4632e-01],\n",
      "          [ 4.6161e-01, -3.9188e-01,  5.7357e-01,  4.8742e-01]],\n",
      "\n",
      "         [[-5.0183e-01, -4.5212e-02,  2.2513e-01, -1.7105e-01],\n",
      "          [ 4.6193e-01, -3.6174e-01,  6.2449e-01,  5.5347e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[-3.7469e-01, -2.6464e-01,  4.8987e-01, -2.3905e-01],\n",
      "          [ 4.5030e-02, -2.8319e-01,  3.6075e-01,  4.5396e-01]],\n",
      "\n",
      "         [[-6.7070e-02, -5.0136e-02,  8.2176e-01, -3.3554e-01],\n",
      "          [-3.7264e-01, -7.0634e-01,  5.3794e-01, -1.3859e-01]],\n",
      "\n",
      "         [[-4.0102e-01, -2.2375e-02, -7.4985e-02,  2.8321e-01],\n",
      "          [-2.3138e-01, -6.5743e-02,  6.8930e-01, -1.8139e-01]]],\n",
      "\n",
      "\n",
      "        [[[-2.6197e-01,  1.6456e-01, -3.8674e-01, -2.4953e-01],\n",
      "          [ 4.6812e-01,  4.2299e-01, -1.2770e-01,  4.1488e-01]],\n",
      "\n",
      "         [[-2.5378e-01,  1.4974e-01, -3.6693e-01, -2.0536e-01],\n",
      "          [ 4.4487e-01,  3.8013e-01, -2.3005e-01,  3.7590e-01]],\n",
      "\n",
      "         [[-1.7349e-01,  1.4437e-01, -2.9268e-01, -1.7047e-01],\n",
      "          [ 4.2580e-01,  3.2135e-01, -3.1640e-01,  3.6216e-01]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[ 1.0874e-01,  1.3506e-01, -5.3571e-02,  5.8103e-01],\n",
      "          [-3.2715e-01, -1.8171e-01,  2.9960e-01,  4.8359e-01]],\n",
      "\n",
      "         [[ 2.6381e-01, -3.2820e-01,  1.5716e-01,  3.3972e-01],\n",
      "          [ 3.1592e-01, -2.6581e-01, -4.1753e-01,  3.5095e-01]],\n",
      "\n",
      "         [[ 5.4988e-02, -3.8589e-01, -3.4877e-01,  2.7734e-01],\n",
      "          [ 1.2285e-01, -1.7753e-01, -3.3278e-01,  6.0755e-01]]]],\n",
      "       grad_fn=<MulBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ayd98/Desktop/MIPT/TXL/pytorch/mem_transformer.py:279: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at  /pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1104.)\n",
      "  attn_score = attn_score.float().masked_fill(\n"
     ]
    }
   ],
   "source": [
    "out = model(src[:, :2], tgt[:24, :2].contiguous())\n",
    "\n",
    "mem_tokens, enc, mems = out[0], out[1], out[2:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8c9fa988",
   "metadata": {},
   "outputs": [
    {
     "ename": "ZeroDivisionError",
     "evalue": "division by zero",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mZeroDivisionError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-9e1622b385b6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;36m1\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mZeroDivisionError\u001b[0m: division by zero"
     ]
    }
   ],
   "source": [
    "1/0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba1ea8fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.mem_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1777914",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _forward(self, dec_inp, mems=None, mem_tokens=None):\n",
    "\n",
    "        word_emb = self.word_emb(dec_inp)\n",
    "\n",
    "        mlen = mems[0].size(0) if mems is not None else 0\n",
    "        \n",
    "        # Concat with mem_tokens\n",
    "        if mem_tokens is not None:\n",
    "            word_emb = torch.cat((mem_tokens.detach(), word_emb), dim=0)\n",
    "        elif self.num_mem_tokens not in (0, None):\n",
    "            mem_tokens = self.mem_tokens.reshape(self.num_mem_tokens, 1, -1).repeat(1, dec_inp.shape[1], 1)\n",
    "            word_emb = torch.cat((mem_tokens.detach(), word_emb), dim=0)\n",
    "\n",
    "        # qlen, bsz = dec_inp.size()\n",
    "        qlen = word_emb.shape[0]\n",
    "        klen = mlen + qlen\n",
    "        if self.same_length:\n",
    "            all_ones = word_emb.new_ones(qlen, klen)\n",
    "            mask_len = klen - self.mem_len\n",
    "            if mask_len > 0:\n",
    "                mask_shift_len = qlen - mask_len\n",
    "            else:\n",
    "                mask_shift_len = qlen\n",
    "            dec_attn_mask = (torch.triu(all_ones, 1+mlen)\n",
    "                    + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1\n",
    "        else:\n",
    "            dec_attn_mask = torch.triu(\n",
    "                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]\n",
    "        \n",
    "        hids = []\n",
    "        if self.attn_type == 0: # default\n",
    "            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, \n",
    "                                   dtype=word_emb.dtype)\n",
    "            if self.clamp_len > 0:\n",
    "                pos_seq.clamp_(max=self.clamp_len)\n",
    "            pos_emb = self.pos_emb(pos_seq)\n",
    "\n",
    "            core_out = self.drop(word_emb)\n",
    "            pos_emb = self.drop(pos_emb)\n",
    "            print(len(core_out), [c.shape for c in core_out])\n",
    "\n",
    "            hids.append(core_out)\n",
    "            for i, layer in enumerate(self.layers):\n",
    "                mems_i = None if mems is None else mems[i]\n",
    "                core_out = layer(core_out, pos_emb, self.r_w_bias,\n",
    "                        self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)\n",
    "                hids.append(core_out)\n",
    "        elif self.attn_type == 1: # learnable\n",
    "            core_out = self.drop(word_emb)\n",
    "            hids.append(core_out)\n",
    "            for i, layer in enumerate(self.layers):\n",
    "                if self.clamp_len > 0:\n",
    "                    r_emb = self.r_emb[i][-self.clamp_len :]\n",
    "                    r_bias = self.r_bias[i][-self.clamp_len :]\n",
    "                else:\n",
    "                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]\n",
    "\n",
    "                mems_i = None if mems is None else mems[i]\n",
    "                core_out = layer(core_out, r_emb, self.r_w_bias[i],\n",
    "                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)\n",
    "                hids.append(core_out)\n",
    "        elif self.attn_type == 2: # absolute\n",
    "            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,\n",
    "                                   dtype=word_emb.dtype)\n",
    "            if self.clamp_len > 0:\n",
    "                pos_seq.clamp_(max=self.clamp_len)\n",
    "            pos_emb = self.pos_emb(pos_seq)\n",
    "\n",
    "            core_out = self.drop(word_emb + pos_emb[-qlen:])\n",
    "\n",
    "            hids.append(core_out)\n",
    "            for i, layer in enumerate(self.layers):\n",
    "                mems_i = None if mems is None else mems[i]\n",
    "                if mems_i is not None and i == 0:\n",
    "                    mems_i += pos_emb[:mlen]\n",
    "                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,\n",
    "                                 mems=mems_i)\n",
    "                hids.append(core_out)\n",
    "        elif self.attn_type == 3:\n",
    "            core_out = self.drop(word_emb)\n",
    "\n",
    "            hids.append(core_out)\n",
    "            for i, layer in enumerate(self.layers):\n",
    "                mems_i = None if mems is None else mems[i]\n",
    "                if mems_i is not None and mlen > 0:\n",
    "                    cur_emb = self.r_emb[i][:-qlen]\n",
    "                    cur_size = cur_emb.size(0)\n",
    "                    if cur_size < mlen:\n",
    "                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)\n",
    "                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)\n",
    "                    else:\n",
    "                        cur_emb = cur_emb[-mlen:]\n",
    "                    mems_i += cur_emb.view(mlen, 1, -1)\n",
    "                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)\n",
    "\n",
    "                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,\n",
    "                                 mems=mems_i)\n",
    "                hids.append(core_out)\n",
    "\n",
    "        core_out = self.drop(core_out)\n",
    "\n",
    "        new_mems = self._update_mems(hids, mems, mlen, qlen)\n",
    "\n",
    "        return core_out, new_mems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5914e9e",
   "metadata": {},
   "outputs": [],
   "source": [
    " _forward(model, src[:, :1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d49d90e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "out, mems = _forward(model, src[:, :1])\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "987e9ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "out[:9]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "509e1a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "mem_tokens[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9a4ea8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "mem_tokens[0][0] - mem_tokens[0][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c369a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "mem.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8889546d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mem.shape, pred_hid.shape, tgt.contiguous().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f06e172",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = model.crit(pred_hid.view(-1, pred_hid.size(-1)), tgt.contiguous().view(-1))\n",
    "loss = loss.view(tgt_len, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8b91e",
   "metadata": {},
   "outputs": [],
   "source": [
    "src.shape, out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9896f597",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerXL(nn.Module):\n",
    "    \"\"\"Encoder/decoder pair built from two `MemTransformerLM` instances.\n",
    "\n",
    "    Args:\n",
    "        enc_kwargs: keyword arguments used to construct the encoder;\n",
    "            `num_mem_tokens` is read again in `forward`.\n",
    "        dec_kwargs: keyword arguments used to construct the decoder;\n",
    "            `mem_len` is read again in `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, enc_kwargs, dec_kwargs):\n",
    "        super().__init__()\n",
    "\n",
    "        self.Encoder = MemTransformerLM(**enc_kwargs)\n",
    "        self.Decoder = MemTransformerLM(**dec_kwargs)\n",
    "\n",
    "        self.enc_kwargs = enc_kwargs\n",
    "        self.dec_kwargs = dec_kwargs\n",
    "\n",
    "    def forward(self, src, tgt, mems=None):\n",
    "        \"\"\"Encode `src`, then run the decoder on the first row of `tgt`.\n",
    "\n",
    "        Args:\n",
    "            src: encoder input, passed to `self.Encoder._forward`.\n",
    "            tgt: decoder target; only `tgt[:1, :]` is used as decoder input.\n",
    "            mems: optional cached memory forwarded to the encoder.\n",
    "\n",
    "        Returns:\n",
    "            `(out, new_mems)` exactly as returned by `self.Decoder._forward`.\n",
    "        \"\"\"\n",
    "        enc, mems = self.Encoder._forward(src, mems=mems)\n",
    "\n",
    "        # Split off the leading memory-token rows of the encoder output.\n",
    "        # A missing `num_mem_tokens` key is treated like None (no memory tokens).\n",
    "        mem = None\n",
    "        num_mem = self.enc_kwargs.get('num_mem_tokens')\n",
    "        if num_mem not in (None, 0):\n",
    "            mem, enc = enc[:num_mem], enc[num_mem:]\n",
    "\n",
    "        if self.dec_kwargs['mem_len'] == 0:\n",
    "            mems = None\n",
    "\n",
    "        # NOTE(review): `mem`, `enc` and `mems` are computed but not passed to\n",
    "        # the decoder yet, so the decoder output does not depend on `src`.\n",
    "        # Presumably work in progress -- confirm the intended wiring.\n",
    "        start_tokens = tgt[:1, :]\n",
    "        out, new_mems = self.Decoder._forward(start_tokens)\n",
    "\n",
    "        # loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))\n",
    "\n",
    "        # Previously nothing was returned, so calling the module always gave None.\n",
    "        return out, new_mems\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98fd977b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ba37da5",
   "metadata": {},
   "outputs": [],
   "source": [
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "738394cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "tgt.contiguous().view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "467618bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs for the cell-by-cell replay of the forward pass in the next cells.\n",
    "# rep = src[0].repeat(16, 1)\n",
    "# hidden, new_mems = model._forward(rep)\n",
    "\n",
    "# No cached memory, so the following cell computes mlen = 0.\n",
    "mems = None\n",
    "# dec_inp = src[0].repeat(21, 1).T\n",
    "# First 21 rows of src, transposed.\n",
    "dec_inp = src[:21].T\n",
    "# Bind `self` so method bodies that reference it can be run at top level.\n",
    "self = model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa644df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the input tokens.\n",
    "word_emb = self.word_emb(dec_inp)\n",
    "\n",
    "# Memory length comes from the first cached tensor; zero when there is no cache.\n",
    "if mems is None:\n",
    "    mlen = 0\n",
    "else:\n",
    "    mlen = mems[0].size(0)\n",
    "\n",
    "# Put the learned memory tokens in front of the embeddings,\n",
    "# repeated dec_inp.shape[1] times along dim 1.\n",
    "if self.num_mem_tokens is not None and self.num_mem_tokens != 0:\n",
    "    memory = self.mem_tokens.reshape(self.num_mem_tokens, 1, -1)\n",
    "    memory = memory.repeat(1, dec_inp.shape[1], 1)\n",
    "    word_emb = torch.cat([memory, word_emb], dim=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f6344ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-level replay of the model forward pass. It continues from the previous\n",
    "# cell and relies on `self`, `mems`, `mlen` and `word_emb` defined there.\n",
    "\n",
    "# qlen, bsz = dec_inp.size()\n",
    "# qlen counts the prepended memory tokens as well as the input tokens.\n",
    "qlen = word_emb.shape[0]\n",
    "klen = mlen + qlen\n",
    "# The masks are built as torch.bool: uint8 masks trigger the masked_fill\n",
    "# deprecation warning recorded in this notebook's stderr output.\n",
    "if self.same_length:\n",
    "    all_ones = word_emb.new_ones(qlen, klen)\n",
    "    mask_len = klen - self.mem_len\n",
    "    if mask_len > 0:\n",
    "        mask_shift_len = qlen - mask_len\n",
    "    else:\n",
    "        mask_shift_len = qlen\n",
    "    dec_attn_mask = (torch.triu(all_ones, 1+mlen)\n",
    "            + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1\n",
    "else:\n",
    "    dec_attn_mask = torch.triu(\n",
    "        word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]\n",
    "\n",
    "# hids collects the input embedding and every layer output for _update_mems.\n",
    "hids = []\n",
    "if self.attn_type == 0: # default\n",
    "    pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,\n",
    "                            dtype=word_emb.dtype)\n",
    "    if self.clamp_len > 0:\n",
    "        pos_seq.clamp_(max=self.clamp_len)\n",
    "    pos_emb = self.pos_emb(pos_seq)\n",
    "\n",
    "    core_out = self.drop(word_emb)\n",
    "    pos_emb = self.drop(pos_emb)\n",
    "\n",
    "    hids.append(core_out)\n",
    "    for i, layer in enumerate(self.layers):\n",
    "        mems_i = None if mems is None else mems[i]\n",
    "        core_out = layer(core_out, pos_emb, self.r_w_bias,\n",
    "                self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)\n",
    "        hids.append(core_out)\n",
    "elif self.attn_type == 1: # learnable\n",
    "    core_out = self.drop(word_emb)\n",
    "    hids.append(core_out)\n",
    "    for i, layer in enumerate(self.layers):\n",
    "        if self.clamp_len > 0:\n",
    "            r_emb = self.r_emb[i][-self.clamp_len :]\n",
    "            r_bias = self.r_bias[i][-self.clamp_len :]\n",
    "        else:\n",
    "            r_emb, r_bias = self.r_emb[i], self.r_bias[i]\n",
    "\n",
    "        mems_i = None if mems is None else mems[i]\n",
    "        core_out = layer(core_out, r_emb, self.r_w_bias[i],\n",
    "                r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)\n",
    "        hids.append(core_out)\n",
    "elif self.attn_type == 2: # absolute\n",
    "    pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,\n",
    "                            dtype=word_emb.dtype)\n",
    "    if self.clamp_len > 0:\n",
    "        pos_seq.clamp_(max=self.clamp_len)\n",
    "    pos_emb = self.pos_emb(pos_seq)\n",
    "\n",
    "    core_out = self.drop(word_emb + pos_emb[-qlen:])\n",
    "\n",
    "    hids.append(core_out)\n",
    "    for i, layer in enumerate(self.layers):\n",
    "        mems_i = None if mems is None else mems[i]\n",
    "        if mems_i is not None and i == 0:\n",
    "            # in-place: this also modifies mems[0]\n",
    "            mems_i += pos_emb[:mlen]\n",
    "        core_out = layer(core_out, dec_attn_mask=dec_attn_mask,\n",
    "                            mems=mems_i)\n",
    "        hids.append(core_out)\n",
    "elif self.attn_type == 3:\n",
    "    core_out = self.drop(word_emb)\n",
    "\n",
    "    hids.append(core_out)\n",
    "    for i, layer in enumerate(self.layers):\n",
    "        mems_i = None if mems is None else mems[i]\n",
    "        if mems_i is not None and mlen > 0:\n",
    "            cur_emb = self.r_emb[i][:-qlen]\n",
    "            cur_size = cur_emb.size(0)\n",
    "            if cur_size < mlen:\n",
    "                # too few embeddings: repeat the first one to reach mlen rows\n",
    "                cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)\n",
    "                cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)\n",
    "            else:\n",
    "                cur_emb = cur_emb[-mlen:]\n",
    "            # in-place: this also modifies mems[i]\n",
    "            mems_i += cur_emb.view(mlen, 1, -1)\n",
    "        # in-place: the tensor last appended to hids is modified as well\n",
    "        core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)\n",
    "\n",
    "        core_out = layer(core_out, dec_attn_mask=dec_attn_mask,\n",
    "                            mems=mems_i)\n",
    "        hids.append(core_out)\n",
    "\n",
    "core_out = self.drop(core_out)\n",
    "\n",
    "new_mems = self._update_mems(hids, mems, mlen, qlen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22fa2f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden.shape, pred_hid.shape, mem_tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c8b4d5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "target.view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c76f847c",
   "metadata": {},
   "outputs": [],
   "source": [
    "target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "841ed665",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_hid.view(-1, pred_hid.size(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccc5e43a",
   "metadata": {},
   "outputs": [],
   "source": [
    "src, tgt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f9ecd65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch re-computation of the loss, using `self`, `src`, `tgt` and `mems`\n",
    "# from the notebook namespace.\n",
    "tgt_len = 49  # NOTE(review): hard-coded; presumably the target length of `tgt` - confirm\n",
    "target = tgt.contiguous()\n",
    "\n",
    "hidden, new_mems = self._forward(src, mems=mems)\n",
    "\n",
    "# Normalise the setting once: None and 0 both mean there are no memory tokens.\n",
    "# Without this, the slice below raises TypeError when num_mem_tokens is None.\n",
    "num_mem_tokens = self.num_mem_tokens or 0\n",
    "\n",
    "# Keep only the trailing positions: memory tokens (if any) followed by targets.\n",
    "pred_hid = hidden[-(tgt_len + num_mem_tokens):]\n",
    "\n",
    "if num_mem_tokens:\n",
    "    # Split the leading memory-token slots off from the prediction positions.\n",
    "    mem_tokens, pred_hid = pred_hid[:num_mem_tokens].clone(), pred_hid[num_mem_tokens:].clone()\n",
    "\n",
    "# Flatten hidden states and targets for the criterion, then reshape the result\n",
    "# so its first dimension is tgt_len.\n",
    "loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))\n",
    "loss = loss.view(tgt_len, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "822c3698",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_hid.shape, target.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa5767f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_hid.view(-1, pred_hid.size(-1)).shape, target.view(-1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f41dcb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_seq.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4e7cabc",
   "metadata": {},
   "outputs": [],
   "source": [
    "dec_inp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d4d979",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07574902",
   "metadata": {},
   "outputs": [],
   "source": [
    "word_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "991e37d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "core_out.shape, core_out.isnan().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20f52202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.layers[0](hids[0],pos_emb,  self.r_w_bias,\n",
    "#                 self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "421835d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_mems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7908edf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "mlen, qlen, klen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4332142c",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_emb.shape, pos_seq.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69f7b596",
   "metadata": {},
   "outputs": [],
   "source": [
    "core_out.shape, core_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4588e65",
   "metadata": {},
   "outputs": [],
   "source": [
    "core_out[0][0] - core_out[1][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36f49f78",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b25c562",
   "metadata": {},
   "outputs": [],
   "source": [
    "core_out[:, 0] - core_out[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07835a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden[0][0] - hidden[0][8]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7472a74d",
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a701a4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# class TransformerXL(nn.Module):\n",
    "#     def __init__(self, enc_kwargs, dec_kwargs):\n",
    "#         super().__init__()\n",
    "        \n",
    "#         self.Encoder = MemTransformerLM(**enc_kwargs)\n",
    "#         self.Decoder = MemTransformerLM(**dec_kwargs)\n",
    "\n",
    "#         self.enc_kwargs = enc_kwargs\n",
    "#         self.dec_kwargs = dec_kwargs\n",
    "\n",
    "#     def forward(self, src, tgt, mems=None):\n",
    "#         hidden, mems = self.Encoder(src, )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "472b17bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "src.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e75684",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.mem_tokens.shape, model.mem_tokens.repeat(2, 1, 1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd81ea39",
   "metadata": {},
   "outputs": [],
   "source": [
    "src.shape, tgt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5783f4ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e14d2799",
   "metadata": {},
   "outputs": [],
   "source": [
    "model(src, tgt, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "596b1aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "src.shape, src.repeat((1, 1, 2)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db167db5",
   "metadata": {},
   "outputs": [],
   "source": [
    "next(model.parameters()).device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e975a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE = 0.0007\n",
    "\n",
    "# Sweep definition: every key maps to the list of values to try.\n",
    "# 'depth,heads' holds paired values that are unpacked by the training cell.\n",
    "grid_spec = {\n",
    "    'dim': [128],\n",
    "    'tie_token_embeds': [True],\n",
    "    'return_tgt_loss': [True],\n",
    "    'enc_num_tokens': [NUM_TOKENS],\n",
    "    'depth,heads': [(2, 4)],\n",
    "    'enc_max_seq_len': [24],\n",
    "    'dec_num_tokens': [NUM_TOKENS],\n",
    "    'dec_max_seq_len': [DEC_SEQ_LEN],\n",
    "    'enc_num_memory_tokens': [2, 8, 0],\n",
    "}\n",
    "model_parameters = ParameterGrid(grid_spec)\n",
    "\n",
    "# One run per (initialisation, parameter combination).\n",
    "total_runs = NUM_INITS * len(model_parameters)\n",
    "print('Total runs: ', total_runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92978633",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep driver: for every initialisation and every hyper-parameter set, build a\n",
    "# fresh model, call train_validate_model and then test_model. Progress and\n",
    "# wall-clock timings are appended to print_file.\n",
    "gen_train = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_train', batch_size=BATCH_SIZE)\n",
    "gen_val = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_val', batch_size=VAL_SIZE)\n",
    "gen_test = data_loader(path=f'data{INPUT_LEN}', task_name=f'{TASK_NAME}_test', batch_size=TEST_SIZE)\n",
    "\n",
    "\n",
    "print_file = f'logs/{TASK_NAME}_{TAG}_memory_logs.txt'\n",
    "# Append-mode open() does not create missing directories, so make sure the\n",
    "# log directory exists before the first write.\n",
    "os.makedirs(os.path.dirname(print_file), exist_ok=True)\n",
    "\n",
    "start_time = time.time()\n",
    "with torch.cuda.device(0):\n",
    "    for init_num in range(NUM_INITS):\n",
    "        with open(print_file, 'a') as f:\n",
    "            f.write('\\n\\nInit number ' + str(init_num)+'\\n')\n",
    "        for i, param in enumerate(list(model_parameters)):\n",
    "            with open(print_file, 'a') as f:\n",
    "                # Log the raw grid entry (still holding 'depth,heads'), then the sweep progress.\n",
    "                f.write('\\n\\n' + str(param)+'\\n')\n",
    "                f.write(f'{i / len(model_parameters) * 100}%')\n",
    "\n",
    "            # The grid stores depth and heads as one paired entry; the encoder and\n",
    "            # decoder both receive the same pair.\n",
    "            depth, heads = param.pop('depth,heads')\n",
    "            param['enc_depth'], param['enc_heads'] = depth, heads\n",
    "            param['dec_depth'], param['dec_heads'] = depth, heads\n",
    "\n",
    "            model = XTransformer(**param).cuda()\n",
    "\n",
    "            model_name = f\"{TASK_NAME}{INPUT_LEN}_dim{param['dim']}d{param['enc_depth']}h{param['enc_heads']}M{param['enc_num_memory_tokens']}l{param['enc_max_seq_len']}_{TAG}_v{init_num}\"\n",
    "\n",
    "            # Bound to `optimizer` so the `import torch.optim as optim` module alias\n",
    "            # from the import cell is not overwritten.\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "            train_validate_model(model, \n",
    "                            train_generator=gen_train, \n",
    "                            val_generator=gen_val, \n",
    "                            optim=optimizer, \n",
    "                            model_name=model_name, \n",
    "                            config=param,\n",
    "                            num_batches=NUM_BATCHES,\n",
    "                            generate_every=GENERATE_EVERY,\n",
    "                            print_file=print_file,\n",
    "                            tag=TAG,\n",
    "                            overfit_stop=False)\n",
    "            test_model(model, gen_test, model_name, param, TASK_NAME, tag=TAG)\n",
    "            with open(print_file, 'a') as f:\n",
    "                f.write(f'\\nTotal time: {time.time() - start_time}\\n')\n",
    "            # Restart the timer so each entry reports a single run.\n",
    "            start_time = time.time()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
