{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "54ecf71a",
   "metadata": {},
   "source": [
    "# Hyperspherical energy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9a1b26f",
   "metadata": {},
   "source": [
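    "To measure the error introduced by the inverse approximation in `generate_approx_orthogonal_matrix`, we run the following experiment: for each matrix dimension and rank, we draw a random matrix $M$, transform it into $Q_1 M Q_2$ with two approximately orthogonal matrices $Q_1$ and $Q_2$, and report the absolute change in hyperspherical energy averaged over the trials. Because every Householder vector is used twice in a row, the exact product of reflections would be the identity, so any change in energy is due to the approximation."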
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25f0051e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def hyperespherical_energy(x: torch.Tensor, s: float = 1.0) -> float:\n",
    "    \"\"\"Inverse-distance (hyperspherical) energy of the rows of `x`.\n",
    "\n",
    "    Every row is projected onto the unit hypersphere (L2-normalised) and the\n",
    "    energy sum_{i != j} ||x_i - x_j||^(-s) is accumulated over all ordered\n",
    "    pairs of distinct rows.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D tensor whose rows are the vectors to compare.\n",
    "        s: Exponent of the inverse-distance kernel.\n",
    "\n",
    "    Returns:\n",
    "        The energy as a Python float.\n",
    "    \"\"\"\n",
    "    # Normalise along the axis the Gram matrix below compares (rows);\n",
    "    # normalising columns would make 2 - 2*G a meaningless quantity.\n",
    "    x = torch.nn.functional.normalize(x, dim=1, p=2)\n",
    "    G = x @ x.t()\n",
    "    # For unit vectors ||a - b||^2 = 2 - 2<a, b>; the clamp avoids an infinite\n",
    "    # kernel for coincident rows and negative values from rounding.\n",
    "    D2 = torch.clamp(2 - 2 * G, min=1e-12)\n",
    "    K = D2.pow(-s / 2)\n",
    "\n",
    "    # K is symmetric: summing the strict upper triangle and doubling it counts\n",
    "    # every ordered pair (i, j), i != j, and skips the diagonal.\n",
    "    return K.triu(1).sum().mul(2).item()\n",
    "\n",
    "def generate_approx_orthogonal_matrix(r : int, dim : int, dist : str = 'gaussian'):\n",
    "    \"\"\"Approximate a product of Householder reflections with a low-rank update.\n",
    "\n",
    "    A product of reflections H_i = I - tau_i u_i u_i^T, tau_i = 2 / (u_i^T u_i),\n",
    "    equals I - U T U^T with T^(-1) = diag(u_i^T u_i / 2) + triu(U^T U, 1).\n",
    "    Instead of inverting, T is replaced by its first-order expansion\n",
    "    T ~ D - D triu(U^T U, 1) D with D = diag(tau), so the result is\n",
    "    I + U (D triu(U^T U, 1) D - D) U^T.\n",
    "\n",
    "    Each random vector is used twice in a row (u_1, u_1, u_2, u_2, ...), so the\n",
    "    exact product would be the identity; any deviation of the returned matrix\n",
    "    from I comes from the approximation.\n",
    "\n",
    "    Args:\n",
    "        r: Number of reflections. Only r // 2 distinct vectors are drawn, so an\n",
    "           odd r effectively uses r - 1 reflections.\n",
    "        dim: Size of the (dim, dim) output matrix.\n",
    "        dist: Distribution of the vector entries; only 'gaussian' (standard\n",
    "              normal) is implemented.\n",
    "\n",
    "    Returns:\n",
    "        A (dim, dim) float tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if dist is not 'gaussian' or r < 2.\n",
    "    \"\"\"\n",
    "    if dist != 'gaussian':\n",
    "        raise ValueError(f\"unsupported dist {dist!r}; only 'gaussian' is implemented\")\n",
    "    if r < 2:\n",
    "        raise ValueError(f\"r must be at least 2, got {r}\")\n",
    "\n",
    "    columns = []\n",
    "    for _ in range(r // 2):\n",
    "        v = torch.zeros((dim, 1))\n",
    "        torch.nn.init.normal_(v)\n",
    "        # Duplicate the vector so consecutive reflections cancel exactly.\n",
    "        columns.extend([v, v])\n",
    "    U = torch.cat(columns, dim=1)  # (dim, 2 * (r // 2)); cat copies the data\n",
    "\n",
    "    WU = U.T @ U                            # Gram matrix of the vectors\n",
    "    DU = WU.diagonal().reciprocal().mul(2)  # tau_i = 2 / ||u_i||^2\n",
    "    # diag(DU) @ triu(WU, 1) @ diag(DU): rows scaled by DU_i, columns by DU_j.\n",
    "    # A 1-D DU on both sides would broadcast over columns only (DU_j ** 2).\n",
    "    SU = DU.unsqueeze(1) * WU.triu(1) * DU\n",
    "    SU.diagonal().sub_(DU)\n",
    "\n",
    "    return torch.eye(dim) + U @ SU @ U.T\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b00524d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute change in hyperspherical energy of a random (dim, dim) matrix M\n",
    "# under M -> Q1 @ M @ Q2, averaged over `trials` draws for every (dim, rank).\n",
    "# NOTE: slow for the largest dims (each trial does two dim x dim x dim matrix\n",
    "# products and two dim x dim Gram matrices); shrink `dims` or `trials` to iterate.\n",
    "torch.manual_seed(0)  # reproducible matrices, hence a reproducible figure\n",
    "\n",
    "dims = [256, 512, 1024, 2048, 4096]\n",
    "ranks = [4, 8, 16, 32, 64, 128]\n",
    "trials = 100\n",
    "\n",
    "# results[i][j] holds the mean error for dims[i] and ranks[j].\n",
    "results = []\n",
    "for dim in dims:\n",
    "    errors_for_dim = []\n",
    "    for r in ranks:\n",
    "        error = 0.0\n",
    "        for _ in range(trials):\n",
    "            M = torch.randn((dim, dim))  # already standard normal\n",
    "            Q1 = generate_approx_orthogonal_matrix(r, dim)\n",
    "            Q2 = generate_approx_orthogonal_matrix(r, dim)\n",
    "            error += abs(hyperespherical_energy(M) - hyperespherical_energy(Q1 @ M @ Q2))\n",
    "        errors_for_dim.append(error / trials)\n",
    "    results.append(errors_for_dim)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0dff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "# One colour per dimension, evenly spaced along viridis (smallest dim lightest).\n",
    "# Stopping at 0.9 keeps the lightest line legible on a white background.\n",
    "colors = plt.get_cmap(\"viridis\")(np.linspace(0.9, 0.0, len(dims)))\n",
    "\n",
    "for dim, errors, color in zip(dims, results, colors):\n",
    "    ax.plot(range(len(ranks)), errors, label=f'{dim}', color=color)\n",
    "\n",
    "# Errors are compared across orders of magnitude; a log axis labels any range.\n",
    "ax.set_yscale('log')\n",
    "ax.set_xticks(range(len(ranks)))\n",
    "ax.set_xticklabels(ranks)\n",
    "ax.set_xlabel('Rank')\n",
    "ax.set_ylabel('Hyperspherical energy difference')\n",
    "ax.set_title('Energy change under approximately orthogonal transforms')\n",
    "ax.legend(title='Dimension')\n",
    "fig.savefig('energy_diff.pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
