{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "LiUs6wBfkabr",
    "outputId": "f73eca20-b08e-4d39-f17e-d1b11bcd5974"
   },
   "outputs": [],
   "source": [
    "# Restrict the CUDA devices visible to this kernel to GPU index 1.\n",
    "# NOTE(review): this only has an effect if it runs before CUDA is initialised;\n",
    "# it is placed in the first cell, ahead of the torch import in the next cell.\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "\n",
    "# Show the machine's GPUs and confirm the variable reached the shell environment.\n",
    "! nvidia-smi\n",
    "! set | grep CUDA_VISIBLE_DEVICES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "LiUs6wBfkabr",
    "outputId": "f73eca20-b08e-4d39-f17e-d1b11bcd5974"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import transformers\n",
    "\n",
    "# Hugging Face Hub checkpoint analysed throughout the notebook.\n",
    "model_name = \"JackFram/llama-160m\"\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n",
    "# eval() is equivalent to train(False): put every submodule in inference mode.\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(model_name).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xk6ux2Ioly-A",
    "outputId": "a4bb552c-1728-44ae-cf45-2b88856ba4a5"
   },
   "outputs": [],
   "source": [
    "# Tokenise the probe prompt.\n",
    "inputs = tokenizer(\n",
    "    \"To estimate the L-smoothness of a language model in parater space, one must\",\n",
    "    return_tensors='pt'\n",
    ")\n",
    "# Inference only: run under no_grad so no autograd graph is built.\n",
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits  # softmax(logits, -1) is the probability of next token\n",
    "print('logits.shape: [num_sequences, sequence length(tokens), vocab_size] =', logits.shape)\n",
    "\n",
    "# Greedy generation example - just for your information.\n",
    "# The chosen token id is appended to input_ids directly instead of decoding and\n",
    "# re-tokenising the text: a decode/encode round trip may split the text into\n",
    "# different tokens, so the model would no longer condition on what it generated.\n",
    "for i in range(5):\n",
    "    with torch.no_grad():\n",
    "        logits = model(**inputs).logits\n",
    "    new_token: int = logits[0, -1].argmax().item()\n",
    "    next_ids = torch.tensor([[new_token]], dtype=inputs['input_ids'].dtype, device=inputs['input_ids'].device)\n",
    "    inputs = {\n",
    "        'input_ids': torch.cat([inputs['input_ids'], next_ids], dim=1),\n",
    "        'attention_mask': torch.cat([inputs['attention_mask'], torch.ones_like(next_ids)], dim=1),\n",
    "    }\n",
    "    new_sequence = tokenizer.decode(inputs['input_ids'][0].tolist(), skip_special_tokens=True)\n",
    "    print(new_sequence)\n",
    "del logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-XX1R6iun1d6"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Restore `inputs` to the bare probe prompt (the generation example above replaced it).\n",
    "inputs = tokenizer(\"To estimate the L-smoothness of a language model in parater space, one must\", return_tensors='pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "X5o1phpfmtOn",
    "outputId": "97748888-7a84-43c0-891d-dc5dab18e1ff"
   },
   "outputs": [],
   "source": [
    "# objective function: mean next-token negative log-likelihood of the prompt\n",
    "logprobs = F.log_softmax(model(**inputs).logits, dim=-1)  # [batch, sequence length, vocab_size]\n",
    "\n",
    "# Select log p(token[t+1] | tokens[:t+1]) with gather instead of multiplying by a\n",
    "# one-hot mask: the one-hot product materialises a vocab-sized tensor and turns any\n",
    "# -inf log-probability into NaN (0 * -inf); gather reads only the target entries.\n",
    "targets = inputs['input_ids'][:, 1:]  # [batch, sequence length - 1]\n",
    "loss_values = -logprobs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [batch, sequence length - 1]\n",
    "\n",
    "# Sanity check\n",
    "loss = loss_values.mean()\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bn7qewpSosWO"
   },
   "outputs": [],
   "source": [
    "# Sanity check: backpropagate the objective so the weights receive gradients\n",
    "# (the next cell reads the gradient of one of the quantized weights).\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3E4cgyo2oNVj",
    "outputId": "d7782d86-b07b-4637-bbea-2ef2c0fa13d8"
   },
   "outputs": [],
   "source": [
    "# the parameters that we need - the ones that we are going to quantize\n",
    "\n",
    "# one example: 5-th layer self-attention Q\n",
    "print(model.model.layers[5].self_attn.q_proj.weight)\n",
    "print('grad norm:', model.model.layers[5].self_attn.q_proj.weight.grad.norm())\n",
    "\n",
    "# Every nn.Linear weight inside the decoder layers, collected into a list in\n",
    "# module-traversal order. A set of tensors hashes by object identity, so its\n",
    "# iteration order depends on memory addresses and differs between kernel\n",
    "# sessions; later cells flatten these tensors into vectors and freeze them with a\n",
    "# seeded RNG in iteration order, so the order must be stable to be reproducible.\n",
    "all_quantized_weights = []\n",
    "_seen_weight_ids = set()\n",
    "for module in model.model.layers.modules():\n",
    "  if isinstance(module, nn.Linear) and id(module.weight) not in _seen_weight_ids:\n",
    "    _seen_weight_ids.add(id(module.weight))\n",
    "    all_quantized_weights.append(module.weight)\n",
    "# Identity test rather than `in`: `==` on tensors is element-wise, not a single bool.\n",
    "assert any(w is model.model.layers[5].self_attn.q_proj.weight for w in all_quantized_weights)\n",
    "print(f\"found {len(all_quantized_weights)} quantized weight tensors\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zROHKkiVwxqg"
   },
   "outputs": [],
   "source": [
    "# META INFO ABOUT TRAINABLE VARIABLES\n",
    "def numberTrainableVariables(weights) -> int:\n",
    "    \"\"\"Total number of scalar entries across the tensors in `weights` that have requires_grad=True.\"\"\"\n",
    "    return sum(p.numel() for p in weights if p.requires_grad)\n",
    "\n",
    "def trainableVariablesStaticInfo(weights) -> list:\n",
    "    \"\"\"One formatted summary line (shape, element count, dtype, size in KiB) per trainable tensor in `weights`.\"\"\"\n",
    "    def describe(t):\n",
    "        n_bytes = t.numel() * t.element_size()\n",
    "        return f\"DIMENSION: {str(list(t.shape)):15} | SCALARS: {str(t.numel()):10}| TYPE: {t.dtype} | SIZE: {n_bytes/1024:.2f} KBYTES\"\n",
    "\n",
    "    return [describe(t) for t in weights if t.requires_grad]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zROHKkiVwxqg"
   },
   "outputs": [],
   "source": [
    "# EXTRA UTILS\n",
    "# Every helper below visits only the tensors of `weights` with requires_grad=True,\n",
    "# in the iteration order of `weights`. The get*/set*AsVector pairs round-trip as\n",
    "# long as both calls see the same collection, order and requires_grad flags.\n",
    "\n",
    "def _trainable(weights):\n",
    "    \"\"\"Return the tensors of `weights` that have requires_grad=True, in iteration order.\"\"\"\n",
    "    return [p for p in weights if p.requires_grad]\n",
    "\n",
    "def _check_vector_length(tensors, vec, what):\n",
    "    \"\"\"Raise ValueError unless `vec` has exactly as many elements as `tensors` in total.\"\"\"\n",
    "    expected = sum(t.numel() for t in tensors)\n",
    "    if vec.numel() != expected:\n",
    "        raise ValueError(f\"{what} has {vec.numel()} elements, expected {expected}\")\n",
    "\n",
    "def setIterateToZero(weights):\n",
    "    \"\"\"Zero every trainable tensor in `weights`, in place.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        for param in _trainable(weights):\n",
    "            param.zero_()\n",
    "\n",
    "def getL2NormSqrOfIterate(weights) -> float:\n",
    "    \"\"\"Squared L2 norm of all trainable tensors in `weights`, as a Python float (0.0 if there are none).\"\"\"\n",
    "    result = 0.0\n",
    "    with torch.no_grad():\n",
    "        for param in _trainable(weights):\n",
    "            result = result + torch.sum(torch.square(param))\n",
    "    # float() works both for the 0-d tensor and for the plain 0.0 left when\n",
    "    # nothing is trainable (0.0.item() would raise AttributeError).\n",
    "    return float(result)\n",
    "\n",
    "def getIterateAsVector(weights):\n",
    "    \"\"\"Concatenate the trainable tensors of `weights` into a new, detached 1-D tensor (empty if none).\"\"\"\n",
    "    with torch.no_grad():\n",
    "        pieces = [param.detach().reshape(-1) for param in _trainable(weights)]\n",
    "    # torch.cat allocates a new tensor, so the result never aliases the parameters.\n",
    "    return torch.cat(pieces) if pieces else torch.empty(0)\n",
    "\n",
    "def setIterateAsVector(weights, parametersVec):\n",
    "    \"\"\"Write consecutive slices of the 1-D `parametersVec` into the trainable tensors of `weights`, in place.\n",
    "\n",
    "    Raises ValueError if the vector length differs from the total trainable size.\n",
    "    \"\"\"\n",
    "    params = _trainable(weights)\n",
    "    _check_vector_length(params, parametersVec, \"parametersVec\")\n",
    "    with torch.no_grad():\n",
    "        offset = 0\n",
    "        for param in params:\n",
    "            sz = param.numel()\n",
    "            # copy_ writes into the parameter's own storage even when it is\n",
    "            # non-contiguous; `param.flatten(0)[:] = ...` would then write into a\n",
    "            # temporary copy and silently leave the parameter unchanged.\n",
    "            param.copy_(parametersVec[offset:offset + sz].reshape(param.shape))\n",
    "            offset += sz\n",
    "\n",
    "def getL2NormSqrOfGradAtIterate(weights) -> float:\n",
    "    \"\"\"Squared L2 norm of the gradients of the trainable tensors in `weights`; a missing .grad counts as zero.\"\"\"\n",
    "    result = 0.0\n",
    "    with torch.no_grad():\n",
    "        for param in _trainable(weights):\n",
    "            if param.grad is not None:\n",
    "                result = result + torch.sum(torch.square(param.grad))\n",
    "    return float(result)\n",
    "\n",
    "def getGradientAsVector(weights):\n",
    "    \"\"\"Concatenate the gradients of the trainable tensors of `weights` into a new, detached 1-D tensor.\n",
    "\n",
    "    A tensor whose .grad is None contributes zeros, so the layout always matches getIterateAsVector.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        pieces = [\n",
    "            param.grad.detach().reshape(-1) if param.grad is not None\n",
    "            else torch.zeros_like(param).reshape(-1)\n",
    "            for param in _trainable(weights)\n",
    "        ]\n",
    "    return torch.cat(pieces) if pieces else torch.empty(0)\n",
    "\n",
    "def setGradientAsVector(weights, gradVec):\n",
    "    \"\"\"Write consecutive slices of the 1-D `gradVec` into .grad of each trainable tensor of `weights`.\n",
    "\n",
    "    A missing .grad is allocated first. Raises ValueError on a length mismatch.\n",
    "    \"\"\"\n",
    "    params = _trainable(weights)\n",
    "    _check_vector_length(params, gradVec, \"gradVec\")\n",
    "    with torch.no_grad():\n",
    "        offset = 0\n",
    "        for param in params:\n",
    "            if param.grad is None:\n",
    "                param.grad = torch.empty_like(param)\n",
    "            sz = param.grad.numel()\n",
    "            param.grad.copy_(gradVec[offset:offset + sz].reshape(param.grad.shape))\n",
    "            offset += sz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Additional helping functions: cycle through fixed marker / colour palettes by index\n",
    "def marker(i):\n",
    "    \"\"\"Matplotlib marker symbol number `i`, wrapping around a fixed list of 14 symbols.\"\"\"\n",
    "    symbols = ('v', '^', '<', '>', 's', 'p', 'P', '*', 'h', 'H', 'x', 'X', 'D', 'd')\n",
    "    return symbols[i % len(symbols)]\n",
    "def color(i):\n",
    "    \"\"\"Line colour number `i`, wrapping around a fixed list of 5 colour names.\"\"\"\n",
    "    palette = (\"red\", \"blue\", \"orange\", \"green\", \"pink\")\n",
    "    return palette[i % len(palette)]\n",
    "\n",
    "def plot(X, Y, title=\"Test\", label=None, xlabel=None, ylabel=None):\n",
    "    \"\"\"Draw Y against X with the notebook's line style in the figure named `title`, label the axes and show it.\n",
    "\n",
    "    A legend (upper left) is added only when `label` is truthy.\n",
    "    \"\"\"\n",
    "    plt.figure(title, dpi=72*2)\n",
    "    # These are the style setting I use for my projects\n",
    "    style = dict(\n",
    "        linewidth=2,\n",
    "        markersize=6,\n",
    "        markeredgecolor='black',\n",
    "        markeredgewidth=1.0,\n",
    "        label=label,\n",
    "        marker=marker(0),\n",
    "        markevery=1,\n",
    "        c=color(0),\n",
    "    )\n",
    "    plt.plot(X, Y, **style)\n",
    "    for set_axis_label, text in ((plt.xlabel, xlabel), (plt.ylabel, ylabel)):\n",
    "        set_axis_label(text, fontsize=16)\n",
    "    for set_ticks in (plt.xticks, plt.yticks):\n",
    "        set_ticks(fontsize=14)\n",
    "\n",
    "    if label:\n",
    "        plt.legend(loc='upper left')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot of the quantized weights before any training step; the estimation\n",
    "# cell below resets the model to this vector for every sampled fraction.\n",
    "x_before_train = getIterateAsVector(weights=all_quantized_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#===================================================================\n",
    "# EPOCHS/LR\n",
    "epochs = 10\n",
    "learning_rate = 0.0001\n",
    "#===================================================================\n",
    "# ESTIMATE L-smooth for all weights\n",
    "# Fraction of the quantized weight tensors kept trainable in each run.\n",
    "# Full sweep: [1.0, 0.85, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1, 0.05]\n",
    "sampled_fractions = [1.0]\n",
    "power_iterations = 10  # power-iteration steps per L estimate\n",
    "fd_step_scale = 1e-5   # finite-difference step size relative to ||x||\n",
    "L_for_epochs_result = {}\n",
    "weights_dimension = {}\n",
    "\n",
    "\n",
    "def lm_loss():\n",
    "    \"\"\"Mean next-token negative log-likelihood of the global `inputs` under `model`.\n",
    "\n",
    "    Same objective as the sanity-check cell, but the target log-probabilities are\n",
    "    picked with gather instead of a one-hot product, which avoids a vocab-sized\n",
    "    temporary and cannot produce NaN from 0 * -inf.\n",
    "    \"\"\"\n",
    "    logprobs = F.log_softmax(model(**inputs).logits, dim=-1)  # [batch, sequence length, vocab_size]\n",
    "    targets = inputs['input_ids'][:, 1:]\n",
    "    loss_values = -logprobs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [batch, sequence length - 1]\n",
    "    return loss_values.mean()\n",
    "\n",
    "\n",
    "for k, sampled_fraction in enumerate(sampled_fractions):\n",
    "    # Start every run from the same pre-training weights.\n",
    "    for param in all_quantized_weights:\n",
    "        param.requires_grad_(True)\n",
    "    setIterateAsVector(all_quantized_weights, x_before_train)\n",
    "    # Drop gradients left over from earlier cells (e.g. the sanity-check\n",
    "    # backward pass): backward() accumulates, so g[0] and the first update\n",
    "    # step would otherwise include them.\n",
    "    model.zero_grad()\n",
    "\n",
    "    print(\"Processing fraction: \", sampled_fraction)\n",
    "    gen = np.random.RandomState(seed = 123)\n",
    "\n",
    "    # Keep each tensor trainable with probability `sampled_fraction`.\n",
    "    frozen = set()\n",
    "    for param in all_quantized_weights:\n",
    "        if gen.random() < sampled_fraction:\n",
    "            param.requires_grad_(True)\n",
    "        else:\n",
    "            param.requires_grad_(False)\n",
    "            frozen.add(param)\n",
    "    #===================================================================\n",
    "    # Plain gradient descent, recording iterate x[i] and gradient g[i] at each step.\n",
    "    g = []\n",
    "    x = []\n",
    "    for i in range(epochs):\n",
    "        loss = lm_loss()\n",
    "        loss.backward()\n",
    "        print(f\"Loss. Epoch {i}: {loss.item():.3f}\")\n",
    "\n",
    "        g.append(getGradientAsVector(all_quantized_weights))\n",
    "        x.append(getIterateAsVector(all_quantized_weights))\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for param in all_quantized_weights:\n",
    "                if param.requires_grad:\n",
    "                    param -= learning_rate * param.grad\n",
    "        model.zero_grad()\n",
    "    #===================================================================\n",
    "    # At every recorded iterate, estimate the largest |eigenvalue| of the Hessian\n",
    "    # (restricted to the trainable tensors) by power iteration, approximating the\n",
    "    # Hessian-vector product by a finite difference of gradients:\n",
    "    #   H r ~= (grad f(x + h r) - grad f(x)) / h,  h = fd_step_scale * ||x||.\n",
    "    # L_for_epochs[i] is the running maximum of these estimates (floored at 0).\n",
    "    L_for_epochs = []\n",
    "    for i in range(len(x)):\n",
    "        print(\"compute L for epoch: \", i)\n",
    "\n",
    "        mult = fd_step_scale * torch.linalg.norm(x[i])\n",
    "        r = g[i]\n",
    "        for _ in range(power_iterations):\n",
    "            r_len = torch.linalg.norm(r)\n",
    "            if r_len == 0:\n",
    "                break  # H r = 0 exactly; normalising would give 0/0 = NaN\n",
    "            r = r / r_len\n",
    "\n",
    "            setIterateAsVector(all_quantized_weights, x[i] + r * mult)\n",
    "            loss = lm_loss()\n",
    "            loss.backward()\n",
    "            g_shift = getGradientAsVector(all_quantized_weights)\n",
    "            r = 1.0 / mult * (g_shift - g[i])\n",
    "            model.zero_grad()\n",
    "\n",
    "        estimate = torch.linalg.norm(r).item()\n",
    "        previous = L_for_epochs[-1] if L_for_epochs else 0\n",
    "        L_for_epochs.append(max(0, previous, estimate))\n",
    "\n",
    "    L_for_epochs_result[k] = L_for_epochs\n",
    "    weights_dimension[k] = x[0].numel()\n",
    "    print(sampled_fraction, \":\", L_for_epochs_result[k][-1])\n",
    "\n",
    "# Leave the model as it was before this cell: the power iteration leaves the\n",
    "# weights perturbed and some tensors may still be frozen.\n",
    "for param in all_quantized_weights:\n",
    "    param.requires_grad_(True)\n",
    "setIterateAsVector(all_quantized_weights, x_before_train)\n",
    "model.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running-maximum L estimate per iteration, keyed by index into sampled_fractions.\n",
    "L_for_epochs_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(dpi=72*2)\n",
    "\n",
    "for k, fraction in enumerate(sampled_fractions):\n",
    "    kwargs = {'linewidth': 2,\n",
    "              'markersize': 6,\n",
    "              'markeredgecolor': 'black',\n",
    "              'markeredgewidth': 1.0,\n",
    "              'label': f\"  {fraction*100}% Linear Layers Variables ($d:$ {weights_dimension[k]/10**6:.2f}M)\",\n",
    "              'marker': marker(k),\n",
    "              'markevery': 1,\n",
    "              'c': color(k)}\n",
    "    L = L_for_epochs_result[k]\n",
    "    # x axis = iteration index; iteration 0 is left out of the plot.\n",
    "    ax.plot(range(1, len(L)), L[1:], **kwargs)\n",
    "    print(fraction, \":\", L)\n",
    "\n",
    "ax.set_xlabel(\"Iterations\", fontsize=16)\n",
    "ax.set_ylabel(\"$L_f$\", fontsize=16)\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.legend()\n",
    "\n",
    "# File name from the last path component of the model id\n",
    "# (\"JackFram/llama-160m\" -> \"llama-160m\"); [-1] also works for ids without \"/\".\n",
    "fname = f\"example-plot-{model_name.split('/')[-1]}.pdf\"\n",
    "fig.savefig(fname, format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "print(\"PLOT IS SAVED TO:\", fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
