{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "metadata": {},
    "notebookRunGroups": {
     "groupValue": "1"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import torch\n",
    "\n",
    "from xlstm.blocks.slstm.cell import sLSTMCell, sLSTMCellConfig\n",
    "from xlstm.blocks.slstm.block import sLSTMBlock, sLSTMBlockConfig\n",
    "from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig\n",
    "from xlstm.components.feedforward import FeedForwardConfig\n",
    "\n",
    "backend = \"vanilla\" #\"cuda\" if torch.cuda.is_available() else \"vanilla\"\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "metadata": {},
    "notebookRunGroups": {
     "groupValue": "1"
    }
   },
   "outputs": [],
   "source": [
    "cell = sLSTMCell(sLSTMCellConfig(hidden_size=128, num_heads=2, backend=backend, dtype=\"bfloat16\", function=\"slstm\", enable_automatic_mixed_precision=False)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "metadata": {},
    "notebookRunGroups": {
     "groupValue": "1"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 2, 4, 64]) torch.Size([4, 4, 128])\n"
     ]
    }
   ],
   "source": [
    "res = cell(torch.zeros([4, 4, 4*128], dtype=torch.bfloat16).to(device))\n",
    "print(res[0].shape, res[1].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "block = sLSTMBlock(sLSTMBlockConfig(slstm=sLSTMLayerConfig(embedding_dim=32,\n",
    "    num_heads=4, backend=backend, function=\"lstm\", bias_init=\"standard\", recurrent_weight_init=\"standard\",\n",
    "    _block_idx=2, _num_blocks=4, enable_automatic_mixed_precision=False), feedforward=FeedForwardConfig())).to(device).to(dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "inp = torch.ones((3, 4, 32), dtype=torch.bfloat16).to(device)\n",
    "inp.requires_grad = True\n",
    "res = block(inp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 4, 32])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0\n",
      "tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0\n",
      "tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0\n",
      "tensor(1.1250, device='cuda:0', dtype=torch.bfloat16) != 0\n",
      "tensor(33.7500, device='cuda:0', dtype=torch.bfloat16) != 0\n"
     ]
    }
   ],
   "source": [
    "res[1, 1].sum().backward()\n",
    "# check for causality, batch interconnect\n",
    "print(inp.grad[2, 1].sum(), \"== 0\")\n",
    "print(inp.grad[0, 1].sum(), \"== 0\")\n",
    "print(inp.grad[1, 2].sum(), \"== 0\")\n",
    "print(inp.grad[1, 0].sum(), \"!= 0\")\n",
    "print(inp.grad[1, 1].sum(), \"!= 0\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "from xlstm.xlstm_block_stack import xLSTMBlockStackConfig, xLSTMBlockStack, mLSTMBlockConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "bs = xLSTMBlockStack(\n",
    "    xLSTMBlockStackConfig(\n",
    "        slstm_block=sLSTMBlockConfig(slstm=sLSTMLayerConfig(backend=backend)),\n",
    "        slstm_at=\"all\",\n",
    "        num_blocks=48,\n",
    "        embedding_dim=1024,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xLSTMBlockStack(\n",
       "  (blocks): ModuleList(\n",
       "    (0-47): 48 x sLSTMBlock(\n",
       "      (xlstm_norm): LayerNorm()\n",
       "      (xlstm): sLSTMLayer(\n",
       "        (conv1d): CausalConv1d(\n",
       "          (conv): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)\n",
       "        )\n",
       "        (conv_act_fn): SiLU()\n",
       "        (fgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (igate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (zgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (ogate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (slstm_cell): sLSTMCell_vanilla(function=slstm, hidden_size=1024, num_heads=4)\n",
       "        (group_norm): MultiHeadLayerNorm()\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "      (ffn_norm): LayerNorm()\n",
       "      (ffn): GatedFeedForward(\n",
       "        (proj_up): Linear(in_features=1024, out_features=2688, bias=False)\n",
       "        (proj_down): Linear(in_features=1344, out_features=1024, bias=False)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (post_blocks_norm): LayerNorm()\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "bs = xLSTMBlockStack(xLSTMBlockStackConfig(mlstm_block=mLSTMBlockConfig(), context_length=2048, num_blocks=48, embedding_dim=1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xLSTMBlockStack(\n",
       "  (blocks): ModuleList(\n",
       "    (0-47): 48 x mLSTMBlock(\n",
       "      (xlstm_norm): LayerNorm()\n",
       "      (xlstm): mLSTMLayer(\n",
       "        (proj_up): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "        (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )\n",
       "        (conv1d): CausalConv1d(\n",
       "          (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)\n",
       "        )\n",
       "        (conv_act_fn): SiLU()\n",
       "        (mlstm_cell): mLSTMCell(\n",
       "          (igate): Linear(in_features=6144, out_features=4, bias=True)\n",
       "          (fgate): Linear(in_features=6144, out_features=4, bias=True)\n",
       "          (outnorm): MultiHeadLayerNorm()\n",
       "        )\n",
       "        (ogate_act_fn): SiLU()\n",
       "        (proj_down): Linear(in_features=2048, out_features=1024, bias=False)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (post_blocks_norm): LayerNorm()\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xlstmpt220cu121",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
