{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad6b7208",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "095928a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the tensor stored under the key \"embs\" from the cached safetensors file\n",
    "# and bind it to `llama_embs`, which the later cells in this notebook use.\n",
    "# NOTE(review): later cells index `llama_embs.shape[1]`, so the tensor is\n",
    "# expected to be 2-D -- confirm against whatever wrote the cache file.\n",
    "import safetensors\n",
    "# framework=\"pt\" asks safetensors to hand tensors back as PyTorch tensors.\n",
    "with safetensors.safe_open(\"./model_embs_cache/meta-llama__Llama-3.2-1B.safetensors\", framework=\"pt\") as f:\n",
    "    llama_embs = f.get_tensor(\"embs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18938979",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Heatmap of the Pearson correlation between every pair of columns of\n",
    "# `llama_embs`. torch.corrcoef treats rows as variables, hence the transpose.\n",
    "fig, ax = plt.subplots(figsize=(10, 10), dpi=100)\n",
    "\n",
    "im = ax.imshow(\n",
    "    llama_embs.T.corrcoef().numpy(),\n",
    "    cmap='viridis'\n",
    ")\n",
    "ax.set_title(\"Correlation between columns of llama_embs\")\n",
    "ax.set_xlabel(\"column index\")\n",
    "ax.set_ylabel(\"column index\")\n",
    "# Without a colorbar the colors cannot be mapped back to correlation values.\n",
    "fig.colorbar(im, ax=ax, label=\"Pearson correlation\")\n",
    "# Explicit show() keeps the AxesImage repr out of the cell output.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2940d15b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express\n",
    "\n",
    "# Distribution of |Pearson correlation| over all ordered pairs of distinct\n",
    "# columns of `llama_embs`.\n",
    "abs_corr = llama_embs.T.corrcoef().abs()\n",
    "\n",
    "# Exclude the diagonal (self-correlations) with a boolean mask. Subtracting the\n",
    "# identity matrix instead would turn those entries into zeros that are still\n",
    "# counted in the lowest histogram bin.\n",
    "off_diagonal = ~torch.eye(\n",
    "    abs_corr.shape[0], dtype=torch.bool, device=abs_corr.device\n",
    ")\n",
    "\n",
    "plotly.express.histogram(\n",
    "    abs_corr[off_diagonal],\n",
    "    nbins=100,\n",
    "    title=\"Absolute correlation between distinct columns of llama_embs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "124acfbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multicolinear_corrcoef(x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Correlate each column of ``x`` with its least-squares fit from the other columns.\n",
    "\n",
    "    For every column ``i`` the remaining columns are used as predictors in a\n",
    "    linear least-squares regression (no intercept term is fitted), and the\n",
    "    Pearson correlation between column ``i`` and its fitted values is recorded.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D tensor of shape ``(n, d)``; rows are observations and columns are\n",
    "            the variables being compared. float16/bfloat16 input is upcast to\n",
    "            float32 before solving.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor with ``d`` entries, one correlation per column of ``x``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``x`` is not 2-D (raised when unpacking ``x.shape``).\n",
    "    \"\"\"\n",
    "    _, d = x.shape\n",
    "    # torch.linalg.lstsq only supports float32/float64 (and complex) inputs.\n",
    "    if x.dtype in (torch.float16, torch.bfloat16):\n",
    "        x = x.float()\n",
    "    corrs = []\n",
    "    for i in range(d):\n",
    "        x_i = x[:, i]\n",
    "        x_rest = torch.cat([x[:, :i], x[:, i+1:]], dim=1)\n",
    "        # lstsq(A, B) solves A @ X = B: the predictors (x_rest) go first and the\n",
    "        # target (x_i) second, which yields one coefficient per predictor column.\n",
    "        beta = torch.linalg.lstsq(x_rest, x_i.unsqueeze(1)).solution\n",
    "        x_i_pred = (x_rest @ beta).flatten()\n",
    "        corrs.append(torch.corrcoef(torch.stack([x_i, x_i_pred]))[0, 1].item())\n",
    "    return torch.tensor(corrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d001d33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate multicolinear_corrcoef on `llama_embs`; the bare call makes the\n",
    "# returned tensor this cell's displayed output.\n",
    "multicolinear_corrcoef(llama_embs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4eb6ab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart of the values multicolinear_corrcoef returns for `llama_embs`.\n",
    "# NOTE(review): this call recomputes the values from scratch rather than\n",
    "# reusing a result from another cell.\n",
    "plotly.express.bar(multicolinear_corrcoef(llama_embs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "731c9210",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the values multicolinear_corrcoef returns for `llama_embs`.\n",
    "# NOTE(review): this call recomputes the values from scratch rather than\n",
    "# reusing a result from another cell.\n",
    "plotly.express.histogram(multicolinear_corrcoef(llama_embs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
