{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import argparse\n",
    "\n",
    "import jax\n",
    "import matplotlib.pyplot as plt\n",
    "import optax\n",
    "import tree_math as tm\n",
    "from flax import linen as nn\n",
    "from jax import nn as jnn\n",
    "from jax import numpy as jnp\n",
    "from jax import random, jit\n",
    "import pickle\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.sampling.invsqrt_vp import inv_sqrt_vp\n",
    "from src.sampling.low_rank import lanczos_tridiag\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem dimensions and the scalar added on the diagonal of the full matrix.\n",
    "params = 500\n",
    "rank = 300\n",
    "prior_prec = 5.\n",
    "\n",
    "# Random Gaussian matrix J of shape (rank, params), drawn from a fixed seed.\n",
    "key = jax.random.PRNGKey(10)\n",
    "J = jax.random.normal(key, (rank, params))\n",
    "\n",
    "# GGN_lr is the Gram matrix J^T J; GGN is the same matrix shifted by prior_prec * I.\n",
    "GGN_lr = J.T @ J\n",
    "GGN = GGN_lr + prior_prec * jnp.eye(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `order` is forwarded as the third argument of lanczos_tridiag\n",
    "# (presumably the number of Lanczos steps -- confirm in src.sampling.low_rank).\n",
    "order = 299\n",
    "v0 = jnp.ones(params)\n",
    "\n",
    "\n",
    "def Av(v):\n",
    "    \"\"\"Return the product of the low-rank matrix GGN_lr with v.\"\"\"\n",
    "    return GGN_lr @ v\n",
    "\n",
    "\n",
    "eigvals, eigvecs = lanczos_tridiag(Av, v0, order)\n",
    "\n",
    "# Rebuild a dense matrix from the returned factors: eigvecs diag(eigvals) eigvecs^T.\n",
    "spectrum_diag = jnp.diag(eigvals)\n",
    "recon = eigvecs @ spectrum_diag @ eigvecs.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reconstruction Relative Error: 0.05795053765177727\n"
     ]
    }
   ],
   "source": [
    "# Relative reconstruction error: norm of the residual divided by the norm of GGN_lr.\n",
    "residual_norm = jnp.linalg.norm(GGN_lr - recon)\n",
    "reference_norm = jnp.linalg.norm(GGN_lr)\n",
    "recon_error = residual_norm / reference_norm\n",
    "print(\"Reconstruction Relative Error: {}\".format(recon_error))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "mvp_fn = inv_sqrt_vp(eigvals, eigvecs, prior_prec)\n",
    "\n",
    "# The dense matrix square root of GGN does not depend on the test vector, so it is\n",
    "# computed once here rather than on every call to test_fn.\n",
    "ggn_sqrt = jax.scipy.linalg.sqrtm(GGN)\n",
    "\n",
    "\n",
    "def test_fn(test_vec):\n",
    "    \"\"\"Relative error of mvp_fn(test_vec) against a dense reference.\n",
    "\n",
    "    The reference solves ggn_sqrt @ x = test_vec, i.e. x = GGN^{-1/2} test_vec.\n",
    "\n",
    "    Args:\n",
    "        test_vec: vector passed both to mvp_fn and to the dense solve.\n",
    "\n",
    "    Returns:\n",
    "        ||mvp_fn(test_vec) - x|| / ||x|| as a scalar array.\n",
    "    \"\"\"\n",
    "    mvp_approx = mvp_fn(test_vec)\n",
    "    gt = jnp.linalg.solve(ggn_sqrt, test_vec)\n",
    "    return jnp.linalg.norm(mvp_approx - gt) / jnp.linalg.norm(gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test 0 relative error:0.184461772441864\n",
      "Test 1 relative error:0.20036917924880981\n",
      "Test 2 relative error:0.17739072442054749\n",
      "Test 3 relative error:0.17245149612426758\n",
      "Test 4 relative error:0.18886618316173553\n"
     ]
    }
   ],
   "source": [
    "num_tests = 5\n",
    "\n",
    "# One subkey per test vector; the i-th vector is a standard-normal draw multiplied by i + 1.\n",
    "# NOTE(review): `key` was already passed to jax.random.normal when sampling J;\n",
    "# consider deriving these subkeys from a key that has not been consumed.\n",
    "subkeys = jax.random.split(key, num_tests)\n",
    "scales = jnp.arange(1, num_tests + 1)\n",
    "test_vecs = [\n",
    "    jax.random.normal(subkey, (params,)) * scale\n",
    "    for subkey, scale in zip(subkeys, scales)\n",
    "]\n",
    "\n",
    "# Report the relative error of the approximate product for each test vector.\n",
    "for idx, vec in enumerate(test_vecs):\n",
    "    rel_err = test_fn(vec)\n",
    "    print(\"Test {} relative error:{}\".format(idx, rel_err))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geom",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
