{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('..')\n",
    "sys.path.append(\"./associative-recurrent-memory-transformer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    }
   ],
   "source": [
    "from grouped_batching.llama1b_grouping import (\n",
    "    wrap_model_with_armt, get_grouped_states, \n",
    "    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,\n",
    "    make_grouped_sliced_layer_from_single_layer\n",
    ")\n",
    "from grouped_batching.fast_executor import (\n",
    "    FastGroupedArmtExecutor, GroupedLayerContext, \n",
    "    associate_with_context, update_mem_with_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from grouped_batching.universal_grouping import (\n",
    "    extract_params_from_module, get_universal_grouped_states,\n",
    "    get_module_by_path, make_universal_grouped_layer,\n",
    "    set_module_by_path, make_universal_grouped_model\n",
    ")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtype = torch.bfloat16\n",
    "torch.set_default_dtype(dtype)\n",
    "torch.set_grad_enabled(False)\n",
    ";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import traceback\n",
    "\n",
    "def add_trace_to_forward(module, msg):\n",
    "    original_forward = module.forward\n",
    "\n",
    "    def traced_forward(*args, **kwargs):\n",
    "        print(f\"\\n[TRACE] {msg} {module.__class__.__name__}.forward called from:\\n\" + ''.join(traceback.format_stack(limit=10)))\n",
    "        return original_forward(*args, **kwargs)\n",
    "\n",
    "    module.forward = traced_forward\n",
    "    return module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import copy\n",
    "\n",
    "MODEL_TYPE = \"gpt2\" # \"gpt2\" \"llama\"\n",
    "\n",
    "\n",
    "if MODEL_TYPE == \"gpt2\":\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125m\")\n",
    "    source_model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-neo-125m\"\n",
    "                                                , attn_implementation=\"flash_attention_2\"\n",
    "                                                ,torch_dtype=dtype)\n",
    "else:\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"JackFram/llama-160m\")\n",
    "    source_model = AutoModelForCausalLM.from_pretrained(\"JackFram/llama-160m\"\n",
    "                                                , attn_implementation=\"flash_attention_2\"\n",
    "                                                ,torch_dtype=dtype)\n",
    "\n",
    "# Replace all LayerNorm modules in the model with LayerNorm without weights and biases\n",
    "# for name, module in source_model.named_modules():\n",
    "#     if isinstance(module, nn.LayerNorm):\n",
    "#         # Create a new LayerNorm with elementwise_affine=False (no weights and biases)\n",
    "#         new_layernorm = nn.LayerNorm(\n",
    "#             normalized_shape=module.normalized_shape,\n",
    "#             eps=module.eps,\n",
    "#             elementwise_affine=False\n",
    "#         )\n",
    "        \n",
    "#         # Get the parent module and attribute name to replace the LayerNorm\n",
    "#         parent_name = '.'.join(name.split('.')[:-1])\n",
    "#         child_name = name.split('.')[-1]\n",
    "        \n",
    "#         if parent_name:\n",
    "#             parent = source_model\n",
    "#             for part in parent_name.split('.'):\n",
    "#                 parent = getattr(parent, part)\n",
    "#             setattr(parent, child_name, new_layernorm)\n",
    "#         else:\n",
    "#             setattr(source_model, child_name, new_layernorm)\n",
    "\n",
    "# for l in source_model.transformer.h:\n",
    "    # l.mlp = nn.Identity()\n",
    "    # l.attn.attention.out_proj = nn.Identity()\n",
    "    # l.attn.attention = nn.Identity()\n",
    "\n",
    "# source_model.transformer.h = source_model.transformer.h[:1]\n",
    "\n",
    "# source_model.transformer.wpe.weight.fill_(0)\n",
    "# source_model.transformer.wpe.weight.data = torch.tensor([3,3,3], dtype=dtype)\n",
    "# source_model.transformer.wte.weight.fill_(0)\n",
    "# source_model.transformer.wte.weight.data = torch.tensor([3,3,3], dtype=dtype)\n",
    "# add_trace_to_forward(source_model.transformer.wpe, 'wpe')\n",
    "# add_trace_to_forward(source_model.transformer.wte, 'wte')\n",
    "\n",
    "\n",
    "# source_model.ln_f = source_model.transformer.ln_f\n",
    "# source_model.transformer.ln_f = nn.Identity()\n",
    "# source_model.lm_head = nn.Identity()\n",
    "reference_model = copy.deepcopy(source_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoForCausalLM(\n",
       "  (transformer): GPTNeoModel(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(2048, 768)\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-11): 12 x GPTNeoBlock(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPTNeoAttention(\n",
       "          (attention): GPTNeoFlashAttention2(\n",
       "            (attn_dropout): Dropout(p=0.0, inplace=False)\n",
       "            (resid_dropout): Dropout(p=0.0, inplace=False)\n",
       "            (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (v_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (q_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPTNeoMLP(\n",
       "          (c_fc): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (c_proj): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "armt_config = dict(\n",
    "    segment_size=1024,\n",
    "    num_mem_tokens=128,\n",
    "    d_mem=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoForCausalLM(\n",
       "  (transformer): GPTNeoModel(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(2048, 768)\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-11): 12 x GPTNeoBlock(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPTNeoAttention(\n",
       "          (attention): GPTNeoFlashAttention2(\n",
       "            (attn_dropout): Dropout(p=0.0, inplace=False)\n",
       "            (resid_dropout): Dropout(p=0.0, inplace=False)\n",
       "            (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (v_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (q_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPTNeoMLP(\n",
       "          (c_fc): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (c_proj): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if MODEL_TYPE == \"gpt2\":\n",
    "    layers_attr = 'transformer.h'\n",
    "else:\n",
    "    layers_attr = 'model.layers'\n",
    "\n",
    "torch.manual_seed(0)\n",
    "armt_model = wrap_model_with_armt(source_model, **armt_config, layers_attr=layers_attr)\n",
    "armt_model.to(\"cuda\")\n",
    "\n",
    "torch.manual_seed(0)\n",
    "armt_reference_model = wrap_model_with_armt(reference_model, **armt_config, layers_attr=layers_attr)\n",
    "armt_reference_model.to(\"cuda\")\n",
    ";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "if MODEL_TYPE == \"gpt2\":\n",
    "    grouped_params = get_universal_grouped_states(armt_model.memory_cell.model.transformer.h)\n",
    "else:\n",
    "    grouped_params = get_universal_grouped_states(armt_model.memory_cell.model.model.layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['W_mq.weight',\n",
       " 'W_mk.weight',\n",
       " 'W_mv.weight',\n",
       " 'W_mb.weight',\n",
       " 'W_mb.bias',\n",
       " 'layer.ln_1.weight',\n",
       " 'layer.ln_1.bias',\n",
       " 'layer.attn.attention.k_proj.weight',\n",
       " 'layer.attn.attention.v_proj.weight',\n",
       " 'layer.attn.attention.q_proj.weight',\n",
       " 'layer.attn.attention.out_proj.weight',\n",
       " 'layer.attn.attention.out_proj.bias',\n",
       " 'layer.ln_2.weight',\n",
       " 'layer.ln_2.bias',\n",
       " 'layer.mlp.c_fc.weight',\n",
       " 'layer.mlp.c_fc.bias',\n",
       " 'layer.mlp.c_proj.weight',\n",
       " 'layer.mlp.c_proj.bias',\n",
       " 'W_mem',\n",
       " 'z']"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(grouped_params.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped_context = GroupedLayerContext()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBSTITUTE efficient module_path='W_mq': len(weight_values)=12 None \n",
      "SUBSTITUTE efficient module_path='W_mk': len(weight_values)=12 None \n",
      "SUBSTITUTE efficient module_path='W_mv': len(weight_values)=12 None \n",
      "SUBSTITUTE naive module_path='W_mb': len(weight_values)=12 12 \n",
      "SUBSTITUTE efficient module_path='layer.attn.attention.k_proj': len(weight_values)=12 None \n",
      "SUBSTITUTE efficient module_path='layer.attn.attention.v_proj': len(weight_values)=12 None \n",
      "SUBSTITUTE efficient module_path='layer.attn.attention.q_proj': len(weight_values)=12 None \n",
      "SUBSTITUTE efficient module_path='layer.attn.attention.out_proj': len(weight_values)=12 12 \n",
      "SUBSTITUTE efficient module_path='layer.mlp.c_fc': len(weight_values)=12 12 \n",
      "SUBSTITUTE efficient module_path='layer.mlp.c_proj': len(weight_values)=12 12 \n",
      "SUBSTITUTE norm_path='layer.ln_1': torch.Size([12, 768]) torch.Size([12, 768]) \n",
      "SUBSTITUTE norm_path='layer.ln_2': torch.Size([12, 768]) torch.Size([12, 768]) \n"
     ]
    }
   ],
   "source": [
    "if MODEL_TYPE == \"gpt2\":\n",
    "    layer_base = copy.deepcopy(armt_model.memory_cell.model.transformer.h[0])\n",
    "else:\n",
    "    layer_base = copy.deepcopy(armt_model.memory_cell.model.model.layers[0])\n",
    "\n",
    "grouped_layer = make_universal_grouped_layer(\n",
    "    grouped_context, \n",
    "    layer_base,\n",
    "    grouped_params,\n",
    "    use_layer_norm=(MODEL_TYPE == \"gpt2\"), # layer norm for gpt2, rms norm for llama\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoForCausalLM(\n",
       "  (transformer): GPTNeoModel(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(2048, 768)\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-11): 12 x AssociativeLayerWrapper(\n",
       "        (W_mq): Linear(in_features=768, out_features=64, bias=False)\n",
       "        (W_mk): Linear(in_features=768, out_features=64, bias=False)\n",
       "        (W_mv): Linear(in_features=768, out_features=768, bias=False)\n",
       "        (W_mb): Linear(in_features=768, out_features=1, bias=True)\n",
       "        (layer): GPTNeoBlock(\n",
       "          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): GPTNeoAttention(\n",
       "            (attention): GPTNeoFlashAttention2(\n",
       "              (attn_dropout): Dropout(p=0.0, inplace=False)\n",
       "              (resid_dropout): Dropout(p=0.0, inplace=False)\n",
       "              (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "              (v_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "              (q_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "              (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "          )\n",
       "          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): GPTNeoMLP(\n",
       "            (c_fc): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (c_proj): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (act): NewGELUActivation()\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([12, 384, 768]), torch.Size([12, 384]))"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grouped_layer.W_mem.data.shape, grouped_layer.z.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from grouped_batching.universal_executor import UniversalGroupedExecutor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPTNeoForCausalLM(\n",
       "  (transformer): GPTNeoModel(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(2048, 768)\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-11): 12 x AssociativeLayerWrapper(\n",
       "        (W_mq): Linear(in_features=768, out_features=64, bias=False)\n",
       "        (W_mk): Linear(in_features=768, out_features=64, bias=False)\n",
       "        (W_mv): Linear(in_features=768, out_features=768, bias=False)\n",
       "        (W_mb): Linear(in_features=768, out_features=1, bias=True)\n",
       "        (layer): GPTNeoBlock(\n",
       "          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): GPTNeoAttention(\n",
       "            (attention): GPTNeoFlashAttention2(\n",
       "              (attn_dropout): Dropout(p=0.0, inplace=False)\n",
       "              (resid_dropout): Dropout(p=0.0, inplace=False)\n",
       "              (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "              (v_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "              (q_proj): Linear(in_features=768, out_features=768, bias=False)\n",
       "              (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "          )\n",
       "          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): GPTNeoMLP(\n",
       "            (c_fc): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (c_proj): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (act): NewGELUActivation()\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "if MODEL_TYPE == \"gpt2\":\n",
    "    layers_path = \"memory_cell.model.transformer.h\"\n",
    "else:\n",
    "    layers_path = \"memory_cell.model.model.layers\"\n",
    "\n",
    "armt_grouped_model, source_model_layers = make_universal_grouped_model(\n",
    "    armt_model, \n",
    "    grouped_layer,\n",
    "    layers_path=layers_path\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def preprocess_segment_gpt2(model, input_segment):\n",
    "    pos_ids = torch.arange(\n",
    "        0, \n",
    "        input_segment.shape[0], \n",
    "        device=input_segment.device\n",
    "    )\n",
    "    wpe_emb = model.memory_cell.model.transformer.wpe(pos_ids)\n",
    "    input_segment.add_(wpe_emb)\n",
    "    return input_segment\n",
    "\n",
    "grouped_compute_gpt2 = None # uses default grouped layer\n",
    "\n",
    "def postprocess_segment_gpt2(model, output_segment):\n",
    "    out = model.memory_cell.model.transformer.ln_f(output_segment)\n",
    "    out_tokens = model.memory_cell.model.lm_head(out)\n",
    "    return out_tokens\n",
    "\n",
    "\n",
    "\n",
    "preprocess_segment_llama = None\n",
    "\n",
    "def grouped_compute_llama(model, grouped_layer, grouped_input_tensor):\n",
    "    position_ids = torch.arange(\n",
    "        0, \n",
    "        grouped_input_tensor.shape[1], \n",
    "        device=grouped_input_tensor.device\n",
    "    ).unsqueeze(0)\n",
    "    position_embeddings = model.memory_cell.model.model.rotary_emb(grouped_input_tensor, position_ids)\n",
    "    batch_output = grouped_layer.forward(grouped_input_tensor, position_embeddings=position_embeddings)\n",
    "    return batch_output\n",
    "    \n",
    "\n",
    "def postprocess_segment_llama(model, output_segment):\n",
    "    out = model.memory_cell.model.model.norm(output_segment)\n",
    "    out_tokens = model.memory_cell.model.lm_head(out)\n",
    "    return out_tokens\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "if MODEL_TYPE == \"gpt2\":\n",
    "    executor = UniversalGroupedExecutor(\n",
    "        model=armt_grouped_model,\n",
    "        grouped_layer=grouped_layer,\n",
    "        context=grouped_context,\n",
    "        n_layers=source_model.config.num_hidden_layers,\n",
    "        model_path=\"memory_cell.model.transformer\",\n",
    "        out_norm_attr=\"memory_cell.model.transformer.ln_f\",\n",
    "        lm_head_path=\"memory_cell.model.lm_head\",\n",
    "        memory_path=\"memory_cell\",\n",
    "        preprocess_segment_fn = preprocess_segment_gpt2,\n",
    "        postprocess_segment_fn = postprocess_segment_gpt2,\n",
    "        grouped_compute_fn = grouped_compute_gpt2,\n",
    "    )\n",
    "else:\n",
    "    executor = UniversalGroupedExecutor(\n",
    "        model=armt_grouped_model,\n",
    "        grouped_layer=grouped_layer,\n",
    "        context=grouped_context,\n",
    "        n_layers=source_model.config.num_hidden_layers,\n",
    "        model_path=\"memory_cell.model.model\",\n",
    "        out_norm_attr=\"memory_cell.model.model.norm\",\n",
    "        lm_head_path=\"memory_cell.model.lm_head\",\n",
    "        memory_path=\"memory_cell\",\n",
    "        preprocess_segment_fn = preprocess_segment_llama,\n",
    "        postprocess_segment_fn = postprocess_segment_llama,\n",
    "        grouped_compute_fn = grouped_compute_llama,\n",
    "    )\n",
    "\n",
    "executor.vanilla_model = armt_reference_model\n",
    "\n",
    "# executor = FastGroupedArmtExecutor(\n",
    "#     armt_grouped_model, \n",
    "#     grouped_layer, \n",
    "#     grouped_context, \n",
    "#     source_model.config.num_hidden_layers\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "### ONLY FOR FAST LATENCY VERSION\n",
    "\n",
    "# compile full layers\n",
    "segments_input = torch.rand((source_model.config.num_hidden_layers, armt_config['segment_size'], source_model.config.hidden_size), device=\"cuda\", dtype=dtype)\n",
    "\n",
    "i, j = 0, source_model.config.num_hidden_layers\n",
    "grouped_context.start_idx = i\n",
    "grouped_context.end_idx = j\n",
    "grouped_context.is_full = True\n",
    "\n",
    "ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "grouped_layer.generate_mode = True\n",
    "if MODEL_TYPE == \"gpt2\":\n",
    "    armt_grouped_model.memory_cell.model.transformer(inputs_embeds=segments_input[i:j], use_cache=False)\n",
    "else:\n",
    "    armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)\n",
    "update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])\n",
    "\n",
    "# del ao\n",
    "# del segments_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4096, 8192, 16384, 32768, 65536, 131072\n",
    "num_segments = 4096//armt_config[\"segment_size\"]\n",
    "input_ids = torch.randint(\n",
    "    0, 10000, \n",
    "    (1, num_segments*armt_config[\"segment_size\"]), \n",
    "    dtype=torch.long, \n",
    "    device=\"cuda\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4096])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Forward (prefill)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 581 ms, sys: 6.47 ms, total: 588 ms\n",
      "Wall time: 99.3 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# %%timeit\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # armt_reference_model.memory_cell.zero_mem()\n",
    "    armt_reference_model.memory_cell.generate_mode(False)\n",
    "    reference_output = armt_reference_model.forward(input_ids)\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 60.4 ms, sys: 3.83 ms, total: 64.3 ms\n",
      "Wall time: 67 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# %%timeit\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = executor.forward(input_ids, skip_concat=False)\n",
    "\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0332, device='cuda:0')"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_kwargs = {\n",
    "    'max_new_tokens': 20,\n",
    "    'pad_token_id': 0,\n",
    "    'eos_token_id': 1,\n",
    "    'attention_mask': None,\n",
    "    # 'attention_mask': torch.tril(torch.ones((1024, 1024), device='cuda', dtype=bool))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_out_ref = armt_reference_model.generate(input_ids, **generate_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  13,  405, 1065,   13,  198,  198,    7,   64,    8,   77,  375,  198,\n",
       "            8,   68,   13,   68,   13,   68,   13,   68]], device='cuda:0')"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gen_out_ref"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3072]) torch.Size([1, 1024])\n"
     ]
    }
   ],
   "source": [
    "gen_out = executor.generate(input_ids, seg_size=1024+128, **generate_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[  13,  405, 1065,   13,  198,  198,    7,   64,    8,   77,  375,  198,\n",
       "             8,   68,   13,   68,   13,   68,   13,   68]], device='cuda:0'),\n",
       " 0.00043964385986328125)"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gen_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True, device='cuda:0')"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(gen_out_ref == gen_out[0]).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "armt_kernel",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
