{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "import torch\n",
    "\n",
    "from mlstm_kernels.utils.test.checks import verify_output\n",
    "from mlstm_kernels.torch.utils import to_numpy\n",
    "from tests.torch.losses_tests import loss_layernorm_offset_quadratic\n",
    "\n",
    "torch.set_printoptions(linewidth=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload 2\n",
    "from mlstm_kernels.torch.parallel._legacy_native_siging.sig_ingate import (\n",
    "    mlstm_siging_parallel,\n",
    ")\n",
    "from mlstm_kernels.torch.parallel.native_siging import (\n",
    "    mlstm_siging_parallel__native_autograd,\n",
    "    mlstm_siging_parallel__native_custbw,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration shared by the cells below.\n",
    "# Inputs are created with shapes (B, NH, S, DHQK) for matQ/matK,\n",
    "# (B, NH, S, DHHV) for matV and (B, NH, S) for vecI/vecF.\n",
    "seed = 1\n",
    "B = 1\n",
    "NH = 1\n",
    "S = 128\n",
    "DHQK = 32\n",
    "DHHV = 64\n",
    "# Prefer the first GPU, but fall back to CPU so that Restart & Run All also\n",
    "# works on machines without CUDA.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "dtype = torch.float32\n",
    "\n",
    "# Constant offsets added to the random vecI / vecF samples.\n",
    "vecI_offset = 0.0\n",
    "vecF_offset = 3.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fresh_leaf(src, cast_dtype):\n",
    "    \"\"\"Return an independent copy of `src` in `cast_dtype` that tracks its own gradient.\"\"\"\n",
    "    return src.clone().to(dtype=cast_dtype).detach().requires_grad_(True)\n",
    "\n",
    "\n",
    "# Sample all inputs once, in float32, from a fixed seed (draw order: Q, K, V, I, F).\n",
    "torch.manual_seed(seed)\n",
    "_sample_kwargs = dict(dtype=torch.float32, device=device)\n",
    "matQ = torch.randn((B, NH, S, DHQK), **_sample_kwargs)\n",
    "matK = torch.randn((B, NH, S, DHQK), **_sample_kwargs)\n",
    "matV = torch.randn((B, NH, S, DHHV), **_sample_kwargs)\n",
    "vecI = vecI_offset + torch.randn((B, NH, S), **_sample_kwargs)\n",
    "vecF = vecF_offset + torch.randn((B, NH, S), **_sample_kwargs)\n",
    "_inputs = (matQ, matK, matV, vecI, vecF)\n",
    "\n",
    "# Leaf copies fed to the baseline implementation.\n",
    "baseline_dtype = dtype\n",
    "matQ_baseline, matK_baseline, matV_baseline, vecI_baseline, vecF_baseline = (\n",
    "    _fresh_leaf(t, baseline_dtype) for t in _inputs\n",
    ")\n",
    "\n",
    "# Separate leaf copies for the target, so each side collects its own .grad.\n",
    "target_dtype = dtype\n",
    "matQ_target, matK_target, matV_target, vecI_target, vecF_target = (\n",
    "    _fresh_leaf(t, target_dtype) for t in _inputs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# backward() accumulates into .grad, so drop any gradients left over from a\n",
    "# previous run of this cell; otherwise re-running it would sum them up and\n",
    "# the gradient comparisons further down would report false mismatches.\n",
    "for _leaf in (\n",
    "    matQ_baseline,\n",
    "    matK_baseline,\n",
    "    matV_baseline,\n",
    "    vecI_baseline,\n",
    "    vecF_baseline,\n",
    "):\n",
    "    _leaf.grad = None\n",
    "\n",
    "# Baseline: legacy implementation; the gate tensors are passed with an extra\n",
    "# trailing singleton dimension.\n",
    "matH_bl = mlstm_siging_parallel(\n",
    "    matQ_baseline,\n",
    "    matK_baseline,\n",
    "    matV_baseline,\n",
    "    vecI_baseline.unsqueeze(-1),\n",
    "    vecF_baseline.unsqueeze(-1),\n",
    "    normalize_sqrt_d=True,\n",
    ")\n",
    "# Backpropagate a scalar test loss to populate the *_baseline.grad tensors.\n",
    "loss_layernorm_offset_quadratic(matH_bl).backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: (B, NH, S) -> unsqueeze(-1) -> (B, NH, S, 1) -> transpose -> (B, NH, 1, S).\n",
    "vecI_row = vecI_baseline.unsqueeze(-1).transpose(-2, -1)\n",
    "vecI_row.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# backward() accumulates into .grad, so drop any gradients left over from a\n",
    "# previous run of this cell (e.g. after editing the kernel under %autoreload);\n",
    "# otherwise the target gradients would be summed across runs.\n",
    "for _leaf in (matQ_target, matK_target, matV_target, vecI_target, vecF_target):\n",
    "    _leaf.grad = None\n",
    "\n",
    "# Target under test. To check the autograd variant instead, call\n",
    "# mlstm_siging_parallel__native_autograd with the same arguments.\n",
    "matH_tgt = mlstm_siging_parallel__native_custbw(\n",
    "    matQ_target,\n",
    "    matK_target,\n",
    "    matV_target,\n",
    "    vecI_target,\n",
    "    vecF_target,\n",
    "    stable_fgate=False,\n",
    ")\n",
    "# Backpropagate the same scalar test loss to populate the *_target.grad tensors.\n",
    "loss_layernorm_offset_quadratic(matH_tgt).backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the forward outputs of the baseline and target runs.\n",
    "# The label is kept neutral because the target run uses stable_fgate=False.\n",
    "fig = verify_output(\n",
    "    \"matH\",\n",
    "    to_numpy(matH_bl),\n",
    "    to_numpy(matH_tgt),\n",
    "    atol=1e-5,\n",
    "    rtol=1e-5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare d(loss)/d(matQ) between the baseline and target runs.\n",
    "matQ_grad_bl = to_numpy(matQ_baseline.grad)\n",
    "matQ_grad_tgt = to_numpy(matQ_target.grad)\n",
    "fig = verify_output(\"matQ.grad\", matQ_grad_bl, matQ_grad_tgt, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare d(loss)/d(matK) between the baseline and target runs.\n",
    "matK_grad_bl = to_numpy(matK_baseline.grad)\n",
    "matK_grad_tgt = to_numpy(matK_target.grad)\n",
    "fig = verify_output(\"matK.grad\", matK_grad_bl, matK_grad_tgt, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare d(loss)/d(matV) between the baseline and target runs.\n",
    "matV_grad_bl = to_numpy(matV_baseline.grad)\n",
    "matV_grad_tgt = to_numpy(matV_target.grad)\n",
    "fig = verify_output(\"matV.grad\", matV_grad_bl, matV_grad_tgt, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare d(loss)/d(vecI) between the baseline and target runs.\n",
    "vecI_grad_bl = to_numpy(vecI_baseline.grad)\n",
    "vecI_grad_tgt = to_numpy(vecI_target.grad)\n",
    "fig = verify_output(\"vecI.grad\", vecI_grad_bl, vecI_grad_tgt, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare d(loss)/d(vecF) between the baseline and target runs.\n",
    "vecF_grad_bl = to_numpy(vecF_baseline.grad)\n",
    "vecF_grad_tgt = to_numpy(vecF_target.grad)\n",
    "fig = verify_output(\"vecF.grad\", vecF_grad_bl, vecF_grad_tgt, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xlstmpt240cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
