{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the CUDA_VISIBLE_DEVICES environment variable to 0 for this kernel.\n",
    "%env CUDA_VISIBLE_DEVICES=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_real_share.modeling_llama_kvsharer import LlamaForCausalLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder: replace with the model name or path that is passed to from_pretrained.\n",
    "llama_path = 'YOUR MODEL'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer that belongs to the checkpoint at llama_path.\n",
    "# NOTE(review): trust_remote_code=True permits running code shipped with the checkpoint;\n",
    "# only use it with sources you trust.\n",
    "tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the KVSharer variant of LlamaForCausalLM (imported from llama_real_share);\n",
    "# device_map='auto' leaves device placement to the loader.\n",
    "llama = LlamaForCausalLM.from_pretrained(llama_path, device_map='auto')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Calibration Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wiki_data_path = './data/wiki_demo.txt'\n",
    "# Read the calibration corpus, one entry per line (trailing newlines are kept).\n",
    "# The encoding is pinned so the result does not depend on the platform default, and\n",
    "# the with-block closes the file on exit, so no explicit close() is needed.\n",
    "with open(wiki_data_path, 'r', encoding='utf-8') as f:\n",
    "    wiki_data = f.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first 30 lines of the corpus as the calibration set.\n",
    "calibration_set = wiki_data[0:30]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate the Euclidean Distance between any two layers of KV cache and sort them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import torch\n",
    "\n",
    "# Map every layer index to itself before collecting the per-sample KV caches.\n",
    "kv_cache_share_layers_map = {layer: layer for layer in range(len(llama.model.layers))}\n",
    "kv_cache_list = []\n",
    "with torch.no_grad():\n",
    "    for sample in tqdm(calibration_set):\n",
    "        # Each sample is truncated to at most 64 tokens and moved to cuda:0.\n",
    "        encoded = tokenizer(sample, return_tensors='pt', max_length=64, truncation=True).to('cuda:0')\n",
    "        outputs = llama(**encoded, kv_cache_share_layers_map=kv_cache_share_layers_map)\n",
    "        # Keep the full past_key_values of this forward pass for later averaging.\n",
    "        kv_cache_list.append(outputs.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers = len(kv_cache_list[0])\n",
    "# Running per-layer (key, value) sums, shaped like the first calibration sample's cache.\n",
    "avg_past_key_values = [(torch.zeros_like(kv_cache_list[0][i][0]), torch.zeros_like(kv_cache_list[0][i][1])) for i in range(num_layers)]\n",
    "# Number of samples that were actually added to each layer's running sum.\n",
    "num_accumulated = [0] * num_layers\n",
    "\n",
    "for past_key_values in tqdm(kv_cache_list):\n",
    "    for i, (key, value) in enumerate(past_key_values):\n",
    "        try:\n",
    "            avg_past_key_values[i] = (avg_past_key_values[i][0] + key, avg_past_key_values[i][1] + value)\n",
    "        except RuntimeError:\n",
    "            # The tensors cannot be added to the running sum (e.g. the sample's cache has a\n",
    "            # different shape than the first sample's): skip it, as before, but do not\n",
    "            # count it towards the average.\n",
    "            continue\n",
    "        num_accumulated[i] += 1\n",
    "\n",
    "# Divide by the number of samples that contributed, not by the total sample count,\n",
    "# so skipped samples do not shrink the average. max(..., 1) guards against a zero divisor.\n",
    "avg_past_key_values = [(key / max(count, 1), value / max(count, 1)) for (key, value), count in zip(avg_past_key_values, num_accumulated)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "\n",
    "def compute_cosine_similarity(tensor1, tensor2):\n",
    "    return F.cosine_similarity(tensor1.flatten(1), tensor2.flatten(1), dim=-1).mean().item()\n",
    "\n",
    "def compute_euclidean_distance(tensor1, tensor2):\n",
    "    return torch.norm(tensor1 - tensor2, p=2, dim=-1).mean().item()\n",
    "\n",
    "num_layers = len(avg_past_key_values)\n",
    "# NaN marks the entries that are not compared (the diagonal and the upper triangle).\n",
    "similarity_matrix = np.full((num_layers, num_layers), np.nan)\n",
    "\n",
    "# Despite its name the matrix stores distances: entry [i, j] with i > j is the mean of\n",
    "# the key-cache distance and the value-cache distance between layers i and j.\n",
    "for i in range(1, num_layers):\n",
    "    keys_i, values_i = avg_past_key_values[i]\n",
    "    for j in range(i):\n",
    "        keys_j, values_j = avg_past_key_values[j]\n",
    "        key_distance = compute_euclidean_distance(keys_i, keys_j)\n",
    "        value_distance = compute_euclidean_distance(values_i, values_j)\n",
    "        similarity_matrix[i, j] = (key_distance + value_distance) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coordinates of every compared layer pair (non-NaN entry), in row-major order.\n",
    "pair_rows, pair_cols = np.where(~np.isnan(similarity_matrix))\n",
    "pair_distances = similarity_matrix[pair_rows, pair_cols]\n",
    "\n",
    "# Rank the pairs from the largest distance to the smallest one.\n",
    "descending_order = np.argsort(pair_distances)[::-1]\n",
    "\n",
    "# pos_rank is a list of (row, column) tuples of NumPy integers in ranked order.\n",
    "pos_rank = list(zip(pair_rows[descending_order], pair_cols[descending_order]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the Sharing Layers and THRESHOLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The strategy search stops once this many layers have been accepted for sharing.\n",
    "SHARE_LAYERS = 4\n",
    "# A candidate sharing map is accepted only when the mean cosine similarity returned by\n",
    "# cal_last_hidden_sim is strictly greater than this value.\n",
    "THRESHOLD = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def cal_last_hidden_sim(model1, model2, kv_cache_share_layers_map, tokenizer, sents):\n",
    "    sim_ls = []\n",
    "    for s in sents:\n",
    "        encoded_inputs = tokenizer(s, max_length=64, truncation=True, return_tensors='pt')\n",
    "        encoded_inputs.to('cuda:0')\n",
    "        with torch.no_grad():\n",
    "            outputs1 = model1(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map={i:i for i in range(len(model1.model.layers))})\n",
    "        hidden_states1 = outputs1.hidden_states[-1] # (1, seq_len, hidden)\n",
    "        with torch.no_grad():\n",
    "            outputs2 = model2(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map=kv_cache_share_layers_map)\n",
    "        hidden_states2 = outputs2.hidden_states[-1] # (1, seq_len, hidden)\n",
    "        sim_ls.append(torch.cosine_similarity(hidden_states1.squeeze(0).flatten().unsqueeze(0), hidden_states2.squeeze(0).flatten().unsqueeze(0)))\n",
    "    sim_ls = [i.item() for i in sim_ls]\n",
    "    print(sim_ls, np.mean(sim_ls))\n",
    "    return np.mean(sim_ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def re_map(kv_cache_share_layers_map):\n",
    "    tmp_kv_cache_share_layers_map = {}\n",
    "    for key, values in kv_cache_share_layers_map.items():\n",
    "        if key == values:\n",
    "            tmp_kv_cache_share_layers_map[key] = values\n",
    "        else:\n",
    "            tmp_kv_cache_share_layers_map[key] = tmp_kv_cache_share_layers_map[values]\n",
    "    return tmp_kv_cache_share_layers_map"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Strategy Searching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "# Start the search from a map that sends every layer index to itself.\n",
    "kv_cache_share_layers_map = {i:i for i in range(len(llama.model.layers))}\n",
    "\n",
    "shared_lay = []  # layers that have already been redirected to another layer\n",
    "shared_num_layers = 0\n",
    "\n",
    "for pair in tqdm(pos_rank):\n",
    "    # pos_rank holds immutable tuples of NumPy integers: unpack them into plain ints so\n",
    "    # the pair can be reordered and the resulting map only contains ordinary ints.\n",
    "    target_layer, source_layer = int(pair[0]), int(pair[1])\n",
    "    if target_layer < source_layer:\n",
    "        target_layer, source_layer = source_layer, target_layer\n",
    "    # A layer is redirected at most once; skip before doing any copying.\n",
    "    if target_layer in shared_lay:\n",
    "        continue\n",
    "    # Build the candidate map on a copy so a rejected candidate leaves the current map intact.\n",
    "    tmp_kv_cache_share_layers_map = deepcopy(kv_cache_share_layers_map)\n",
    "    tmp_kv_cache_share_layers_map[target_layer] = source_layer\n",
    "    tmp_kv_cache_share_layers_map = re_map(tmp_kv_cache_share_layers_map)\n",
    "    sim_value = cal_last_hidden_sim(llama, llama, tmp_kv_cache_share_layers_map, tokenizer, calibration_set)\n",
    "    if sim_value > THRESHOLD:\n",
    "        # re_map returned a fresh dict, so it can be adopted without another copy.\n",
    "        kv_cache_share_layers_map = tmp_kv_cache_share_layers_map\n",
    "        shared_lay.append(target_layer)\n",
    "        shared_num_layers += 1\n",
    "    if shared_num_layers >= SHARE_LAYERS:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the layer map produced by the strategy search.\n",
    "print(kv_cache_share_layers_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference with KVSharer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(model, tokenizer, sent, kv_cache_share_layers_map=None):\n",
    "    inputs = tokenizer(sent, return_tensors='pt')\n",
    "    inputs = inputs.to('cuda:0')\n",
    "    pred = model.generate(**inputs, kv_cache_share_layers_map=kv_cache_share_layers_map, max_new_tokens=256, repetition_penalty=1.1)\n",
    "    print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example prompt: generate with the layer map found by the strategy search.\n",
    "sent = 'Hello, what is your name'\n",
    "generate(llama, tokenizer, sent, kv_cache_share_layers_map=kv_cache_share_layers_map)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
