{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "54ecf71a",
   "metadata": {},
   "source": [
    "# Approximation error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9a1b26f",
   "metadata": {},
   "source": [
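    "In order to observe the error made by the inverse approximation, we conduct the following experiment: for each dimension $d$ and rank $r$, we draw a Gaussian matrix $U \\in \\mathbb{R}^{d \\times r}$, build the approximately orthogonal matrix $Q$ from it, and report $\\lVert I - Q Q^\\top \\rVert_F / \\sqrt{d}$ averaged over repeated trials."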
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25f0051e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def generate_approx_orthogonal_matrix(r : int, dim : int, dist : str = 'gaussian'):\n",
    "    \"\"\"Build an approximately orthogonal (dim x dim) matrix from r random vectors.\n",
    "\n",
    "    With W = U^T U, A = triu(W, 1) and E = diag(2 / diag(W)), the exact product\n",
    "    of the r Householder reflections defined by the columns of U is\n",
    "    I - U T^{-1} U^T with T = E^{-1} + A. This function replaces T^{-1} by its\n",
    "    first-order approximation E - E A E, so the result is exactly orthogonal\n",
    "    only for r <= 2 (up to floating-point error).\n",
    "\n",
    "    Args:\n",
    "        r: number of columns of U.\n",
    "        dim: size of the returned square matrix.\n",
    "        dist: distribution used to sample U; only 'gaussian' is supported.\n",
    "\n",
    "    Returns:\n",
    "        A (dim, dim) tensor equal to I + U @ SU @ U.T.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if dist is not a supported distribution.\n",
    "    \"\"\"\n",
    "    if dist != 'gaussian':\n",
    "        # Otherwise U stays all-zero, the reciprocal below is infinite and\n",
    "        # the function would silently return a matrix of NaNs.\n",
    "        raise ValueError(f\"unsupported dist {dist!r}; expected 'gaussian'\")\n",
    "\n",
    "    U = torch.zeros((dim, r))\n",
    "    torch.nn.init.normal_(U)\n",
    "\n",
    "    WU = torch.matmul(U.T, U)\n",
    "    DU = WU.diagonal().reciprocal().mul(2)\n",
    "    # E @ A @ E: scale row i by DU[i] and column j by DU[j]. Multiplying by the\n",
    "    # 1-D DU twice would broadcast over columns both times (A[i, j] * DU[j]**2).\n",
    "    SU = DU.unsqueeze(1) * WU.triu(1) * DU.unsqueeze(0)\n",
    "    SU.diagonal().sub_(DU)\n",
    "\n",
    "    return torch.eye(dim) + U @ SU @ U.T\n",
    "\n",
    "\n",
    "# Fix the RNG so the Monte Carlo averages are reproducible across runs.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# results[i][j] is the mean orthogonality error for dims[i] and ranks[j].\n",
    "results = []\n",
    "\n",
    "dims = [256, 512, 1024, 2048, 4096]\n",
    "ranks = [1, 2, 4, 8, 16, 32, 64, 128]\n",
    "\n",
    "trials = 100\n",
    "\n",
    "for dim in dims:\n",
    "    # The identity only depends on dim, so build it once per dimension.\n",
    "    identity = torch.eye(dim)\n",
    "    results.append([])\n",
    "    for r in ranks:\n",
    "\n",
    "        error = 0.0\n",
    "\n",
    "        for _ in range(trials):\n",
    "\n",
    "            Q = generate_approx_orthogonal_matrix(r, dim)\n",
    "            # Frobenius norm of (I - Q Q^T), normalised by sqrt(dim).\n",
    "            error += (torch.norm(identity - Q @ Q.T) / (dim ** .5)).item()\n",
    "\n",
    "        results[-1].append(error / trials)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91f38165",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.cm as cm\n",
    "\n",
    "\n",
    "norm = plt.Normalize(min(dims), max(dims))\n",
    "# matplotlib.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap is the\n",
    "# lookup that works on both old and current releases.\n",
    "colormap = plt.get_cmap(\"viridis\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6.5, 4))\n",
    "\n",
    "# One curve per dimension: x is the index into ranks, y is log10 of the mean error.\n",
    "for i, errors in enumerate(results):\n",
    "    # The colour is taken from the mirrored entry of dims, scaled by 1.7.\n",
    "    ax.plot(range(len(ranks)), np.log10(errors), label=f'{dims[i]}', color=colormap(norm(dims[-i - 1] * 1.7)))\n",
    "\n",
    "ax.set_xticks(range(len(ranks)))\n",
    "ax.set_xticklabels(ranks)\n",
    "# Ticks sit at the log10 exponents -7..0 and are labelled 1e-7 .. 1e-0.\n",
    "ax.set_yticks(range(-7, 1))\n",
    "ax.set_yticklabels([f'1e-{n}' for n in range(7, -1, -1)])\n",
    "ax.set_xlabel('Rank')\n",
    "ax.set_ylabel('Error')\n",
    "ax.legend()\n",
    "fig.savefig('error.pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
