{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoints/llama_130m-2023-05-16-18-18-14\n",
    "\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM\n",
    "from peft_pretraining.relora import ReLoRaModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: list the contents of the model_5000 checkpoint directory that the next cell loads as model_5K.\n",
    "!ls ../checkpoints/llama_130m-2023-05-14-19-54-05/model_5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain Llama checkpoint from the model_5000 directory; the delta cells below subtract these weights.\n",
    "model_5K = LlamaForCausalLM.from_pretrained(\"../checkpoints/llama_130m-2023-05-14-19-54-05/model_5000\")\n",
    "# ReLoRA-wrapped checkpoint (model_20000) from a different run directory; labelled \"ReLoRA\" in the plots.\n",
    "model_20K = ReLoRaModel.from_pretrained(\"../checkpoints/llama_130m-2023-05-16-18-18-14/model_20000\")\n",
    "# model_20000 checkpoint from the same run directory as model_5K; labelled as the full-model / full-rank series.\n",
    "full_model_20K = LlamaForCausalLM.from_pretrained(\"../checkpoints/llama_130m-2023-05-14-19-54-05/model_20000\")\n",
    "\n",
    "# ReLoRA-wrapped checkpoint (model_10000) from a third run directory; only its low-rank product is analysed\n",
    "# below, labelled \"LoRA\" in the plots.\n",
    "# NOTE(review): presumably this run trained the low-rank factors without restarts - confirm against the run config.\n",
    "full_peft = ReLoRaModel.from_pretrained(\"../checkpoints/llama_130m-2023-05-09-19-17-11/model_10000\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Singular values of every attention / MLP projection weight in model_5K,\n",
    "# stored as one NumPy array per layer in the lists below.\n",
    "#\n",
    "# The SVD runs on the GPU when one is visible and falls back to the CPU otherwise;\n",
    "# an unconditional .cuda() call raises on CPU-only machines.\n",
    "# NOTE(review): torch.svd is deprecated upstream; torch.linalg.svdvals returns only the\n",
    "# singular values and could replace it - confirm against the installed PyTorch version.\n",
    "svd_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# get singular values of all layers\n",
    "q_projs = []\n",
    "k_projs = []\n",
    "v_projs = []\n",
    "o_projs = []\n",
    "gate_projs = []\n",
    "down_projs = []\n",
    "up_projs = []\n",
    "\n",
    "# (parent sub-module, projection attribute, destination list), in per-layer processing order\n",
    "projection_targets = [\n",
    "    ('self_attn', 'q_proj', q_projs),\n",
    "    ('self_attn', 'k_proj', k_projs),\n",
    "    ('self_attn', 'v_proj', v_projs),\n",
    "    ('self_attn', 'o_proj', o_projs),\n",
    "    ('mlp', 'gate_proj', gate_projs),\n",
    "    ('mlp', 'down_proj', down_projs),\n",
    "    ('mlp', 'up_proj', up_projs),\n",
    "]\n",
    "\n",
    "for layer in tqdm(model_5K.model.layers):\n",
    "    for parent_name, proj_name, destination in projection_targets:\n",
    "        projection = getattr(getattr(layer, parent_name), proj_name)\n",
    "        weight = projection.weight.detach().to(svd_device)\n",
    "        # Keep only the singular values, moved back to the host as a NumPy array.\n",
    "        destination.append(torch.svd(weight).S.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "def get_linear_weight_from_relora(relora_layer):\n",
    "    \"\"\"Return the base weight of a ReLoRA linear layer plus its scaled low-rank product.\n",
    "\n",
    "    Evaluates ``weight + (lora_B.weight @ lora_A.weight) * scaling`` on the given layer.\n",
    "    \"\"\"\n",
    "    low_rank_product = relora_layer.lora_B.weight @ relora_layer.lora_A.weight\n",
    "    return relora_layer.weight + low_rank_product * relora_layer.scaling\n",
    "\n",
    "# One list per projection type; each receives one singular-value tensor per layer of model_20K.\n",
    "peft_q_projs, peft_k_projs, peft_v_projs, peft_o_projs = [], [], [], []\n",
    "peft_gate_projs, peft_down_projs, peft_up_projs = [], [], []\n",
    "\n",
    "# (parent sub-module, projection attribute, destination list), in per-layer processing order\n",
    "peft_targets = (\n",
    "    ('self_attn', 'q_proj', peft_q_projs),\n",
    "    ('self_attn', 'k_proj', peft_k_projs),\n",
    "    ('self_attn', 'v_proj', peft_v_projs),\n",
    "    ('self_attn', 'o_proj', peft_o_projs),\n",
    "    ('mlp', 'gate_proj', peft_gate_projs),\n",
    "    ('mlp', 'down_proj', peft_down_projs),\n",
    "    ('mlp', 'up_proj', peft_up_projs),\n",
    ")\n",
    "\n",
    "for layer in tqdm(model_20K.wrapped_model.model.layers):\n",
    "    for parent_name, proj_name, destination in peft_targets:\n",
    "        relora_linear = getattr(getattr(layer, parent_name), proj_name)\n",
    "        combined_weight = get_linear_weight_from_relora(relora_linear).detach()\n",
    "        destination.append(torch.svd(combined_weight).S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now, deltas\n",
    "#\n",
    "# For each projection: (combined ReLoRA weight of model_20K) - (weight of model_5K), then its singular values.\n",
    "# get_linear_weight_from_relora must be the helper that includes the layer's base weight.\n",
    "\n",
    "q_projs_delta, k_projs_delta, v_projs_delta, o_projs_delta = [], [], [], []\n",
    "gate_projs_delta, down_projs_delta, up_projs_delta = [], [], []\n",
    "\n",
    "# (parent sub-module, projection attribute, destination list), in per-layer processing order\n",
    "delta_targets = (\n",
    "    ('self_attn', 'q_proj', q_projs_delta),\n",
    "    ('self_attn', 'k_proj', k_projs_delta),\n",
    "    ('self_attn', 'v_proj', v_projs_delta),\n",
    "    ('self_attn', 'o_proj', o_projs_delta),\n",
    "    ('mlp', 'gate_proj', gate_projs_delta),\n",
    "    ('mlp', 'down_proj', down_projs_delta),\n",
    "    ('mlp', 'up_proj', up_projs_delta),\n",
    ")\n",
    "\n",
    "layer_pairs = zip(model_5K.model.layers, model_20K.wrapped_model.model.layers)\n",
    "for base_layer, relora_layer in layer_pairs:\n",
    "    for parent_name, proj_name, destination in delta_targets:\n",
    "        base_weight = getattr(getattr(base_layer, parent_name), proj_name).weight.detach()\n",
    "        relora_linear = getattr(getattr(relora_layer, parent_name), proj_name)\n",
    "        relora_weight = get_linear_weight_from_relora(relora_linear).detach()\n",
    "        destination.append(torch.svd(relora_weight - base_weight).S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "def get_lora_delta_from_relora(relora_layer):\n",
    "    \"\"\"Return only the scaled low-rank product of a ReLoRA linear layer.\n",
    "\n",
    "    Evaluates ``(lora_B.weight @ lora_A.weight) * scaling`` and leaves out ``relora_layer.weight``.\n",
    "    The helper has its own name so that running this cell does not overwrite\n",
    "    ``get_linear_weight_from_relora``, which adds the base weight and is used by other cells.\n",
    "    \"\"\"\n",
    "    return relora_layer.lora_B.weight @ relora_layer.lora_A.weight * relora_layer.scaling\n",
    "\n",
    "# get singular values of all layers\n",
    "full_peft_q_projs = []\n",
    "full_peft_k_projs = []\n",
    "full_peft_v_projs = []\n",
    "full_peft_o_projs = []\n",
    "full_peft_gate_projs = []\n",
    "full_peft_down_projs = []\n",
    "full_peft_up_projs = []\n",
    "\n",
    "# (parent sub-module, projection attribute, destination list), in per-layer processing order\n",
    "full_peft_targets = [\n",
    "    ('self_attn', 'q_proj', full_peft_q_projs),\n",
    "    ('self_attn', 'k_proj', full_peft_k_projs),\n",
    "    ('self_attn', 'v_proj', full_peft_v_projs),\n",
    "    ('self_attn', 'o_proj', full_peft_o_projs),\n",
    "    ('mlp', 'gate_proj', full_peft_gate_projs),\n",
    "    ('mlp', 'down_proj', full_peft_down_projs),\n",
    "    ('mlp', 'up_proj', full_peft_up_projs),\n",
    "]\n",
    "\n",
    "for layer in tqdm(full_peft.wrapped_model.model.layers):\n",
    "    for parent_name, proj_name, destination in full_peft_targets:\n",
    "        relora_linear = getattr(getattr(layer, parent_name), proj_name)\n",
    "        low_rank_delta = get_lora_delta_from_relora(relora_linear).detach()\n",
    "        destination.append(torch.svd(low_rank_delta).S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# delta between full_model_20K and model_5K\n",
    "#\n",
    "# For each projection: (weight in full_model_20K) - (weight in model_5K), then its singular values.\n",
    "\n",
    "q_projs_delta_full, k_projs_delta_full, v_projs_delta_full, o_projs_delta_full = [], [], [], []\n",
    "gate_projs_delta_full, down_projs_delta_full, up_projs_delta_full = [], [], []\n",
    "\n",
    "# (parent sub-module, projection attribute, destination list), in per-layer processing order\n",
    "full_delta_targets = (\n",
    "    ('self_attn', 'q_proj', q_projs_delta_full),\n",
    "    ('self_attn', 'k_proj', k_projs_delta_full),\n",
    "    ('self_attn', 'v_proj', v_projs_delta_full),\n",
    "    ('self_attn', 'o_proj', o_projs_delta_full),\n",
    "    ('mlp', 'gate_proj', gate_projs_delta_full),\n",
    "    ('mlp', 'down_proj', down_projs_delta_full),\n",
    "    ('mlp', 'up_proj', up_projs_delta_full),\n",
    ")\n",
    "\n",
    "for layer_20k, layer_5k in zip(full_model_20K.model.layers, model_5K.model.layers):\n",
    "    for parent_name, proj_name, destination in full_delta_targets:\n",
    "        later_weight = getattr(getattr(layer_20k, parent_name), proj_name).weight.detach()\n",
    "        earlier_weight = getattr(getattr(layer_5k, parent_name), proj_name).weight.detach()\n",
    "        destination.append(torch.svd(later_weight - earlier_weight).S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the q_proj singular values pooled over all layers, one series per setup\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=200)\n",
    "ax.set(title=\"Singular Values of Q Projections\", xlabel=\"Singular Value\", ylabel=\"Frequency\")\n",
    "\n",
    "# (legend label, per-layer singular-value tensors, bar opacity), listed in drawing order\n",
    "q_series = [\n",
    "    (\"ReLoRA Delta\", q_projs_delta, 0.9),\n",
    "    (\"LoRA Delta\", full_peft_q_projs, 0.5),\n",
    "    (\"Delta between full models\", q_projs_delta_full, 0.3),\n",
    "]\n",
    "# Concatenate the per-layer tensors once so the histogram and the count share the same array\n",
    "q_pooled = [(label, torch.cat(per_layer).numpy(), opacity) for label, per_layer, opacity in q_series]\n",
    "\n",
    "for label, values, opacity in q_pooled:\n",
    "    ax.hist(values, density=True, bins=100, alpha=opacity, label=label)\n",
    "\n",
    "# print numbers of singular values < 0.1\n",
    "for label, values, _opacity in q_pooled:\n",
    "    print(label + \": \", (values < 0.1).sum())\n",
    "\n",
    "# Fix the visible window of the axes\n",
    "ax.set_ylim(0, 3)\n",
    "ax.set_xlim(0, 4)\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# pdf\n",
    "plt.savefig(\"q_proj_singular_values.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram of singular values for up_projs over layers\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=200)\n",
    "ax.set_title(\"Singular Values of Up Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "\n",
    "# (legend label, per-layer singular-value tensors, bar opacity), listed in drawing order\n",
    "up_series = [\n",
    "    (\"ReLoRA Delta\", up_projs_delta, 0.9),\n",
    "    (\"LoRA Delta\", full_peft_up_projs, 0.5),\n",
    "    (\"Delta between full models\", up_projs_delta_full, 0.3),\n",
    "]\n",
    "# Concatenate the per-layer tensors once so the histogram and the count share the same array\n",
    "up_pooled = [(label, torch.cat(per_layer).numpy(), opacity) for label, per_layer, opacity in up_series]\n",
    "\n",
    "for label, values, opacity in up_pooled:\n",
    "    ax.hist(values, density=True, bins=100, alpha=opacity, label=label)\n",
    "\n",
    "# print numbers of singular values < 0.1\n",
    "for label, values, _opacity in up_pooled:\n",
    "    print(label + \": \", (values < 0.1).sum())\n",
    "\n",
    "# Fix the visible window of the axes\n",
    "ax.set_ylim(0, 3)\n",
    "ax.set_xlim(0, 2)\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# pdf\n",
    "plt.savefig(\"up_proj_singular_values.pdf\", bbox_inches='tight')\n",
    "# plt.show() instead of fig.show(): Figure.show() only warns on a non-GUI (inline) backend,\n",
    "# and plt.show() matches the Q-projection cell.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the v_proj singular values pooled over all layers, one series per setup\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=200)\n",
    "ax.set_title(\"Singular Values of V Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "\n",
    "# (legend label, per-layer singular-value tensors, bar opacity), listed in drawing order\n",
    "v_series = [\n",
    "    (\"ReLoRA Delta\", v_projs_delta, 0.9),\n",
    "    (\"LoRA Delta\", full_peft_v_projs, 0.5),\n",
    "    (\"Delta between full models\", v_projs_delta_full, 0.3),\n",
    "]\n",
    "# Concatenate the per-layer tensors once so the histogram and the count share the same array\n",
    "v_pooled = [(label, torch.cat(per_layer).numpy(), opacity) for label, per_layer, opacity in v_series]\n",
    "\n",
    "for label, values, opacity in v_pooled:\n",
    "    ax.hist(values, density=True, bins=100, alpha=opacity, label=label)\n",
    "\n",
    "# Print the number of singular values < 0.1\n",
    "for label, values, _opacity in v_pooled:\n",
    "    print(label + \": \", (values < 0.1).sum())\n",
    "\n",
    "# Set the axis limits\n",
    "ax.set_ylim(0, 3)\n",
    "ax.set_xlim(0, 2)\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# Save the figure as a PDF\n",
    "plt.savefig(\"v_proj_singular_values.pdf\", bbox_inches='tight')\n",
    "# plt.show() instead of fig.show(): Figure.show() only warns on a non-GUI (inline) backend,\n",
    "# and plt.show() matches the Q-projection cell.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the down_proj singular values pooled over all layers, one series per setup\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=200)\n",
    "ax.set_title(\"Singular Values of Down Projections\")\n",
    "ax.set_xlabel(\"Singular Value\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "\n",
    "# (legend label, per-layer singular-value tensors, bar opacity), listed in drawing order\n",
    "down_series = [\n",
    "    (\"ReLoRA Delta\", down_projs_delta, 0.9),\n",
    "    (\"LoRA Delta\", full_peft_down_projs, 0.5),\n",
    "    (\"Delta between full models\", down_projs_delta_full, 0.3),\n",
    "]\n",
    "# Concatenate the per-layer tensors once so the histogram and the count share the same array\n",
    "down_pooled = [(label, torch.cat(per_layer).numpy(), opacity) for label, per_layer, opacity in down_series]\n",
    "\n",
    "for label, values, opacity in down_pooled:\n",
    "    ax.hist(values, density=True, bins=100, alpha=opacity, label=label)\n",
    "\n",
    "# Print the number of singular values < 0.1\n",
    "for label, values, _opacity in down_pooled:\n",
    "    print(label + \": \", (values < 0.1).sum())\n",
    "\n",
    "# Set the axis limits\n",
    "ax.set_ylim(0, 3)\n",
    "ax.set_xlim(0, 2)\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# Save the figure as a PDF\n",
    "plt.savefig(\"down_proj_singular_values.pdf\", bbox_inches='tight')\n",
    "# plt.show() instead of fig.show(): Figure.show() only warns on a non-GUI (inline) backend,\n",
    "# and plt.show() matches the Q-projection cell.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four side-by-side histograms (Q, V, Up, Down) comparing the three sets of singular values\n",
    "fig, axes = plt.subplots(1, 4, figsize=(16, 4), dpi=150)\n",
    "# Set font size\n",
    "# NOTE(review): rcParams is updated after the figure is created and stays set in the kernel, so a\n",
    "# second run of this cell may render differently from the first - confirm this is intended.\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Parallel lists: entry i of each list belongs to panel i\n",
    "titles = [\"Q Projections\", \"V Projections\", \"Up Projections\", \"Down Projections\"]\n",
    "delta_data = [q_projs_delta, v_projs_delta, up_projs_delta, down_projs_delta]\n",
    "full_delta_data = [q_projs_delta_full, v_projs_delta_full, up_projs_delta_full, down_projs_delta_full]\n",
    "lora_delta_data = [full_peft_q_projs, full_peft_v_projs, full_peft_up_projs, full_peft_down_projs]\n",
    "\n",
    "panels = zip(axes, titles, delta_data, lora_delta_data, full_delta_data)\n",
    "for panel_index, (ax, title, relora_layers, lora_layers, full_layers) in enumerate(panels):\n",
    "    # Pool the per-layer tensors once; each array feeds both the histogram and the count\n",
    "    relora_values = torch.cat(relora_layers).numpy()\n",
    "    lora_values = torch.cat(lora_layers).numpy()\n",
    "    full_values = torch.cat(full_layers).numpy()\n",
    "\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"Singular Value\", fontsize=16)\n",
    "    # Only the left-most panel carries the y-axis label\n",
    "    if panel_index == 0:\n",
    "        ax.set_ylabel(\"Frequency\", fontsize=16)\n",
    "\n",
    "    hist_options = dict(density=True, bins=50, range=(0, 2))\n",
    "    ax.hist(relora_values, alpha=0.9, label=\"ReLoRA\", **hist_options)\n",
    "    ax.hist(lora_values, alpha=0.5, label=\"LoRA\", **hist_options)\n",
    "    ax.hist(full_values, alpha=0.3, label=\"Full-rank\\ntraining\", **hist_options)\n",
    "\n",
    "    # Print the number of singular values < 0.1\n",
    "    print(f\"Number of singular values < 0.1 ReLoRA ({title}): \", (relora_values < 0.1).sum())\n",
    "    print(f\"Number of singular values < 0.1 full-rank training ({title}): \", (full_values < 0.1).sum())\n",
    "    print(f\"Number of singular values < 0.1 LoRA ({title}): \", (lora_values < 0.1).sum())\n",
    "\n",
    "    # Set the axis limits\n",
    "    ax.set_ylim(0, 3)\n",
    "    ax.set_xlim(0, 2)\n",
    "\n",
    "    # Set legend font size\n",
    "    ax.legend(fontsize=16)\n",
    "\n",
    "# Save the figure as a PDF\n",
    "plt.savefig(\"projection_singular_values.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Five side-by-side histograms (Q, K, V, Up, Down) comparing the three sets of singular values\n",
    "fig, axes = plt.subplots(1, 5, figsize=(16, 4), dpi=150)\n",
    "# Set font size\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Parallel lists: entry i of each list belongs to panel i\n",
    "titles = [\"Q Projections\", \"K Projections\", \"V Projections\", \"Up Projections\", \"Down Projections\"]\n",
    "delta_data = [q_projs_delta, k_projs_delta, v_projs_delta, up_projs_delta, down_projs_delta]\n",
    "full_delta_data = [q_projs_delta_full, k_projs_delta_full, v_projs_delta_full, up_projs_delta_full, down_projs_delta_full]\n",
    "# The second entry must be the K-projection data; it previously repeated full_peft_q_projs,\n",
    "# so the \"K Projections\" panel showed the Q-projection LoRA series.\n",
    "lora_delta_data = [full_peft_q_projs, full_peft_k_projs, full_peft_v_projs, full_peft_up_projs, full_peft_down_projs]\n",
    "\n",
    "for i, ax in enumerate(axes):\n",
    "    # Pool the per-layer tensors once; each array feeds both the histogram and the count\n",
    "    relora_values = torch.cat(delta_data[i]).numpy()\n",
    "    lora_values = torch.cat(lora_delta_data[i]).numpy()\n",
    "    full_values = torch.cat(full_delta_data[i]).numpy()\n",
    "\n",
    "    ax.set_title(titles[i])\n",
    "    ax.set_xlabel(\"Singular Value\", fontsize=16)\n",
    "    # Only the left-most panel carries the y-axis label\n",
    "    if i == 0:\n",
    "        ax.set_ylabel(\"Frequency\", fontsize=16)\n",
    "\n",
    "    ax.hist(relora_values, density=True, bins=50, range=(0, 2), alpha=0.9, label=\"ReLoRA\")\n",
    "    ax.hist(lora_values, density=True, bins=50, range=(0, 2), alpha=0.5, label=\"LoRA\")\n",
    "    ax.hist(full_values, density=True, bins=50, range=(0, 2), alpha=0.3, label=\"Full-rank\\ntraining\")\n",
    "\n",
    "    # Print the number of singular values < 0.1\n",
    "    print(f\"Number of singular values < 0.1 ReLoRA ({titles[i]}): \", (relora_values < 0.1).sum())\n",
    "    print(f\"Number of singular values < 0.1 full-rank training ({titles[i]}): \", (full_values < 0.1).sum())\n",
    "    print(f\"Number of singular values < 0.1 LoRA ({titles[i]}): \", (lora_values < 0.1).sum())\n",
    "\n",
    "    # Set the axis limits\n",
    "    ax.set_ylim(0, 3)\n",
    "    ax.set_xlim(0, 2)\n",
    "\n",
    "    # Set legend font size\n",
    "    ax.legend(fontsize=16)\n",
    "\n",
    "# Save the figure as a PDF\n",
    "# NOTE(review): this is the same output path the four-panel cell writes to, so on a full run this\n",
    "# figure replaces that file - confirm the overwrite is intended.\n",
    "plt.savefig(\"projection_singular_values.pdf\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
