{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"[%(asctime)s][%(name)s:%(lineno)d][%(levelname)s] - %(message)s\",\n",
    ")\n",
    "\n",
    "from dacite import from_dict\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from mlstm_kernels.utils.benchmark.param_handling import BenchmarkConfig\n",
    "from mlstm_kernels.utils.benchmark.run_benchmark import run_benchmarks\n",
    "from mlstm_kernels.utils.benchmark.benchmarks.training_kernel_benchmarks import (\n",
    "    create_training_kernel_benchmark,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence_length_limits = [9, 17]\n",
    "sequence_lengths = list(map(lambda i: 1 << i, range(*sequence_length_limits)))\n",
    "batch_sizes = list(\n",
    "    map(\n",
    "        lambda i: 1 << i,\n",
    "        reversed(range(sequence_length_limits[1] - sequence_length_limits[0])),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence_lengths, batch_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "S = 8192\n",
    "DHQK = 64  # 256  # *2\n",
    "DHHV = 128  # 512  # *2\n",
    "NH = 8  # 8\n",
    "B = 2\n",
    "D = NH * DHHV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg_yaml = f\"\"\"\n",
    "vary_type: grid\n",
    "vary_params: {dict()}\n",
    "fixed_params: \n",
    "  batch_size: {B}\n",
    "  sequence_length: {S}\n",
    "  num_heads: {NH}\n",
    "  head_dim_qk: {DHQK}\n",
    "  head_dim_v: {DHHV}\n",
    "  warmup: 5\n",
    "  rep: 10\n",
    "\n",
    "kernel_specs:\n",
    "  - kernel_name: \"chunkwise--triton_limit_chunk\"\n",
    "    fwbw: False\n",
    "    dtype: bfloat16\n",
    "    additional_params:\n",
    "      chunk_size: 64\n",
    "  - kernel_name: \"chunkwise--triton_limit_chunk\"\n",
    "    fwbw: True\n",
    "    dtype: bfloat16\n",
    "    additional_params:\n",
    "      chunk_size: 64\n",
    "  - kernel_name: \"chunkwise--triton_xl_chunk\"\n",
    "    fwbw: False\n",
    "    dtype: bfloat16\n",
    "    additional_params:\n",
    "      chunk_size: 128\n",
    "  - kernel_name: \"chunkwise--triton_xl_chunk\"\n",
    "    fwbw: True\n",
    "    dtype: bfloat16\n",
    "    additional_params:\n",
    "      chunk_size: 128\n",
    "  - kernel_name: \"chunkwise--triton_xl_chunk_siging\"\n",
    "    fwbw: False\n",
    "    dtype: bfloat16\n",
    "    use_torch_compile: False\n",
    "    additional_params:\n",
    "      chunk_size: 128\n",
    "      normalize: False\n",
    "  - kernel_name: \"chunkwise--triton_xl_chunk_siging\"\n",
    "    fwbw: True\n",
    "    dtype: bfloat16\n",
    "    use_torch_compile: False\n",
    "    additional_params:\n",
    "      chunk_size: 128\n",
    "      normalize: False\n",
    "\n",
    "\n",
    "benchmark_name: \"quick_kernel_benchmark\"\n",
    "\"\"\"\n",
    "cfg_baseline = from_dict(\n",
    "    data_class=BenchmarkConfig, data=OmegaConf.to_container(OmegaConf.create(cfg_yaml))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df = run_benchmarks(\n",
    "    cfg_baseline,\n",
    "    benchmark_creator=create_training_kernel_benchmark,\n",
    "    run_garbage_collection=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df.filter(regex=\"R--.*\", axis=1).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df.filter(regex=\"M--.*\", axis=1).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlstmpt251cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
