{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "718b01e0-e747-451f-9c64-409c6eaae792",
   "metadata": {},
   "source": [
    "## Various implementations of differentiable topk for singular value selection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c52e678-0dd4-4b27-88a6-dc08f1b27268",
   "metadata": {},
   "source": [
    "### Approach 1: Gumbel-noise mask with TV loss\n",
    "1. Random variable (RV)\n",
    "   - Create a trainable parameter $\\alpha$ of shape (hidden_dim,), i.e. one entry per singular value.\n",
    "2. Convert $\\alpha$ to a mask:\n",
    "   - Add Gumbel noise to $\\alpha$ to get a tensor whose values are close to 1 or close to 0.\n",
    "3. Scale the diagonal singular-value matrix by this mask, effectively masking out singular values.\n",
    "4. Add a TV (total variation) loss to push the mask towards a top-k selection instead of an arbitrary set of k values:\n",
    "   - without the TV loss, the mask can come out as [1, 1, 0, 1, 0, 0, 1]\n",
    "   - with the TV loss, transitions between 1 and 0 are penalized, so the mask is pushed towards a contiguous block such as [1, 1, 1, 0, 0, 0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "d78a3de5-4545-4dcc-a6f8-cbbf54025ad3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0405e-14, 9.4531e-02, 9.1460e-01, 5.8875e-12, 4.1301e-01, 1.0000e+00,\n",
       "         1.1703e-14, 5.8825e-02, 3.4913e-03, 1.3637e-13]])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
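    "tau = 0.50  # temperature: lower values push samples closer to 0 or 1\n",
    "# stand-in keep-probabilities, one per singular value (negative draws clamp to 0)\n",
    "probs = torch.clamp(torch.randn(1, 10), 0., 1.)\n",
    "bernoulli = torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature=tau, probs=probs)\n",
    "# rsample() uses the reparameterization trick, so the sample is differentiable w.r.t. probs\n",
    "y_soft = bernoulli.rsample()\n",
    "y_soft"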
   ]
  },
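  {
   "cell_type": "markdown",
   "id": "approach1-sketch-md",
   "metadata": {},
   "source": [
    "A minimal sketch of steps 1-4 end to end. Assumptions: a random 32x16 matrix `W` stands in for a layer weight, the mask is sampled with `RelaxedBernoulli(logits=alpha)` (the cell above passes `probs` instead), and the objective is a mean-squared reconstruction error plus a TV term weighted by a hypothetical `tv_weight`. These choices are illustrative rather than the method as written; the point is that gradients reach `alpha` through the sampled mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "approach1-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# stand-in weight; torch.linalg.svd returns S sorted in descending order\n",
    "W = torch.randn(32, 16)\n",
    "U, S, Vh = torch.linalg.svd(W, full_matrices=False)\n",
    "\n",
    "# step 1: one trainable logit per singular value\n",
    "alpha = torch.nn.Parameter(torch.zeros(S.shape[0]))\n",
    "\n",
    "# step 2: relaxed Bernoulli sample = sigmoid((alpha + logistic noise) / tau);\n",
    "# logistic noise is the difference of two Gumbel(0, 1) samples\n",
    "tau = 0.5\n",
    "mask = torch.distributions.RelaxedBernoulli(temperature=tau, logits=alpha).rsample()\n",
    "\n",
    "# step 3: scale the singular values by the mask and rebuild the weight\n",
    "W_masked = U @ torch.diag(S * mask) @ Vh\n",
    "\n",
    "# step 4: reconstruction loss plus TV loss on the mask (tv_weight is hypothetical)\n",
    "tv_weight = 0.1\n",
    "tv = (mask[1:] - mask[:-1]).abs().mean()\n",
    "loss = torch.mean((W_masked - W) ** 2) + tv_weight * tv\n",
    "loss.backward()\n",
    "\n",
    "print('mask:', mask.detach())\n",
    "print('alpha.grad:', alpha.grad)"
   ]
  },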
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "7d16f4dc-b949-4e04-b1e7-1ddc189f6852",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def get_tv_loss(x):\n",
    "    \"\"\"\n",
    "    Computes the Total Variation (TV) loss for a 1D tensor.\n",
    "    \n",
    "    Args:\n",
    "        x (torch.Tensor): Input mask of shape (seq_len,).\n",
    "    \n",
    "    Returns:\n",
    "        torch.Tensor: A scalar tensor, the mean absolute difference between neighboring elements.\n",
    "    \"\"\"\n",
    "    assert x.dim() == 1, 'expected input into get_tv_loss to be of dim 1'\n",
    "    # Compute the absolute differences between neighboring elements\n",
    "    tv = torch.abs(x[1:] - x[:-1])\n",
    "    \n",
    "    # Average over the sequence dimension (an L2 norm is a possible alternative)\n",
    "    # tv_loss = torch.norm(tv, p=2)\n",
    "    tv_loss = tv.mean()\n",
    "    \n",
    "    return tv_loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "d268fd6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "da55661a-3d14-4e77-8196-783e37ea1086",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.3333)\n"
     ]
    }
   ],
   "source": [
    "# Example input tensor\n",
    "x = torch.tensor([1., 1., 1., 0., 0., 0., 1])\n",
    "\n",
    "# Compute the TV loss\n",
    "tv_error = get_tv_loss(x)\n",
    "print(tv_error)  # Output: tensor(0.3333), i.e. 2 transitions over 6 neighboring pairs"
   ]
  },
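  {
   "cell_type": "markdown",
   "id": "tv-check-md",
   "metadata": {},
   "source": [
    "A quick check of the TV argument in Approach 1, using the mask patterns from that list. The TV value is computed inline with the same formula as `get_tv_loss` (mean absolute difference between neighbors): the scattered mask scores 4/6, the contiguous one 1/6. TV only counts 1/0 transitions, so a block at the end of the mask scores the same as a block at the start; TV on its own does not decide which singular values are kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tv-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def tv(m):\n",
    "    # mean absolute difference between neighbors, same formula as get_tv_loss\n",
    "    return (m[1:] - m[:-1]).abs().mean()\n",
    "\n",
    "scattered = torch.tensor([1., 1., 0., 1., 0., 0., 1.])\n",
    "top_block = torch.tensor([1., 1., 1., 1., 0., 0., 0.])\n",
    "tail_block = torch.tensor([0., 0., 0., 1., 1., 1., 1.])\n",
    "\n",
    "# expected: scattered 0.6667, top_block 0.1667, tail_block 0.1667\n",
    "for name, m in [('scattered', scattered), ('top_block', top_block), ('tail_block', tail_block)]:\n",
    "    print(f'{name:>10}: {tv(m).item():.4f}')"
   ]
  },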
  {
   "cell_type": "markdown",
   "id": "3ade3b70-c1b1-4974-972a-5b3103dd164c",
   "metadata": {},
   "source": [
    "### Approach 2: Differentiable topk from Differentiable Model Scaling\n",
    "* https://arxiv.org/pdf/2405.07194\n",
    "\n",
    "#### Pros:\n",
    "* Only one new (scalar) parameter is introduced.\n",
    "\n",
    "#### Challenges:\n",
    "* The mask is not strictly binary. Values near the threshold come out at intermediate levels such as 0.5, 1e-1 or 1e-2; these are not 0 and still affect the reconstruction, so they have to be accounted for during the final reconstruction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "ad86c3e9-f1e3-479c-ac10-8c4e724052d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Topk=2.00 singular will be selected, using predicted topk_ratio=0.98\n",
      "In the below tensor, the top 2.00 values should be 1 and rest 0\n",
      "\n",
      "pred_mask:\n",
      " tensor([1.0000e+00, 9.9998e-01, 4.4466e-01, 1.1967e-05, 1.7886e-10, 2.6730e-15,\n",
      "        3.9949e-20, 5.9704e-25, 8.9229e-30, 1.3335e-34, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "       grad_fn=<SigmoidBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# hidden_dim is 768 in the real model; 100 is used here to keep the printout short\n",
    "hidden_dim = 100\n",
    "\n",
    "def calculate_topk_mask(topk_ratio):\n",
    "    \"\"\"\n",
    "    Inputs:\n",
    "        topk_ratio: scalar in [0, 1] used as the sigmoid threshold alpha. Entries whose normalized\n",
    "                    rank c_prime lies above it are kept, so roughly (1 - topk_ratio) * hidden_dim\n",
    "                    singular values are selected.\n",
    "\n",
    "    Outputs:\n",
    "        mask: tensor with values in [0, 1] that serves as the mask for singular value selection\n",
    "        k: expected number of selected singular values\n",
    "    \"\"\"\n",
    "    # normalized rank: 1 for the largest singular value, 0 for the smallest (fixed, no gradient needed)\n",
    "    c_prime = torch.linspace(1, 0, hidden_dim)\n",
    "\n",
    "    # lambda is usually set to the length of c_prime; the extra 1000 sharpens the sigmoid\n",
    "    lambda_value = len(c_prime) + 1000\n",
    "    alpha = topk_ratio\n",
    "\n",
    "    mask = torch.sigmoid(lambda_value * (c_prime - alpha))\n",
    "    k = (1 - alpha) * len(c_prime)\n",
    "\n",
    "    return mask, k\n",
    "\n",
    "# this parameter is trained through backprop\n",
    "ratio = 0.98  # fixed value for this example\n",
    "topk_ratio = torch.nn.Parameter(torch.tensor(ratio, dtype=torch.float))\n",
    "pred_mask, k = calculate_topk_mask(topk_ratio)\n",
    "\n",
    "print(f'\\nTopk={k.item():.02f} singular will be selected, using predicted topk_ratio={topk_ratio.item():.2f}')\n",
    "print(f'In the below tensor, the top {k.item():0.2f} values should be 1 and rest 0')\n",
    "print('\\npred_mask:\\n', pred_mask)"
   ]
  },
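  {
   "cell_type": "markdown",
   "id": "threshold-grad-md",
   "metadata": {},
   "source": [
    "A minimal sketch of how the threshold can be trained, assuming a hypothetical budget loss that pulls the soft count `mask.sum()` towards 20 selected singular values (the target and the squared-error form are illustrative, not taken from the paper). `threshold_mask` is a hypothetical helper that mirrors `calculate_topk_mask`; a fixed `linspace(1, 0, n)` works as the normalized rank because `torch.linalg.svd` returns singular values in descending order. The soft count stays close to `k = (1 - alpha) * n`, and `alpha` receives a gradient through the sigmoid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "threshold-grad-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def threshold_mask(alpha, n, lam):\n",
    "    # hypothetical helper; normalized rank: 1 for the largest singular value, 0 for the smallest\n",
    "    c_prime = torch.linspace(1, 0, n)\n",
    "    return torch.sigmoid(lam * (c_prime - alpha))\n",
    "\n",
    "n = 100\n",
    "alpha = torch.nn.Parameter(torch.tensor(0.90))\n",
    "mask = threshold_mask(alpha, n, lam=n + 1000)\n",
    "\n",
    "soft_k = mask.sum()            # differentiable count of selected singular values\n",
    "k_formula = (1 - alpha) * n    # the k returned by calculate_topk_mask\n",
    "\n",
    "# hypothetical budget loss: pull the soft count towards 20 selected values\n",
    "target_k = 20.0\n",
    "loss = (soft_k - target_k) ** 2\n",
    "loss.backward()\n",
    "\n",
    "print(f'soft_k={soft_k.item():.3f}, k_formula={k_formula.item():.3f}')\n",
    "# a positive gradient means a gradient-descent step lowers alpha, which keeps more values\n",
    "print(f'alpha.grad={alpha.grad.item():.3f}')"
   ]
  },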
  {
   "cell_type": "markdown",
   "id": "4168385e",
   "metadata": {},
   "source": [
    "## Test NN Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c432ed96-032c-4067-a621-dae233784af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def compute_E_train_mask(E_train, UE):\n",
    "    # Normalized rank of each singular value: 1 for the largest, 0 for the smallest (fixed, no gradient)\n",
    "    c_prime = torch.linspace(1, 0, UE.shape[1], dtype=E_train.dtype, device=E_train.device)\n",
    "\n",
    "    # Sigmoid sharpness: length of c_prime plus 1000\n",
    "    lambda_value = len(c_prime) + 1000\n",
    "\n",
    "    # Clamp E_train (the learnable threshold) to [0, 1]\n",
    "    E_train2 = torch.clamp(E_train, min=0., max=1.)\n",
    "\n",
    "    # Entries whose rank lies above the threshold get a mask value close to 1, the rest close to 0\n",
    "    E_train_mask = torch.sigmoid(lambda_value * (c_prime - E_train2))\n",
    "\n",
    "    return E_train_mask\n",
    "\n",
    "# Example usage:\n",
    "UE = torch.randn(100, 500)\n",
    "E_train = torch.nn.Parameter(torch.zeros((1,), device=UE.device, dtype=UE.dtype))\n",
    "E_train_mask = compute_E_train_mask(E_train, UE)\n",
    "E_train_mask[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "148aed62-9f33-4adb-aeca-5e47767f4a6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.6889, 0.9641, 0.9896, 0.8598, 0.8937, 0.6957, 0.7350, 0.8949, 0.6708,\n",
       "         0.7791, 0.9477, 0.9342, 0.8948, 0.9996, 0.6775, 0.9568, 0.9029, 0.9979,\n",
       "         0.9971, 0.9350, 0.9960, 0.9188, 0.9860, 0.8756, 0.9969, 0.8882, 0.9654,\n",
       "         0.9646, 0.9996, 0.9935, 0.9917, 0.8736, 0.9597, 0.9956, 0.9989, 0.9987,\n",
       "         0.8825, 0.8811, 0.9020, 0.8000, 0.9293, 0.9905, 0.9471, 0.9251, 0.6947,\n",
       "         0.9742, 0.7911, 0.8579, 0.9273, 0.7817, 0.6796, 0.7451, 0.8763, 0.9146,\n",
       "         0.9832, 0.9653, 0.7622, 0.7551, 0.7119, 0.9985, 0.8506, 0.8755, 0.9999,\n",
       "         0.6270, 1.0000, 0.7161, 0.9986, 0.7204, 0.9497, 0.6563, 0.7509, 0.9541,\n",
       "         0.9999, 0.9910, 0.8930, 0.7092, 0.9698, 0.9973, 0.3730, 0.8664, 0.9985,\n",
       "         0.9207, 0.7152, 0.5185, 0.8944, 0.9617, 0.9322, 0.5974, 0.9141, 0.9564,\n",
       "         0.4668, 0.9884, 0.4657, 0.8275, 0.9959, 0.8169, 0.5619, 0.7061, 0.7129,\n",
       "         0.9723]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "import torch\n",
    "from utils.lowrank_modeling import gumbel_sigmoid\n",
    "\n",
    "# all-ones input for 100 mask entries\n",
    "mask = torch.ones(1, 100, dtype=torch.float32)\n",
    "gumbel_sigmoid(mask)"
   ]
  },
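  {
   "cell_type": "markdown",
   "id": "gumbel-sigmoid-sketch-md",
   "metadata": {},
   "source": [
    "The body of `utils.lowrank_modeling.gumbel_sigmoid` is not shown in this notebook. For reference, here is a standard Gumbel-sigmoid sketch; `gumbel_sigmoid_sketch` is a hypothetical name, and its `tau` and `hard` arguments are assumptions that may not match the helper's actual signature or defaults. It adds logistic noise (the difference of two Gumbel(0, 1) samples) to the logits and applies a tempered sigmoid; with `hard=True` it returns an exact 0/1 mask in the forward pass while passing the soft gradient backward (straight-through estimator)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gumbel-sigmoid-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def gumbel_sigmoid_sketch(logits, tau=1.0, hard=False):\n",
    "    # hypothetical stand-in for utils.lowrank_modeling.gumbel_sigmoid\n",
    "    # logistic noise log(u) - log(1 - u) equals the difference of two Gumbel(0, 1) samples\n",
    "    u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)\n",
    "    noise = torch.log(u) - torch.log1p(-u)\n",
    "    y_soft = torch.sigmoid((logits + noise) / tau)\n",
    "    if not hard:\n",
    "        return y_soft\n",
    "    # straight-through: forward pass uses the 0/1 mask, backward pass uses the soft gradient\n",
    "    y_hard = (y_soft > 0.5).to(y_soft.dtype)\n",
    "    return y_hard + (y_soft - y_soft.detach())\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.ones(1, 100, requires_grad=True)\n",
    "soft = gumbel_sigmoid_sketch(logits, tau=1.0)\n",
    "hard = gumbel_sigmoid_sketch(logits, tau=1.0, hard=True)\n",
    "hard.sum().backward()\n",
    "print(soft[0, :5], hard[0, :5], logits.grad[0, :5])"
   ]
  },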
  {
   "cell_type": "markdown",
   "id": "8bd16209",
   "metadata": {},
   "source": [
    "## Test gating layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c7dd5447",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "from utils import lowrank_modeling_v2\n",
    "import torch\n",
    "\n",
    "old_layer = torch.nn.Linear(768, 768*2)\n",
    "layer = lowrank_modeling_v2.LowrankLinearGate(old_layer, 0.4, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "529043c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950, 0.9950,\n",
      "        0.9950, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000], grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(2, 5, 768)\n",
    "layer(x)"
   ]
  },
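  {
   "cell_type": "markdown",
   "id": "masked-lowrank-sketch-md",
   "metadata": {},
   "source": [
    "`LowrankLinearGate.forward` in `utils/lowrank_modeling_v2.py` computes the masked low-rank product `(self.UE * E_train_mask.unsqueeze(0)) @ (self.V_t @ inputs)` on inputs transposed to (batch, features, seq). A self-contained sketch of that computation, assuming `UE` is `U` scaled by the singular values, `V_t` is `Vh` from `torch.linalg.svd` of the original weight, and the bias is added back afterwards (the module's actual construction may differ). With an all-ones mask it reproduces `torch.nn.Linear`; with the first `k` entries set to 1 it matches the rank-k truncated weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "masked-lowrank-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "linear = torch.nn.Linear(16, 32)\n",
    "\n",
    "# assumption: UE = U * S (columns of U scaled by the singular values), V_t = Vh\n",
    "U, S, Vh = torch.linalg.svd(linear.weight.detach(), full_matrices=False)\n",
    "UE, V_t = U * S, Vh                     # UE: (out, r), V_t: (r, in)\n",
    "\n",
    "def masked_lowrank_forward(inputs, mask):\n",
    "    # inputs: (batch, seq, in) -> (batch, in, seq), mirroring inputs.transpose(1, 2)\n",
    "    h = inputs.transpose(1, 2)\n",
    "    output = (UE * mask.unsqueeze(0)) @ (V_t @ h)   # (batch, out, seq)\n",
    "    return output.transpose(1, 2) + linear.bias.detach()\n",
    "\n",
    "x = torch.randn(2, 5, 16)\n",
    "r = S.shape[0]\n",
    "\n",
    "full = masked_lowrank_forward(x, torch.ones(r))\n",
    "print('full mask matches nn.Linear:', torch.allclose(full, linear(x), atol=1e-4))\n",
    "\n",
    "# keep only the top-k singular values\n",
    "k = 4\n",
    "topk_mask = torch.zeros(r)\n",
    "topk_mask[:k] = 1.0\n",
    "W_k = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k]     # rank-k truncated weight\n",
    "expected = x @ W_k.T + linear.bias.detach()\n",
    "print('top-k mask matches rank-k weight:',\n",
    "      torch.allclose(masked_lowrank_forward(x, topk_mask), expected, atol=1e-4))"
   ]
  },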
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f86aff42",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
