{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b41de05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialised,\n",
    "# so set it before importing torch (changing it after CUDA is up has no effect).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # Change \"0\" to the index of the GPU you want to use\n",
    "\n",
    "import torch\n",
    "\n",
    "print(torch.cuda.device_count())  # This should print 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "299c3029",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7fc7fc160c10>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_from_disk, Dataset\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from matplotlib import pyplot as plt\n",
    "from trl import DataCollatorForCompletionOnlyLM\n",
    "from tqdm import tqdm\n",
    "from time import time\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv\n",
    "import math\n",
    "import pickle\n",
    "import numpy as np\n",
    "from contextualization import compute_overlap_area_torch\n",
    "\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6455106",
   "metadata": {},
   "source": [
    "# Load model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fce5fdd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model checkpoint and its matching tokenizer (alternative: \"NousResearch/Llama-3.2-1B\").\n",
    "model_name = \"unsloth/Llama-3.2-3B-Instruct\"\n",
    "\n",
    "# torch_dtype is left at the library default (pass torch_dtype=torch.bfloat16 to save memory);\n",
    "# nn.Module.to moves the weights in place and returns the same model object.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name).to(\"cuda:0\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
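  {
   "cell_type": "markdown",
   "id": "61586dbf",
   "metadata": {},
   "source": [
    "# Load dataset\n",
    "\n",
    "Run exactly one of the dataset cells below (SQuAD v2, 2WikiMultihopQA or QMSum). Each one overwrites `dataset`, `train_dataset` and `valid_dataset`, so the last one executed determines what the rest of the notebook analyses."
   ]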
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "69893ae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Truncation length: the tokenizer's own maximum (an explicit cap such as 5500 also works).\n",
    "max_seq_length = tokenizer.model_max_length\n",
    "\n",
    "# Keep at 1: get_info builds its prompt/response masks from the first sequence's labels only.\n",
    "batch_size = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5d68b80",
   "metadata": {},
   "source": [
    "## SQuAD v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8877da77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['text'],\n",
      "    num_rows: 10000\n",
      "})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'content': \"You are tasked with answering questions based on the given context. Provide a precise, direct answer based solely on the information in the context. Do not repeat the question or the context. If the answer cannot be found, return 'NOT FOUND.'\\n\\nContext:\\nIn Islam, dogs are viewed as unclean because they are viewed as scavengers. In 2015 city councillor Hasan Küçük of The Hague called for dog ownership to be made illegal in that city. Islamic activists in Lérida, Spain, lobbied for dogs to be kept out of Muslim neighborhoods, saying their presence violated Muslims' religious freedom. In Britain, police sniffer dogs are carefully used, and are not permitted to contact passengers, only their luggage. They are required to wear leather dog booties when searching mosques or Muslim homes.\\n\\nQuestion: How are dogs viewed in Islam?\",\n",
       "  'role': 'user'},\n",
       " {'content': 'Dogs are viewed as unclean in Islam due to their association with scavenging.',\n",
       "  'role': 'assistant'}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SQuAD: load the examples and hold out 20% for validation (fixed seed for reproducibility).\n",
    "data_name = 'squad'\n",
    "data_path = \"<path to dataset>\"\n",
    "dataset = Dataset.from_json(data_path)\n",
    "print(dataset)\n",
    "\n",
    "train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
    "train_dataset, valid_dataset = train_test_split['train'], train_test_split['test']\n",
    "\n",
    "# Preview the chat messages of the first training example.\n",
    "train_dataset[0]['text']['generated_text']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14a2a0ff",
   "metadata": {},
   "source": [
    "## 2WikiMultihopQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d8bc749f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['text'],\n",
      "    num_rows: 10000\n",
      "})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'content': 'You are tasked with answering questions based on the given context. Provide a precise, direct answer based solely on the information in the context. Do not repeat the question or the context. If the answer cannot be found, return \\'NOT FOUND.\\'\\n\\nContext:\\nAna María Vignoli: Ana María Vignoli or Ana Maria Vignoli( born 27 July 1945) is a Uruguayan former minister of Social Development.\\nPreston Sturges: Preston Sturges( born Edmund Preston Biden; August 29, 1898 – August 6, 1959) was an American playwright, screenwriter, and film director. In 1941, he won the Oscar for Best Original Screenplay for the film\" The Great McGinty\", his first of three nominations in the category. Sturges took the screwball comedy format of the 1930s to another level, writing dialogue that, heard today, is often surprisingly naturalistic, mature, and ahead of its time, despite the farcical situations. It is not uncommon for a Sturges character to deliver an exquisitely turned phrase and take an elaborate pratfall within the same scene. Prior to Sturges, other figures in Hollywood( such as Charlie Chaplin, D. W. Griffith, and Frank Capra) had directed films from their own scripts, however Sturges is often regarded as the first Hollywood figure to establish success as a screenwriter and then move into directing his own scripts, at a time when those roles were separate. Sturges famously sold the story for\" The Great McGinty\" to Paramount Pictures for$ 1, in return for being allowed to direct the film.\\nRumbi Katedza: Rumbi Katedza is a Zimbabwean Film Producer and Director who was born on 17 January 1974.\\nChristine McIntyre: Christine Cecilia McIntyre( April 16, 1911 – July 8, 1984) was an American actress and singer who appeared in various films in the 1930s and 1940s. 
She is mainly remembered as the beautiful blonde actress who appeared in many of The Three Stooges shorts produced by Columbia Pictures.\\nAna Arabia: Ana Arabia( I Am Arab) is a 2013 French- Israeli drama film written and directed by Amos Gitai. It was entered into the main competition at the 70th Venice International Film Festival. It consists of a single long take.\\nThe Beautiful Blonde from Bashful Bend: The Beautiful Blonde from Bashful Bend is a 1949 romantic comedy Western film starring Betty Grable and featuring Cesar Romero and Rudy Vallee. It was directed by Preston Sturges and written by him based on a story by Earl Felton. The film, Sturges\\' first Technicolor production, was not well received at the time it was released, and was generally conceded to be a disaster – even Betty Grable bad-mouthed it – but its reputation has improved somewhat over time, though it is not considered to be in the same league as the intelligent comedies Sturges made at Paramount Pictures for which he is known. \"The Beautiful Blonde from Bashful Bend\" would turn out to be the last American film on which Sturges would work – although he would receive credit for films that were remakes or adaptations of his earlier movies. Sturges directed only one more film in his life, the 1955 French comedy \"Les carnets du Major Thompson\" (released in the U.S. as \"The French, They Are a Funny Race\").\\nThe Blonde from Singapore: The Blonde from Singapore( also released as Hot Pearls) is a 1941 American adventure film directed by Edward Dmytryk.\\nAmos Gitai: Amos Gitai( born 11 October 1950) is an Israeli filmmaker, who was trained as an architect. Gitai\\'s work was presented in several major retrospectives in Pompidou Center Paris, the Museum of Modern Art( MoMA) New York, Lincoln Center New York, and the British Film Institute London. To date Amos Gitai has created over 90 works of art throughout 38 years. 
Between 1999 and 2017 ten of his films were entered in the Cannes Film Festival for the Palme d\\' Or as well as The Venice International Film Festival for the Golden Lion award. He has worked with Juliette Binoche, Jeanne Moreau, Natalie Portman, Yael Abecassis, Samuel Fuller, Hanna Schygulla, Annie Lennox, Barbara Hendricks, Léa Seydoux, Valeria Bruni Tedeschi, Henri Alekan, Renato Berta, Nurith Aviv, Éric Gautier and more. Since 2000 he has been collaborating with the French screenwriter Marie- José Sanselme. He received several prestigious prizes, in particular the Leopard of Honor at the Locarno International Film Festival( 2008), the Roberto Rossellini prize( 2005), the Robert Bresson prize( 2013), the Paradjanov prize( 2014), and Légion d\\'Honneur( 2017). In 2018, Amos Gitai has been elected professor at the chair of artistic creation at the Collège de France, with a series of 12 lessons on cinema( 16 October – 18 December 2018)\\nHassan Zee: Hassan Zee is a Pakistani- American film director who was born in Chakwal, Pakistan.\\nEdward Yates: Edward J. Yates( September 16, 1918 – June 2, 2006) was an American television director who was the director of the ABC television program\" American Bandstand\" from 1952 until 1969.\\n\\nQuestion: Which film has the director who was born earlier, The Beautiful Blonde From Bashful Bend or Ana Arabia?',\n",
       "  'role': 'user'},\n",
       " {'content': \"The Beautiful Blonde From Bashful Bend (1949) and Ana Arabia (2013) films' directors, Preston Sturges and Amos Gitai, were born in 1898 and 1950, respectively. Therefore, Preston Sturges was born earlier.\",\n",
       "  'role': 'assistant'}]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2WikiMultihopQA: load the examples and hold out 20% for validation (fixed seed for reproducibility).\n",
    "data_name = '2WikiMultihopQA'\n",
    "data_path = \"<path to dataset>\"\n",
    "dataset = Dataset.from_json(data_path)\n",
    "print(dataset)\n",
    "\n",
    "train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
    "train_dataset, valid_dataset = train_test_split['train'], train_test_split['test']\n",
    "\n",
    "# Preview the chat messages of the first training example.\n",
    "train_dataset[0]['text']['generated_text']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15c471d0",
   "metadata": {},
   "source": [
    "## QMSum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f180ffe5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['text'],\n",
      "    num_rows: 1095\n",
      "})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'content': 'You are a helpful assistant. Summarize the following meeting transcript as clearly and concisely as possible.\\n\\nTranscript:\\nMark Reckless AM: What are those messages on why prospective students should study in Wales ? \\nDr David Blaney: One of them in particular is relative safety We know that one of the considerations particularly for parents of overseas students is are they going to go to a safe environment and we know that the perception of international students who study in Wales is that this is a comfortable and safe place to be That is partly a function of the size of our larger cities—quite a lot smaller than many of the cities in England So that is a key message Being part of a UK system is also an important message there as well So we have got a UKquality system a UK degree and the strength of that brand is available in Wales but it is available in a way that is safer and more supportive I think is the messaging that is coming through \\nJohn Griffiths AM: We would better move on I think had not we ? Darren then',\n",
       "  'role': 'user'},\n",
       " {'content': \"Here is a summary of the meeting transcript:\\n\\nThe conversation centered on promoting Wales as a destination for prospective international students. Key points discussed included:\\n\\n* Wales is considered a safe environment, which is a major concern for parents of overseas students.\\n* The country's smaller cities are perceived as being more comfortable and safe compared to many English cities.\\n* Studying in Wales provides a UK-quality education with a UK degree, which is a strong brand.\\n* The messaging emphasizes that Wales offers a safe and supportive environment for international students.\",\n",
       "  'role': 'assistant'}]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# QMSum: load the examples and hold out 20% for validation (fixed seed for reproducibility).\n",
    "data_name = 'qmsum'\n",
    "data_path = \"<path to dataset>\"\n",
    "dataset = Dataset.from_json(data_path)\n",
    "print(dataset)\n",
    "\n",
    "train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
    "train_dataset, valid_dataset = train_test_split['train'], train_test_split['test']\n",
    "\n",
    "# Preview the chat messages of the first training example.\n",
    "train_dataset[0]['text']['generated_text']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d666400",
   "metadata": {},
   "source": [
    "# Prepare dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "13c7cd6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenizer.chat_template = \"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\"\n",
    "tokenizer.chat_template = \"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}\"\n",
    "\n",
    "def prepare_dataset(dataset, n_proc=20):\n",
    "    \"\"\"Render each example with the chat template above and tokenize it.\n",
    "\n",
    "    All original columns other than 'input_ids'/'attention_mask' are removed,\n",
    "    so the result holds only the tokenizer outputs. Sequences are truncated to\n",
    "    `max_seq_length` and left unpadded. The template already prepends\n",
    "    bos_token, hence add_special_tokens=False.\n",
    "    \"\"\"\n",
    "    # datasets documents remove_columns as str or list[str]; a list (unlike a set)\n",
    "    # also keeps the argument, and thus the map() cache fingerprint, deterministic.\n",
    "    columns_to_remove = [c for c in dataset.column_names if c not in ('input_ids', 'attention_mask')]\n",
    "    return dataset.map(\n",
    "        lambda x: tokenizer(tokenizer.apply_chat_template(x['text']['generated_text'], tokenize=False), \n",
    "                            add_special_tokens=False,\n",
    "                            truncation=True,\n",
    "                            padding=False,\n",
    "                            max_length=max_seq_length,\n",
    "                            return_overflowing_tokens=False,\n",
    "                            return_length=False),\n",
    "        remove_columns=columns_to_remove,\n",
    "        num_proc=n_proc,\n",
    "    )\n",
    "\n",
    "tokenized_train_ds = prepare_dataset(train_dataset, n_proc=20)\n",
    "tokenized_valid_ds = prepare_dataset(valid_dataset, n_proc=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "34e3fec7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of batches in train dataloader: 876\n",
      "Number of batches in valid dataloader: 219\n"
     ]
    }
   ],
   "source": [
    "# Templates match the Llama-3 chat template set above\n",
    "# (ChatML alternative: \"<|im_start|>user\" / \"<|im_start|>assistant\").\n",
    "data_collator = DataCollatorForCompletionOnlyLM(\n",
    "    instruction_template=\"<|start_header_id|>user<|end_header_id|>\",\n",
    "    response_template=\"<|start_header_id|>assistant<|end_header_id|>\",\n",
    "    tokenizer=tokenizer,\n",
    "    mlm=False,\n",
    ")\n",
    "\n",
    "# Seeded generator so the shuffled training order is reproducible across runs\n",
    "# (same seed as the train/validation split).\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_train_ds,\n",
    "    batch_size=batch_size,  # must stay 1: get_info only reads labels[0]\n",
    "    shuffle=True,\n",
    "    collate_fn=data_collator,\n",
    "    generator=torch.Generator().manual_seed(42),\n",
    ")\n",
    "print(f\"Number of batches in train dataloader: {len(train_dataloader)}\")\n",
    "\n",
    "valid_dataloader = DataLoader(\n",
    "    tokenized_valid_ds,\n",
    "    batch_size=batch_size,  # must stay 1: get_info only reads labels[0]\n",
    "    shuffle=False,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "print(f\"Number of batches in valid dataloader: {len(valid_dataloader)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf97871d",
   "metadata": {},
   "source": [
    "# Utilities to extract per-layer activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "39c1862f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "head_dim:128, n_heads:24, n_kv_heads:8\n"
     ]
    }
   ],
   "source": [
    "head_dim = model.config.head_dim\n",
    "n_heads = model.config.num_attention_heads\n",
    "n_kv_heads = model.config.num_key_value_heads\n",
    "print(f'head_dim:{head_dim}, n_heads:{n_heads}, n_kv_heads:{n_kv_heads}')\n",
    "\n",
    "def get_input_output(vectors, input_mask, output_mask):\n",
    "    \"\"\"Split `vectors` along dim 2 into input positions and output positions.\n",
    "\n",
    "    Used with [batch_size, heads, seq_len, head_dim] tensors and boolean masks of length seq_len.\n",
    "    \"\"\"\n",
    "    vectors_input = vectors[:, :, input_mask, :]\n",
    "    vectors_output = vectors[:, :, output_mask, :]\n",
    "    return vectors_input, vectors_output\n",
    "\n",
    "def get_info(batch, return_activations=False):\n",
    "    \"\"\"Per layer and head, compare output->input and output->output attention scores.\n",
    "\n",
    "    Positions whose label is -100 form the input (prompt); all other positions\n",
    "    form the output (response). For every layer, Q/K/V are recomputed from that\n",
    "    layer's input hidden state, rotary embeddings are applied, and for each head\n",
    "    compute_overlap_area_torch is called on the causal output->output scores\n",
    "    (intra) and the output->input scores (inter).\n",
    "\n",
    "    Args:\n",
    "        batch: collated batch holding exactly one sequence. Every entry is moved\n",
    "            to the model device and passed to the model; 'labels' defines the\n",
    "            input/output split.\n",
    "        return_activations: if True, also return per-layer intermediate tensors.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of overlap areas with shape [n_layers, n_heads], preceded by a list\n",
    "        of per-layer activation dicts when `return_activations` is True.\n",
    "    \"\"\"\n",
    "    batch = {k: v.to(model.device) for k, v in batch.items()}\n",
    "    bsz = batch['labels'].shape[0]\n",
    "    # The masks below come from the first sequence only; with more sequences the\n",
    "    # positions of different samples would be mixed silently.\n",
    "    assert bsz == 1, f'get_info expects a batch with exactly one sequence, got {bsz}'\n",
    "    # Loop-invariant masks: prompt tokens are labelled -100, response tokens are not.\n",
    "    input_mask = batch['labels'][0] == -100\n",
    "    output_mask = ~input_mask\n",
    "    n_rep = n_heads // n_kv_heads  # query heads sharing each key/value head\n",
    "    scale = math.sqrt(head_dim)\n",
    "    activations = []  # per-layer dicts of intermediate tensors (filled only if return_activations)\n",
    "    overlap_areas = []  # per-layer lists of per-head overlap areas\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**batch, output_hidden_states=True)\n",
    "        # hidden_states[0] holds the token embeddings; hidden_states[i] is used as the input of layer i.\n",
    "        hidden_states = outputs.hidden_states\n",
    "        _, seqlen, _ = hidden_states[0].shape  # [batch_size, seq_len, hidden_size]\n",
    "        position_ids = torch.arange(seqlen, device=model.device).unsqueeze(0)\n",
    "        cos, sin = model.model.rotary_emb(hidden_states[0], position_ids)\n",
    "        # Causal (lower-triangular) mask over the output x output scores; identical for every layer.\n",
    "        n_out = int(output_mask.sum())\n",
    "        ltr_mask = torch.tril(torch.ones(n_out, n_out, dtype=torch.bool, device=model.device))\n",
    "\n",
    "        for i, layer_hidden_state in enumerate(hidden_states[:-1]):  # Exclude the final output layer\n",
    "            layer = model.model.layers[i]\n",
    "            nmz_hs = layer.input_layernorm(layer_hidden_state)\n",
    "\n",
    "            # Project and split into heads: [batch_size, heads, seq_len, head_dim].\n",
    "            q = layer.self_attn.q_proj(nmz_hs).view(bsz, seqlen, n_heads, head_dim).transpose(1, 2)\n",
    "            k = layer.self_attn.k_proj(nmz_hs).view(bsz, seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "            v = layer.self_attn.v_proj(nmz_hs).view(bsz, seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "\n",
    "            qr, kr = apply_rotary_pos_emb(q, k, cos, sin)\n",
    "\n",
    "            q_input, q_output = get_input_output(qr, input_mask, output_mask)\n",
    "            k_input, k_output = get_input_output(kr, input_mask, output_mask)\n",
    "            v_input, v_output = get_input_output(v, input_mask, output_mask)\n",
    "\n",
    "            # Expand key heads to the number of query heads before the dot products.\n",
    "            k_input = repeat_kv(k_input, n_rep)\n",
    "            k_output = repeat_kv(k_output, n_rep)\n",
    "            inter_qk = torch.matmul(q_output, k_input.transpose(-1, -2)) / scale  # output queries x input keys\n",
    "            intra_qk = torch.matmul(q_output, k_output.transpose(-1, -2)) / scale  # output queries x output keys\n",
    "\n",
    "            if return_activations:\n",
    "                activations.append({'k_input': k_input, 'k_output': k_output,\n",
    "                                    'q': q,  # note: before the rotary embedding\n",
    "                                    'nmz_hs': nmz_hs,\n",
    "                                    'q_input': q_input, 'q_output': q_output,\n",
    "                                    'v_input': v_input, 'v_output': v_output,\n",
    "                                    'h_input': layer_hidden_state[:, input_mask, :],\n",
    "                                    'h_output': layer_hidden_state[:, output_mask, :],\n",
    "                                    'inter_qk': inter_qk, 'intra_qk': intra_qk})\n",
    "\n",
    "            layer_overlap_area = []\n",
    "            for head in range(n_heads):\n",
    "                intra_qk_lower = intra_qk[0, head, :, :][ltr_mask]\n",
    "                overlap_area = compute_overlap_area_torch(intra_qk_lower.flatten(), inter_qk[:, head, :, :].flatten())\n",
    "                layer_overlap_area.append(overlap_area)\n",
    "            overlap_areas.append(layer_overlap_area)\n",
    "    if return_activations:\n",
    "        return activations, torch.tensor(overlap_areas)\n",
    "    return torch.tensor(overlap_areas)\n",
    "\n",
    "# sample_batch = next(iter(train_dataloader))\n",
    "# print(sample_batch)\n",
    "# activations, overlap_areas = get_info(sample_batch, return_activations=True)\n",
    "# # overlap_areas = get_info(sample_batch)\n",
    "# print(overlap_areas.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cecea384",
   "metadata": {},
   "source": [
    "# Analyse one example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1b3f1ad8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Here's a summary of the meeting transcript:\n",
      "\n",
      "The meeting discussed the design of a new TV with an LCD panel and jogdial. Key points include:\n",
      "\n",
      "* The LCD panel displays simple functions, such as brightness and contrast, and may have pictures or symbols.\n",
      "* The number pad is used for input selection and volume control.\n",
      "* The slogan will be displayed in a small space, likely below the buttons.\n",
      "* Button sizes are planned to be around 10 cm (4 inches) in height, with the power button being red and prominent.\n",
      "* The jogdial will be used for volume control and advanced functions, such as contrast and color.\n",
      "* The TV will have basic audio settings, including balance and preprogrammed sound modes.\n",
      "* Additional features may include input selection, sharpness, and sound settings that can be adjusted by the user.\n",
      "\n",
      "The designers discussed various aspects of the TV's design, including the LCD panel, number pad, and button sizes. They also considered the TV's functionality, including input selection, audio settings, and advanced functions.<|eot_id|>\n"
     ]
    }
   ],
   "source": [
    "# The validation loader is not shuffled, so this is always the same example.\n",
    "sample_batch = next(iter(valid_dataloader))\n",
    "# Decode the response tokens, i.e. the positions that carry a label (label != -100).\n",
    "print(tokenizer.decode(sample_batch['input_ids'][sample_batch['labels'] != -100]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a3901c97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overlap_area torch.Size([28, 24])\n",
      "Layer 0 activations:\n",
      "k_input: torch.Size([1, 24, 1003, 128])\n",
      "k_output: torch.Size([1, 24, 213, 128])\n",
      "q: torch.Size([1, 24, 1216, 128])\n",
      "nmz_hs: torch.Size([1, 1216, 3072])\n",
      "q_input: torch.Size([1, 24, 1003, 128])\n",
      "q_output: torch.Size([1, 24, 213, 128])\n",
      "v_input: torch.Size([1, 8, 1003, 128])\n",
      "v_output: torch.Size([1, 8, 213, 128])\n",
      "h_input: torch.Size([1, 1003, 3072])\n",
      "h_output: torch.Size([1, 213, 3072])\n",
      "inter_qk: torch.Size([1, 24, 213, 1003])\n",
      "intra_qk: torch.Size([1, 24, 213, 213])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "activations, overlap_areas = get_info(sample_batch, return_activations=True)\n",
    "print('overlap_area', overlap_areas.shape)\n",
    "\n",
    "# Every layer yields tensors of the same shapes, so printing the first layer is enough.\n",
    "for i, layer_activations in enumerate(activations[:1]):\n",
    "    print(f\"Layer {i} activations:\")\n",
    "    for key, value in layer_activations.items():\n",
    "        print(f\"{key}: {value.shape}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1246c9f8",
   "metadata": {},
   "source": [
    "## With error propagation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5754f88c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "overall_evictable = 0\n",
    "overall_cache = 0\n",
    "overall_compressibility = []\n",
    "overall_error_rate = []\n",
    "window = 64\n",
    "manual_cr = None\n",
    "for i, layer_activations in enumerate(activations):\n",
    "    k_input = layer_activations['k_input']\n",
    "    k_output = layer_activations['k_output']\n",
    "    q_input = layer_activations['q_input']\n",
    "    q_output = layer_activations['q_output']\n",
    "    v_input = repeat_kv(layer_activations['v_input'], n_heads // n_kv_heads)\n",
    "    v_output_oracle = repeat_kv(layer_activations['v_output'], n_heads // n_kv_heads)\n",
    "    inter_qk = layer_activations['inter_qk']\n",
    "    intra_qk_oracle = layer_activations['intra_qk']\n",
    "    hidden_state_oracle = layer_activations['h_output']\n",
    "    layer = model.model.layers[i]\n",
    "    if i < 1:\n",
    "        intra_qk = intra_qk_oracle\n",
    "        hidden_states = hidden_state_oracle\n",
    "        v_output = v_output_oracle\n",
    "    else:\n",
    "        # herror_rate = ((hidden_states - hidden_state_oracle).norm(p=2, dim=-1) / hidden_state_oracle.norm(p=2, dim=-1)).mean().item()\n",
    "        # print(f\"Hidden state Error rate: {herror_rate:.4f}\")\n",
    "        \n",
    "        hidden_state_norm = layer.input_layernorm(hidden_states)\n",
    "        out_seqlen = hidden_states.shape[1]  # Shape: [batch_size, seq_len, hidden_size]\n",
    "        in_seqlen = k_input.shape[2]\n",
    "        # position embedding\n",
    "        position_ids = torch.arange(in_seqlen, in_seqlen + out_seqlen).unsqueeze(0).to('cuda')\n",
    "        cos, sin = model.model.rotary_emb(hidden_states[0], position_ids)\n",
    "        \n",
    "        q = layer.self_attn.q_proj(hidden_state_norm)  # Query: [batch_size, out_seqlen, hidden_size]\n",
    "        k = layer.self_attn.k_proj(hidden_state_norm)  # Key: [batch_size, out_seqlen, hidden_size]\n",
    "        v = layer.self_attn.v_proj(hidden_state_norm)  # Value: [batch_size, out_seqlen, hidden_size]\n",
    "        \n",
    "        q = q.view(bsz, out_seqlen, n_heads, head_dim).transpose(1, 2)\n",
    "        k = k.view(bsz, out_seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "        v_output = v.view(bsz, out_seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "        v_output = repeat_kv(v_output, n_heads // n_kv_heads)\n",
    "        \n",
    "        qr, kr = apply_rotary_pos_emb(q, k, cos, sin)\n",
    "        kr = repeat_kv(kr, n_heads // n_kv_heads)\n",
    "        intra_qk = torch.matmul(qr, kr.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "        \n",
    "    # print(f\"======Layer {i}:\")\n",
    "    \n",
    "    utr_mask = torch.triu(torch.ones_like(intra_qk[0, 0, :, :], dtype=torch.bool) * torch.inf, diagonal=1)\n",
    "    total_evictable = 0\n",
    "    total_cache = 0\n",
    "    layer_error_rate = []\n",
    "    layer_compressibility = []\n",
    "    attn_output = []\n",
    "    for head in range(n_heads):\n",
    "        # print(f\"----Head {head}:\")\n",
    "        if i == 0:\n",
    "            evictable = torch.zeros(inter_qk.shape[-1], dtype=torch.bool, device=inter_qk.device)\n",
    "        elif window > 0:\n",
    "            pre_w_k_input = k_input[0, head, :-window, :]\n",
    "            post_w_k_input = k_input[0, head, -window:, :]\n",
    "            post_w_q_input = q_input[0, head, -window:, :]\n",
    "            w_inter_qk = torch.matmul(post_w_q_input, pre_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "            w_intra_qk = torch.matmul(post_w_q_input, post_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "            ltr_mask = torch.tril(torch.ones_like(w_intra_qk, dtype=torch.bool))\n",
    "            w_intra_qk_lower = w_intra_qk[ltr_mask]\n",
    "            # intk_wise_oa = [compute_overlap_area_torch(w_intra_qk_lower.flatten(), w_inter_qk[:, c]) for c in range(w_inter_qk.shape[-1])]\n",
    "            # evictable = torch.tensor(intk_wise_oa, device=inter_qk.device) < overlap_areas[i, head]\n",
    "            mx_z = torch.nn.functional.relu(w_inter_qk.unsqueeze(-1) - w_intra_qk_lower.view(1, 1, -1))\n",
    "            intk_wise_ez = mx_z.mean(dim=[0, 2])\n",
    "            if manual_cr is None:\n",
    "                mx_ez = mx_z.mean()   # mx_z[mx_z > 0].sum() / mx_z.numel()).item()\n",
    "                evictable = intk_wise_ez < mx_ez\n",
    "            else:\n",
    "                k = int(inter_qk.shape[-1] * manual_cr)\n",
    "                _, botk_indices = torch.topk(intk_wise_ez, k, largest=False)\n",
    "                evictable = torch.zeros_like(intk_wise_ez, dtype=torch.bool)\n",
    "                evictable[botk_indices] = True\n",
    "            \n",
    "            evictable = torch.cat((evictable, torch.tensor([False]*window, device=evictable.device)))\n",
    "        else:\n",
    "            ltr_mask = torch.tril(torch.ones_like(intra_qk[0, head], dtype=torch.bool))\n",
    "            intra_qk_lower = intra_qk[0, head][ltr_mask]\n",
    "            # intk_wise_oa = [compute_overlap_area_torch(intra_qk_lower.flatten(), inter_qk[0, head, :, c]) for c in range(inter_qk.shape[-1])]\n",
    "            # evictable = torch.tensor(intk_wise_oa, device=inter_qk.device) < overlap_areas[i, head]\n",
    "            \n",
    "            mx_z = torch.nn.functional.relu(inter_qk[0, head].unsqueeze(-1) - intra_qk_lower.view(1, 1, -1))\n",
    "            mx_ez = mx_z.mean()   # mx_z[mx_z > 0].sum() / mx_z.numel()).item()\n",
    "            intk_wise_ez = mx_z.mean(dim=[0, 2])\n",
    "            evictable = intk_wise_ez < mx_ez\n",
    "        n_envictable = evictable.sum()\n",
    "        total_evictable += n_envictable\n",
    "        total_cache += inter_qk.shape[-1]\n",
    "        # print(f'evictable {n_envictable} out of {inter_qk.shape[-1]}; in percentage: {n_envictable/inter_qk.shape[-1]}')\n",
    "        layer_compressibility.append((n_envictable/inter_qk.shape[-1]).item())\n",
    "        \n",
    "        intra_qk_masked = intra_qk[0, head, :, :] - utr_mask\n",
    "        if n_envictable == 0:\n",
    "            v_combined = torch.cat((v_input[0, head, :, :], v_output[0, head, :, :]), dim=0)   # shape: (n+m, head_dim)\n",
    "            qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_masked), dim=-1) # shape: (m, n+m)\n",
    "        else:\n",
    "            v_evicted = v_input[0, head, evictable, :].mean(dim=0, keepdim=True)\n",
    "            v_remaining = v_input[0, head, ~evictable, :]\n",
    "            v_combined = torch.cat((v_evicted, v_remaining, v_output[0, head, :, :]), dim=0)   # shape: (1+n'+m, head_dim)\n",
    "            # v_combined = torch.cat((v_remaining, v_output[0, head, :, :]), dim=0)   # shape: (n'+m, head_dim)\n",
    "        \n",
    "            inter_qk_evicted = inter_qk[0, head, :, evictable].mean(dim=-1, keepdim=True)\n",
    "            # inter_qk_evicted = inter_qk[0, head, evictable, :].mean(dim=1).exp().sum(dim=0).log()\n",
    "\n",
    "            inter_qk_remaining = inter_qk[0, head, :, ~evictable]\n",
    "            qk_combined = torch.cat((inter_qk_evicted, inter_qk_remaining, intra_qk_masked), dim=-1) # shape: (m, 1+n'+m)\n",
    "            # qk_combined = torch.cat((inter_qk_remaining, intra_qk_masked), dim=-1) # shape: (m, n'+m)\n",
    "        resulting_v_output = torch.matmul(qk_combined.softmax(dim=-1), v_combined)\n",
    "        attn_output.append(resulting_v_output)\n",
    "        \n",
    "        oracle_v_combined = torch.cat((v_input[0, head, :, :], v_output_oracle[0, head, :, :]), dim=0)  # shape: (n+m, head_dim)\n",
    "        intra_qk_oracle_masked = intra_qk_oracle[0, head, :, :] - utr_mask\n",
    "        oracle_qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_oracle_masked), dim=-1) # shape: (m, n+m)\n",
    "        oracle_v_output = torch.matmul(oracle_qk_combined.softmax(dim=-1), oracle_v_combined)\n",
    "        error_rate = ((resulting_v_output - oracle_v_output).norm(p=2, dim=-1) / oracle_v_output.norm(p=2, dim=-1)).mean().item()\n",
    "        layer_error_rate.append(error_rate)\n",
    "        # print(f\"Error rate: {error_rate:.4f}\")\n",
    "        \n",
    "    attn_output = torch.stack(attn_output).unsqueeze(0)\n",
    "    bsz, num_heads, q_len, head_dim = attn_output.shape\n",
    "    attn_output = attn_output.transpose(1, 2).contiguous()\n",
    "    attn_output = attn_output.reshape(bsz, q_len, model.config.hidden_size)\n",
    "    attn_output = layer.self_attn.o_proj(attn_output)\n",
    "    \n",
    "    hidden_states = hidden_states + attn_output  # residual connection\n",
    "    # Fully Connected\n",
    "    hidden_states_norm = layer.post_attention_layernorm(hidden_states)\n",
    "    mlp_output = layer.mlp(hidden_states_norm)\n",
    "    hidden_states = hidden_states + mlp_output  # residual connection\n",
    "    \n",
    "    \n",
    "    overall_compressibility.append(layer_compressibility)\n",
    "    overall_error_rate.append(layer_error_rate)\n",
    "    # print(f\"==total evictable in layer {i}: {total_evictable} out of {total_cache}; in percentage: {total_evictable/total_cache}\")\n",
    "    overall_evictable += total_evictable\n",
    "    overall_cache += total_cache\n",
    "# print(f\"Overall evictable: {overall_evictable} out of {overall_cache}; in percentage: {overall_evictable/overall_cache}\")\n",
    "overall_compressibility = np.array(overall_compressibility).T\n",
    "overall_error_rate = np.array(overall_error_rate).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f909e2d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and std (in %) of the per-head relative error at the last layer\n",
    "overall_error_rate[:, -1].mean() * 100, np.std(overall_error_rate[:, -1]) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e5cbfeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replay every decoder layer with KV-cache eviction: entries flagged as evictable are\n",
    "# dropped outright (no merged centroid slot). The kept (key, value) pairs of each head\n",
    "# are collected in `kv_cache`, and each head's attention output is compared with the oracle.\n",
    "model.eval()\n",
    "overall_evictable = 0\n",
    "overall_cache = 0\n",
    "overall_compressibility = []\n",
    "overall_error_rate = []\n",
    "window = 64\n",
    "manual_cr = None\n",
    "kv_cache = []  # kv_cache[layer][head] = (kept_keys, kept_values)\n",
    "for i, layer_activations in enumerate(activations):\n",
    "    k_input = layer_activations['k_input']\n",
    "    k_output = layer_activations['k_output']\n",
    "    q_input = layer_activations['q_input']\n",
    "    q_output = layer_activations['q_output']\n",
    "    v_input = repeat_kv(layer_activations['v_input'], n_heads // n_kv_heads)\n",
    "    v_output_oracle = repeat_kv(layer_activations['v_output'], n_heads // n_kv_heads)\n",
    "    inter_qk = layer_activations['inter_qk']\n",
    "    intra_qk_oracle = layer_activations['intra_qk']\n",
    "    hidden_state_oracle = layer_activations['h_output']\n",
    "    layer = model.model.layers[i]\n",
    "    if i < 1:\n",
    "        intra_qk = intra_qk_oracle\n",
    "        hidden_states = hidden_state_oracle\n",
    "        v_output = v_output_oracle\n",
    "    else:\n",
    "        # herror_rate = ((hidden_states - hidden_state_oracle).norm(p=2, dim=-1) / hidden_state_oracle.norm(p=2, dim=-1)).mean().item()\n",
    "        # print(f\"Hidden state Error rate: {herror_rate:.4f}\")\n",
    "        \n",
    "        hidden_state_norm = layer.input_layernorm(hidden_states)\n",
    "        bsz = hidden_states.shape[0]\n",
    "        out_seqlen = hidden_states.shape[1]  # Shape: [batch_size, seq_len, hidden_size]\n",
    "        in_seqlen = k_input.shape[2]\n",
    "        # position embedding: positions continue after the cached prompt; built on the\n",
    "        # hidden states' device instead of a hard-coded 'cuda'\n",
    "        position_ids = torch.arange(in_seqlen, in_seqlen + out_seqlen, device=hidden_states.device).unsqueeze(0)\n",
    "        cos, sin = model.model.rotary_emb(hidden_states[0], position_ids)\n",
    "        \n",
    "        q = layer.self_attn.q_proj(hidden_state_norm)  # Query: [batch_size, out_seqlen, hidden_size]\n",
    "        k = layer.self_attn.k_proj(hidden_state_norm)  # Key: [batch_size, out_seqlen, hidden_size]\n",
    "        v = layer.self_attn.v_proj(hidden_state_norm)  # Value: [batch_size, out_seqlen, hidden_size]\n",
    "        \n",
    "        q = q.view(bsz, out_seqlen, n_heads, head_dim).transpose(1, 2)\n",
    "        k = k.view(bsz, out_seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "        v_output = v.view(bsz, out_seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "        v_output = repeat_kv(v_output, n_heads // n_kv_heads)\n",
    "        \n",
    "        qr, kr = apply_rotary_pos_emb(q, k, cos, sin)\n",
    "        kr = repeat_kv(kr, n_heads // n_kv_heads)\n",
    "        intra_qk = torch.matmul(qr, kr.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "    \n",
    "    kv_cache.append([])\n",
    "    # print(f\"======Layer {i}:\")\n",
    "    \n",
    "    # Additive causal mask: +inf strictly above the diagonal, 0 elsewhere (subtracted from the logits).\n",
    "    utr_mask = torch.triu(torch.ones_like(intra_qk[0, 0, :, :], dtype=torch.bool) * torch.inf, diagonal=1)\n",
    "    total_evictable = 0\n",
    "    total_cache = 0\n",
    "    layer_error_rate = []\n",
    "    layer_compressibility = []\n",
    "    attn_output = []\n",
    "    for head in range(n_heads):\n",
    "        # print(f\"----Head {head}:\")\n",
    "        if i == 0:\n",
    "            evictable = torch.zeros(inter_qk.shape[-1], dtype=torch.bool, device=inter_qk.device)\n",
    "        elif window > 0:\n",
    "            # Score each pre-window prompt entry by how much the last `window` prompt queries'\n",
    "            # logits to it exceed their (causal) in-window logits; low scorers are evictable.\n",
    "            pre_w_k_input = k_input[0, head, :-window, :]\n",
    "            post_w_k_input = k_input[0, head, -window:, :]\n",
    "            post_w_q_input = q_input[0, head, -window:, :]\n",
    "            w_inter_qk = torch.matmul(post_w_q_input, pre_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "            w_intra_qk = torch.matmul(post_w_q_input, post_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "            ltr_mask = torch.tril(torch.ones_like(w_intra_qk, dtype=torch.bool))\n",
    "            w_intra_qk_lower = w_intra_qk[ltr_mask]\n",
    "            # intk_wise_oa = [compute_overlap_area_torch(w_intra_qk_lower.flatten(), w_inter_qk[:, c]) for c in range(w_inter_qk.shape[-1])]\n",
    "            # evictable = torch.tensor(intk_wise_oa, device=inter_qk.device) < overlap_areas[i, head]\n",
    "            mx_z = torch.nn.functional.relu(w_inter_qk.unsqueeze(-1) - w_intra_qk_lower.view(1, 1, -1))\n",
    "            intk_wise_ez = mx_z.mean(dim=[0, 2])\n",
    "            if manual_cr is None:\n",
    "                mx_ez = mx_z.mean()   # mx_z[mx_z > 0].sum() / mx_z.numel()).item()\n",
    "                evictable = intk_wise_ez < mx_ez\n",
    "            else:\n",
    "                # number of lowest-scoring entries to evict (not `k`, which holds the key projection)\n",
    "                n_evict = int(inter_qk.shape[-1] * manual_cr)\n",
    "                _, botk_indices = torch.topk(intk_wise_ez, n_evict, largest=False)\n",
    "                evictable = torch.zeros_like(intk_wise_ez, dtype=torch.bool)\n",
    "                evictable[botk_indices] = True\n",
    "            \n",
    "            # The observation window itself is always kept.\n",
    "            evictable = torch.cat((evictable, torch.zeros(window, dtype=torch.bool, device=evictable.device)))\n",
    "        else:\n",
    "            ltr_mask = torch.tril(torch.ones_like(intra_qk[0, head], dtype=torch.bool))\n",
    "            intra_qk_lower = intra_qk[0, head][ltr_mask]\n",
    "            # intk_wise_oa = [compute_overlap_area_torch(intra_qk_lower.flatten(), inter_qk[0, head, :, c]) for c in range(inter_qk.shape[-1])]\n",
    "            # evictable = torch.tensor(intk_wise_oa, device=inter_qk.device) < overlap_areas[i, head]\n",
    "            \n",
    "            mx_z = torch.nn.functional.relu(inter_qk[0, head].unsqueeze(-1) - intra_qk_lower.view(1, 1, -1))\n",
    "            mx_ez = mx_z.mean()   # mx_z[mx_z > 0].sum() / mx_z.numel()).item()\n",
    "            intk_wise_ez = mx_z.mean(dim=[0, 2])\n",
    "            evictable = intk_wise_ez < mx_ez\n",
    "        # evictable[:] = False\n",
    "        n_envictable = evictable.sum()\n",
    "        total_evictable += n_envictable\n",
    "        total_cache += inter_qk.shape[-1]\n",
    "        # print(f'evictable {n_envictable} out of {inter_qk.shape[-1]}; in percentage: {n_envictable/inter_qk.shape[-1]}')\n",
    "        layer_compressibility.append((n_envictable/inter_qk.shape[-1]).item())\n",
    "        \n",
    "        # Kept (key, value) pairs of this head: keys come from k_input, values from v_input.\n",
    "        k_remaining = k_input[0, head, ~evictable, :]\n",
    "        v_remaining = v_input[0, head, ~evictable, :]\n",
    "        kv_cache[-1].append((k_remaining, v_remaining))\n",
    "        \n",
    "        intra_qk_masked = intra_qk[0, head, :, :] - utr_mask\n",
    "        if n_envictable == 0:\n",
    "            v_combined = torch.cat((v_input[0, head, :, :], v_output[0, head, :, :]), dim=0)   # shape: (n+m, head_dim)\n",
    "            qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_masked), dim=-1) # shape: (m, n+m)\n",
    "        else:\n",
    "            # Evicted entries are dropped outright (no merged centroid slot).\n",
    "            v_combined = torch.cat((v_remaining, v_output[0, head, :, :]), dim=0)   # shape: (n'+m, head_dim)\n",
    "            inter_qk_remaining = inter_qk[0, head, :, ~evictable]\n",
    "            qk_combined = torch.cat((inter_qk_remaining, intra_qk_masked), dim=-1) # shape: (m, n'+m)\n",
    "        \n",
    "        resulting_v_output = torch.matmul(qk_combined.softmax(dim=-1), v_combined)\n",
    "        attn_output.append(resulting_v_output)\n",
    "        \n",
    "        oracle_v_combined = torch.cat((v_input[0, head, :, :], v_output_oracle[0, head, :, :]), dim=0)  # shape: (n+m, head_dim)\n",
    "        intra_qk_oracle_masked = intra_qk_oracle[0, head, :, :] - utr_mask\n",
    "        oracle_qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_oracle_masked), dim=-1) # shape: (m, n+m)\n",
    "        oracle_v_output = torch.matmul(oracle_qk_combined.softmax(dim=-1), oracle_v_combined)\n",
    "        error_rate = ((resulting_v_output - oracle_v_output).norm(p=2, dim=-1) / oracle_v_output.norm(p=2, dim=-1)).mean().item()\n",
    "        layer_error_rate.append(error_rate)\n",
    "        # print(f\"Error rate: {error_rate:.4f}\")\n",
    "        \n",
    "    attn_output = torch.stack(attn_output).unsqueeze(0)\n",
    "    bsz, num_heads, q_len, head_dim = attn_output.shape\n",
    "    attn_output = attn_output.transpose(1, 2).contiguous()\n",
    "    attn_output = attn_output.reshape(bsz, q_len, model.config.hidden_size)\n",
    "    attn_output = layer.self_attn.o_proj(attn_output)\n",
    "    \n",
    "    hidden_states = hidden_states + attn_output  # residual connection\n",
    "    # Fully Connected\n",
    "    hidden_states_norm = layer.post_attention_layernorm(hidden_states)\n",
    "    mlp_output = layer.mlp(hidden_states_norm)\n",
    "    hidden_states = hidden_states + mlp_output  # residual connection\n",
    "    \n",
    "    \n",
    "    overall_compressibility.append(layer_compressibility)\n",
    "    overall_error_rate.append(layer_error_rate)\n",
    "    # print(f\"==total evictable in layer {i}: {total_evictable} out of {total_cache}; in percentage: {total_evictable/total_cache}\")\n",
    "    overall_evictable += total_evictable\n",
    "    overall_cache += total_cache\n",
    "# print(f\"Overall evictable: {overall_evictable} out of {overall_cache}; in percentage: {overall_evictable/overall_cache}\")\n",
    "overall_compressibility = np.array(overall_compressibility).T\n",
    "overall_error_rate = np.array(overall_error_rate).T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea89dfb",
   "metadata": {},
   "source": [
    "# With Error Propagation and Autoregressive Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "aeaa1dac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "old_gt_ids.shape torch.Size([1, 213])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:04,  4.94s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\n\\nHere is a concise summary of the meeting transcript:\\n\\n**Meeting Summary**\\n\\nThe meeting discussed the design of a new TV with an LCD panel and jogdial. The key points are:\\n\\n* The LCD panel will display simple functions, such as brightness and contrast, with minimal text.\\n* The number pad will be used for volume control, with the jogdial used for advanced functions like contrast and color.\\n* The TV will have a prominent power button that can be held for 2 seconds to send a standby signal.\\n* The TV will have basic audio settings, including bass, treble, and balance, with preprogrammed sound modes.\\n* The TV will have multiple inputs, including VCR and other options.\\n* The design team discussed the dimensions of the TV, with a height of around 10 cm and a width of around 3.5 cm.\\n* The colors used will be similar to those used in the previous design, with a foggy yellow for the buttons.\\n\\n**Action Items**\\n\\n* The project manager will write a report summarizing'"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gt_ids.shape torch.Size([1, 213])\n",
      "old_gt_ids: tensor([[  271,  8586,   596,   264, 12399,   315,   279,  6574, 36815,  1473]],\n",
      "       device='cuda:0')\n",
      "gt_ids    : tensor([[  271,  8586,   374,   264, 64694, 12399,   315,   279,  6574, 36815]],\n",
      "       device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "for sample_i, sample_batch in tqdm(enumerate(valid_dataloader)):\n",
    "    if sample_i == 1:\n",
    "        break\n",
    "    old_gt_ids = sample_batch['input_ids'][sample_batch['labels'] != -100].unsqueeze(0).to(model.device)\n",
    "    print('old_gt_ids.shape', old_gt_ids.shape)\n",
    "    prompt_ids = sample_batch['input_ids'][sample_batch['labels'] == -100].unsqueeze(0).to(model.device) \n",
    "\n",
    "    ptgt_ids = model.generate(\n",
    "        prompt_ids,\n",
    "        do_sample=False,         # Deterministic (greedy) decoding\n",
    "        max_new_tokens=old_gt_ids.shape[1],      # Number of tokens to generate *after* the prompt\n",
    "        temperature=None,        # No randomness\n",
    "        top_p=None,           # No top-p sampling\n",
    "    )\n",
    "\n",
    "    gt_ids = ptgt_ids[:, len(prompt_ids[0]):]\n",
    "    all_text = tokenizer.decode(ptgt_ids[0])\n",
    "    gt_text = tokenizer.decode(gt_ids[0])\n",
    "\n",
    "display(gt_text)\n",
    "print('gt_ids.shape', gt_ids.shape)\n",
    "print('old_gt_ids:', old_gt_ids[:, :10])\n",
    "print('gt_ids    :', gt_ids[:, :10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4a1f0f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_out_ids(activations, speculative_ip_ids, max_gen_tks=10, window=None, f=1, use_evictables=None):\n",
    "    \"\"\"Run the decoder on `speculative_ip_ids` on top of a (possibly compressed) prompt cache.\n",
    "\n",
    "    Args:\n",
    "        activations: one dict per layer (the activations returned by `get_info`) holding\n",
    "            the prompt tensors 'k_input', 'q_input', 'v_input' and the oracle\n",
    "            'v_output' / 'intra_qk'.\n",
    "        speculative_ip_ids: token ids fed after the prompt; batch size 1 is assumed\n",
    "            (index 0 is used throughout).\n",
    "        max_gen_tks: length the oracle 'v_output' / 'intra_qk' are cropped to; should equal\n",
    "            speculative_ip_ids.shape[1] so the (unreturned) oracle comparison lines up.\n",
    "        window: None keeps the whole prompt cache (layer 0 never evicts either); > 0 scores\n",
    "            prompt entries with the last `window` prompt queries; <= 0 scores them against\n",
    "            the new tokens' own causal logits.\n",
    "        f: multiplier on the mean-excess eviction threshold in the `window > 0` case.\n",
    "        use_evictables: eviction masks [layer][head] from a previous call, reused as-is.\n",
    "\n",
    "    Returns:\n",
    "        (greedy next-token ids for every input position, eviction masks [layer][head]).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    head_dim = model.config.head_dim\n",
    "    n_heads = model.config.num_attention_heads\n",
    "    n_kv_heads = model.config.num_key_value_heads\n",
    "    n_rep = n_heads // n_kv_heads  # query heads per KV head\n",
    "    overall_evictable = 0\n",
    "    overall_cache = 0\n",
    "    overall_compressibility = []\n",
    "    overall_error_rate = []\n",
    "    manual_cr = None\n",
    "    all_evictables = [] if use_evictables is None else use_evictables\n",
    "    hidden_states = model.model.embed_tokens(speculative_ip_ids)\n",
    "    for i, layer_activations in enumerate(activations):\n",
    "        k_input = layer_activations['k_input']\n",
    "        q_input = layer_activations['q_input']\n",
    "        v_input = repeat_kv(layer_activations['v_input'], n_rep)\n",
    "        v_output_oracle = repeat_kv(layer_activations['v_output'], n_rep)[:, :, :max_gen_tks, :]\n",
    "        intra_qk_oracle = layer_activations['intra_qk'][:, :, :max_gen_tks, :max_gen_tks]\n",
    "        layer = model.model.layers[i]\n",
    "\n",
    "        # Project the new tokens to queries / keys / values.\n",
    "        hidden_state_norm = layer.input_layernorm(hidden_states)\n",
    "        bsz, out_seqlen = hidden_states.shape[:2]  # [batch_size, seq_len, hidden_size]\n",
    "        in_seqlen = k_input.shape[2]\n",
    "        # Positions continue right after the cached prompt.\n",
    "        position_ids = torch.arange(in_seqlen, in_seqlen + out_seqlen).unsqueeze(0).to(model.device)\n",
    "        cos, sin = model.model.rotary_emb(hidden_states[0], position_ids)\n",
    "\n",
    "        q = layer.self_attn.q_proj(hidden_state_norm).view(bsz, out_seqlen, n_heads, head_dim).transpose(1, 2)\n",
    "        k = layer.self_attn.k_proj(hidden_state_norm).view(bsz, out_seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "        v = layer.self_attn.v_proj(hidden_state_norm).view(bsz, out_seqlen, n_kv_heads, head_dim).transpose(1, 2)\n",
    "        v_output = repeat_kv(v, n_rep)\n",
    "        qr, kr = apply_rotary_pos_emb(q, k, cos, sin)\n",
    "        kr = repeat_kv(kr, n_rep)\n",
    "\n",
    "        intra_qk = torch.matmul(qr, kr.transpose(-1, -2)) / math.sqrt(head_dim)  # new -> new\n",
    "        inter_qk = torch.matmul(qr, k_input.transpose(-1, -2)) / math.sqrt(head_dim)  # new -> prompt cache\n",
    "\n",
    "        # Additive causal mask: +inf strictly above the diagonal, 0 elsewhere.\n",
    "        utr_mask = torch.triu(torch.ones_like(intra_qk[0, 0, :, :], dtype=torch.bool) * torch.inf, diagonal=1)\n",
    "        total_evictable = 0\n",
    "        total_cache = 0\n",
    "        layer_error_rate = []\n",
    "        layer_compressibility = []\n",
    "        attn_output = []\n",
    "        layer_evictables = [] if use_evictables is None else use_evictables[i]\n",
    "        for head in range(n_heads):\n",
    "            if use_evictables is not None:\n",
    "                evictable = layer_evictables[head]\n",
    "            else:\n",
    "                if (i == 0) or (window is None):\n",
    "                    evictable = torch.zeros(inter_qk.shape[-1], dtype=torch.bool, device=inter_qk.device)\n",
    "                elif window > 0:\n",
    "                    # Score each pre-window prompt entry by how much the last `window` prompt\n",
    "                    # queries' logits to it exceed their (causal) in-window logits.\n",
    "                    pre_w_k_input = k_input[0, head, :-window, :]\n",
    "                    post_w_k_input = k_input[0, head, -window:, :]\n",
    "                    post_w_q_input = q_input[0, head, -window:, :]\n",
    "                    w_inter_qk = torch.matmul(post_w_q_input, pre_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "                    w_intra_qk = torch.matmul(post_w_q_input, post_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "                    ltr_mask = torch.tril(torch.ones_like(w_intra_qk, dtype=torch.bool))\n",
    "                    w_intra_qk_lower = w_intra_qk[ltr_mask]\n",
    "\n",
    "                    mx_z = torch.nn.functional.relu(w_inter_qk.unsqueeze(-1) - w_intra_qk_lower.view(1, 1, -1))\n",
    "                    intk_wise_ez = mx_z.mean(dim=[0, 2])\n",
    "                    if manual_cr is None:\n",
    "                        mx_ez = mx_z.mean()\n",
    "                        evictable = intk_wise_ez < f * mx_ez\n",
    "                    else:\n",
    "                        # number of lowest-scoring entries to evict (not `k`, the key projection)\n",
    "                        n_evict = int(inter_qk.shape[-1] * manual_cr)\n",
    "                        _, botk_indices = torch.topk(intk_wise_ez, n_evict, largest=False)\n",
    "                        evictable = torch.zeros_like(intk_wise_ez, dtype=torch.bool)\n",
    "                        evictable[botk_indices] = True\n",
    "\n",
    "                    # The observation window itself is always kept.\n",
    "                    evictable = torch.cat((evictable, torch.zeros(window, dtype=torch.bool, device=evictable.device)))\n",
    "                else:\n",
    "                    ltr_mask = torch.tril(torch.ones_like(intra_qk[0, head], dtype=torch.bool))\n",
    "                    intra_qk_lower = intra_qk[0, head][ltr_mask]\n",
    "\n",
    "                    mx_z = torch.nn.functional.relu(inter_qk[0, head].unsqueeze(-1) - intra_qk_lower.view(1, 1, -1))\n",
    "                    mx_ez = mx_z.mean()\n",
    "                    intk_wise_ez = mx_z.mean(dim=[0, 2])  # mx_z is already non-negative\n",
    "                    evictable = intk_wise_ez < mx_ez\n",
    "                layer_evictables.append(evictable)\n",
    "            n_envictable = evictable.sum()\n",
    "            total_evictable += n_envictable\n",
    "            total_cache += inter_qk.shape[-1]\n",
    "            layer_compressibility.append((n_envictable / inter_qk.shape[-1]).item())\n",
    "\n",
    "            intra_qk_masked = intra_qk[0, head, :, :] - utr_mask\n",
    "            if n_envictable == 0:\n",
    "                v_combined = torch.cat((v_input[0, head, :, :], v_output[0, head, :, :]), dim=0)  # shape: (n+m, head_dim)\n",
    "                qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_masked), dim=-1)  # shape: (m, n+m)\n",
    "            else:\n",
    "                # Evicted entries are merged into one slot: their mean value and mean logit.\n",
    "                v_evicted = v_input[0, head, evictable, :].mean(dim=0, keepdim=True)\n",
    "                v_remaining = v_input[0, head, ~evictable, :]\n",
    "                v_combined = torch.cat((v_evicted, v_remaining, v_output[0, head, :, :]), dim=0)  # shape: (1+n'+m, head_dim)\n",
    "                inter_qk_evicted = inter_qk[0, head, :, evictable].mean(dim=-1, keepdim=True)\n",
    "                inter_qk_remaining = inter_qk[0, head, :, ~evictable]\n",
    "                qk_combined = torch.cat((inter_qk_evicted, inter_qk_remaining, intra_qk_masked), dim=-1)  # shape: (m, 1+n'+m)\n",
    "            resulting_v_output = torch.matmul(qk_combined.softmax(dim=-1), v_combined)\n",
    "            attn_output.append(resulting_v_output)\n",
    "\n",
    "            # Diagnostics only (not returned): error vs. full-cache attention computed with\n",
    "            # the oracle 'v_output' / 'intra_qk'.\n",
    "            oracle_v_combined = torch.cat((v_input[0, head, :, :], v_output_oracle[0, head, :, :]), dim=0)  # shape: (n+m, head_dim)\n",
    "            intra_qk_oracle_masked = intra_qk_oracle[0, head, :, :] - utr_mask\n",
    "            oracle_qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_oracle_masked), dim=-1)  # shape: (m, n+m)\n",
    "            oracle_v_output = torch.matmul(oracle_qk_combined.softmax(dim=-1), oracle_v_combined)\n",
    "            error_rate = ((resulting_v_output - oracle_v_output).norm(p=2, dim=-1) / oracle_v_output.norm(p=2, dim=-1)).mean().item()\n",
    "            layer_error_rate.append(error_rate)\n",
    "\n",
    "        if use_evictables is None:\n",
    "            all_evictables.append(layer_evictables)\n",
    "        # Merge heads and project back to the hidden size.\n",
    "        attn_output = torch.stack(attn_output).unsqueeze(0)\n",
    "        bsz, num_heads, q_len, head_dim = attn_output.shape\n",
    "        attn_output = attn_output.transpose(1, 2).contiguous()\n",
    "        attn_output = attn_output.reshape(bsz, q_len, model.config.hidden_size)\n",
    "        attn_output = layer.self_attn.o_proj(attn_output)\n",
    "\n",
    "        hidden_states = hidden_states + attn_output  # residual connection\n",
    "        # Fully Connected\n",
    "        hidden_states_norm = layer.post_attention_layernorm(hidden_states)\n",
    "        mlp_output = layer.mlp(hidden_states_norm)\n",
    "        hidden_states = hidden_states + mlp_output  # residual connection\n",
    "\n",
    "        overall_compressibility.append(layer_compressibility)\n",
    "        overall_error_rate.append(layer_error_rate)\n",
    "        overall_evictable += total_evictable\n",
    "        overall_cache += total_cache\n",
    "    overall_compressibility = np.array(overall_compressibility).T\n",
    "    overall_error_rate = np.array(overall_error_rate).T\n",
    "    # The model's final RMSNorm must precede the LM head, as in the model's own forward;\n",
    "    # skipping it changes the logits and hence the greedy tokens.\n",
    "    logits = model.lm_head(model.model.norm(hidden_states))\n",
    "    return logits.argmax(dim=-1), all_evictables\n",
    "\n",
    "\n",
    "def get_ar_model_out_ids(activations, input_ids, max_gen_tks=10, window=None, f=1):\n",
    "    \"\"\"Greedy autoregressive decoding with `get_model_out_ids` on top of the prompt cache.\n",
    "\n",
    "    Each step re-runs the whole current sequence and appends its last prediction. It runs\n",
    "    for at most `max_gen_tks - 1` steps, or until EOS is predicted. Eviction masks are\n",
    "    computed on the first step and reused for all later steps.\n",
    "\n",
    "    Returns:\n",
    "        (model_out_ids, ov_comp): next-token predictions for every position of the final\n",
    "        input sequence, and the fraction of prompt-cache entries marked evictable,\n",
    "        averaged over all layers and heads.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `max_gen_tks` < 2, since no decoding step would run.\n",
    "    \"\"\"\n",
    "    if max_gen_tks < 2:\n",
    "        raise ValueError(f'max_gen_tks must be >= 2, got {max_gen_tks}')\n",
    "    use_evictables = None\n",
    "    for _ in range(1, max_gen_tks):\n",
    "        model_out_ids, use_evictables = get_model_out_ids(activations, input_ids,\n",
    "                                                          max_gen_tks=input_ids.shape[1], window=window,\n",
    "                                                          f=f, use_evictables=use_evictables)\n",
    "        if model_out_ids[0, -1] == tokenizer.eos_token_id:\n",
    "            break\n",
    "        input_ids = torch.cat((input_ids, model_out_ids[:, [-1]]), dim=1)\n",
    "    # Stack the boolean masks to (layers, heads, cache_len) and average them.\n",
    "    ov_comp = torch.stack([torch.stack(layer_ev) for layer_ev in use_evictables]).float().mean()\n",
    "    return model_out_ids, ov_comp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "2cc1afaf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "old_gt_ids.shape torch.Size([1, 213])\n",
      "gt_ids torch.Size([1, 213])\n",
      "no_comp_op_ids: torch.Size([1, 212])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [03:09, 189.77s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "comp_op_ids: torch.Size([1, 212])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# For the first validation sample: regenerate the reference continuation with greedy\n",
    "# decoding, then re-decode it from the cached prompt activations without eviction\n",
    "# (window=None) and with eviction (window=16).\n",
    "all_old_gt_ids, all_gt_ids = [], []\n",
    "all_no_comp_op_ids, all_comp_op_ids, all_comp_r = [], [], []\n",
    "for sample_i, sample_batch in tqdm(enumerate(valid_dataloader)):\n",
    "    if sample_i >= 1:  # evaluate only the first sample\n",
    "        break\n",
    "    old_gt_ids = sample_batch['input_ids'][sample_batch['labels'] != -100].unsqueeze(0).to(model.device)\n",
    "    prompt_ids = sample_batch['input_ids'][sample_batch['labels'] == -100].unsqueeze(0).to(model.device)\n",
    "    all_old_gt_ids.append(old_gt_ids)\n",
    "    print('old_gt_ids.shape', old_gt_ids.shape)\n",
    "\n",
    "    max_gen_tks = old_gt_ids.shape[1]\n",
    "    ptgt_ids = model.generate(\n",
    "        prompt_ids,\n",
    "        max_new_tokens=max_gen_tks,  # as many new tokens as the dataset target has\n",
    "        do_sample=False,  # greedy decoding\n",
    "        temperature=None,\n",
    "        top_p=None,\n",
    "    )\n",
    "    gt_ids = ptgt_ids[:, prompt_ids.shape[1]:]\n",
    "    # Sanity check: the first regenerated token matches the dataset target's first token.\n",
    "    assert torch.equal(old_gt_ids[:, 0], gt_ids[:, 0])\n",
    "    print('gt_ids', gt_ids.shape)\n",
    "    all_gt_ids.append(gt_ids)\n",
    "\n",
    "    activations, overlap_areas = get_info(sample_batch, return_activations=True)\n",
    "    # Both runs are seeded with the first generated token, which is prepended to their output.\n",
    "    no_comp_op_ids, _ = get_ar_model_out_ids(activations, gt_ids[:, [0]], max_gen_tks=gt_ids.shape[1], window=None, f=1)\n",
    "    all_no_comp_op_ids.append(torch.cat([gt_ids[:, [0]], no_comp_op_ids], dim=1))\n",
    "    print('no_comp_op_ids:', no_comp_op_ids.shape)\n",
    "\n",
    "    comp_op_ids, cr = get_ar_model_out_ids(activations, gt_ids[:, [0]], max_gen_tks=gt_ids.shape[1], window=16, f=1)\n",
    "    all_comp_op_ids.append(torch.cat([gt_ids[:, [0]], comp_op_ids], dim=1))\n",
    "    all_comp_r.append(cr)\n",
    "    print('comp_op_ids:', comp_op_ids.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "effb0412",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "no coomp: rouge scores: {'rouge1': Score(precision=0.5921052631578947, recall=0.6040268456375839, fmeasure=0.5980066445182725), 'rougeL': Score(precision=0.3881578947368421, recall=0.3959731543624161, fmeasure=0.3920265780730897)}\n",
      "ar unsloth: rouge scores: {'rouge1': Score(precision=0.5592105263157895, recall=0.5120481927710844, fmeasure=0.5345911949685535), 'rougeL': Score(precision=0.375, recall=0.3433734939759036, fmeasure=0.3584905660377358)}\n",
      "data unsloth: rouge scores: {'rouge1': Score(precision=0.5526315789473685, recall=0.49411764705882355, fmeasure=0.5217391304347826), 'rougeL': Score(precision=0.3618421052631579, recall=0.3235294117647059, fmeasure=0.3416149068322981)}\n"
     ]
    }
   ],
   "source": [
    "from rouge_score import rouge_scorer\n",
    "\n",
    "# ROUGE-1 and ROUGE-L of the compressed decode (the prediction) against three references.\n",
    "# RougeScorer.score takes (target, prediction).\n",
    "scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)\n",
    "\n",
    "comp_text = tokenizer.decode(all_comp_op_ids[0][0])\n",
    "references = [\n",
    "    ('no coomp: rouge scores:', all_no_comp_op_ids[0][0]),   # decode without compression\n",
    "    ('ar unsloth: rouge scores:', all_gt_ids[0][0]),         # greedy generation of the full model\n",
    "    ('data unsloth: rouge scores:', all_old_gt_ids[0][0]),   # answer tokens from the dataset\n",
    "]\n",
    "for header, reference_ids in references:\n",
    "    scores = scorer.score(tokenizer.decode(reference_ids), comp_text)\n",
    "    print(header, scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "01a090f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "no coomp: rouge scores: {'rouge1': Score(precision=0.45410628019323673, recall=0.6308724832214765, fmeasure=0.5280898876404494), 'rougeL': Score(precision=0.26570048309178745, recall=0.3691275167785235, fmeasure=0.3089887640449438)}\n",
      "ar unsloth: rouge scores: {'rouge1': Score(precision=0.4444444444444444, recall=0.5542168674698795, fmeasure=0.4932975871313673), 'rougeL': Score(precision=0.26570048309178745, recall=0.3313253012048193, fmeasure=0.29490616621983917)}\n",
      "data unsloth: rouge scores: {'rouge1': Score(precision=0.45410628019323673, recall=0.5529411764705883, fmeasure=0.4986737400530504), 'rougeL': Score(precision=0.2608695652173913, recall=0.3176470588235294, fmeasure=0.286472148541114)}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# Score the same three references against a saved baseline response.\n",
    "# Only the first record of the log (index 0) is compared.\n",
    "output_file = 'Results_outputs/unsloth_Llama-3.2-3B-Instruct/0_6/qmsum_responses_TOVAPress_0_6_unsloth_Llama-3.2-3B-Instruct.json'\n",
    "\n",
    "# Read explicitly as UTF-8 (the JSON interchange encoding) instead of the platform's default.\n",
    "with open(output_file, 'r', encoding='utf-8') as f:\n",
    "    log_data_baseline = json.load(f)\n",
    "\n",
    "baseline_response = log_data_baseline[0]['response']\n",
    "for header, reference_ids in (\n",
    "    ('no coomp: rouge scores:', all_no_comp_op_ids[0][0]),\n",
    "    ('ar unsloth: rouge scores:', all_gt_ids[0][0]),\n",
    "    ('data unsloth: rouge scores:', all_old_gt_ids[0][0]),\n",
    "):\n",
    "    scores = scorer.score(tokenizer.decode(reference_ids), baseline_response)\n",
    "    print(header, scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f370742",
   "metadata": {},
   "source": [
    "## Without Error Propagation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f4ddad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eviction analysis WITHOUT error propagation: each layer is evaluated independently on its own\n",
    "# captured activations, so the error introduced in one layer is not fed into the next.\n",
    "# For each (layer, head) this records the fraction of cached keys marked evictable and the\n",
    "# relative L2 error of the attention output after merging the evicted keys into one slot.\n",
    "# Uses `activations` (from the last sample of the generation loop), `n_heads`, `n_kv_heads`\n",
    "# and `head_dim` from earlier cells.\n",
    "\n",
    "\n",
    "def _mean_excess(inter_scores, intra_scores):\n",
    "    \"\"\"Score candidate keys by how far their logits exceed a set of reference logits.\n",
    "\n",
    "    Args:\n",
    "        inter_scores: 2-D (queries, candidate_keys) attention logits.\n",
    "        intra_scores: 1-D reference logits.\n",
    "\n",
    "    Returns:\n",
    "        (per_key, overall): mean of relu(inter - intra) for each candidate key (averaged over\n",
    "        queries and reference logits), and the mean over all entries.\n",
    "\n",
    "    NOTE: materializes a (queries, candidate_keys, len(intra_scores)) tensor, which can be very\n",
    "    large for long prompts.\n",
    "    \"\"\"\n",
    "    excess = torch.nn.functional.relu(inter_scores.unsqueeze(-1) - intra_scores.view(1, 1, -1))\n",
    "    return excess.mean(dim=[0, 2]), excess.mean()\n",
    "\n",
    "\n",
    "def _select_evictable(per_key, overall, manual_cr, n_cache):\n",
    "    \"\"\"Boolean mask over the candidate keys marking those to evict.\n",
    "\n",
    "    manual_cr=None: a key is evictable when its mean excess is below the overall mean excess.\n",
    "    Otherwise the int(n_cache * manual_cr) lowest-scoring keys are evicted, capped at the number\n",
    "    of candidates because torch.topk raises when k exceeds the dimension size.\n",
    "    \"\"\"\n",
    "    if manual_cr is None:\n",
    "        return per_key < overall\n",
    "    n_evict = min(int(n_cache * manual_cr), per_key.numel())\n",
    "    evictable = torch.zeros_like(per_key, dtype=torch.bool)\n",
    "    evictable[torch.topk(per_key, n_evict, largest=False).indices] = True\n",
    "    return evictable\n",
    "\n",
    "\n",
    "overall_evictable = 0\n",
    "overall_cache = 0\n",
    "wop_overall_compressibility = []\n",
    "wop_overall_error_rate = []\n",
    "window = 64        # trailing positions whose queries score the earlier keys; <= 0 scores with inter_qk/intra_qk directly\n",
    "manual_cr = None   # fraction of the cache to evict per head; None -> data-driven threshold\n",
    "for i, layer_activations in enumerate(activations):\n",
    "    k_input = layer_activations['k_input']\n",
    "    k_output = layer_activations['k_output']\n",
    "    q_input = layer_activations['q_input']\n",
    "    q_output = layer_activations['q_output']\n",
    "    # Values are expanded from KV heads to attention heads; keys/queries are indexed by attention\n",
    "    # head as stored. NOTE(review): confirm k_input/q_input already hold n_heads heads.\n",
    "    v_input = repeat_kv(layer_activations['v_input'], n_heads // n_kv_heads)\n",
    "    v_output = repeat_kv(layer_activations['v_output'], n_heads // n_kv_heads)\n",
    "    inter_qk = layer_activations['inter_qk']\n",
    "    intra_qk = layer_activations['intra_qk']\n",
    "    n_cache = inter_qk.shape[-1]\n",
    "\n",
    "    # +inf strictly above the diagonal; subtracting it causally masks intra_qk.\n",
    "    utr_mask = torch.triu(torch.ones_like(intra_qk[0, 0, :, :], dtype=torch.bool) * torch.inf, diagonal=1)\n",
    "    total_evictable = 0\n",
    "    total_cache = 0\n",
    "    layer_error_rate = []\n",
    "    layer_compressibility = []\n",
    "    for head in range(n_heads):\n",
    "        if i == 0 or (window > 0 and n_cache <= window):\n",
    "            # Layer 0 is never compressed; with no keys before the window nothing can be evicted.\n",
    "            evictable = torch.zeros(n_cache, dtype=torch.bool, device=inter_qk.device)\n",
    "        elif window > 0:\n",
    "            # Queries at the last `window` positions score the keys that precede the window.\n",
    "            pre_w_k_input = k_input[0, head, :-window, :]\n",
    "            post_w_k_input = k_input[0, head, -window:, :]\n",
    "            post_w_q_input = q_input[0, head, -window:, :]\n",
    "            w_inter_qk = torch.matmul(post_w_q_input, pre_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "            w_intra_qk = torch.matmul(post_w_q_input, post_w_k_input.transpose(-1, -2)) / math.sqrt(head_dim)\n",
    "            ltr_mask = torch.tril(torch.ones_like(w_intra_qk, dtype=torch.bool))\n",
    "            intk_wise_ez, mx_ez = _mean_excess(w_inter_qk, w_intra_qk[ltr_mask])\n",
    "            evictable = _select_evictable(intk_wise_ez, mx_ez, manual_cr, n_cache)\n",
    "            # Keys inside the window are always kept.\n",
    "            evictable = torch.cat((evictable, torch.zeros(window, dtype=torch.bool, device=evictable.device)))\n",
    "        else:\n",
    "            ltr_mask = torch.tril(torch.ones_like(intra_qk[0, head], dtype=torch.bool))\n",
    "            intk_wise_ez, mx_ez = _mean_excess(inter_qk[0, head], intra_qk[0, head][ltr_mask])\n",
    "            evictable = _select_evictable(intk_wise_ez, mx_ez, manual_cr, n_cache)\n",
    "\n",
    "        n_evictable = evictable.sum()\n",
    "        total_evictable += n_evictable\n",
    "        total_cache += n_cache\n",
    "        layer_compressibility.append((n_evictable / n_cache).item())\n",
    "\n",
    "        intra_qk_masked = intra_qk[0, head, :, :] - utr_mask\n",
    "        # Oracle: attention over the full cache followed by the causally masked intra block.\n",
    "        oracle_v_combined = torch.cat((v_input[0, head, :, :], v_output[0, head, :, :]), dim=0)  # shape: (n+m, head_dim)\n",
    "        oracle_qk_combined = torch.cat((inter_qk[0, head, :, :], intra_qk_masked), dim=-1)  # shape: (m, n+m)\n",
    "        if n_evictable == 0:\n",
    "            v_combined, qk_combined = oracle_v_combined, oracle_qk_combined\n",
    "        else:\n",
    "            # All evicted keys are merged into one slot holding their mean value and mean logit.\n",
    "            v_evicted = v_input[0, head, evictable, :].mean(dim=0, keepdim=True)\n",
    "            v_remaining = v_input[0, head, ~evictable, :]\n",
    "            v_combined = torch.cat((v_evicted, v_remaining, v_output[0, head, :, :]), dim=0)  # shape: (1+n'+m, head_dim)\n",
    "            inter_qk_evicted = inter_qk[0, head, :, evictable].mean(dim=-1, keepdim=True)\n",
    "            inter_qk_remaining = inter_qk[0, head, :, ~evictable]\n",
    "            qk_combined = torch.cat((inter_qk_evicted, inter_qk_remaining, intra_qk_masked), dim=-1)  # shape: (m, 1+n'+m)\n",
    "\n",
    "        resulting_v_output = torch.matmul(qk_combined.softmax(dim=-1), v_combined)\n",
    "        oracle_v_output = torch.matmul(oracle_qk_combined.softmax(dim=-1), oracle_v_combined)\n",
    "        # Relative L2 error of each query row's output, averaged over rows.\n",
    "        error_rate = ((resulting_v_output - oracle_v_output).norm(p=2, dim=-1) / oracle_v_output.norm(p=2, dim=-1)).mean().item()\n",
    "        layer_error_rate.append(error_rate)\n",
    "\n",
    "    wop_overall_compressibility.append(layer_compressibility)\n",
    "    wop_overall_error_rate.append(layer_error_rate)\n",
    "    overall_evictable += total_evictable\n",
    "    overall_cache += total_cache\n",
    "\n",
    "# Transpose so rows are heads and columns are layers.\n",
    "wop_overall_compressibility = np.array(wop_overall_compressibility).T\n",
    "wop_overall_error_rate = np.array(wop_overall_error_rate).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b1fd294",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "press",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
