{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys \n",
    "\n",
    "sys.path.append('..')\n",
    "from omegaconf import OmegaConf\n",
    "from pprint import pprint\n",
    "from dacite import from_dict\n",
    "from dacite import Config as DaciteConfig\n",
    "import torch\n",
    "\n",
    "from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build an xLSTM language model from an inline YAML config.\n",
    "# No checkpoint is loaded in this notebook: the model is used exactly as constructed\n",
    "# to compare the full-sequence forward pass with the step() interface.\n",
    "xlstm_cfg = \"\"\" \n",
    "vocab_size: 50304\n",
    "context_length: 2048      \n",
    "num_blocks: 24 #!\n",
    "embedding_dim: 768 #!\n",
    "tie_weights: false\n",
    "weight_decay_on_embedding: false\n",
    "mlstm_block:\n",
    "  mlstm:\n",
    "    conv1d_kernel_size: 4\n",
    "    qkv_proj_blocksize: 4\n",
    "    num_heads: 4\n",
    "\"\"\"\n",
    "# Parse the YAML into plain Python containers first, then let dacite build the typed\n",
    "# config object. The intermediate gets its own name so `cfg` only ever refers to the\n",
    "# xLSTMLMModelConfig instance that later cells use.\n",
    "cfg_raw = OmegaConf.to_container(OmegaConf.create(xlstm_cfg))\n",
    "cfg = from_dict(data_class=xLSTMLMModelConfig, data=cfg_raw, config=DaciteConfig(strict=True))\n",
    "model_new = xLSTMLMModel(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xLSTMLMModelConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,\n",
      "                                                                       round_proj_up_dim_up=True,\n",
      "                                                                       round_proj_up_to_multiple_of=64,\n",
      "                                                                       _proj_up_dim=1536,\n",
      "                                                                       conv1d_kernel_size=4,\n",
      "                                                                       qkv_proj_blocksize=4,\n",
      "                                                                       num_heads=4,\n",
      "                                                                       embedding_dim=768,\n",
      "                                                                       bias=False,\n",
      "                                                                       dropout=0.0,\n",
      "                                                                       context_length=2048,\n",
      "                                                                       _num_blocks=24,\n",
      "                                                                       _inner_embedding_dim=1536)),\n",
      "                   slstm_block=None,\n",
      "                   context_length=2048,\n",
      "                   num_blocks=24,\n",
      "                   embedding_dim=768,\n",
      "                   add_post_blocks_norm=True,\n",
      "                   bias=False,\n",
      "                   dropout=0.0,\n",
      "                   slstm_at=[],\n",
      "                   _block_map='0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0',\n",
      "                   vocab_size=50304,\n",
      "                   tie_weights=False,\n",
      "                   weight_decay_on_embedding=False,\n",
      "                   add_embedding_dropout=False)\n"
     ]
    }
   ],
   "source": [
    "# Pretty-print the resolved config, including fields that were not set in the YAML above.\n",
    "pprint(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random token ids for one sequence of 32 tokens. The exclusive upper bound is read\n",
    "# from the config instead of repeating the literal, so the ids stay within the\n",
    "# vocabulary even if vocab_size is changed in the config cell.\n",
    "x_in = torch.randint(0, cfg.vocab_size, (1, 32)).to(device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to DEVICE, the same device x_in was placed on.\n",
    "model_new = model_new.to(device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xLSTMLMModel(\n",
       "  (xlstm_block_stack): xLSTMBlockStack(\n",
       "    (blocks): ModuleList(\n",
       "      (0-23): 24 x mLSTMBlock(\n",
       "        (xlstm_norm): LayerNorm()\n",
       "        (xlstm): mLSTMLayer(\n",
       "          (proj_up): Linear(in_features=768, out_features=3072, bias=False)\n",
       "          (q_proj): LinearHeadwiseExpand(in_features=1536, num_heads=384, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "          (k_proj): LinearHeadwiseExpand(in_features=1536, num_heads=384, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "          (v_proj): LinearHeadwiseExpand(in_features=1536, num_heads=384, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "          (conv1d): CausalConv1d(\n",
       "            (conv): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)\n",
       "          )\n",
       "          (conv_act_fn): SiLU()\n",
       "          (mlstm_cell): mLSTMCell(\n",
       "            (igate): Linear(in_features=4608, out_features=4, bias=True)\n",
       "            (fgate): Linear(in_features=4608, out_features=4, bias=True)\n",
       "            (outnorm): MultiHeadLayerNorm()\n",
       "          )\n",
       "          (ogate_act_fn): SiLU()\n",
       "          (proj_down): Linear(in_features=1536, out_features=768, bias=False)\n",
       "          (dropout): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (post_blocks_norm): LayerNorm()\n",
       "  )\n",
       "  (token_embedding): Embedding(50304, 768)\n",
       "  (emb_dropout): Identity()\n",
       "  (lm_head): Linear(in_features=768, out_features=50304, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module tree of the constructed model.\n",
    "model_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-sequence forward pass; this is the reference the step-wise outputs are compared to.\n",
    "# The result is only inspected, never back-propagated, so skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    y_new = model_new(x_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 32, 50304])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the full-sequence output; the last dimension matches vocab_size (50304).\n",
    "y_new.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1])\n"
     ]
    }
   ],
   "source": [
    "# Feed the same tokens one at a time through step(), starting with state = None and\n",
    "# passing the state returned by each call into the next one.\n",
    "step_outputs = []\n",
    "state = None\n",
    "# Inference only: no_grad avoids building an autograd graph across all steps.\n",
    "with torch.no_grad():\n",
    "    for x in x_in.split(1, dim=1):\n",
    "        y, state = model_new.step(x, state)\n",
    "        step_outputs.append(y)\n",
    "# Keep the list and the concatenated tensor under separate names.\n",
    "y_new_step = torch.cat(step_outputs, dim=1)\n",
    "# Sanity check: the last chunk from split(1, dim=1) has length 1 along dim 1.\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 32, 50304])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The concatenated per-step outputs should have the same shape as y_new.\n",
    "y_new_step.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 5.9605e-08,  4.1723e-07,  5.9605e-08,  ...,  1.0431e-07,\n",
       "           1.1921e-07,  3.7253e-07],\n",
       "         [-2.3842e-07,  8.9407e-08,  5.9605e-08,  ..., -2.8312e-07,\n",
       "           1.7881e-07,  3.5763e-07],\n",
       "         [ 2.0862e-07, -2.9802e-08,  4.7684e-07,  ...,  3.5763e-07,\n",
       "           3.5763e-07,  5.9605e-08],\n",
       "         ...,\n",
       "         [ 5.9605e-08,  5.9605e-07, -5.9605e-08,  ..., -8.9407e-08,\n",
       "           5.9605e-08, -2.9802e-07],\n",
       "         [ 2.3842e-07, -1.9372e-07, -1.3411e-07,  ..., -3.2783e-07,\n",
       "           4.4703e-07, -5.9605e-08],\n",
       "         [-1.1921e-07, -2.3842e-07,  6.2585e-07,  ..., -5.0664e-07,\n",
       "          -1.7881e-07, -2.2352e-07]]], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise difference between the full-sequence output and the step-wise output.\n",
    "y_new - y_new_step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fail loudly under Restart & Run All if the two code paths ever diverge, while\n",
    "# still displaying the boolean result as the cell output.\n",
    "outputs_match = torch.allclose(y_new, y_new_step, atol=1e-5)\n",
    "assert outputs_match, \"step() outputs diverge from the full forward pass (atol=1e-5)\"\n",
    "outputs_match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xlstmpt220cu121",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
