{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e67a7619",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import plotly.express"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "560045cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-10 softmax probabilities obtained when an equal mix of two raw\n",
    "# (un-normalized) Gaussian embeddings is scored against the full random vocabulary.\n",
    "for dim in [100, 10000]:\n",
    "    for vocab_size in [10000, 100000]:\n",
    "        # Re-seed per configuration so every (dim, vocab_size) pair is reproducible on its own.\n",
    "        torch.manual_seed(1)\n",
    "        embs = torch.normal(mean=0, std=1, size=(vocab_size, dim))\n",
    "        idx_1, idx_2 = 1000, 2000\n",
    "        mixed = 0.5 * embs[idx_1] + 0.5 * embs[idx_2]\n",
    "        logits = mixed @ embs.T\n",
    "        probs = torch.softmax(logits, dim=-1)\n",
    "        probs, idcs = torch.topk(probs, k=10, sorted=True)\n",
    "        # Bar labels are the row indices of the top-10 tokens, rendered as strings.\n",
    "        labels = [str(i) for i in idcs.tolist()]\n",
    "        fig = plotly.express.bar(\n",
    "            y=probs,\n",
    "            x=labels,\n",
    "            title=f\"Random Embs top10 (vocab={vocab_size}, dim={dim})\",\n",
    "            text_auto=\".6f\",\n",
    "        )\n",
    "        fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c039ad9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import tqdm\n",
    "import polars as pl\n",
    "\n",
    "# Sweep over (dim, vocab_size): mix two unit-norm random embeddings 50/50, score the\n",
    "# mix against every embedding, and record how far the two mixed-in (\"correct\")\n",
    "# tokens stand out from the rest. The result is collected into the polars frame `df`.\n",
    "dims = [50, 100, 500, 1000, 5000, 10000]\n",
    "vocab = [1000, 5000, 10000, 50000, 100000]\n",
    "\n",
    "records = []  # one dict per (dim, vocab_size) configuration\n",
    "for dim, vocab_size in tqdm.tqdm(itertools.product(dims, vocab), total=len(dims) * len(vocab)):\n",
    "    torch.manual_seed(42)  # same seed for every configuration\n",
    "    embs = torch.normal(mean=0, std=1, size=(vocab_size, dim))\n",
    "    embs /= embs.norm(dim=-1, keepdim=True)  # scale every row to unit L2 norm\n",
    "    idx_1, idx_2 = 400, 600\n",
    "    logits = (0.5 * embs[idx_1] + 0.5 * embs[idx_2]) @ embs.T\n",
    "    probs = logits.softmax(dim=-1)\n",
    "    median = probs.median()\n",
    "\n",
    "    # Compute the gap from the two mixed-in tokens themselves instead of assuming\n",
    "    # they occupy top-k ranks 1 and 2. When another token outranks one of them,\n",
    "    # rank-based gaps would stay positive; here the difference goes negative and\n",
    "    # the ratio drops below 1. Whenever the mixed-in tokens are the top two, the\n",
    "    # values equal top2 - top3 and top2 / top3.\n",
    "    weakest_correct = torch.minimum(probs[idx_1], probs[idx_2])\n",
    "    incorrect = probs.clone()\n",
    "    incorrect[[idx_1, idx_2]] = 0.0  # softmax outputs are positive, so 0 can never be the max\n",
    "    strongest_incorrect = incorrect.max()\n",
    "    gap_diff = weakest_correct - strongest_incorrect\n",
    "    gap_ratio = weakest_correct / strongest_incorrect\n",
    "\n",
    "    top_probs, _ = probs.topk(3, sorted=True)\n",
    "    records.append(\n",
    "        {\n",
    "            \"dim\": dim,\n",
    "            \"vocab\": vocab_size,\n",
    "            \"gap_diff\": gap_diff.item(),\n",
    "            \"gap_ratio\": gap_ratio.item(),\n",
    "            \"top1\": top_probs[0].item(),\n",
    "            \"top2\": top_probs[1].item(),\n",
    "            \"top3\": top_probs[2].item(),\n",
    "            \"median\": median.item(),\n",
    "        }\n",
    "    )\n",
    "df = pl.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "595b790c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gap_ratio from the sweep frame `df`, plotted against embedding dimension (one line per vocab size).\n",
    "plotly.express.line(\n",
    "    df,\n",
    "    x=\"dim\",\n",
    "    y=\"gap_ratio\",\n",
    "    color=\"vocab\",\n",
    "    markers=True,\n",
    "    title=\"Probability ratio between the correct tokens and most probable incorrect token\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "594bde62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gap_ratio from the sweep frame `df`, plotted against vocabulary size (one line per dim).\n",
    "plotly.express.line(\n",
    "    df,\n",
    "    x=\"vocab\",\n",
    "    y=\"gap_ratio\",\n",
    "    color=\"dim\",\n",
    "    markers=True,\n",
    "    title=\"Probability ratio between the correct tokens and most probable incorrect token, by vocabulary size\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd73d497",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gap_diff from the sweep frame `df`, plotted against embedding dimension (one line per vocab size).\n",
    "plotly.express.line(\n",
    "    df,\n",
    "    x=\"dim\",\n",
    "    y=\"gap_diff\",\n",
    "    color=\"vocab\",\n",
    "    markers=True,\n",
    "    title=\"Difference in probability between the correct tokens and most probable incorrect token\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b084fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gap_diff from the sweep frame `df`, plotted against vocabulary size (one line per dim).\n",
    "plotly.express.line(\n",
    "    df,\n",
    "    x=\"vocab\",\n",
    "    y=\"gap_diff\",\n",
    "    color=\"dim\",\n",
    "    markers=True,\n",
    "    title=\"Difference in probability between the correct tokens and most probable incorrect token, by vocabulary size\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "967a59b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same top-10 bar chart as for the raw random embeddings, but with every\n",
    "# embedding row scaled to unit L2 norm before scoring.\n",
    "for dim in [100, 1000, 10000]:\n",
    "    for vocab_size in [1000, 10000, 100000]:\n",
    "        torch.manual_seed(1)  # re-seed so each configuration is reproducible on its own\n",
    "        embs = torch.normal(mean=0, std=1, size=(vocab_size, dim))\n",
    "        row_norms = embs.norm(dim=-1, keepdim=True)\n",
    "        embs /= row_norms\n",
    "        idx_1, idx_2 = 400, 600\n",
    "        mixed = 0.5 * embs[idx_1] + 0.5 * embs[idx_2]\n",
    "        logits = mixed @ embs.T\n",
    "        probs = torch.softmax(logits, dim=-1)\n",
    "        probs, idcs = torch.topk(probs, k=10, sorted=True)\n",
    "        # Bar labels are the row indices of the top-10 tokens, rendered as strings.\n",
    "        labels = [str(i) for i in idcs.tolist()]\n",
    "        fig = plotly.express.bar(\n",
    "            y=probs,\n",
    "            x=labels,\n",
    "            title=f\"Random Normalized Embs topk (vocab={vocab_size}, dim={dim})\",\n",
    "            text_auto=\".6f\",\n",
    "        )\n",
    "        fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7dddbd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "model_ckpt = \"meta-llama/Llama-3.2-1B\"\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt).eval()\n",
    "\n",
    "# Repeat the mixing experiment on a random subset of the model's learned token embeddings.\n",
    "for vocab_size in [1000, 100000]:\n",
    "    torch.manual_seed(0)\n",
    "    # Sample token ids WITHOUT replacement. torch.randint may return the same id\n",
    "    # several times, which puts duplicate rows into `embs`: a duplicate of row 400\n",
    "    # or 600 scores exactly like the original and shows up as a spurious extra\n",
    "    # \"correct\" bar, and the effective vocabulary is smaller than the title claims.\n",
    "    assert vocab_size <= model.config.vocab_size, \"cannot sample more tokens than the model has\"\n",
    "    token_ids = torch.randperm(model.config.vocab_size)[:vocab_size]\n",
    "    embs = model.embed_tokens.weight[token_ids].detach()\n",
    "    idx_1, idx_2 = 400, 600  # row positions inside the sampled subset, not tokenizer ids\n",
    "    logits = (0.5 * embs[idx_1] + 0.5 * embs[idx_2]) @ embs.T\n",
    "    probs = logits.softmax(dim=-1)\n",
    "    probs, idcs = probs.topk(10, sorted=True)\n",
    "    plotly.express.bar(y=probs, x=list(map(str, idcs.tolist())), title=f\"LLama 1B - topk (vocab_size={vocab_size}, dim={embs.shape[-1]})\", text_auto=\".6f\").show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc81f4b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "model_ckpt = \"meta-llama/Llama-3.1-70B\"\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt).eval()\n",
    "\n",
    "# Repeat the mixing experiment on a random subset of the model's learned token embeddings.\n",
    "for vocab_size in [1000, 100000]:\n",
    "    torch.manual_seed(0)\n",
    "    # Sample token ids WITHOUT replacement. torch.randint may return the same id\n",
    "    # several times, which puts duplicate rows into `embs`: a duplicate of row 400\n",
    "    # or 600 scores exactly like the original and shows up as a spurious extra\n",
    "    # \"correct\" bar, and the effective vocabulary is smaller than the title claims.\n",
    "    assert vocab_size <= model.config.vocab_size, \"cannot sample more tokens than the model has\"\n",
    "    token_ids = torch.randperm(model.config.vocab_size)[:vocab_size]\n",
    "    embs = model.embed_tokens.weight[token_ids].detach()\n",
    "    idx_1, idx_2 = 400, 600  # row positions inside the sampled subset, not tokenizer ids\n",
    "    logits = (0.5 * embs[idx_1] + 0.5 * embs[idx_2]) @ embs.T\n",
    "    probs = logits.softmax(dim=-1)\n",
    "    probs, idcs = probs.topk(10, sorted=True)\n",
    "    plotly.express.bar(y=probs, x=list(map(str, idcs.tolist())), title=f\"LLama 70B - topk (vocab_size={vocab_size}, dim={embs.shape[-1]})\", text_auto=\".6f\").show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
