{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension (mode 2) so edits to imported modules are picked up live.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys \n",
    "\n",
    "# Put the parent directory on the import path before importing `xlstm` below\n",
    "# (presumably the repository root when the notebook is run from its own folder).\n",
    "sys.path.append('..')\n",
    "from omegaconf import OmegaConf\n",
    "from pprint import pprint\n",
    "from dacite import from_dict\n",
    "from dacite import Config as DaciteConfig\n",
    "import torch\n",
    "\n",
    "from xlstm.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig\n",
    "\n",
    "# Use the first CUDA device when CUDA is available, otherwise the CPU.\n",
    "# Later cells reuse `device` without redefining it.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# xLSTM Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Pick the sLSTM backend name up front so the YAML text below stays plain.\n",
    "slstm_backend = 'cuda' if torch.cuda.is_available() else 'vanilla'\n",
    "xlstm_cfg = f\"\"\" \n",
    "mlstm_block:\n",
    "  mlstm:\n",
    "    conv1d_kernel_size: 4\n",
    "    qkv_proj_blocksize: 4\n",
    "    num_heads: 4\n",
    "slstm_block:\n",
    "  slstm:\n",
    "    backend: {slstm_backend} #! only vanilla here works\n",
    "    num_heads: 4\n",
    "    conv1d_kernel_size: 4\n",
    "    bias_init: powerlaw_blockdependent\n",
    "  feedforward:\n",
    "    proj_factor: 1.3\n",
    "    act_fn: gelu\n",
    "context_length: 256\n",
    "num_blocks: 7\n",
    "embedding_dim: 128\n",
    "slstm_at: [1] #[1] # for [] it also works, so if no sLSTM is in the stack\n",
    "\"\"\"\n",
    "# YAML text -> OmegaConf config -> plain container -> typed config dataclass.\n",
    "cfg_container = OmegaConf.to_container(OmegaConf.create(xlstm_cfg))\n",
    "cfg = from_dict(\n",
    "    data_class=xLSTMBlockStackConfig,\n",
    "    data=cfg_container,\n",
    "    config=DaciteConfig(strict=True),\n",
    ")\n",
    "xlstm_stack = xLSTMBlockStack(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xLSTMBlockStackConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,\n",
      "                                                                          round_proj_up_dim_up=True,\n",
      "                                                                          round_proj_up_to_multiple_of=64,\n",
      "                                                                          _proj_up_dim=256,\n",
      "                                                                          conv1d_kernel_size=4,\n",
      "                                                                          qkv_proj_blocksize=4,\n",
      "                                                                          num_heads=4,\n",
      "                                                                          embedding_dim=128,\n",
      "                                                                          bias=False,\n",
      "                                                                          dropout=0.0,\n",
      "                                                                          context_length=256,\n",
      "                                                                          _num_blocks=7,\n",
      "                                                                          _inner_embedding_dim=256)),\n",
      "                      slstm_block=sLSTMBlockConfig(slstm=sLSTMLayerConfig(hidden_size=128,\n",
      "                                                                          num_heads=4,\n",
      "                                                                          num_states=4,\n",
      "                                                                          backend='vanilla',\n",
      "                                                                          function='slstm',\n",
      "                                                                          bias_init='powerlaw_blockdependent',\n",
      "                                                                          recurrent_weight_init='zeros',\n",
      "                                                                          _block_idx=0,\n",
      "                                                                          _num_blocks=7,\n",
      "                                                                          num_gates=4,\n",
      "                                                                          gradient_recurrent_cut=False,\n",
      "                                                                          gradient_recurrent_clipval=None,\n",
      "                                                                          forward_clipval=None,\n",
      "                                                                          batch_size=8,\n",
      "                                                                          input_shape='BSGNH',\n",
      "                                                                          internal_input_shape='SBNGH',\n",
      "                                                                          output_shape='BNSH',\n",
      "                                                                          constants={},\n",
      "                                                                          dtype='bfloat16',\n",
      "                                                                          dtype_b='float32',\n",
      "                                                                          dtype_r='bfloat16',\n",
      "                                                                          dtype_w='bfloat16',\n",
      "                                                                          dtype_g='bfloat16',\n",
      "                                                                          dtype_s='bfloat16',\n",
      "                                                                          dtype_a='float32',\n",
      "                                                                          enable_automatic_mixed_precision=True,\n",
      "                                                                          initial_val=0.0,\n",
      "                                                                          embedding_dim=128,\n",
      "                                                                          conv1d_kernel_size=4,\n",
      "                                                                          group_norm_weight=True,\n",
      "                                                                          dropout=0.0),\n",
      "                                                   feedforward=FeedForwardConfig(proj_factor=1.3,\n",
      "                                                                                 round_proj_up_dim_up=True,\n",
      "                                                                                 round_proj_up_to_multiple_of=64,\n",
      "                                                                                 _proj_up_dim=0,\n",
      "                                                                                 act_fn='gelu',\n",
      "                                                                                 embedding_dim=-1,\n",
      "                                                                                 dropout=0.0,\n",
      "                                                                                 bias=False,\n",
      "                                                                                 ff_type='ffn_gated',\n",
      "                                                                                 _num_blocks=1),\n",
      "                                                   _num_blocks=7,\n",
      "                                                   _block_idx=0),\n",
      "                      context_length=256,\n",
      "                      num_blocks=7,\n",
      "                      embedding_dim=128,\n",
      "                      add_post_blocks_norm=True,\n",
      "                      bias=False,\n",
      "                      dropout=0.0,\n",
      "                      slstm_at=[1],\n",
      "                      _block_map='0,1,0,0,0,0,0')\n"
     ]
    }
   ],
   "source": [
    "# Pretty-print the typed config built in the previous cell, including the fields that were not set in the YAML.\n",
    "pprint(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xLSTMBlockStack(\n",
       "  (blocks): ModuleList(\n",
       "    (0): mLSTMBlock(\n",
       "      (xlstm_norm): LayerNorm()\n",
       "      (xlstm): mLSTMLayer(\n",
       "        (proj_up): Linear(in_features=128, out_features=512, bias=False)\n",
       "        (q_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (k_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (v_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (conv1d): CausalConv1d(\n",
       "          (conv): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)\n",
       "        )\n",
       "        (conv_act_fn): SiLU()\n",
       "        (mlstm_cell): mLSTMCell(\n",
       "          (igate): Linear(in_features=768, out_features=4, bias=True)\n",
       "          (fgate): Linear(in_features=768, out_features=4, bias=True)\n",
       "          (outnorm): MultiHeadLayerNorm()\n",
       "        )\n",
       "        (ogate_act_fn): SiLU()\n",
       "        (proj_down): Linear(in_features=256, out_features=128, bias=False)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (1): sLSTMBlock(\n",
       "      (xlstm_norm): LayerNorm()\n",
       "      (xlstm): sLSTMLayer(\n",
       "        (conv1d): CausalConv1d(\n",
       "          (conv): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)\n",
       "        )\n",
       "        (conv_act_fn): SiLU()\n",
       "        (fgate): LinearHeadwiseExpand(in_features=128, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (igate): LinearHeadwiseExpand(in_features=128, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (zgate): LinearHeadwiseExpand(in_features=128, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (ogate): LinearHeadwiseExpand(in_features=128, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (slstm_cell): sLSTMCell_vanilla(function=slstm, hidden_size=128, num_heads=4)\n",
       "        (group_norm): MultiHeadLayerNorm()\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "      (ffn_norm): LayerNorm()\n",
       "      (ffn): GatedFeedForward(\n",
       "        (proj_up): Linear(in_features=128, out_features=384, bias=False)\n",
       "        (proj_down): Linear(in_features=192, out_features=128, bias=False)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (2-6): 5 x mLSTMBlock(\n",
       "      (xlstm_norm): LayerNorm()\n",
       "      (xlstm): mLSTMLayer(\n",
       "        (proj_up): Linear(in_features=128, out_features=512, bias=False)\n",
       "        (q_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (k_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (v_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (conv1d): CausalConv1d(\n",
       "          (conv): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)\n",
       "        )\n",
       "        (conv_act_fn): SiLU()\n",
       "        (mlstm_cell): mLSTMCell(\n",
       "          (igate): Linear(in_features=768, out_features=4, bias=True)\n",
       "          (fgate): Linear(in_features=768, out_features=4, bias=True)\n",
       "          (outnorm): MultiHeadLayerNorm()\n",
       "        )\n",
       "        (ogate_act_fn): SiLU()\n",
       "        (proj_down): Linear(in_features=256, out_features=128, bias=False)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (post_blocks_norm): LayerNorm()\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bare expression: the notebook renders the module's repr (its block layout) as the cell output.\n",
    "xlstm_stack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Random input; 256 and 128 match `context_length` and `embedding_dim` in the config above.\n",
    "# The tensor is created first and then moved to `device`.\n",
    "input_shape = (4, 256, 128)\n",
    "x = torch.randn(*input_shape).to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Move the model to the same device the input `x` was moved to.\n",
    "xlstm_stack = xlstm_stack.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Run the block stack on the random input.\n",
    "y = xlstm_stack(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 256, 128])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the output shape; the saved output shows it equals the input shape (4, 256, 128).\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
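    "### README examples\n",
    "\n",
    "The cells below rebuild the model in several ways; each one reuses `device` from the first cell."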
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "from dacite import from_dict\n",
    "from dacite import Config as DaciteConfig\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig\n",
    "\n",
    "# Backend name that gets spliced into the YAML config text below.\n",
    "slstm_backend = 'cuda' if torch.cuda.is_available() else 'vanilla'\n",
    "xlstm_cfg = f\"\"\" \n",
    "mlstm_block:\n",
    "  mlstm:\n",
    "    conv1d_kernel_size: 4\n",
    "    qkv_proj_blocksize: 4\n",
    "    num_heads: 4\n",
    "slstm_block:\n",
    "  slstm:\n",
    "    backend: {slstm_backend}\n",
    "    num_heads: 4\n",
    "    conv1d_kernel_size: 4\n",
    "    bias_init: powerlaw_blockdependent\n",
    "  feedforward:\n",
    "    proj_factor: 1.3\n",
    "    act_fn: gelu\n",
    "context_length: 256\n",
    "num_blocks: 7\n",
    "embedding_dim: 128\n",
    "slstm_at: [1]\n",
    "\"\"\"\n",
    "# Parse the YAML, convert it to a plain container, then build the typed config.\n",
    "cfg_container = OmegaConf.to_container(OmegaConf.create(xlstm_cfg))\n",
    "cfg = from_dict(\n",
    "    data_class=xLSTMBlockStackConfig,\n",
    "    data=cfg_container,\n",
    "    config=DaciteConfig(strict=True),\n",
    ")\n",
    "xlstm_stack = xLSTMBlockStack(cfg)\n",
    "\n",
    "# `device` is defined in the first cell of the notebook.\n",
    "x = torch.randn(4, 256, 128).to(device=device)\n",
    "xlstm_stack = xlstm_stack.to(device=device)\n",
    "y = xlstm_stack(x)\n",
    "# Check that the output has the same shape as the input.\n",
    "y.shape == x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from xlstm import (\n",
    "    xLSTMBlockStack,\n",
    "    xLSTMBlockStackConfig,\n",
    "    mLSTMBlockConfig,\n",
    "    mLSTMLayerConfig,\n",
    "    sLSTMBlockConfig,\n",
    "    sLSTMLayerConfig,\n",
    "    FeedForwardConfig,\n",
    ")\n",
    "\n",
    "# Build the config directly from the config dataclasses, using the same values\n",
    "# as the YAML-based example.\n",
    "cfg = xLSTMBlockStackConfig(\n",
    "    mlstm_block=mLSTMBlockConfig(\n",
    "        mlstm=mLSTMLayerConfig(\n",
    "            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4\n",
    "        )\n",
    "    ),\n",
    "    slstm_block=sLSTMBlockConfig(\n",
    "        slstm=sLSTMLayerConfig(\n",
    "            # Note: the PyTorch API is `torch.cuda.is_available()` (with an underscore);\n",
    "            # `torch.cuda.isavailable` does not exist and raises AttributeError.\n",
    "            backend=\"cuda\" if torch.cuda.is_available() else \"vanilla\",\n",
    "            num_heads=4,\n",
    "            conv1d_kernel_size=4,\n",
    "            bias_init=\"powerlaw_blockdependent\",\n",
    "        ),\n",
    "        feedforward=FeedForwardConfig(proj_factor=1.3, act_fn=\"gelu\"),\n",
    "    ),\n",
    "    context_length=256,\n",
    "    num_blocks=7,\n",
    "    embedding_dim=128,\n",
    "    slstm_at=[1],\n",
    ")\n",
    "\n",
    "xlstm_stack = xLSTMBlockStack(cfg)\n",
    "\n",
    "# `device` is defined in the first cell of the notebook.\n",
    "x = torch.randn(4, 256, 128).to(device=device)\n",
    "xlstm_stack = xlstm_stack.to(device=device)\n",
    "y = xlstm_stack(x)\n",
    "# Check that the output has the same shape as the input.\n",
    "y.shape == x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from dacite import from_dict\n",
    "from dacite import Config as DaciteConfig\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from xlstm.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig\n",
    "\n",
    "# Same config as before, but with an empty `slstm_at` list.\n",
    "# `torch` and `device` come from the first cell of the notebook.\n",
    "slstm_backend = 'cuda' if torch.cuda.is_available() else 'vanilla'\n",
    "xlstm_cfg = f\"\"\" \n",
    "mlstm_block:\n",
    "  mlstm:\n",
    "    conv1d_kernel_size: 4\n",
    "    qkv_proj_blocksize: 4\n",
    "    num_heads: 4\n",
    "slstm_block:\n",
    "  slstm:\n",
    "    backend: {slstm_backend}\n",
    "    num_heads: 4\n",
    "    conv1d_kernel_size: 4\n",
    "    bias_init: powerlaw_blockdependent\n",
    "  feedforward:\n",
    "    proj_factor: 1.3\n",
    "    act_fn: gelu\n",
    "context_length: 256\n",
    "num_blocks: 7\n",
    "embedding_dim: 128\n",
    "slstm_at: [] #[1]\n",
    "\"\"\"\n",
    "# Parse the YAML, convert it to a plain container, then build the typed config.\n",
    "cfg_container = OmegaConf.to_container(OmegaConf.create(xlstm_cfg))\n",
    "cfg = from_dict(\n",
    "    data_class=xLSTMBlockStackConfig,\n",
    "    data=cfg_container,\n",
    "    config=DaciteConfig(strict=True),\n",
    ")\n",
    "xlstm_stack = xLSTMBlockStack(cfg)\n",
    "\n",
    "x = torch.randn(4, 256, 128).to(device=device)\n",
    "xlstm_stack = xlstm_stack.to(device=device)\n",
    "y = xlstm_stack(x)\n",
    "# Check that the output has the same shape as the input.\n",
    "y.shape == x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from dacite import from_dict\n",
    "from dacite import Config as DaciteConfig\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig\n",
    "\n",
    "# Language-model variant: the config gains a `vocab_size` entry.\n",
    "# `torch` and `device` come from the first cell of the notebook.\n",
    "slstm_backend = 'cuda' if torch.cuda.is_available() else 'vanilla'\n",
    "xlstm_cfg = f\"\"\" \n",
    "vocab_size: 50304\n",
    "mlstm_block:\n",
    "  mlstm:\n",
    "    conv1d_kernel_size: 4\n",
    "    qkv_proj_blocksize: 4\n",
    "    num_heads: 4\n",
    "slstm_block:\n",
    "  slstm:\n",
    "    backend: {slstm_backend}\n",
    "    num_heads: 4\n",
    "    conv1d_kernel_size: 4\n",
    "    bias_init: powerlaw_blockdependent\n",
    "  feedforward:\n",
    "    proj_factor: 1.3\n",
    "    act_fn: gelu\n",
    "context_length: 256\n",
    "num_blocks: 7\n",
    "embedding_dim: 128\n",
    "slstm_at: [] #[1]\n",
    "\"\"\"\n",
    "# Parse the YAML, convert it to a plain container, then build the typed config.\n",
    "cfg_container = OmegaConf.to_container(OmegaConf.create(xlstm_cfg))\n",
    "cfg = from_dict(\n",
    "    data_class=xLSTMLMModelConfig,\n",
    "    data=cfg_container,\n",
    "    config=DaciteConfig(strict=True),\n",
    ")\n",
    "xlstm_stack = xLSTMLMModel(cfg)\n",
    "\n",
    "# Integer inputs drawn from [0, 50304), shape (4, 256).\n",
    "x = torch.randint(0, 50304, size=(4, 256)).to(device=device)\n",
    "xlstm_stack = xlstm_stack.to(device)\n",
    "y = xlstm_stack(x)\n",
    "# Trailing output dims should be (256, 50304), matching the input length and `vocab_size`.\n",
    "y.shape[1:] == (256, 50304)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xlstmpt220cu121",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
