{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the directory two levels above the working directory to the import path\n",
    "# (presumably so the `grouped_batching` package imported below resolves - confirm).\n",
    "sys.path.append(\"../..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "import copy\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "from grouped_batching.llama1b_grouping import (\n",
    "    wrap_model_with_armt, get_grouped_states, \n",
    "    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,\n",
    "    make_grouped_sliced_layer_from_single_layer\n",
    ")\n",
    "from grouped_batching.batching import GroupedBatcher\n",
    "from grouped_batching.executor import ArmtGroupedExecutor\n",
    "from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.set_default_device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtype = torch.bfloat16\n",
    "torch.set_default_dtype(dtype)\n",
    "torch.set_grad_enabled(False)\n",
    ";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
     ]
    }
   ],
   "source": [
    "# Checkpoint used for this run. Other options kept for reference:\n",
    "#   \"meta-llama/Llama-3.2-1B\"\n",
    "#   \"meta-llama/Llama-3.2-3B\"\n",
    "#   \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "model_name = \"JackFram/llama-160m\"\n",
    "\n",
    "source_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    attn_implementation=\"flash_attention_2\",  # alternative: \"sdpa\"\n",
    "    torch_dtype=dtype,\n",
    ")\n",
    "source_model.eval()\n",
    "# Replace the LM head with a pass-through so no vocabulary projection is applied.\n",
    "source_model.lm_head = torch.nn.Identity()\n",
    "# Independent copy taken before the ARMT wrapping performed in a later cell.\n",
    "reference_model = copy.deepcopy(source_model)\n",
    "# reference_model = source_model\n",
    "# Load the tokenizer from the same checkpoint as the model so that the vocabulary\n",
    "# matches the embedding table (previously pinned to \"meta-llama/Llama-3.2-1B\").\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the loaded model's config at hand; num_hidden_layers and hidden_size are read in later cells.\n",
    "model_config = source_model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "armt_config = dict(\n",
    "    segment_size=1024,\n",
    "    num_mem_tokens=128,\n",
    "    d_mem=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed the RNG before wrapping (presumably the wrapper initialises new weights - confirm).\n",
    "torch.manual_seed(0)\n",
    "armt_model = wrap_model_with_armt(source_model, **armt_config)\n",
    "# The trailing `;` on the expression line suppresses the module repr.\n",
    "# A bare `;` on its own line is an IPython auto-quote escape and evaluates to ''.\n",
    "armt_model.to(\"cuda\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "### FAST LATENCY VERSION\n",
    "\n",
    "# Context object handed to both the grouped layer and the executor; later cells set\n",
    "# start_idx / end_idx / is_full on it before running a grouped step.\n",
    "grouped_context = GroupedLayerContext()\n",
    "\n",
    "# Build one grouped layer from a deep copy of layer 0 plus the states collected from the\n",
    "# ARMT model (the states are unpacked as extra positional arguments).\n",
    "grouped_states = get_grouped_states(armt_model)\n",
    "grouped_layer = make_grouped_sliced_layer_from_single_layer(\n",
    "    grouped_context, copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states\n",
    ")\n",
    "# Returns a pair; presumably the second item holds the original per-layer modules - confirm.\n",
    "armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)\n",
    "\n",
    "\n",
    "# NOTE(review): `executor` is constructed here but not referenced by the later cells shown in this notebook.\n",
    "executor = FastGroupedArmtExecutor(\n",
    "    armt_grouped_model, \n",
    "    grouped_layer, \n",
    "    grouped_context, \n",
    "    model_config.num_hidden_layers, \n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jit compile As: [torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768])] Bs: [torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64])]\n",
      "\n",
      "// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8\n",
      "using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =\n",
      "  typename cutlass::gemm::kernel::DefaultGemmGrouped<\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor,\n",
      "    float,\n",
      "    cutlass::arch::OpClassTensorOp,\n",
      "    cutlass::arch::Sm80,\n",
      "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
      "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
      "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
      "    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,\n",
      "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
      "    3,\n",
      "    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,\n",
      "    cutlass::arch::OpMultiplyAdd\n",
      ">::GemmKernel;\n",
      "\n",
      "// Define named type\n",
      "struct cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_type :\n",
      "  public cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base { };\n",
      "\n",
      "USE_EFFICIENT_ALLOCATION\n"
     ]
    }
   ],
   "source": [
    "### ONLY FOR FAST LATENCY VERSION\n",
    "\n",
    "# Warm-up: run one full grouped step so that JIT compilation happens before anything is timed.\n",
    "# Random input with one segment per layer: (num_hidden_layers, segment_size, hidden_size).\n",
    "segments_input = torch.rand(\n",
    "    (model_config.num_hidden_layers, armt_config['segment_size'], model_config.hidden_size),\n",
    "    device=\"cuda\",\n",
    "    dtype=dtype,\n",
    ")\n",
    "\n",
    "# Full load: the group spans every layer. i and j are reused by the following cells.\n",
    "i, j = 0, model_config.num_hidden_layers\n",
    "grouped_context.start_idx = i\n",
    "grouped_context.end_idx = j\n",
    "grouped_context.is_full = True\n",
    "\n",
    "ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "grouped_layer.generate_mode = True\n",
    "armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)\n",
    "update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "\n",
    "# CUDA work is queued asynchronously; wait for the warm-up to finish so that none of it\n",
    "# is attributed to the timed cell that follows.\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 512, 768])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check of the slice that is fed to the grouped model.\n",
    "segments_input[i:j].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # %%time\n",
    "\n",
    "\n",
    "# with torch.profiler.profile(\n",
    "#     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],\n",
    "#     record_shapes=True,\n",
    "#     profile_memory=True,\n",
    "# ) as prof:\n",
    "#     # for _ in range(1):\n",
    "#     #     for __ in range(model_config.num_hidden_layers):\n",
    "#         # if __ != 0:\n",
    "#         ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "#         grouped_layer.generate_mode = True\n",
    "#         _ = armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)\n",
    "#         # if __ != 0:\n",
    "#         update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "\n",
    "# torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 281 ms, sys: 0 ns, total: 281 ms\n",
      "Wall time: 281 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "num_retries = 5\n",
    "for _ in range(num_retries):\n",
    "    for __ in range(model_config.num_hidden_layers):\n",
    "        # if __ != 0:\n",
    "        ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "        grouped_layer.generate_mode = True\n",
    "        _ = armt_grouped_model.memory_cell.model.model(\n",
    "            inputs_embeds=segments_input[i:j], use_cache=False\n",
    "        )\n",
    "        # if __ != 0:\n",
    "        update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config.num_hidden_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.683333333333334"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Then we have full load, group size is equal to \n",
    "# divide by num_retries and num_layers (equal to full load) to get average time per segment\n",
    "# this number is used for ideal scaling line in paper's table\n",
    "281/num_retries/model_config.num_hidden_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kernel_armt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
