{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc107dd6",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Checkpoint id handed to transformers' from_pretrained() for both the model and the tokenizer.\n",
    "# This cell carries the `parameters` tag, so the value is presumably meant to be overridden per run.\n",
    "model_ckpt = \"meta-llama/Llama-3.2-1B\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79888b9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import transformers\n",
    "import plotly.express\n",
    "import plotly.graph_objects\n",
    "import matplotlib\n",
    "import PIL.Image\n",
    "import numpy as np\n",
    "import torch\n",
    "import sklearn.decomposition\n",
    "from torch import Tensor\n",
    "\n",
    "def pca(embs: Tensor, low_dim: int) -> Tensor:\n",
    "    \"\"\"Project embeddings onto their first ``low_dim`` principal components.\n",
    "\n",
    "    Args:\n",
    "        embs: Tensor of embeddings; scikit-learn's PCA treats each row as one\n",
    "            sample. May live on any device.\n",
    "        low_dim: Number of principal components to keep.\n",
    "\n",
    "    Returns:\n",
    "        A new CPU tensor built from the array returned by ``PCA.fit_transform``.\n",
    "    \"\"\"\n",
    "    # Tensor.numpy() only works for CPU tensors in a dtype NumPy understands,\n",
    "    # so move the data to host memory and widen bfloat16 (which NumPy lacks).\n",
    "    features = embs.detach().cpu()\n",
    "    if features.dtype == torch.bfloat16:\n",
    "        features = features.float()\n",
    "    # Named `reducer` rather than `pca` to avoid shadowing this function.\n",
    "    reducer = sklearn.decomposition.PCA(n_components=low_dim)\n",
    "    reduced_embs = reducer.fit_transform(features.numpy())\n",
    "    return torch.tensor(reduced_embs)\n",
    "\n",
    "def fourier(embs: Tensor) -> Tensor:\n",
    "    return torch.fft.fft(embs, dim=0)\n",
    "\n",
    "\n",
    "def vis_emb(embs: Tensor, colorful: bool) -> PIL.Image.Image:\n",
    "    \"\"\"Render a tensor as an image after min-max scaling its values to [0, 1].\n",
    "\n",
    "    Args:\n",
    "        embs: Tensor to draw; presumably 2-D (rows x columns) -- TODO confirm.\n",
    "        colorful: If true, map the scaled values through matplotlib's \"Blues\"\n",
    "            colormap; otherwise use the scaled values directly as intensities.\n",
    "\n",
    "    Returns:\n",
    "        A PIL image built from the scaled values multiplied by 255 and cast to uint8.\n",
    "    \"\"\"\n",
    "    x = embs.cpu().detach()\n",
    "    lowest = x.min()\n",
    "    value_range = x.max() - lowest\n",
    "    if value_range > 0:\n",
    "        x_normalized = (x - lowest) / value_range\n",
    "    else:\n",
    "        # Constant input: dividing by a zero range would give NaN, which turns\n",
    "        # into arbitrary bytes when cast to uint8. Draw a uniform image instead.\n",
    "        x_normalized = torch.zeros_like(x)\n",
    "    pixels = x_normalized.numpy()\n",
    "    if colorful:\n",
    "        # The colormap call turns each scalar into a colour tuple (extra trailing axis).\n",
    "        pixels = matplotlib.colormaps[\"Blues\"](pixels)\n",
    "    return PIL.Image.fromarray((pixels * 255).astype(np.uint8))\n",
    "\n",
    "# Load the checkpoint in eval mode together with its tokenizer.\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt).eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)\n",
    "\n",
    "# The prompts are batched into one tensor and a later cell reads position -1 of\n",
    "# every sequence. Left-padding keeps the final real token at index -1 when a\n",
    "# tokenizer splits some prompts into more tokens than others; for prompts of\n",
    "# equal token length it changes nothing.\n",
    "tokenizer.padding_side = \"left\"\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# 1000 prompts of the form \"a + b\": a comes from a seeded shuffle of 0..999,\n",
    "# b simply counts up from 0 to 999.\n",
    "inputs_str = [\n",
    "    f\"{x1} + {x2}\" for x1, x2 in zip(random.Random(0).sample(range(0, 1000), 1000), range(0, 1000))\n",
    "]\n",
    "# Only ask for padding when a pad token exists; otherwise keep the unpadded call.\n",
    "inputs = tokenizer(inputs_str, return_tensors=\"pt\", padding=tokenizer.pad_token is not None)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs, output_hidden_states=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9dd44a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_hidden_states = len(outputs.hidden_states)\n",
    "for layer_idx, layer_states in enumerate(outputs.hidden_states):\n",
    "    # Hidden state at the last sequence position of every prompt for this layer.\n",
    "    last_token_states = layer_states[:, -1, :]\n",
    "    components = pca(last_token_states.cpu(), 64)\n",
    "    # FFT magnitudes, reduced with a max across the PCA components per frequency bin.\n",
    "    spectrum = fourier(components).abs().T.max(dim=0).values\n",
    "    print(f\"Layer {layer_idx} / {num_hidden_states}\")\n",
    "\n",
    "    # First 16 PCA components as a colour image, enlarged with nearest-neighbour resampling.\n",
    "    component_image = vis_emb(components.T[:16], colorful=True)\n",
    "    display(component_image.resize((800, 200), PIL.Image.Resampling.NEAREST))\n",
    "\n",
    "    # Bar chart of the spectrum; styling is applied step by step on the same figure.\n",
    "    fig = plotly.express.bar(\n",
    "        x=torch.arange(len(spectrum)),\n",
    "        y=spectrum.cpu().detach(),\n",
    "        color_discrete_sequence=[\"black\"],\n",
    "    )\n",
    "    fig.update_layout(\n",
    "        showlegend=False,\n",
    "        width=1200,\n",
    "        height=400,\n",
    "        margin=dict(l=0, r=8, t=0, b=0),\n",
    "        bargap=0,\n",
    "    )\n",
    "    fig.update_xaxes(title=None, dtick=100, tickfont=dict(size=28))\n",
    "    fig.update_yaxes(\n",
    "        title=None,\n",
    "        title_font=dict(size=32),\n",
    "        tickfont=dict(size=28),\n",
    "        showticklabels=False,\n",
    "    )\n",
    "    fig.update_traces(marker_line_width=0)\n",
    "    display(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
