{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import argparse\n",
    "import jax\n",
    "import matplotlib.pyplot as plt\n",
    "import optax\n",
    "import matfree\n",
    "import tree_math as tm\n",
    "from flax import linen as nn\n",
    "from jax import nn as jnn\n",
    "from jax import numpy as jnp\n",
    "from jax import random, jit\n",
    "import pickle\n",
    "from src.losses import mse_loss\n",
    "from src.helper import calculate_exact_ggn, tree_random_normal_like\n",
    "from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive\n",
    "from jax import flatten_util\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.torch_datasets import MNIST, numpy_collate_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OOD dtasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.datasets import (\n",
    "    get_rotated_mnist_loaders,\n",
    "    get_rotated_fmnist_loaders,\n",
    "    get_rotated_cifar_loaders,\n",
    "    load_corrupted_cifar10,\n",
    "    load_corrupted_cifar10_per_type,\n",
    "    get_mnist_ood_loaders,\n",
    "    get_cifar10_ood_loaders,\n",
    "    get_cifar10_train_set,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_samples = 30#1000\n",
    "classes_train = [0,1,2,3,4,5,6,7,8,9]\n",
    "n_classes = 10\n",
    "batch_size = 20#256\n",
    "test_batch_size = 256\n",
    "\n",
    "data_train = MNIST(path_root= \"/xxx/data/\",\n",
    "            train=True, n_samples=train_samples if train_samples > 0 else None, cls=classes_train\n",
    "        )\n",
    "data_test = MNIST(path_root = \"/xxx/data/\", train=False, cls=classes_train)\n",
    "\n",
    "if train_samples > 0:\n",
    "    N = train_samples * n_classes\n",
    "else:\n",
    "    N = len(data_train)\n",
    "N_test = len(data_test)\n",
    "if test_batch_size > 0:\n",
    "    test_batch_size = test_batch_size\n",
    "else:\n",
    "    test_batch_size = len(data_test)\n",
    "\n",
    "n_test_batches = int(N_test / test_batch_size)\n",
    "n_batches = int(N / batch_size)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    data_train, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate_fn, drop_last=True,\n",
    ")\n",
    "\n",
    "valid_loader = torch.utils.data.DataLoader(\n",
    "    data_test, batch_size=test_batch_size, shuffle=True, collate_fn=numpy_collate_fn, drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    output_dim: int = 10\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        if len(x.shape) != 4:\n",
    "            x = jnp.expand_dims(x, 0)\n",
    "        x = jnp.transpose(x, (0, 2, 3, 1))\n",
    "        x = nn.Conv(features=4, kernel_size=(3, 3), strides=(2, 2), padding=1)(x)\n",
    "        x = nn.tanh(x)\n",
    "        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "        x = nn.Conv(features=4, kernel_size=(3, 3), strides=(2, 2), padding=1)(x)\n",
    "        x = nn.tanh(x)\n",
    "        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "        x = x.reshape((x.shape[0], -1))\n",
    "        return nn.Dense(features=self.output_dim)(x)\n",
    "\n",
    "def compute_num_params(pytree):\n",
    "    return sum(x.size if hasattr(x, \"size\") else 0 for x in jax.tree_util.tree_leaves(pytree))\n",
    "\n",
    "\n",
    "model = ConvNet()\n",
    "batch = next(iter(train_loader))\n",
    "x_init, y_init = batch[\"image\"], batch[\"label\"]\n",
    "output_dim = y_init.shape[-1]\n",
    "key, split_key = random.split(jax.random.PRNGKey(0))\n",
    "params = model.init(key, x_init)\n",
    "alpha = 1.\n",
    "optim = optax.chain(\n",
    "        optax.clip(1.),\n",
    "        getattr(optax, \"adam\")(1e-2),\n",
    "    )\n",
    "opt_state = optim.init(params)\n",
    "n_params = compute_num_params(params)\n",
    "n_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross_entropy_loss(preds, y, rho=1.0):\n",
    "    \"\"\"\n",
    "    preds: (n_samples, n_classes) (logits)\n",
    "    y: (n_samples, n_classes) (one-hot labels)\n",
    "    \"\"\"\n",
    "    preds = preds * rho\n",
    "    preds = jax.nn.log_softmax(preds, axis=-1)\n",
    "    return -jnp.sum(jnp.sum(preds * y, axis=-1))\n",
    "\n",
    "def accuracy(params, model, batch_x, batch_y):\n",
    "    preds = model.apply(params, batch_x)\n",
    "    return jnp.sum(preds.argmax(axis=-1) == batch_y.argmax(axis=-1))\n",
    "\n",
    "\n",
    "def map_loss(\n",
    "    params,\n",
    "    model,\n",
    "    x_batch,\n",
    "    y_batch,\n",
    "    alpha,\n",
    "    n_params: int,\n",
    "    N_datapoints_max: int,\n",
    "):\n",
    "    # define dict for logging purposes\n",
    "    B = x_batch.shape[0]\n",
    "    O = y_batch.shape[-1]\n",
    "    D = n_params\n",
    "    N = N_datapoints_max\n",
    "\n",
    "    # hessian_scaler = 1\n",
    "\n",
    "    vparams = tm.Vector(params)\n",
    "\n",
    "    rho = 1.\n",
    "    nll = lambda x, y, rho: 1/B * cross_entropy_loss(x, y, rho)\n",
    "\n",
    "    y_pred = model.apply(params, x_batch)\n",
    "\n",
    "    loglike_loss = nll(y_pred, y_batch, rho) #* hessian_scaler\n",
    "\n",
    "    log_prior_term = -D / 2 * jnp.log(2 * jnp.pi) - (1 / 2) * alpha * (vparams @ vparams) + D / 2 * jnp.log(alpha)\n",
    "    # log_det_term = 0\n",
    "    loss = loglike_loss - 0. * log_prior_term\n",
    "\n",
    "    return loss\n",
    "\n",
    "def make_step(params, alpha, opt_state, x, y):\n",
    "    grad_fn = jax.value_and_grad(map_loss, argnums=0, has_aux=False)\n",
    "    loss, grads = grad_fn(params, model, x, y, alpha, n_params, N)\n",
    "    param_updates, opt_state = optim.update(grads, opt_state)\n",
    "    params = optax.apply_updates(params, param_updates)\n",
    "    return loss, params, opt_state\n",
    "\n",
    "jit_make_step = jit(make_step)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1, loss=34.942, , accuracy=0.12, alpha=1.00, time=1.226s\n",
      "epoch=2, loss=33.223, , accuracy=0.26, alpha=1.00, time=0.244s\n",
      "epoch=3, loss=30.935, , accuracy=0.36, alpha=1.00, time=0.244s\n",
      "epoch=4, loss=27.365, , accuracy=0.49, alpha=1.00, time=0.242s\n",
      "epoch=5, loss=24.140, , accuracy=0.53, alpha=1.00, time=0.242s\n",
      "epoch=6, loss=21.069, , accuracy=0.59, alpha=1.00, time=0.241s\n",
      "epoch=7, loss=18.855, , accuracy=0.66, alpha=1.00, time=0.242s\n",
      "epoch=8, loss=16.580, , accuracy=0.67, alpha=1.00, time=0.242s\n",
      "epoch=9, loss=14.824, , accuracy=0.74, alpha=1.00, time=0.241s\n",
      "epoch=10, loss=13.271, , accuracy=0.75, alpha=1.00, time=0.240s\n",
      "epoch=11, loss=12.114, , accuracy=0.77, alpha=1.00, time=0.240s\n",
      "epoch=12, loss=11.268, , accuracy=0.79, alpha=1.00, time=0.241s\n",
      "epoch=13, loss=10.514, , accuracy=0.82, alpha=1.00, time=0.240s\n",
      "epoch=14, loss=10.056, , accuracy=0.83, alpha=1.00, time=0.240s\n",
      "epoch=15, loss=9.443, , accuracy=0.84, alpha=1.00, time=0.251s\n",
      "epoch=16, loss=9.131, , accuracy=0.84, alpha=1.00, time=0.246s\n",
      "epoch=17, loss=8.813, , accuracy=0.85, alpha=1.00, time=0.246s\n",
      "epoch=18, loss=8.440, , accuracy=0.85, alpha=1.00, time=0.256s\n",
      "epoch=19, loss=8.135, , accuracy=0.86, alpha=1.00, time=0.239s\n",
      "epoch=20, loss=7.702, , accuracy=0.87, alpha=1.00, time=0.239s\n",
      "epoch=21, loss=7.317, , accuracy=0.88, alpha=1.00, time=0.240s\n",
      "epoch=22, loss=7.290, , accuracy=0.88, alpha=1.00, time=0.247s\n",
      "epoch=23, loss=7.152, , accuracy=0.88, alpha=1.00, time=0.238s\n",
      "epoch=24, loss=6.900, , accuracy=0.89, alpha=1.00, time=0.238s\n",
      "epoch=25, loss=6.615, , accuracy=0.91, alpha=1.00, time=0.237s\n",
      "epoch=26, loss=6.365, , accuracy=0.90, alpha=1.00, time=0.237s\n",
      "epoch=27, loss=6.137, , accuracy=0.91, alpha=1.00, time=0.237s\n",
      "epoch=28, loss=5.839, , accuracy=0.91, alpha=1.00, time=0.245s\n",
      "epoch=29, loss=5.771, , accuracy=0.92, alpha=1.00, time=0.238s\n",
      "epoch=30, loss=5.597, , accuracy=0.92, alpha=1.00, time=0.238s\n",
      "epoch=31, loss=5.387, , accuracy=0.92, alpha=1.00, time=0.245s\n",
      "epoch=32, loss=5.231, , accuracy=0.93, alpha=1.00, time=0.239s\n",
      "epoch=33, loss=5.085, , accuracy=0.93, alpha=1.00, time=0.239s\n",
      "epoch=34, loss=4.911, , accuracy=0.93, alpha=1.00, time=0.239s\n",
      "epoch=35, loss=4.795, , accuracy=0.92, alpha=1.00, time=0.239s\n",
      "epoch=36, loss=4.636, , accuracy=0.92, alpha=1.00, time=0.244s\n",
      "epoch=37, loss=4.502, , accuracy=0.95, alpha=1.00, time=0.245s\n",
      "epoch=38, loss=4.321, , accuracy=0.95, alpha=1.00, time=0.378s\n",
      "epoch=39, loss=4.241, , accuracy=0.94, alpha=1.00, time=0.240s\n",
      "epoch=40, loss=4.182, , accuracy=0.95, alpha=1.00, time=0.239s\n",
      "epoch=41, loss=3.959, , accuracy=0.95, alpha=1.00, time=0.239s\n",
      "epoch=42, loss=3.907, , accuracy=0.95, alpha=1.00, time=0.239s\n",
      "epoch=43, loss=3.783, , accuracy=0.96, alpha=1.00, time=0.240s\n",
      "epoch=44, loss=3.632, , accuracy=0.96, alpha=1.00, time=0.265s\n",
      "epoch=45, loss=3.511, , accuracy=0.97, alpha=1.00, time=0.301s\n",
      "epoch=46, loss=3.444, , accuracy=0.96, alpha=1.00, time=0.246s\n",
      "epoch=47, loss=3.319, , accuracy=0.97, alpha=1.00, time=0.245s\n",
      "epoch=48, loss=3.387, , accuracy=0.96, alpha=1.00, time=0.246s\n",
      "epoch=49, loss=3.164, , accuracy=0.97, alpha=1.00, time=0.239s\n",
      "epoch=50, loss=3.114, , accuracy=0.97, alpha=1.00, time=0.244s\n",
      "epoch=51, loss=3.045, , accuracy=0.97, alpha=1.00, time=0.240s\n",
      "epoch=52, loss=3.031, , accuracy=0.97, alpha=1.00, time=0.240s\n",
      "epoch=53, loss=3.019, , accuracy=0.96, alpha=1.00, time=0.239s\n",
      "epoch=54, loss=2.809, , accuracy=0.97, alpha=1.00, time=0.246s\n",
      "epoch=55, loss=2.742, , accuracy=0.98, alpha=1.00, time=0.238s\n",
      "epoch=56, loss=2.670, , accuracy=0.98, alpha=1.00, time=0.246s\n",
      "epoch=57, loss=2.725, , accuracy=0.98, alpha=1.00, time=0.239s\n",
      "epoch=58, loss=2.522, , accuracy=0.98, alpha=1.00, time=0.247s\n",
      "epoch=59, loss=2.455, , accuracy=0.98, alpha=1.00, time=0.240s\n",
      "epoch=60, loss=2.326, , accuracy=0.98, alpha=1.00, time=0.240s\n",
      "epoch=61, loss=2.359, , accuracy=0.98, alpha=1.00, time=0.246s\n",
      "epoch=62, loss=2.227, , accuracy=0.98, alpha=1.00, time=0.240s\n",
      "epoch=63, loss=2.179, , accuracy=0.98, alpha=1.00, time=0.240s\n",
      "epoch=64, loss=2.180, , accuracy=0.98, alpha=1.00, time=0.259s\n",
      "epoch=65, loss=2.094, , accuracy=0.98, alpha=1.00, time=0.280s\n",
      "epoch=66, loss=2.083, , accuracy=0.98, alpha=1.00, time=0.268s\n",
      "epoch=67, loss=1.960, , accuracy=0.99, alpha=1.00, time=0.266s\n",
      "epoch=68, loss=1.934, , accuracy=0.98, alpha=1.00, time=0.251s\n",
      "epoch=69, loss=1.964, , accuracy=0.98, alpha=1.00, time=0.262s\n",
      "epoch=70, loss=1.810, , accuracy=0.99, alpha=1.00, time=0.240s\n",
      "epoch=71, loss=1.769, , accuracy=0.99, alpha=1.00, time=0.240s\n",
      "epoch=72, loss=1.743, , accuracy=0.98, alpha=1.00, time=0.240s\n",
      "epoch=73, loss=1.672, , accuracy=0.99, alpha=1.00, time=0.240s\n",
      "epoch=74, loss=1.683, , accuracy=0.99, alpha=1.00, time=0.252s\n",
      "epoch=75, loss=1.630, , accuracy=0.99, alpha=1.00, time=0.240s\n",
      "epoch=76, loss=1.646, , accuracy=0.99, alpha=1.00, time=0.246s\n",
      "epoch=77, loss=1.593, , accuracy=0.99, alpha=1.00, time=0.240s\n",
      "epoch=78, loss=1.580, , accuracy=0.98, alpha=1.00, time=0.241s\n",
      "epoch=79, loss=1.497, , accuracy=0.99, alpha=1.00, time=0.240s\n",
      "epoch=80, loss=1.519, , accuracy=0.99, alpha=1.00, time=0.241s\n",
      "epoch=81, loss=1.449, , accuracy=0.99, alpha=1.00, time=0.241s\n",
      "epoch=82, loss=1.433, , accuracy=0.99, alpha=1.00, time=0.247s\n",
      "epoch=83, loss=1.353, , accuracy=0.99, alpha=1.00, time=0.238s\n",
      "epoch=84, loss=1.347, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=85, loss=1.350, , accuracy=0.99, alpha=1.00, time=0.239s\n",
      "epoch=86, loss=1.289, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=87, loss=1.241, , accuracy=0.99, alpha=1.00, time=0.239s\n",
      "epoch=88, loss=1.212, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=89, loss=1.196, , accuracy=0.99, alpha=1.00, time=0.238s\n",
      "epoch=90, loss=1.157, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=91, loss=1.153, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=92, loss=1.122, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=93, loss=1.139, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=94, loss=1.128, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=95, loss=1.077, , accuracy=1.00, alpha=1.00, time=0.239s\n",
      "epoch=96, loss=1.051, , accuracy=1.00, alpha=1.00, time=0.243s\n",
      "epoch=97, loss=1.076, , accuracy=1.00, alpha=1.00, time=0.238s\n",
      "epoch=98, loss=1.017, , accuracy=1.00, alpha=1.00, time=0.247s\n",
      "epoch=99, loss=1.001, , accuracy=1.00, alpha=1.00, time=0.254s\n",
      "epoch=100, loss=0.921, , accuracy=1.00, alpha=1.00, time=0.240s\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, n_epochs + 1):\n",
    "    epoch_loss = 0\n",
    "    epoch_accuracy = 0\n",
    "    start_time = time.time()\n",
    "    for _, batch in zip(range(n_batches), train_loader):\n",
    "        X = batch[\"image\"]\n",
    "        y = batch[\"label\"]\n",
    "        B = X.shape[0]\n",
    "        train_key, split_key = random.split(split_key)\n",
    "\n",
    "        loss, params, opt_state = jit_make_step(params, alpha, opt_state, X, y)\n",
    "        loss = loss\n",
    "        epoch_loss += loss.item()\n",
    "\n",
    "        epoch_accuracy += accuracy(params, model, X, y).item()\n",
    "\n",
    "    epoch_accuracy /= (n_batches * B)\n",
    "    epoch_time = time.time() - start_time\n",
    "    print(\n",
    "        f\"epoch={epoch}, loss={epoch_loss:.3f}, , accuracy={epoch_accuracy:.2f}, alpha={alpha:.2f}, time={epoch_time:.3f}s\"\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampling_train_loader = torch.utils.data.DataLoader(\n",
    "    data_train, batch_size=N, shuffle=True, collate_fn=numpy_collate_fn, drop_last=True,\n",
    ")\n",
    "data = next(iter(sampling_train_loader))\n",
    "x_train = jnp.array(data[\"image\"])\n",
    "y_train = jnp.array(data[\"label\"])\n",
    "sampling_val_loader = torch.utils.data.DataLoader(\n",
    "    data_test, batch_size=N_test, shuffle=True, collate_fn=numpy_collate_fn, drop_last=True,\n",
    ")\n",
    "data = next(iter(sampling_val_loader))\n",
    "x_val = jnp.array(data[\"image\"])\n",
    "y_val = jnp.array(data[\"label\"])\n",
    "\n",
    "sample_key = jax.random.PRNGKey(0)\n",
    "n_posterior_samples = 200\n",
    "num_iterations = 1\n",
    "n_sample_batch_size = 1\n",
    "n_sample_batches = N // n_sample_batch_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Laplace shit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.sampling.exact_ggn import exact_ggn_laplace\n",
    "from src.sampling.laplace_ode import ode_ggn\n",
    "from src.sampling.lanczos_diffusion import lanczos_diffusion\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "_model_fn = lambda params, x: model.apply(params, x[None, ...])[0]\n",
    "ggn = calculate_exact_ggn(cross_entropy_loss, _model_fn, params, x_train, y_train, n_params)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigvals, eigvecs = jnp.linalg.eigh(ggn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 1.0\n",
    "rank = 100\n",
    "\n",
    "def ggn_lr_vp(v):\n",
    "    return eigvecs[:,-rank:] @ jnp.diag(1/jnp.sqrt(eigvals[-rank:]+ alpha)) @ v\n",
    "\n",
    "def ggn_vp(v):\n",
    "    return eigvecs @ jnp.diag(1/jnp.sqrt(eigvals + alpha)) @ v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_posterior_samples = 20\n",
    "D = compute_num_params(params)\n",
    "sample_key = jax.random.PRNGKey(0)\n",
    "eps = jax.random.normal(sample_key, (n_posterior_samples, D))\n",
    "p0_flat, unravel_func_p = flatten_util.ravel_pytree(params)\n",
    "var = 0.1\n",
    "def get_posteriors(single_eps):\n",
    "    lr_sample = unravel_func_p(ggn_lr_vp(single_eps[:rank]) + p0_flat)\n",
    "    posterior_sample = unravel_func_p(ggn_vp(single_eps) + p0_flat)\n",
    "    isotropic_sample = unravel_func_p(var * single_eps + p0_flat)\n",
    "    return lr_sample, posterior_sample, isotropic_sample\n",
    "lr_posterior_samples, posterior_samples, isotropic_posterior_samples = jax.vmap(get_posteriors)(eps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampled Laplace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictive = sample_predictive(lr_posterior_samples, params, model, x_val, False, \"Pytree\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20, 10000, 10)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictive.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.6786, dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy(params, model, x_val, y_val)/x_val.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy_preds(preds, batch_y):\n",
    "    return jnp.sum(preds.argmax(axis=-1) == batch_y.argmax(axis=-1))\n",
    "accuracies = jax.vmap(accuracy_preds, in_axes=(0,None))(predictive, y_val)\n",
    "accuracies /= x_val.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.64703, dtype=float32)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jnp.mean(accuracies)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lanczos Diffusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.sampling.low_rank import lanczos_tridiag\n",
    "from typing import Callable, Literal, Optional\n",
    "from functools import partial\n",
    "\n",
    "def get_gvp_fun(\n",
    "    model_fn: Callable,\n",
    "    loss_fn: Callable,\n",
    "    params,\n",
    "    x,\n",
    "    y\n",
    "  ) -> Callable:\n",
    "\n",
    "  def gvp(eps):\n",
    "    def scan_fun(carry, batch):\n",
    "      x_, y_ = batch\n",
    "      fn = lambda p: model_fn(p,x_[None,:])\n",
    "      loss_fn_ = lambda preds: loss_fn(preds, y_)\n",
    "      out, Je = jax.jvp(fn, (params,), (eps,))\n",
    "      _, HJe = jax.jvp(jax.jacrev(loss_fn_, argnums=0), (out,), (Je,))\n",
    "      _, vjp_fn = jax.vjp(fn, params)\n",
    "      value = vjp_fn(HJe)[0]\n",
    "      return jax.tree_map(lambda c, v: c + v, carry, value), None\n",
    "    init_value = jax.tree_map(lambda x: jnp.zeros_like(x), params)\n",
    "    return jax.lax.scan(scan_fun, init_value, (x, y))[0]\n",
    "  p0_flat, unravel_func_p = jax.flatten_util.ravel_pytree(params)\n",
    "  def matvec(v_like_params):\n",
    "    p_unravelled = unravel_func_p(v_like_params)\n",
    "    ggn_vp = gvp(p_unravelled)\n",
    "    f_eval, _ = jax.flatten_util.ravel_pytree(ggn_vp)\n",
    "    return f_eval\n",
    "  return matvec\n",
    "\n",
    "gvp = get_gvp_fun(model.apply, mse_loss, params, x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 20\n",
    "n_samples = 50\n",
    "alpha = 10.0\n",
    "rank = 7\n",
    "eps = jax.random.normal(sample_key, (n_samples, n_steps, rank))\n",
    "p0_flat, unravel_func_p = jax.flatten_util.ravel_pytree(params)\n",
    "# rank = 100\n",
    "# alpha = 0.1\n",
    "v0 = jnp.ones(n_params)*5\n",
    "\n",
    "@jax.jit\n",
    "def rw_nonker(single_eps_path):\n",
    "    params_ = p0_flat\n",
    "    posterior_list = [params]\n",
    "    def body_fun(n, res):\n",
    "        gvp = get_gvp_fun(model.apply, mse_loss, unravel_func_p(res), x_train, y_train)\n",
    "        _, eigvecs = lanczos_tridiag(gvp, v0, rank - 1)\n",
    "        # ggn = calculate_exact_ggn(mse_loss, _model_fn, unravel_func_p(res), x_train, y_train, D)\n",
    "        # _, eigvecs = jnp.linalg.eigh(ggn)\n",
    "        lr_sample = 1/jnp.sqrt(alpha) * eigvecs @ single_eps_path[n]\n",
    "        params_ = res + 1/jnp.sqrt(n_steps) * lr_sample\n",
    "        return params_\n",
    "    v_ = jax.lax.fori_loop(0, n_steps - 1, body_fun, params_)\n",
    "    return unravel_func_p(v_)\n",
    "nonker_posterior_samples = jax.vmap(rw_nonker)(eps)#jax.lax.map(rw, eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 20\n",
    "n_samples = 50\n",
    "alpha = 10.0\n",
    "rank = 10\n",
    "eps = jax.random.normal(sample_key, (n_samples, n_steps, n_params))\n",
    "p0_flat, unravel_func_p = jax.flatten_util.ravel_pytree(params)\n",
    "# rank = 100\n",
    "# alpha = 0.1\n",
    "v0 = jnp.ones(n_params)*5\n",
    "delta = 1.0\n",
    "\n",
    "@jax.jit\n",
    "def rw_ker(single_eps_path):\n",
    "    params_ = p0_flat\n",
    "    posterior_list = [params]\n",
    "    def body_fun(n, res):\n",
    "        gvp = get_gvp_fun(model.apply, mse_loss, unravel_func_p(res), x_train, y_train)\n",
    "        gvp_ = lambda v: gvp(v) + delta * v\n",
    "        eigvals, eigvecs = lanczos_tridiag(gvp_, v0, rank - 1)\n",
    "        # ggn = calculate_exact_ggn(mse_loss, _model_fn, unravel_func_p(res), x_train, y_train, D)\n",
    "        # _, eigvecs = jnp.linalg.eigh(ggn)\n",
    "        lr_sample = 1/jnp.sqrt(alpha) * 1/delta * (gvp_(single_eps_path[n]) - eigvecs @ jnp.diag(eigvals) @ eigvecs.T @ single_eps_path[n])\n",
    "        params_ = res + 1/jnp.sqrt(n_steps) * lr_sample\n",
    "        return params_\n",
    "    v_ = jax.lax.fori_loop(0, n_steps - 1, body_fun, params_)\n",
    "    return unravel_func_p(v_)\n",
    "ker_posterior_samples = jax.vmap(rw_ker)(eps)#jax.lax.map(rw, eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 20\n",
    "n_samples = 50\n",
    "alpha = 10.0\n",
    "rank = 50\n",
    "nonker_posterior_samples = lanczos_diffusion(cross_entropy_loss, model.apply, params,n_steps,n_samples,alpha,sample_key,n_params,rank,x_train,y_train,1.0,\"non-kernel-eigvals\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6318\n",
      "0.40998\n"
     ]
    }
   ],
   "source": [
    "def accuracy_preds(preds, batch_y):\n",
    "    return jnp.sum(preds.argmax(axis=-1) == batch_y.argmax(axis=-1))\n",
    "predictive_nonker = sample_predictive(nonker_posterior_samples, params, model, x_val, False, \"Pytree\")\n",
    "print(accuracy(params, model, x_val, y_val)/x_val.shape[0])\n",
    "nonker_accuracies = jax.vmap(accuracy_preds, in_axes=(0,None))(predictive_nonker, y_val)\n",
    "nonker_accuracies /= x_val.shape[0]\n",
    "print(jnp.mean(nonker_accuracies))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictive_ker = sample_predictive(ker_posterior_samples, params, model, x_val, False, \"Pytree\")\n",
    "predictive_nonker = sample_predictive(nonker_posterior_samples, params, model, x_val, False, \"Pytree\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.6318, dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy(params, model, x_val, y_val)/x_val.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.099659994\n",
      "0.40998\n"
     ]
    }
   ],
   "source": [
    "ker_accuracies = jax.vmap(accuracy_preds, in_axes=(0,None))(predictive_ker, y_val)\n",
    "ker_accuracies /= x_val.shape[0]\n",
    "nonker_accuracies = jax.vmap(accuracy_preds, in_axes=(0,None))(predictive_nonker, y_val)\n",
    "nonker_accuracies /= x_val.shape[0]\n",
    "print(jnp.mean(ker_accuracies))\n",
    "print(jnp.mean(nonker_accuracies))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geom",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
