{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mlstm_kernels.torch.backend_module import mLSTMBackendConfig, mLSTMBackend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backend configuration:\n",
    "#   - the chunkwise kernel is the mLSTMexp TFLA kernel\n",
    "#   - the step kernel is the triton step kernel, used for inference\n",
    "mlstm_backend_config = mLSTMBackendConfig(\n",
    "    chunk_size=256,\n",
    "    return_last_states=False,\n",
    "    chunkwise_kernel=\"chunkwise--triton_xl_chunk\",\n",
    "    sequence_kernel=\"native_sequence__triton\",\n",
    "    step_kernel=\"triton\",\n",
    ")\n",
    "# build the backend module from the configuration\n",
    "mlstm_backend = mLSTMBackend(mlstm_backend_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example inputs. This cell only prepares the tensors; no kernel is called here.\n",
    "\n",
    "# fail early with a clear message instead of erroring inside the first torch.randn\n",
    "if not torch.cuda.is_available():\n",
    "    raise RuntimeError(\n",
    "        \"CUDA is not available, but this example creates its inputs on a CUDA device.\"\n",
    "    )\n",
    "\n",
    "# tensor options shared by all inputs\n",
    "DEVICE = torch.device(\"cuda\")\n",
    "DTYPE = torch.bfloat16\n",
    "\n",
    "# sizes (presumably B = batch, NH = heads, S = sequence length -- TODO confirm)\n",
    "B = 2  # dim 0 of every input\n",
    "S = 512  # dim 2 of every input\n",
    "DHQK = 128  # last dim of matQ and matK\n",
    "DHHV = 256  # last dim of matV\n",
    "NH = 4  # dim 1 of every input\n",
    "\n",
    "# create input tensors; the RNG is seeded right before sampling\n",
    "#   matQ, matK: (B, NH, S, DHQK)\n",
    "#   matV:       (B, NH, S, DHHV)\n",
    "#   vecI, vecF: (B, NH, S)\n",
    "torch.manual_seed(1)\n",
    "matQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\n",
    "matK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\n",
    "matV = torch.randn((B, NH, S, DHHV), dtype=DTYPE, device=DEVICE)\n",
    "vecI = torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)\n",
    "# vecF: standard-normal samples shifted by +3.0\n",
    "vecF = 3.0 + torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# call the configured backend module on the example inputs\n",
    "matH1 = mlstm_backend(\n",
    "    q=matQ,\n",
    "    k=matK,\n",
    "    v=matV,\n",
    "    i=vecI,\n",
    "    f=vecF,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# directly import mLSTMexp TFLA kernel\n",
    "from mlstm_kernels.torch.chunkwise.triton_xl_chunk import mlstm_chunkwise__xl_chunk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same inputs, calling the kernel function directly;\n",
    "# chunk_size and return_last_states repeat the values given to mLSTMBackendConfig\n",
    "matH2 = mlstm_chunkwise__xl_chunk(\n",
    "    q=matQ,\n",
    "    k=matK,\n",
    "    v=matV,\n",
    "    i=vecI,\n",
    "    f=vecF,\n",
    "    chunk_size=256,\n",
    "    return_last_states=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The backend module (matH1) and the direct kernel call (matH2) received the same\n",
    "# inputs, so their outputs are expected to agree within tolerance.\n",
    "# assert_close raises with a mismatch report, so a Run All stops here on divergence\n",
    "# instead of merely displaying False.\n",
    "torch.testing.assert_close(matH1, matH2, atol=1e-5, rtol=1e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlstmpt251cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
