{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"[%(asctime)s][%(name)s:%(lineno)d][%(levelname)s] - %(message)s\",\n",
    ")\n",
    "\n",
    "import torch\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from mlstm_kernels.baselines.lightning_attention.lightning_attn2 import lightning_attn2\n",
    "from mlstm_kernels.baselines.lightning_attention.utils import _build_slope_tensor\n",
    "\n",
    "from dacite import from_dict\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from mlstm_kernels.utils.benchmark.param_handling import BenchmarkConfig\n",
    "from mlstm_kernels.utils.benchmark.run_benchmark import run_benchmarks\n",
    "from mlstm_kernels.utils.benchmark.benchmarks.training_kernel_benchmarks import (\n",
    "    create_training_kernel_benchmark,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### quick test if it is runnable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = 4\n",
    "h = 8\n",
    "n = 512\n",
    "d = 128\n",
    "dtype = torch.bfloat16\n",
    "device = torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()\n",
    "k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()\n",
    "v = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()\n",
    "s = _build_slope_tensor(h).to(q.device).to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = lightning_attn2(q, k, v, s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### benchmark\n",
    "\n",
    "Note: lightning attention does not support large head dimensions. Get a Out of shared memory error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "S = 8192\n",
    "DHQK = 128  # 64 #256 #128 #256  # *2\n",
    "DHHV = 128  # 64 #256 #128 #512  # *2\n",
    "NH = 32  # 64 #16 #32 #8\n",
    "B = 2\n",
    "D = NH * DHHV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg_yaml = f\"\"\"\n",
    "vary_type: grid\n",
    "vary_params: {dict()}\n",
    "fixed_params: \n",
    "  batch_size: {B}\n",
    "  sequence_length: {S}\n",
    "  num_heads: {NH}\n",
    "  head_dim_qk: {DHQK}\n",
    "  head_dim_v: {DHHV}\n",
    "  warmup: 5\n",
    "  rep: 10\n",
    "\n",
    "kernel_specs:\n",
    "  # - kernel_name: \"chunkwise--triton_limit_chunk\"\n",
    "  #   fwbw: False\n",
    "  #   dtype: bfloat16\n",
    "  #   additional_params:\n",
    "  #     chunk_size: 64\n",
    "  # - kernel_name: \"chunkwise--triton_limit_chunk\"\n",
    "  #   fwbw: True\n",
    "  #   dtype: bfloat16\n",
    "  #   additional_params:\n",
    "  #     chunk_size: 64\n",
    "  # - kernel_name: \"chunkwise--triton_xl_chunk\"\n",
    "  #   fwbw: False\n",
    "  #   dtype: bfloat16\n",
    "  #   additional_params:\n",
    "  #     chunk_size: 128\n",
    "  # - kernel_name: \"chunkwise--triton_xl_chunk\"\n",
    "  #   fwbw: True\n",
    "  #   dtype: bfloat16\n",
    "  #   additional_params:\n",
    "  #     chunk_size: 128\n",
    "  # - kernel_name: \"chunkwise--triton_xl_chunk_siging\"\n",
    "  #   fwbw: False\n",
    "  #   dtype: bfloat16\n",
    "  #   use_torch_compile: False\n",
    "  #   additional_params:\n",
    "  #     chunk_size: 128\n",
    "  #     normalize: False\n",
    "  # - kernel_name: \"chunkwise--triton_xl_chunk_siging\"\n",
    "  #   fwbw: True\n",
    "  #   dtype: bfloat16\n",
    "  #   use_torch_compile: False\n",
    "  #   additional_params:\n",
    "  #     chunk_size: 128\n",
    "  #     normalize: False\n",
    "  - kernel_name: \"lightning_attn2\"\n",
    "    fwbw: False\n",
    "    dtype: bfloat16\n",
    "    use_torch_compile: False\n",
    "  - kernel_name: \"lightning_attn2\"\n",
    "    fwbw: True\n",
    "    dtype: bfloat16\n",
    "    use_torch_compile: False\n",
    "\n",
    "  - kernel_name: \"chunkwise--triton_xl_chunk_siging\"\n",
    "    fwbw: False\n",
    "    dtype: bfloat16\n",
    "    use_torch_compile: False\n",
    "    additional_params:\n",
    "      chunk_size: 256\n",
    "      normalize: False\n",
    "  - kernel_name: \"chunkwise--triton_xl_chunk_siging\"\n",
    "    fwbw: True\n",
    "    dtype: bfloat16\n",
    "    use_torch_compile: False\n",
    "    additional_params:\n",
    "      chunk_size: 256\n",
    "      normalize: False\n",
    "    \n",
    "\n",
    "\n",
    "  \n",
    "benchmark_name: \"quick_kernel_benchmark\"\n",
    "\"\"\"\n",
    "cfg_baseline = from_dict(\n",
    "    data_class=BenchmarkConfig, data=OmegaConf.to_container(OmegaConf.create(cfg_yaml))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df = run_benchmarks(\n",
    "    cfg_baseline,\n",
    "    benchmark_creator=create_training_kernel_benchmark,\n",
    "    run_garbage_collection=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df.filter(regex=\"(R|P)--.*\", axis=1).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df.filter(regex=\"(P|M)--.*\", axis=1).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlstmpt251cu124_beck",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
