{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79888b9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import plotly.express\n",
    "import torch\n",
    "import sklearn.decomposition\n",
    "from torch import Tensor\n",
    "\n",
    "def pca(embs: Tensor, low_dim: int) -> Tensor:\n",
    "    \"\"\"Project `embs` onto its first `low_dim` principal components.\n",
    "\n",
    "    Args:\n",
    "        embs: 2-D tensor handed to scikit-learn's `PCA.fit_transform`, which\n",
    "            treats rows as samples and columns as features. May live on any\n",
    "            device and may be attached to an autograd graph.\n",
    "        low_dim: Number of principal components to keep.\n",
    "\n",
    "    Returns:\n",
    "        A new CPU tensor holding the reduced embeddings, one row per input row.\n",
    "    \"\"\"\n",
    "    # Named `reducer` so the local does not shadow this function's own name.\n",
    "    reducer = sklearn.decomposition.PCA(n_components=low_dim)\n",
    "    # `.cpu()` is a no-op for CPU tensors; without it `.numpy()` raises for\n",
    "    # tensors that live on an accelerator.\n",
    "    embs_np = embs.detach().cpu().numpy()\n",
    "    reduced_embs = reducer.fit_transform(embs_np)\n",
    "    return torch.tensor(reduced_embs)\n",
    "\n",
    "def fourier(embs: Tensor) -> Tensor:\n",
    "    \"\"\"Return the discrete Fourier transform of `embs` taken along dimension 0.\"\"\"\n",
    "    spectrum = torch.fft.fft(embs, dim=0)\n",
    "    return spectrum\n",
    "\n",
    "def vis_emb_plotly(embs: Tensor, title: str) -> plotly.graph_objects.Figure:\n",
    "    \"\"\"Draw `embs` as a blue-scale heatmap and return the plotly figure.\n",
    "\n",
    "    The tensor is transposed before plotting, so the first dimension of `embs`\n",
    "    runs along the x-axis (labelled \"Token Value\") and the second along the\n",
    "    y-axis (labelled \"Feature\").\n",
    "\n",
    "    Args:\n",
    "        embs: Tensor to visualise; it is moved to the CPU and detached first.\n",
    "        title: Title shown above the figure.\n",
    "\n",
    "    Returns:\n",
    "        The configured figure.\n",
    "    \"\"\"\n",
    "    heatmap_values = embs.cpu().T.detach()\n",
    "    fig = plotly.express.imshow(\n",
    "        heatmap_values,\n",
    "        color_continuous_scale=\"blues\",\n",
    "        aspect=\"auto\",\n",
    "    )\n",
    "    # Each update_* call mutates `fig` in place and returns it, so applying\n",
    "    # them one by one is equivalent to chaining them.\n",
    "    fig.update_layout(title=title)\n",
    "    fig.update_xaxes(title=\"Token Value\")\n",
    "    fig.update_yaxes(title=\"Feature\")\n",
    "    return fig\n",
    "\n",
    "# Checkpoint shared by the model and its tokenizer.\n",
    "MODEL_NAME = \"meta-llama/Llama-3.2-1B\"\n",
    "\n",
    "# One `.eval()` call is enough to put the model into inference mode.\n",
    "llama1b = transformers.AutoModel.from_pretrained(MODEL_NAME).eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Batch the prompts with left-padding so that position -1 is always the final\n",
    "# real token of every prompt. Without padding the tokenizer cannot build a\n",
    "# rectangular tensor if two prompts tokenize to different lengths; when all\n",
    "# prompts have the same length no pad tokens are inserted at all.\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"left\"\n",
    "\n",
    "# Prompts \"0 + 1 = \" ... \"999 + 1 = \".\n",
    "inputs_str = [f\"{x} + 1 = \" for x in range(1000)]\n",
    "inputs = tokenizer(inputs_str, return_tensors=\"pt\", padding=True)\n",
    "\n",
    "# Inference only: skip autograd bookkeeping and keep every layer's hidden states.\n",
    "with torch.no_grad():\n",
    "    outputs = llama1b(**inputs, output_hidden_states=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9dd44a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each selected entry of `outputs.hidden_states`, take the last sequence\n",
    "# position of every prompt, reduce it to 16 principal components and plot it.\n",
    "for hidden_state_layer_idx in (10, 11, 12, 13):\n",
    "    layer_hidden_states = outputs.hidden_states[hidden_state_layer_idx]\n",
    "    hidden_states_last_token = layer_hidden_states[:, -1, :]\n",
    "    reduced_last_token = pca(hidden_states_last_token.cpu(), 16)\n",
    "    layer_title = f\"hidden states in layer {hidden_state_layer_idx} of LLama 1B\"\n",
    "    display(vis_emb_plotly(reduced_last_token, layer_title))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
