{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM\n",
    "from peft_pretraining.relora import ReLoRaModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the contents of the llama_130m run directory under ../checkpoints\n",
    "# (a later cell loads model_10000 from this same directory).\n",
    "!ls ../checkpoints/llama_130m-2023-05-09-19-17-11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint directories of the two models compared in this notebook.\n",
    "# NOTE(review): the variable names below say 250M, yet the paths point at a llama_60m run\n",
    "# and a llama_130m run -- confirm these are the checkpoints that were meant to be compared.\n",
    "# llama_130m-2023-05-09-19-18-46 is major-moon-154\n",
    "regular_checkpoint = \"../checkpoints/llama_60m-2023-05-13-17-13-02/model_10000\"\n",
    "relora_checkpoint = \"../checkpoints/llama_130m-2023-05-09-19-17-11/model_10000\"\n",
    "\n",
    "# Loaded as a plain Hugging Face causal-LM.\n",
    "regular_250M = AutoModelForCausalLM.from_pretrained(regular_checkpoint)\n",
    "\n",
    "# Loaded through the project's ReLoRaModel wrapper.\n",
    "peft_250M = ReLoRaModel.from_pretrained(relora_checkpoint)\n",
    "\n",
    "# llama_7b = AutoModelForCausalLM.from_pretrained(\"huggyllama/llama-7b\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Singular values of every attention and MLP projection of the regular model.\n",
    "# Each list receives one tensor of singular values per entry of regular_250M.model.layers.\n",
    "q_projs = []\n",
    "k_projs = []\n",
    "v_projs = []\n",
    "o_projs = []\n",
    "gate_projs = []\n",
    "down_projs = []\n",
    "up_projs = []\n",
    "\n",
    "# (sub-module of the layer, projection attribute, list collecting its singular values)\n",
    "regular_projection_targets = [\n",
    "    (\"self_attn\", \"q_proj\", q_projs),\n",
    "    (\"self_attn\", \"k_proj\", k_projs),\n",
    "    (\"self_attn\", \"v_proj\", v_projs),\n",
    "    (\"self_attn\", \"o_proj\", o_projs),\n",
    "    (\"mlp\", \"gate_proj\", gate_projs),\n",
    "    (\"mlp\", \"down_proj\", down_projs),\n",
    "    (\"mlp\", \"up_proj\", up_projs),\n",
    "]\n",
    "\n",
    "for layer in tqdm(regular_250M.model.layers):\n",
    "    for submodule_name, projection_name, collected in regular_projection_targets:\n",
    "        projection = getattr(getattr(layer, submodule_name), projection_name)\n",
    "        # svdvals yields the singular values in descending order like torch.svd(...).S did,\n",
    "        # but skips building U and V and avoids the deprecated torch.svd entry point.\n",
    "        collected.append(torch.linalg.svdvals(projection.weight.detach()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def get_linear_weight_from_relora(relora_layer):\n",
    "    \"\"\"Return the dense matrix encoded by the low-rank factors of a ReLoRA layer.\n",
    "\n",
    "    Computes (lora_B.weight @ lora_A.weight) * scaling from the attributes of relora_layer.\n",
    "    The result is still attached to the autograd graph; callers detach it themselves.\n",
    "\n",
    "    NOTE(review): only the low-rank product is returned. If the layer also carries a\n",
    "    full-rank base weight, it is not added here -- confirm that is the intended comparison.\n",
    "    \"\"\"\n",
    "    low_rank_product = relora_layer.lora_B.weight @ relora_layer.lora_A.weight\n",
    "    return low_rank_product * relora_layer.scaling\n",
    "\n",
    "\n",
    "# Singular values of every attention and MLP projection of the ReLoRA model.\n",
    "# Each list receives one tensor of singular values per wrapped decoder layer.\n",
    "peft_q_projs = []\n",
    "peft_k_projs = []\n",
    "peft_v_projs = []\n",
    "peft_o_projs = []\n",
    "peft_gate_projs = []\n",
    "peft_down_projs = []\n",
    "peft_up_projs = []\n",
    "\n",
    "# (sub-module of the layer, projection attribute, list collecting its singular values)\n",
    "peft_projection_targets = [\n",
    "    (\"self_attn\", \"q_proj\", peft_q_projs),\n",
    "    (\"self_attn\", \"k_proj\", peft_k_projs),\n",
    "    (\"self_attn\", \"v_proj\", peft_v_projs),\n",
    "    (\"self_attn\", \"o_proj\", peft_o_projs),\n",
    "    (\"mlp\", \"gate_proj\", peft_gate_projs),\n",
    "    (\"mlp\", \"down_proj\", peft_down_projs),\n",
    "    (\"mlp\", \"up_proj\", peft_up_projs),\n",
    "]\n",
    "\n",
    "for layer in tqdm(peft_250M.wrapped_model.model.layers):\n",
    "    for submodule_name, projection_name, collected in peft_projection_targets:\n",
    "        relora_layer = getattr(getattr(layer, submodule_name), projection_name)\n",
    "        dense_weight = get_linear_weight_from_relora(relora_layer).detach()\n",
    "        # svdvals returns only the singular values; the deprecated torch.svd also built U and V.\n",
    "        collected.append(torch.linalg.svdvals(dense_weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Standard deviation of the Gaussian noise added to every entry of the ReLoRA matrices.\n",
    "NOISE_STD = 0.04\n",
    "\n",
    "\n",
    "def noised_singular_values(relora_layer):\n",
    "    \"\"\"Singular values of a ReLoRA layer's dense matrix after adding N(0, NOISE_STD**2) noise.\"\"\"\n",
    "    # Build the dense matrix once and reuse it for both the signal and the noise shape;\n",
    "    # the earlier version recomputed the lora_B @ lora_A product twice per projection.\n",
    "    dense_weight = get_linear_weight_from_relora(relora_layer).detach()\n",
    "    noisy_weight = dense_weight + torch.randn_like(dense_weight) * NOISE_STD\n",
    "    return torch.linalg.svdvals(noisy_weight)\n",
    "\n",
    "\n",
    "# Singular values of the noised q / down projections, one tensor per layer.\n",
    "# No RNG seed is set in this cell, so the values differ from run to run.\n",
    "noised_peft_q_projs = []\n",
    "noised_peft_down_projs = []\n",
    "\n",
    "for layer in tqdm(peft_250M.wrapped_model.model.layers):\n",
    "    noised_peft_q_projs.append(noised_singular_values(layer.self_attn.q_proj))\n",
    "    noised_peft_down_projs.append(noised_singular_values(layer.mlp.down_proj))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# perform magnitude pruning\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Percentile of |w| at or below which entries are zeroed. np.percentile expects q on a\n",
    "# 0-100 scale, so 0.01 removes only the smallest 0.01% of entries.\n",
    "# NOTE(review): the previous variable name (threshold_90p) hints that 90 may have been\n",
    "# intended -- confirm the desired sparsity and adjust this single constant.\n",
    "PRUNE_PERCENTILE = 0.01\n",
    "\n",
    "\n",
    "def magnitude_pruned_singular_values(weight):\n",
    "    \"\"\"Zero entries with |w| <= the PRUNE_PERCENTILE-th percentile of |weight|; return singular values.\"\"\"\n",
    "    magnitude = weight.abs()\n",
    "    threshold = float(np.percentile(magnitude.numpy(), PRUNE_PERCENTILE))\n",
    "    return torch.linalg.svdvals(weight * (magnitude > threshold))\n",
    "\n",
    "\n",
    "def randomly_pruned_singular_values(weight):\n",
    "    \"\"\"Zero a random PRUNE_PERCENTILE percent of the entries of weight; return singular values.\"\"\"\n",
    "    # rand_like is uniform on [0, 1), so it has to be compared against the pruned *fraction*.\n",
    "    # Comparing it against a weight-magnitude threshold (as before) made the kept fraction\n",
    "    # depend on the scale of the weights instead of on the requested pruning level.\n",
    "    keep_mask = torch.rand_like(weight) > PRUNE_PERCENTILE / 100.0\n",
    "    return torch.linalg.svdvals(weight * keep_mask)\n",
    "\n",
    "\n",
    "# magnitude pruning of the ReLoRA q projections\n",
    "pruned_peft_q_projs = [\n",
    "    magnitude_pruned_singular_values(get_linear_weight_from_relora(layer.self_attn.q_proj).detach())\n",
    "    for layer in tqdm(peft_250M.wrapped_model.model.layers)\n",
    "]\n",
    "\n",
    "# magnitude pruning of the regular q projections\n",
    "pruned_q_projs = [\n",
    "    magnitude_pruned_singular_values(layer.self_attn.q_proj.weight.detach())\n",
    "    for layer in tqdm(regular_250M.model.layers)\n",
    "]\n",
    "\n",
    "# random pruning of the regular q projections (same expected sparsity as above)\n",
    "random_pruned_q_projs = [\n",
    "    randomly_pruned_singular_values(layer.self_attn.q_proj.weight.detach())\n",
    "    for layer in tqdm(regular_250M.model.layers)\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Magnitude-prune each down projection of the regular model and keep its singular values.\n",
    "# Relies on np, torch and tqdm imported by earlier cells.\n",
    "\n",
    "pruned_down_projs = []\n",
    "\n",
    "for block in tqdm(regular_250M.model.layers):\n",
    "    weight = block.mlp.down_proj.weight.detach()\n",
    "    magnitude = weight.abs()\n",
    "\n",
    "    # cutoff is the 0.01-th percentile of |w| (np.percentile takes q on a 0-100 scale)\n",
    "    cutoff = np.percentile(magnitude, 0.01)\n",
    "    keep_mask = magnitude > cutoff\n",
    "    _, sigma, _ = torch.svd(weight * keep_mask)\n",
    "    pruned_down_projs.append(sigma)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlaid, density-normalised histograms of the q_proj singular values (all layers pooled).\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# (per-layer singular values, opacity, legend label) -- drawn in this order\n",
    "q_series = [\n",
    "    # (noised_peft_q_projs, 0.3, \"Noised PEFT\"),\n",
    "    (q_projs, 0.8, \"Regular\"),\n",
    "    (peft_q_projs, 0.5, \"PEFT\"),\n",
    "    # (pruned_peft_q_projs, 0.3, \"Pruned PEFT\"),\n",
    "    (pruned_q_projs, 0.3, \"Pruned Regular\"),\n",
    "    # (random_pruned_q_projs, 0.3, \"Random Pruned Regular\"),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)\n",
    "for per_layer_values, opacity, series_label in q_series:\n",
    "    pooled = torch.cat(per_layer_values).numpy()\n",
    "    ax.hist(pooled, density=True, bins=100, alpha=opacity, label=series_label)\n",
    "\n",
    "ax.set_title(\"Singular Values of Q Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "\n",
    "# fixed axis window\n",
    "ax.set_ylim(0, 4)\n",
    "ax.set_xlim(0, 6)\n",
    "\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for k_projs over layers\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)\n",
    "ax.set_title(\"Singular Values of K Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.hist(torch.cat(k_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n",
    "ax.hist(torch.cat(peft_k_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n",
    "# llama7b_k_projs is not computed by any cell visible in this notebook (the LLaMA-7B load\n",
    "# is commented out), so draw it only when it exists instead of raising NameError.\n",
    "if \"llama7b_k_projs\" in globals():\n",
    "    ax.hist(torch.cat(llama7b_k_projs).numpy(), bins=100, alpha=0.5, label=\"LLAMA-7B\")\n",
    "# the labels passed to hist() only become visible once the legend is drawn\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for v_projs over layers (plt comes from an earlier cell)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.set_title(\"Singular Values of V Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.hist(torch.cat(v_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n",
    "ax.hist(torch.cat(peft_v_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n",
    "# llama7b_v_projs is not computed by any cell visible in this notebook (the LLaMA-7B load\n",
    "# is commented out), so draw it only when it exists instead of raising NameError.\n",
    "if \"llama7b_v_projs\" in globals():\n",
    "    ax.hist(torch.cat(llama7b_v_projs).numpy(), bins=100, alpha=0.5, label=\"LLAMA-7B\")\n",
    "# the labels passed to hist() only become visible once the legend is drawn\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for o_projs over layers (plt comes from an earlier cell)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.set_title(\"Singular Values of O Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.hist(torch.cat(o_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n",
    "ax.hist(torch.cat(peft_o_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n",
    "# the labels passed to hist() only become visible once the legend is drawn\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for up_projs over layers (plt comes from an earlier cell)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.set_title(\"Singular Values of Up Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.hist(torch.cat(up_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n",
    "ax.hist(torch.cat(peft_up_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n",
    "# the labels passed to hist() only become visible once the legend is drawn\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for down_projs over layers (plt comes from an earlier cell)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.set_title(\"Singular Values of Down Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.hist(torch.cat(down_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n",
    "ax.hist(torch.cat(peft_down_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n",
    "# the labels passed to hist() only become visible once the legend is drawn\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for gate_projs over layers (plt comes from an earlier cell)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.set_title(\"Singular Values of Gate Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.hist(torch.cat(gate_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n",
    "ax.hist(torch.cat(peft_gate_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n",
    "# the labels passed to hist() only become visible once the legend is drawn\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
