{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "import copy\n",
    "import torch\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "from grouped_batching.llama1b_grouping import (\n",
    "    wrap_model_with_armt, get_grouped_states, \n",
    "    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,\n",
    "    make_grouped_sliced_layer_from_single_layer\n",
    ")\n",
    "from grouped_batching.batching import GroupedBatcher\n",
    "from grouped_batching.executor import ArmtGroupedExecutor\n",
    "from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.set_default_device(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtype = torch.bfloat16\n",
    "torch.set_default_dtype(dtype)\n",
    "torch.set_grad_enabled(False)\n",
    ";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
     ]
    }
   ],
   "source": [
    "source_model = AutoModelForCausalLM.from_pretrained(\n",
    "                                                    \"meta-llama/Llama-3.2-1B\"\n",
    "                                                    # \"meta-llama/Llama-3.2-3B\"\n",
    "                                                #     \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "                                                    # \"JackFram/llama-160m\"\n",
    "                                            #  , attn_implementation=\"sdpa\"\n",
    "                                            , attn_implementation=\"flash_attention_2\"\n",
    "                                             ,torch_dtype=dtype)\n",
    "source_model.eval()\n",
    "source_model.lm_head = torch.nn.Identity()\n",
    "reference_model = copy.deepcopy(source_model)\n",
    "# reference_model = source_model\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = source_model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "armt_config = dict(\n",
    "    # segment_size=512,\n",
    "    segment_size=1024,\n",
    "    # segment_size=2048,\n",
    "    # segment_size=4096,\n",
    "    num_mem_tokens=128,\n",
    "    d_mem=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "armt_model = wrap_model_with_armt(source_model, **armt_config)\n",
    "armt_model.to(\"cuda\")\n",
    "\n",
    "torch.manual_seed(0)\n",
    "armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)\n",
    "armt_reference_model.to(\"cuda\")\n",
    ";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# grouped_states = get_grouped_states(armt_model)\n",
    "# grouped_layer = make_grouped_layer_from_single_layer(\n",
    "#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)\n",
    "# # grouped_layer._grouped_execution = True\n",
    "# # grouped_layer._skip_associating = True\n",
    "# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "### TRAINABLE VERSION\n",
    "\n",
    "# grouped_layer = make_grouped_training_layer_from_single_layer(\n",
    "#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]),\n",
    "#     armt_model.memory_cell.model.model.layers\n",
    "# )\n",
    "# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)\n",
    "\n",
    "### AMORTIZIBLE VERSION\n",
    "\n",
    "# grouped_states = get_grouped_states(armt_model)\n",
    "# grouped_layer = make_grouped_layer_from_single_layer(\n",
    "#         copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)\n",
    "    \n",
    "# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)\n",
    "\n",
    "# batcher = GroupedBatcher(\n",
    "#     armt_grouped_model, \n",
    "#     n_layers=model_config.num_hidden_layers, \n",
    "#     seg_size=armt_config[\"segment_size\"]+armt_config[\"num_mem_tokens\"], \n",
    "#     hid_dim=model_config.hidden_size, \n",
    "#     pos_embed_dim=model_config.hidden_size\n",
    "# )\n",
    "# executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)\n",
    "\n",
    "### FAST LATENCY VERSION\n",
    "\n",
    "grouped_context = GroupedLayerContext()\n",
    "\n",
    "grouped_states = get_grouped_states(armt_model)\n",
    "grouped_layer = make_grouped_sliced_layer_from_single_layer(\n",
    "    grouped_context, copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states\n",
    ")\n",
    "armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)\n",
    "\n",
    "\n",
    "executor = FastGroupedArmtExecutor(\n",
    "    armt_grouped_model, \n",
    "    grouped_layer, \n",
    "    grouped_context, \n",
    "    model_config.num_hidden_layers, \n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jit compile As: [torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048])] Bs: [torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64])]\n",
      "\n",
      "// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8\n",
      "using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =\n",
      "  typename cutlass::gemm::kernel::DefaultGemmGrouped<\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
      "    cutlass::bfloat16_t, cutlass::layout::RowMajor,\n",
      "    float,\n",
      "    cutlass::arch::OpClassTensorOp,\n",
      "    cutlass::arch::Sm80,\n",
      "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
      "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
      "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
      "    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,\n",
      "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
      "    3,\n",
      "    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,\n",
      "    cutlass::arch::OpMultiplyAdd\n",
      ">::GemmKernel;\n",
      "\n",
      "// Define named type\n",
      "struct cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_type :\n",
      "  public cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base { };\n",
      "\n",
      "USE_EFFICIENT_ALLOCATION\n"
     ]
    }
   ],
   "source": [
    "### ONLY FOR FAST LATENCY VERSION\n",
    "\n",
    "# compile full layers\n",
    "segments_input = torch.rand((model_config.num_hidden_layers, armt_config['segment_size'], model_config.hidden_size), device=\"cuda\", dtype=dtype)\n",
    "\n",
    "i, j = 0, model_config.num_hidden_layers\n",
    "grouped_context.start_idx = i\n",
    "grouped_context.end_idx = j\n",
    "grouped_context.is_full = True\n",
    "\n",
    "ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "grouped_layer.generate_mode = True\n",
    "armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)\n",
    "update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "\n",
    "# del ao\n",
    "# del segments_input\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 1024, 2048])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "segments_input[i:j].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4096, 8192, 16384, 32768, 65536, 131072\n",
    "num_segments = 32768//armt_config[\"segment_size\"]\n",
    "input_ids = torch.randint(\n",
    "    0, 10000, \n",
    "    (1, num_segments*armt_config[\"segment_size\"]), \n",
    "    dtype=torch.long, \n",
    "    device=\"cuda\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 32768])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.3 s, sys: 55.1 ms, total: 2.36 s\n",
      "Wall time: 1.83 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# %%timeit\n",
    "\n",
    "with torch.no_grad():\n",
    "    # armt_reference_model.memory_cell.zero_mem()\n",
    "    armt_reference_model.memory_cell.generate_mode(False)\n",
    "    reference_output = armt_reference_model.forward(input_ids)\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# del reference_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 669 ms, sys: 3.3 ms, total: 673 ms\n",
      "Wall time: 672 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# %%timeit\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = executor.forward(input_ids, skip_concat=False)\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# del output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0103, device='cuda:0')"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# del output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer 0: rel_norm_mem: 0.0, rel_norm_zvalue: 0.000820159912109375\n",
      "Layer 1: rel_norm_mem: 0.0, rel_norm_zvalue: 0.005950927734375\n",
      "Layer 2: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00567626953125\n",
      "Layer 3: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0086669921875\n",
      "Layer 4: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01544189453125\n",
      "Layer 5: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0142822265625\n",
      "Layer 6: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00787353515625\n",
      "Layer 7: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0211181640625\n",
      "Layer 8: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01312255859375\n",
      "Layer 9: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01544189453125\n",
      "Layer 10: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00909423828125\n",
      "Layer 11: rel_norm_mem: 0.0, rel_norm_zvalue: 0.015869140625\n",
      "Layer 12: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0084228515625\n",
      "Layer 13: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01214599609375\n",
      "Layer 14: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00836181640625\n",
      "Layer 15: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00982666015625\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(armt_reference_model.memory_cell.model.model.layers)):\n",
    "    armt_ref_mem = armt_reference_model.memory_cell.model.model.layers[i].W_mem\n",
    "    armt_grouped_mem = executor.armt_model.memory_cell.model.model.layers[0].W_mem[i]\n",
    "    \n",
    "    rel_norm_mem = torch.norm(armt_ref_mem-armt_grouped_mem)/(torch.norm(armt_ref_mem)+1e-6)\n",
    "    \n",
    "    armt_ref_zvalue = armt_reference_model.memory_cell.model.model.layers[i].z\n",
    "    armt_grouped_zvalue = executor.armt_model.memory_cell.model.model.layers[0].z[i]\n",
    "    \n",
    "    rel_norm_zvalue = torch.norm(armt_ref_zvalue-armt_grouped_zvalue)/(torch.norm(armt_ref_zvalue)+1e-6)\n",
    "    \n",
    "    print(f\"Layer {i}: rel_norm_mem: {rel_norm_mem}, rel_norm_zvalue: {rel_norm_zvalue}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### this way you can \"batch\" several inputs to amortize the cost of the batcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n",
      "torch.Size([16, 128, 1]) torch.Size([16, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "### ONLY FOR AMORTIZABLE VERSION\n",
    "\n",
    "output_list = executor.forward([input_ids, input_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(output_list[0].logits, output_list[1].logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kernel_armt",
   "language": "python",
   "name": "conda-env-conda_rmt_it_venv-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
