{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import DistilBertModel, DistilBertTokenizer\n",
    "\n",
    "# Load DistilBERT model and tokenizer\n",
    "model_name = \"distilbert-base-uncased\"\n",
    "model = DistilBertModel.from_pretrained(model_name, output_hidden_states=True)\n",
    "tokenizer = DistilBertTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of X1: torch.Size([1, 12, 768])\n",
      "Shape of W2: torch.Size([768, 768])\n",
      "W2: torch.Size([768, 768]), X2: torch.Size([12, 768])\n"
     ]
    }
   ],
   "source": [
    "# Tokenize input\n",
    "input_text = \"This is a sample input to DistilBERT.\"\n",
    "inputs = tokenizer(input_text, return_tensors=\"pt\")\n",
    "\n",
    "# Pass the input through the model\n",
    "outputs = model(**inputs)\n",
    "hidden_states = outputs.hidden_states\n",
    "\n",
    "# Get the first hidden layer's output (intermediate input X1)\n",
    "X1 = hidden_states[1]\n",
    "\n",
    "# Get the weights of the first attention layer's key matrix (W2)\n",
    "W2 = model.transformer.layer[2].attention.q_lin.weight.data\n",
    "# Print shapes\n",
    "print(f\"Shape of X1: {X1.shape}\")\n",
    "print(f\"Shape of W2: {W2.shape}\")\n",
    "\n",
    "X2 = X1[0, :, :]\n",
    "print(f\"W2: {W2.shape}, X2: {X2.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reconstruction Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mse_loss: 0.034828413277864456\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def compute_mse_with_svd_lowrank(W2, X2, rank):\n",
    "    # Step 1: Calculate Y1\n",
    "    Y1 = torch.matmul(X2, W2)\n",
    "\n",
    "    # Step 2: Perform low-rank SVD on W2\n",
    "    U, S, V = torch.svd_lowrank(W2, q=rank)\n",
    "\n",
    "    # Step 3: Reconstruct W2 using the top rank singular values\n",
    "    W_reconstructed = torch.matmul(U, torch.matmul(torch.diag(S), V.t()))\n",
    "\n",
    "    # Step 4: Calculate Y_constructed\n",
    "    Y_constructed = torch.matmul(X2, W_reconstructed)\n",
    "\n",
    "    # Step 5: Calculate MSE loss between Y1 and Y_constructed\n",
    "    mse_loss = F.mse_loss(Y1, Y_constructed)\n",
    "\n",
    "    return Y1, Y_constructed, mse_loss.item()\n",
    "\n",
    "rank = len(W2) // 2  # This gives 384\n",
    "Y1, Y_constructed, mse_loss = compute_mse_with_svd_lowrank(W2, X2, rank)\n",
    "\n",
    "# Step 5: Calculate MSE loss between Y1 and Y_constructed\n",
    "mse_loss = F.mse_loss(Y1, Y_constructed, reduction='mean')\n",
    "print(f'mse_loss: {mse_loss}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ASVD Logic\n",
    "* Input X: (m, n)\n",
    "* S: X2.abs().mean(0): (n, 1)\n",
    "    * Calculates importance of each column in X. \n",
    "* W: (n, q) \n",
    "* Scale W with S: W * s: Scales each row in by n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X2: torch.Size([12, 768]), U: torch.Size([768, 384]) S: torch.Size([384]) V: torch.Size([768, 384]) S2: torch.Size([768, 1])\n",
      "S2: torch.Size([768, 1]), W_reconstructed: torch.Size([768, 768]), Y_constructed: torch.Size([12, 768])\n",
      "mse_loss: 0.023501595482230186\n"
     ]
    }
   ],
   "source": [
    "def compute_asvd(W2, X2, rank):\n",
    "    # Step 1: Calculate Y1\n",
    "    Y1 = torch.matmul(X2, W2)\n",
    "\n",
    "    # S2 = X2.abs().mean(0, keepdims=True)\n",
    "    S2 = X2.abs().mean(0).unsqueeze(1)\n",
    "\n",
    "    # Step 2: Perform low-rank SVD on W2\n",
    "    U, S, V = torch.svd_lowrank(W2 * S2, q=rank)\n",
    "    print(f\"X2: {X2.shape}, U: {U.shape} S: {S.shape} V: {V.shape} S2: {S2.shape}\")\n",
    "\n",
    "    # Step 3: Reconstruct W2 using the top rank singular values\n",
    "    W_reconstructed = torch.matmul(U, torch.matmul(torch.diag(S), V.t()))\n",
    "\n",
    "    # Step 4: Calculate Y_constructed\n",
    "    Y_constructed = torch.matmul(X2, W_reconstructed/S2)\n",
    "\n",
    "    # Step 5: Calculate MSE loss between Y1 and Y_constructed\n",
    "    mse_loss = F.mse_loss(Y1, Y_constructed)\n",
    "\n",
    "    print(f\"S2: {S2.shape}, W_reconstructed: {W_reconstructed.shape}, Y_constructed: {Y_constructed.shape}\")\n",
    "\n",
    "    return Y1, Y_constructed, mse_loss.item()\n",
    "\n",
    "_, _, mse_loss = compute_asvd(W2, X2, rank)\n",
    "print(f'mse_loss: {mse_loss}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reproduce ASVD Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mse_loss: 0.03417380526661873\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class SVDLinear(nn.Module):\n",
    "    def __init__(self, U, S, V) -> None:\n",
    "        super().__init__()\n",
    "        \n",
    "        self.U =( U * S.sqrt()[None, :]).T\n",
    "        self.V = V.t().mul(S.sqrt().view(-1, 1)).contiguous().T\n",
    "\n",
    "    def forward(self, inp):\n",
    "        # compute USV^Tx + b\n",
    "        y = inp @ self.V\n",
    "        y = y @ self.U\n",
    "\n",
    "        return y\n",
    "\n",
    "U, S, V = torch.svd_lowrank(W2.T, q=rank)\n",
    "layer = SVDLinear(U, S, V)\n",
    "Y = X2 @ W2\n",
    "Y_constructed = layer(X2)\n",
    "mse_loss = F.mse_loss(Y, Y_constructed)\n",
    "print(f\"mse_loss: {mse_loss}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Singular Value Validation\n",
    "* If we use full rank, and then select top 50% or just select 50% rank, which is better loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of X1: torch.Size([1, 12, 768])\n",
      "Shape of W2: torch.Size([768, 768])\n"
     ]
    }
   ],
   "source": [
    "# Tokenize input\n",
    "input_text = \"This is a sample input to DistilBERT.\"\n",
    "inputs = tokenizer(input_text, return_tensors=\"pt\")\n",
    "\n",
    "# Pass the input through the model\n",
    "outputs = model(**inputs)\n",
    "hidden_states = outputs.hidden_states\n",
    "\n",
    "# Get the first hidden layer's output (intermediate input X1)\n",
    "X1 = hidden_states[1]\n",
    "\n",
    "# Get the weights of the first attention layer's key matrix (W2)\n",
    "W2 = model.transformer.layer[2].attention.q_lin.weight.data\n",
    "# Print shapes\n",
    "print(f\"Shape of X1: {X1.shape}\")\n",
    "print(f\"Shape of W2: {W2.shape}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_singular_values 153\n",
      "153\n",
      "MSE Error 1:  0.000760575057938695\n",
      "MSE Error 2:  0.0007566662970930338\n"
     ]
    }
   ],
   "source": [
    "def low_rank_svd_reconstruction(W2, rank, pct_singular_values):\n",
    "    # Step 2: Perform low-rank SVD on W2\n",
    "    U, S, V = torch.svd_lowrank(W2, q=rank)\n",
    "\n",
    "    # Determine the number of singular values to use\n",
    "    if pct_singular_values < 1:\n",
    "        num_singular_values = int(len(S) * pct_singular_values)\n",
    "        print('num_singular_values', num_singular_values)\n",
    "        S = S[:num_singular_values]\n",
    "        U = U[:, :num_singular_values]\n",
    "        V = V[:, :num_singular_values]\n",
    "\n",
    "    # Reconstruct W using the specified singular values\n",
    "    W_reconstructed = torch.matmul(U, torch.matmul(torch.diag(S), V.t()))\n",
    "    \n",
    "    return W_reconstructed\n",
    "\n",
    "W2 = W2.to(torch.float32)\n",
    "rank = 768 \n",
    "pct_singular_values = 0.2\n",
    "W_r = low_rank_svd_reconstruction(W2, rank, pct_singular_values)\n",
    "\n",
    "rank2 = int(rank * pct_singular_values)\n",
    "print(rank2)\n",
    "W_r2 = low_rank_svd_reconstruction(W2, rank2, 1.0)\n",
    "\n",
    "print('MSE Error 1: ', ((W2-W_r)**2).mean().item())\n",
    "print('MSE Error 2: ', ((W2-W_r2)**2).mean().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0007570076268166304"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "((W2-W_r2)**2).mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
