{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import config \n",
    "config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "\n",
    "import jax \n",
    "import jax.numpy as jnp\n",
    "import gpjax \n",
    "from gpjax.typing import Float, ScalarFloat\n",
    "from jaxtyping import Num \n",
    "from gpjax.base import static_field, param_field, Module\n",
    "from jax.tree_util import Partial\n",
    "import tensorflow_probability.substrates.jax.distributions as tfd\n",
    "from jaxtyping import Key\n",
    "\n",
    "from matplotlib import pyplot as plt \n",
    "\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "def array(x):\n",
    "    return jnp.array(x, dtype=jnp.float64)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def sph_to_car(sph):\n",
    "    \"\"\"\n",
    "    From spherical (colat, lon) coordinates to cartesian, single point.\n",
    "    \"\"\"\n",
    "    colat, lon = sph[..., 0], sph[..., 1]\n",
    "    z = jnp.cos(colat)\n",
    "    r = jnp.sin(colat)\n",
    "    x = r * jnp.cos(lon)\n",
    "    y = r * jnp.sin(lon)\n",
    "    return jnp.stack([x, y, z], axis=-1)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def car_to_sph(car):\n",
    "    x, y, z = car[..., 0], car[..., 1], car[..., 2]\n",
    "    colat = jnp.arccos(z)\n",
    "    lon = jnp.arctan2(y, x)\n",
    "    lon = (lon + 2 * jnp.pi) % (2 * jnp.pi)\n",
    "    return jnp.stack([colat, lon], axis=-1)\n",
    "\n",
    "\n",
    "\n",
    "from pathlib import Path\n",
    "from typing import Callable\n",
    "\n",
    "import numpy as np\n",
    "from jax import Array\n",
    "\n",
    "\n",
    "class FundamentalSystemNotPrecomputedError(ValueError):\n",
    "\n",
    "    def __init__(self, dimension: int):\n",
    "        message = f\"Fundamental system for dimension {dimension} has not been precomputed.\"\n",
    "        super().__init__(message)\n",
    "\n",
    "\n",
    "def fundamental_set_loader(dimension: int, load_dir=\"fundamental_system\") -> Callable[[int], Array]:\n",
    "    \"\"\"\n",
    "    Build a lookup function over the precomputed fundamental system stored for `dimension`.\n",
    "\n",
    "    The archive `fs_<dimension>D.npz` is read once from `../<load_dir>` and kept in memory.\n",
    "\n",
    "    Args:\n",
    "        dimension: Dimension used to select the archive file.\n",
    "        load_dir: Directory holding the archives, resolved relative to the parent directory.\n",
    "\n",
    "    Returns:\n",
    "        A function mapping a degree to the array stored under the key `degree_<degree>`;\n",
    "        it raises `ValueError` for degrees missing from the archive.\n",
    "\n",
    "    Raises:\n",
    "        FundamentalSystemNotPrecomputedError: If the archive file does not exist.\n",
    "    \"\"\"\n",
    "    archive_path = Path(\"../\") / load_dir / f\"fs_{dimension}D.npz\"\n",
    "    if not archive_path.exists():\n",
    "        raise FundamentalSystemNotPrecomputedError(dimension)\n",
    "\n",
    "    # Materialise every entry while the archive is open so lookups never touch the file again.\n",
    "    with np.load(archive_path) as archive:\n",
    "        arrays_by_key = dict(archive.items())\n",
    "\n",
    "    def load(degree: int) -> Array:\n",
    "        key = f\"degree_{degree}\"\n",
    "        if key in arrays_by_key:\n",
    "            return arrays_by_key[key]\n",
    "        raise ValueError(f\"key: {key} not in cache.\")\n",
    "\n",
    "    return load\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))\n",
    "def gegenbauer(x: Float[Array, \"N D\"], max_ell: int, alpha: float = 0.5) -> Float[Array, \"N L\"]:\n",
    "    \"\"\"\n",
    "    Compute the gegenbauer polynomial Cᵅₙ(x) recursively.\n",
    "\n",
    "    Cᵅ₀(x) = 1\n",
    "    Cᵅ₁(x) = 2αx\n",
    "    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n\n",
    "\n",
    "    Args:\n",
    "        level: The order of the polynomial.\n",
    "        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.\n",
    "        x: Input array.\n",
    "\n",
    "    Returns:\n",
    "        The Gegenbauer polynomial evaluated at `x`.\n",
    "    \"\"\"\n",
    "    C_0 = jnp.ones_like(x, dtype=x.dtype)\n",
    "    C_1 = 2 * alpha * x\n",
    "    \n",
    "    res = jnp.empty((*x.shape, max_ell + 1), dtype=x.dtype)\n",
    "    res = res.at[..., 0].set(C_0)\n",
    "\n",
    "    def step(n: int, res_and_Cs: tuple[Float, Float, Float]) -> tuple[Float, Float, Float]:\n",
    "        res, C, C_prev = res_and_Cs\n",
    "        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C\n",
    "        res = res.at[..., n].set(C)\n",
    "        return res, C, C_prev\n",
    "    \n",
    "    return jax.lax.cond(\n",
    "        max_ell == 0,\n",
    "        lambda: res,\n",
    "        lambda: jax.lax.fori_loop(2, max_ell + 1, step, (res.at[..., 1].set(C_1), C_1, C_0))[0],\n",
    "    )\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call \n",
    "def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:\n",
    "    \"\"\"\n",
    "    Compute the gegenbauer polynomial Cᵅₙ(x) recursively.\n",
    "\n",
    "    Cᵅ₀(x) = 1\n",
    "    Cᵅ₁(x) = 2αx\n",
    "    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n\n",
    "\n",
    "    Args:\n",
    "        level: The order of the polynomial.\n",
    "        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.\n",
    "        x: Input array.\n",
    "\n",
    "    Returns:\n",
    "        The Gegenbauer polynomial evaluated at `x`.\n",
    "    \"\"\"\n",
    "    C_0 = jnp.ones_like(x, dtype=x.dtype)\n",
    "    C_1 = 2 * alpha * x\n",
    "\n",
    "    def step(Cs_and_n):\n",
    "        C, C_prev, n = Cs_and_n\n",
    "        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C\n",
    "        return C, C_prev, n + 1\n",
    "\n",
    "    def cond(Cs_and_n):\n",
    "        n = Cs_and_n[2]\n",
    "        return n <= ell\n",
    "\n",
    "    return jax.lax.cond(\n",
    "        ell == 0,\n",
    "        lambda: C_0,\n",
    "        lambda: jax.lax.while_loop(cond, step, (C_1, C_0, jnp.array(2, jnp.float64)))[0],\n",
    "    )\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class SphericalHarmonics(gpjax.Module):\n",
    "    \"\"\"\n",
    "    Spherical harmonics inducing features for sparse inference in Gaussian processes.\n",
    "\n",
    "    The spherical harmonics, Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the sphere and,\n",
    "    as such, they form an orthogonal basis.\n",
    "\n",
    "    To construct the harmonics, we use a fundamental set of points on the sphere {vᵢ}ᵢ and compute\n",
    "    b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b now forms a complete basis on the sphere and we can orthogonalise it via\n",
    "    a Cholesky decomposition. However, we only need to run the Cholesky decomposition once during\n",
    "    initialisation.\n",
    "\n",
    "    Attributes:\n",
    "        max_ell: The highest frequency for which harmonics are computed.\n",
    "        sphere_dim: Dimension of the sphere; the fundamental set is loaded for `sphere_dim + 1`.\n",
    "        alpha: Gegenbauer parameter `(sphere_dim - 1) / 2`, derived in `__post_init__`.\n",
    "        orth_basis: Per-frequency Cholesky factors, derived in `__post_init__`.\n",
    "        Vs: Per-frequency fundamental sets loaded from disk.\n",
    "        num_phases_per_frequency: Row count of each entry of `Vs`.\n",
    "        num_phases: Sum of `num_phases_per_frequency`.\n",
    "    \"\"\"\n",
    "\n",
    "    max_ell: int = static_field()\n",
    "    sphere_dim: int = static_field()\n",
    "    alpha: float = static_field(init=False)\n",
    "    # Holds a list with one Cholesky factor per frequency (see `_orthogonalise_basis`).\n",
    "    orth_basis: Array = static_field(init=False)\n",
    "    Vs: list[Array] = static_field(init=False)\n",
    "    # Holds a plain list of Python ints, one per frequency (see `__post_init__`).\n",
    "    num_phases_per_frequency: Float[Array, \" L\"] = static_field(init=False)\n",
    "    num_phases: int = static_field(init=False)\n",
    "\n",
    "    @property\n",
    "    def levels(self):\n",
    "        \"\"\"Frequencies 0, ..., max_ell as an int32 array.\"\"\"\n",
    "        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        \"\"\"\n",
    "        Load the fundamental sets and precompute every derived field.\n",
    "\n",
    "        Raises:\n",
    "            FundamentalSystemNotPrecomputedError: If no fundamental set is stored for this dimension.\n",
    "        \"\"\"\n",
    "        dim = self.sphere_dim + 1\n",
    "\n",
    "        # Load a pre-computed fundamental set.\n",
    "        fund_set = fundamental_set_loader(dim)\n",
    "\n",
    "        # Gegenbauer parameter shared by every frequency.\n",
    "        self.alpha = (dim - 2.0) / 2.0\n",
    "\n",
    "        # One fundamental set per frequency.\n",
    "        self.Vs = [fund_set(n) for n in self.levels]\n",
    "\n",
    "        # Pre-compute and save the orthogonal basis.\n",
    "        self.orth_basis = self._orthogonalise_basis()\n",
    "\n",
    "        # Cache the counts instead of recomputing them on every call.\n",
    "        self.num_phases_per_frequency = [v.shape[0] for v in self.Vs]\n",
    "        self.num_phases = sum(self.num_phases_per_frequency)\n",
    "\n",
    "    @property\n",
    "    def Ls(self) -> list[Array]:\n",
    "        \"\"\"\n",
    "        Alias for the orthogonal basis at every frequency.\n",
    "        \"\"\"\n",
    "        return self.orth_basis\n",
    "\n",
    "    def _orthogonalise_basis(self) -> list[Array]:\n",
    "        \"\"\"\n",
    "        Compute the basis from the fundamental set and orthogonalise it via Cholesky decomposition.\n",
    "\n",
    "        Returns:\n",
    "            One lower-triangular Cholesky factor per frequency, in the order of `self.Vs`.\n",
    "        \"\"\"\n",
    "        alpha = self.alpha\n",
    "        # Split into length-1 arrays so the three sequences line up leaf-for-leaf in `tree.map`.\n",
    "        levels = jnp.split(self.levels, self.max_ell + 1)\n",
    "        const = alpha / (alpha + self.levels.astype(jnp.float64))\n",
    "        const = jnp.split(const, self.max_ell + 1)\n",
    "\n",
    "        def _func(v, n, c):\n",
    "            x = jnp.matmul(v, v.T)\n",
    "            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)\n",
    "            # A tiny diagonal jitter is added before factorising.\n",
    "            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))\n",
    "            return L\n",
    "\n",
    "        return jax.tree.map(_func, self.Vs, levels, const)\n",
    "\n",
    "    def custom_gegenbauer_single(self, x, ell, alpha):\n",
    "        \"\"\"Evaluate all degrees up to `max_ell` with `gegenbauer` and select degree `ell`.\"\"\"\n",
    "        return gegenbauer(x, self.max_ell, alpha)[..., ell]\n",
    "\n",
    "    @jax.jit\n",
    "    def polynomial_expansion(self, X: Float[Array, \"N D\"]) -> Float[Array, \"M N\"]:\n",
    "        \"\"\"\n",
    "        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.\n",
    "\n",
    "        Args:\n",
    "            X: Input Array.\n",
    "\n",
    "        Returns:\n",
    "            The harmonics evaluated at the input as a polynomial expansion of the basis,\n",
    "            concatenated over frequencies along the first axis.\n",
    "        \"\"\"\n",
    "        levels = jnp.split(self.levels, self.max_ell + 1)\n",
    "\n",
    "        def _func(v, n, L):\n",
    "            vxT = jnp.dot(v, X.T)\n",
    "            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)\n",
    "            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)\n",
    "            return harmonic\n",
    "\n",
    "        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)\n",
    "        return jnp.concatenate(harmonics, axis=0)\n",
    "\n",
    "    def __eq__(self, other: \"SphericalHarmonics\") -> bool:\n",
    "        \"\"\"\n",
    "        Check if two spherical harmonic features are equal.\n",
    "\n",
    "        Args:\n",
    "            other: The object to compare against.\n",
    "\n",
    "        Returns:\n",
    "            A boolean indicating if the two features are equal, or `NotImplemented` when `other`\n",
    "            is not a `SphericalHarmonics` instance.\n",
    "        \"\"\"\n",
    "        # Defer to Python's fallback for foreign types instead of raising AttributeError.\n",
    "        if not isinstance(other, SphericalHarmonics):\n",
    "            return NotImplemented\n",
    "        # Given the first two parameters, the rest are deterministic.\n",
    "        # The user must not mutate all other fields, but that is not enforced for now.\n",
    "        return (\n",
    "            self.max_ell == other.max_ell\n",
    "            and self.sphere_dim == other.sphere_dim\n",
    "        )\n",
    "\n",
    "def angles_to_radians_colat(x: Array) -> Array:\n",
    "    return jnp.pi * x / 180 + jnp.pi / 2\n",
    "\n",
    "def angles_to_radians_lon(x: Array) -> Array:\n",
    "    return jnp.pi * x / 180 \n",
    "\n",
    "\n",
    "from gpjax.base import static_field, param_field\n",
    "from gpjax.kernels import AbstractKernel\n",
    "from gpjax.likelihoods import AbstractLikelihood\n",
    "from gpjax.gps import AbstractPosterior\n",
    "import tensorflow_probability.substrates.jax.bijectors as tfb\n",
    "from jax.scipy.special import gammaln\n",
    "from jaxtyping import Int\n",
    "\n",
    "\n",
    "@jax.jit \n",
    "def comb(N, k) -> Int:\n",
    "    return jnp.round(jnp.exp(gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1))).astype(jnp.int64)\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=(\"sphere_dim\"))\n",
    "def num_phases_in_frequency(sphere_dim: int, frequency: Int) -> Int:\n",
    "    \"\"\"\n",
    "    Count the phases at each frequency for a sphere of dimension `sphere_dim`.\n",
    "\n",
    "    Frequency 0 always has exactly one phase; for frequency l > 0 the count is\n",
    "    C(l + d - 2, l - 1) + C(l + d - 1, l) with d = `sphere_dim`.\n",
    "\n",
    "    Args:\n",
    "        sphere_dim: Dimension of the sphere. Static under `jax.jit`.\n",
    "        frequency: Frequency or array of frequencies.\n",
    "\n",
    "    Returns:\n",
    "        Integer counts with the same shape as `frequency`.\n",
    "    \"\"\"\n",
    "    count_if_positive = (\n",
    "        comb(frequency + sphere_dim - 2, frequency - 1)\n",
    "        + comb(frequency + sphere_dim - 1, frequency)\n",
    "    )\n",
    "    count_if_zero = jnp.ones_like(frequency, dtype=jnp.int64)\n",
    "    return jnp.where(frequency == 0, count_if_zero, count_if_positive)\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=(\"max_ell\", \"sphere_dim\"))\n",
    "def sphere_addition_theorem(x: Float[Array, \"D\"], y: Float[Array, \"D\"], *, max_ell: int, sphere_dim: int) -> Float:\n",
    "    \"\"\"\n",
    "    Per-frequency zonal terms for a pair of points, for frequencies 0, ..., max_ell.\n",
    "\n",
    "    For each frequency the result is the number of phases times the Gegenbauer polynomial\n",
    "    at <x, y>, normalised by the same polynomial evaluated at 1.\n",
    "\n",
    "    Args:\n",
    "        x: First point.\n",
    "        y: Second point.\n",
    "        max_ell: Highest frequency to include. Static under `jax.jit`.\n",
    "        sphere_dim: Dimension of the sphere. Static under `jax.jit`.\n",
    "\n",
    "    Returns:\n",
    "        Array with one entry per frequency.\n",
    "    \"\"\"\n",
    "    alpha = (sphere_dim - 1) / 2.0\n",
    "    frequencies = jnp.arange(max_ell + 1)\n",
    "    phase_counts = num_phases_in_frequency(sphere_dim=sphere_dim, frequency=frequencies)\n",
    "    gegenbauer_at_one = gegenbauer(1.0, max_ell=max_ell, alpha=alpha)\n",
    "    gegenbauer_at_xy = gegenbauer(jnp.dot(x, y), max_ell=max_ell, alpha=alpha)\n",
    "    return phase_counts / gegenbauer_at_one * gegenbauer_at_xy\n",
    "\n",
    "\n",
    "def addition_theorem_scalar_kernel(spectrum: Float[Array, \"I\"], z: Float[Array, \"I\"]) -> Float[Array, \"\"]:\n",
    "    return jnp.dot(spectrum, z)\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('dim',))\n",
    "def matern_spectrum(ell: Float, kappa: Float, nu: Float, variance: Float, dim: int) -> Float:\n",
    "    \"\"\"\n",
    "    Matern spectral weights at the frequencies `ell`.\n",
    "\n",
    "    The unnormalised log-weight at frequency l is\n",
    "    -(nu + dim / 2) * log1p(l * (l + dim - 1) * kappa**2 / (2 * nu)). The weights are rescaled so\n",
    "    that their sum, weighted by the number of phases per frequency, equals `variance`.\n",
    "\n",
    "    Args:\n",
    "        ell: Array of frequencies.\n",
    "        kappa: Length-scale parameter.\n",
    "        nu: Smoothness parameter.\n",
    "        variance: Value the phase-weighted sum of the returned weights adds up to.\n",
    "        dim: Dimension of the sphere. Static under `jax.jit`.\n",
    "\n",
    "    Returns:\n",
    "        Array of spectral weights, one per entry of `ell`.\n",
    "    \"\"\"\n",
    "    eigenvalues = ell * (ell + dim - 1)\n",
    "    log_weights = -(nu + dim / 2) * jnp.log1p((eigenvalues * kappa**2) / (2 * nu))\n",
    "\n",
    "    # Shift by the largest log-weight before exponentiating, for numerical stability.\n",
    "    weights = jnp.exp(log_weights - jnp.max(log_weights))\n",
    "\n",
    "    # Normalise against the phase-weighted total.\n",
    "    phase_counts = num_phases_in_frequency(frequency=ell, sphere_dim=dim)\n",
    "    weighted_total = jnp.dot(phase_counts, weights)\n",
    "    return variance * weights / weighted_total\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class SphereMaternKernel(Module):\n",
    "    \"\"\"\n",
    "    Scalar Matern kernel on the sphere, evaluated from its spectrum truncated at `max_ell`.\n",
    "    \"\"\"\n",
    "\n",
    "    sphere_dim: int = static_field(2)\n",
    "    kappa: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())\n",
    "    nu: ScalarFloat = param_field(jnp.array(1.5), bijector=tfb.Softplus())\n",
    "    variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())\n",
    "    max_ell: int = static_field(25)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"\"\"Cast the three hyperparameters to float64 arrays.\"\"\"\n",
    "        for name in (\"kappa\", \"nu\", \"variance\"):\n",
    "            setattr(self, name, jnp.asarray(getattr(self, name), dtype=jnp.float64))\n",
    "\n",
    "    @property\n",
    "    def ells(self):\n",
    "        \"\"\"Frequencies 0, ..., max_ell as a float64 array.\"\"\"\n",
    "        return jnp.arange(self.max_ell + 1, dtype=jnp.float64)\n",
    "\n",
    "    def spectrum(self) -> Num[Array, \"I\"]:\n",
    "        \"\"\"Matern spectral weights at every frequency in `ells`.\"\"\"\n",
    "        return matern_spectrum(self.ells, self.kappa, self.nu, self.variance, dim=self.sphere_dim)\n",
    "\n",
    "    @jax.jit\n",
    "    def from_spectrum(self, spectrum: Float[Array, \"M\"], x: Float[Array, \"D\"], y: Float[Array, \"D\"]) -> Float[Array, \"\"]:\n",
    "        \"\"\"Evaluate the kernel between `x` and `y` using the supplied `spectrum`.\"\"\"\n",
    "        zonal_terms = sphere_addition_theorem(x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim)\n",
    "        return addition_theorem_scalar_kernel(spectrum, zonal_terms)\n",
    "\n",
    "    @jax.jit\n",
    "    def __call__(self, x: Float[Array, \"D\"], y: Float[Array, \"D\"]) -> Float[Array, \"\"]:\n",
    "        \"\"\"Evaluate the kernel between `x` and `y` using the kernel's own spectrum.\"\"\"\n",
    "        own_spectrum = self.spectrum()\n",
    "        return self.from_spectrum(own_spectrum, x, y)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class MultioutputSphereMaternKernel(Module):\n",
    "    \"\"\"\n",
    "    Matern kernel on the sphere with `num_outputs` independent outputs.\n",
    "\n",
    "    Each output has its own `kappa`, `nu` and `variance`; all outputs share `max_ell`.\n",
    "    \"\"\"\n",
    "\n",
    "    num_outputs: int = static_field()\n",
    "    sphere_dim: int = static_field(2)\n",
    "    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())\n",
    "    nu: ScalarFloat = param_field(jnp.array([1.5]), bijector=tfb.Softplus())\n",
    "    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())\n",
    "    max_ell: int = static_field(25)\n",
    "\n",
    "    def _validate_params(self) -> None:\n",
    "        \"\"\"Cast each hyperparameter to float64 and broadcast it to shape `(num_outputs,)`.\"\"\"\n",
    "        per_output_shape = (self.num_outputs,)\n",
    "        names = (\"kappa\", \"nu\", \"variance\")\n",
    "\n",
    "        # float64 for numerical stability\n",
    "        for name in names:\n",
    "            setattr(self, name, jnp.asarray(getattr(self, name), dtype=jnp.float64))\n",
    "\n",
    "        # shape for multioutput\n",
    "        for name in names:\n",
    "            setattr(self, name, jnp.broadcast_to(getattr(self, name), per_output_shape))\n",
    "\n",
    "    def __post_init__(self):\n",
    "        self._validate_params()\n",
    "\n",
    "    @property\n",
    "    def ells(self):\n",
    "        \"\"\"Frequencies 0, ..., max_ell.\"\"\"\n",
    "        return jnp.arange(self.max_ell + 1)\n",
    "\n",
    "    @jax.jit\n",
    "    def spectrum(self) -> Num[Array, \"O L\"]:\n",
    "        \"\"\"Matern spectral weights for every output, stacked along the first axis.\"\"\"\n",
    "\n",
    "        def single_output_spectrum(kappa, nu, variance):\n",
    "            return matern_spectrum(self.ells, kappa, nu, variance, dim=self.sphere_dim)\n",
    "\n",
    "        return jax.vmap(single_output_spectrum)(self.kappa, self.nu, self.variance)\n",
    "\n",
    "    @jax.jit\n",
    "    def from_spectrum(self, spectrum: Float[Array, \"O L\"], x: Float[Array, \"D\"], y: Float[Array, \"D\"]) -> Float[Array, \"O\"]:\n",
    "        \"\"\"Evaluate every output's kernel between `x` and `y` using the supplied `spectrum`.\"\"\"\n",
    "\n",
    "        def single_output_kernel(output_spectrum):\n",
    "            zonal_terms = sphere_addition_theorem(x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim)\n",
    "            return addition_theorem_scalar_kernel(output_spectrum, zonal_terms)\n",
    "\n",
    "        return jax.vmap(single_output_kernel)(spectrum)\n",
    "\n",
    "    @jax.jit\n",
    "    def __call__(self, x: Float[Array, \"D\"], y: Float[Array, \"D\"]) -> Float[Array, \"O\"]:\n",
    "        \"\"\"Evaluate every output's kernel between `x` and `y` using the kernel's own spectrum.\"\"\"\n",
    "        own_spectrum = self.spectrum()\n",
    "        return self.from_spectrum(own_spectrum, x, y)\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class MultioutputPrior(Module):\n",
    "    \"\"\"Prior for a multi-output model: a multi-output sphere Matern kernel plus a jitter constant.\"\"\"\n",
    "\n",
    "    # Kernel is registered as a parameter field, so its own parameters are part of this module.\n",
    "    kernel: MultioutputSphereMaternKernel = param_field()\n",
    "    # NOTE(review): `jitter` is not read anywhere in the visible part of the file - confirm its use.\n",
    "    jitter: Float = static_field(1e-12)\n",
    "\n",
    "    @property \n",
    "    def num_outputs(self):\n",
    "        \"\"\"Number of outputs, forwarded from the kernel.\"\"\"\n",
    "        return self.kernel.num_outputs\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class Prior(Module):\n",
    "    \"\"\"Prior for a single-output model: a sphere Matern kernel plus a jitter constant.\"\"\"\n",
    "\n",
    "    # Kernel is registered as a parameter field, so its own parameters are part of this module.\n",
    "    kernel: SphereMaternKernel = param_field()\n",
    "    # NOTE(review): `jitter` is not read anywhere in the visible part of the file - confirm its use.\n",
    "    jitter: Float = static_field(1e-12)\n",
    "    \n",
    "\n",
    "@dataclass\n",
    "class Posterior(Module):\n",
    "    \"\"\"Container pairing a single-output `Prior` with a likelihood module.\"\"\"\n",
    "\n",
    "    prior: Prior = param_field()\n",
    "    # Any module can be supplied here; no likelihood interface is enforced by this class.\n",
    "    likelihood: Module = param_field()\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class MultioutputPosterior(Module):\n",
    "    \"\"\"Container pairing a `MultioutputPrior` with a likelihood module.\"\"\"\n",
    "\n",
    "    prior: MultioutputPrior = param_field()\n",
    "    # Any module can be supplied here; no likelihood interface is enforced by this class.\n",
    "    likelihood: Module = param_field()\n",
    "\n",
    "    @property \n",
    "    def num_outputs(self) -> int:\n",
    "        \"\"\"Number of outputs, forwarded from the prior.\"\"\"\n",
    "        return self.prior.num_outputs\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('jitter',))\n",
    "def spherical_harmonic_features_moments(\n",
    "    Kxz: Float[Array, \"M\"], \n",
    "    Kzz_inv_diag: Float[Array, \"M\"], \n",
    "    m: Float[Array, \"M\"], \n",
    "    sqrtS: Float[Array, \"M M\"], \n",
    "    jitter: float = 1e-12\n",
    ") -> tuple[Float[Array, \"\"], Float[Array, \"\"]]:\n",
    "    Lzz_T_inv_diag = jnp.sqrt(Kzz_inv_diag) / jnp.sqrt(1 + jitter * Kzz_inv_diag)\n",
    "    Kxz_Lzz_T_inv = Kxz * Lzz_T_inv_diag\n",
    "    Kxz_Lzz_T_inv_sqrtS = Kxz_Lzz_T_inv @ sqrtS\n",
    "\n",
    "    covariance = (\n",
    "        jnp.sum(jnp.square(Kxz_Lzz_T_inv_sqrtS))\n",
    "        # + Kxz_Lzz_T_inv_sqrtS @ Kxz_Lzz_T_inv_sqrtS.T\n",
    "        # - Kxz_Lzz_T_inv @ Kxz_Lzz_T_inv.T\n",
    "        # No need for the term above as it is absorbed into Kxx \n",
    "    )\n",
    "\n",
    "    mean = (\n",
    "        Kxz_Lzz_T_inv @ m\n",
    "    )\n",
    "\n",
    "    return mean, covariance\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('jitter',))\n",
    "def pathwise_sample_spherical_harmonic_features_posterior(\n",
    "    Kxz: Float[Array, \"M\"],\n",
    "    Kzz_inv_diag: Float[Array, \"M\"],\n",
    "    m: Float[Array, \"M\"],\n",
    "    sqrtS: Float[Array, \"M M\"],\n",
    "    jitter: float = 1e-12,\n",
    "    *, \n",
    "    key: Key\n",
    ") -> Float[Array, \"\"]:\n",
    "    u = jax.random.normal(key=key, shape=m.shape)\n",
    "\n",
    "    # f(x) + Kxz Kzz^{-1} (u - f(z)) = Kxz Kzz^{-1} u\n",
    "    Lzz_T_inv_diag = jnp.sqrt(Kzz_inv_diag) / jnp.sqrt(1 + jitter * Kzz_inv_diag)\n",
    "    Kxz_Lzz_T_inv = Kxz * Lzz_T_inv_diag\n",
    "    Kxz_Lzz_T_inv_sqrtS = Kxz_Lzz_T_inv @ sqrtS\n",
    "    return Kxz_Lzz_T_inv_sqrtS @ u + Kxz_Lzz_T_inv @ m\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def whitened_prior_kl(m: Float, sqrtS: Float) -> Float:\n",
    "    \"\"\"\n",
    "    KL divergence from the variational distribution N(m, sqrtS sqrtSᵀ) to the standard normal.\n",
    "\n",
    "    Args:\n",
    "        m: Variational mean.\n",
    "        sqrtS: Factor of the variational covariance.\n",
    "\n",
    "    Returns:\n",
    "        The value of KL(q || p) with p a zero-mean, identity-covariance normal.\n",
    "    \"\"\"\n",
    "    variational = tfd.MultivariateNormalFullCovariance(\n",
    "        loc=m,\n",
    "        covariance_matrix=sqrtS @ sqrtS.T,\n",
    "    )\n",
    "    standard_normal = tfd.MultivariateNormalFullCovariance(\n",
    "        loc=jnp.zeros(m.shape),\n",
    "        covariance_matrix=jnp.eye(m.shape[0]),\n",
    "    )\n",
    "    return tfd.kl_divergence(variational, standard_normal)\n",
    "\n",
    "\n",
    "def inducing_points_prior_kl(m: Float, sqrtS: Float) -> Float:\n",
    "    \"\"\"Alias of `whitened_prior_kl`: forwards `m` and `sqrtS` unchanged.\"\"\"\n",
    "    kl = whitened_prior_kl(m, sqrtS)\n",
    "    return kl\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class DummyPosterior(Module):\n",
    "    \"\"\"Posterior stand-in that carries only a `Prior` and no likelihood.\"\"\"\n",
    "\n",
    "    prior: Prior = param_field()\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class MultioutputDummyPosterior(Module):\n",
    "    \"\"\"Posterior stand-in that carries only a `MultioutputPrior` and no likelihood.\"\"\"\n",
    "\n",
    "    prior: MultioutputPrior = param_field()\n",
    "\n",
    "    @property \n",
    "    def num_outputs(self):\n",
    "        \"\"\"Number of outputs, forwarded from the prior.\"\"\"\n",
    "        return self.prior.num_outputs\n",
    "    \n",
    "\n",
    "@dataclass\n",
    "class SphericalHarmonicFeaturesPosterior(Module):\n",
    "    \"\"\"\n",
    "    Variational posterior of a single-output GP with spherical harmonic inducing features.\n",
    "\n",
    "    The variational parameters `m` and `sqrtS` are created in `__post_init__` with one entry\n",
    "    per harmonic: `m` starts at zero and `sqrtS` at the identity.\n",
    "    \"\"\"\n",
    "\n",
    "    posterior: Posterior = param_field()\n",
    "    spherical_harmonics: SphericalHarmonics = static_field()\n",
    "    m: Float[Array, \"M\"] = param_field(init=False)\n",
    "    sqrtS: Float[Array, \"M M\"] = param_field(init=False, bijector=tfb.FillTriangular())\n",
    "    num_inducing: int = static_field(init=False)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"\"\"Size and initialise the variational parameters from the number of harmonics.\"\"\"\n",
    "        self.num_inducing = self.spherical_harmonics.num_phases\n",
    "        self.m = jnp.zeros(self.num_inducing)\n",
    "        self.sqrtS = jnp.eye(self.num_inducing)\n",
    "\n",
    "    @jax.jit\n",
    "    def Kzz_diag(self, spectrum: Float[Array, \"L\"]) -> Float[Array, \"M\"]:\n",
    "        \"\"\"\n",
    "        Expand a per-frequency spectrum to one entry per harmonic.\n",
    "\n",
    "        Only the first `spherical_harmonics.max_ell + 1` entries of `spectrum` are used; each is\n",
    "        repeated once for every phase at its frequency.\n",
    "        \"\"\"\n",
    "        shf = self.spherical_harmonics\n",
    "        # NumPy (not JAX) array so the repeat counts stay concrete while tracing.\n",
    "        repeats = np.array(shf.num_phases_per_frequency)\n",
    "        total_repeat_length = shf.num_phases\n",
    "        return jnp.repeat(\n",
    "            spectrum[:shf.max_ell + 1],\n",
    "            repeats=repeats,\n",
    "            total_repeat_length=total_repeat_length,\n",
    "        )\n",
    "\n",
    "    def Kxz(self, x: Float[Array, \"D\"]) -> Float[Array, \"M\"]:\n",
    "        \"\"\"Harmonic features of the input `x` (transposed polynomial expansion).\"\"\"\n",
    "        return self.spherical_harmonics.polynomial_expansion(x).T\n",
    "\n",
    "    def prior_kl(self) -> Float[Array, \"\"]:\n",
    "        \"\"\"KL divergence from the variational distribution to the whitened prior.\"\"\"\n",
    "        return whitened_prior_kl(self.m, self.sqrtS)\n",
    "\n",
    "    @jax.jit\n",
    "    def moments(self, x: Float[Array, \"D\"]) -> tuple[Float[Array, \"\"], Float[Array, \"\"]]:\n",
    "        \"\"\"\n",
    "        Predictive mean and variance at a single input `x`.\n",
    "\n",
    "        Batches are handled by callers through `jax.vmap` (see `diag`).\n",
    "        \"\"\"\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        spectrum = kernel.spectrum()\n",
    "\n",
    "        Kzz_diag = self.Kzz_diag(spectrum)\n",
    "        Kxz = self.Kxz(x)\n",
    "\n",
    "        return spherical_harmonic_features_moments(Kxz, Kzz_diag, self.m, self.sqrtS)\n",
    "\n",
    "    @jax.jit\n",
    "    def diag(self, x: Float[Array, \"N D\"]) -> tfd.Normal:\n",
    "        \"\"\"Marginal predictive distribution at every row of `x`.\"\"\"\n",
    "        mean, variance = jax.vmap(self.moments)(x)\n",
    "        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))\n",
    "\n",
    "    @jax.jit\n",
    "    def pathwise_sample_single(self, x: Float[Array, \"D\"], *, key: Key) -> Float[Array, \"\"]:\n",
    "        \"\"\"Value at a single input `x` of the posterior sample determined by `key`.\"\"\"\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        Kxz = self.Kxz(x)\n",
    "        Kzz_diag = self.Kzz_diag(kernel.spectrum())\n",
    "        return pathwise_sample_spherical_harmonic_features_posterior(\n",
    "            Kxz, Kzz_diag, self.m, self.sqrtS, key=key\n",
    "        )\n",
    "\n",
    "    @jax.jit\n",
    "    def pathwise_sample(self, x: Float[Array, \"N D\"], *, key: Key) -> Float[Array, \"N\"]:\n",
    "        \"\"\"\n",
    "        Evaluate one posterior sample at every row of `x`.\n",
    "\n",
    "        The same `key` is deliberately passed to every row so that all rows are evaluated\n",
    "        with the same random draw.\n",
    "        \"\"\"\n",
    "        return jax.vmap(lambda x: self.pathwise_sample_single(x, key=key))(x)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class MultioutputSphericalHarmonicFeaturesPosterior(Module):\n",
    "    \"\"\"\n",
    "    Variational posterior of a multi-output GP with spherical harmonic inducing features.\n",
    "\n",
    "    Every output has its own variational mean `m`, covariance factor `sqrtS` and per-frequency\n",
    "    weights `sqrtS_augment`; all three are created in `__post_init__` with a leading output axis.\n",
    "    \"\"\"\n",
    "\n",
    "    num_outputs: int = static_field(init=False)\n",
    "\n",
    "    posterior: MultioutputPosterior = param_field()\n",
    "    spherical_harmonics: SphericalHarmonics = static_field()\n",
    "    # After `__post_init__` the three parameters below carry a leading axis of size `num_outputs`.\n",
    "    m: Float[Array, \"M\"] = param_field(init=False)\n",
    "    sqrtS: Float[Array, \"M M\"] = param_field(init=False, bijector=tfb.FillTriangular())\n",
    "    sqrtS_augment: Float[Array, \"L\"] = param_field(init=False)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"\"\"Size and initialise the variational parameters, one copy per output.\"\"\"\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        self.num_outputs = self.posterior.num_outputs\n",
    "\n",
    "        num_inducing = self.spherical_harmonics.num_phases\n",
    "        self.m = jnp.zeros(num_inducing)\n",
    "        self.sqrtS = jnp.eye(num_inducing)\n",
    "        # Zero for frequencies covered by the harmonics, one for the remaining kernel frequencies.\n",
    "        self.sqrtS_augment = jnp.ones(kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)\n",
    "\n",
    "        self.m = jnp.broadcast_to(self.m, (self.num_outputs, num_inducing))\n",
    "        self.sqrtS = jnp.broadcast_to(self.sqrtS, (self.num_outputs, num_inducing, num_inducing))\n",
    "        self.sqrtS_augment = jnp.broadcast_to(self.sqrtS_augment, (self.num_outputs, kernel.max_ell + 1))\n",
    "\n",
    "    @jax.jit\n",
    "    def prior_kl(self) -> Float:\n",
    "        \"\"\"Sum over outputs of the KL divergence to the whitened prior.\"\"\"\n",
    "        return jnp.sum(jax.vmap(whitened_prior_kl)(self.m, self.sqrtS), axis=0)\n",
    "\n",
    "    @jax.jit\n",
    "    def Kzz_diag(self, spectrum: Float[Array, \"O L\"]) -> Float[Array, \"O M\"]:\n",
    "        \"\"\"\n",
    "        Expand each output's per-frequency spectrum to one entry per harmonic.\n",
    "\n",
    "        Only the first `spherical_harmonics.max_ell + 1` frequencies of `spectrum` are used.\n",
    "        \"\"\"\n",
    "        shf = self.spherical_harmonics\n",
    "        # NumPy (not JAX) array so the repeat counts stay concrete while tracing.\n",
    "        repeats = np.array(shf.num_phases_per_frequency)\n",
    "        total_repeat_length = shf.num_phases\n",
    "        return jax.vmap(\n",
    "            lambda spectrum: jnp.repeat(spectrum, repeats=repeats, total_repeat_length=total_repeat_length)\n",
    "        )(spectrum[:, :shf.max_ell + 1])\n",
    "\n",
    "    def Kxz(self, x: Float[Array, \"D\"]) -> Float[Array, \"M\"]:\n",
    "        \"\"\"Harmonic features of the input `x`; they are shared by every output.\"\"\"\n",
    "        return self.spherical_harmonics.polynomial_expansion(x).T\n",
    "\n",
    "    @jax.jit\n",
    "    def moments(self, x: Float[Array, \"D\"]) -> tuple[Float[Array, \"O\"], Float[Array, \"O\"]]:\n",
    "        \"\"\"\n",
    "        Predictive mean and variance of every output at a single input `x`.\n",
    "\n",
    "        The variance is the sum of the prior term `Kxx`, computed from the spectrum rescaled by\n",
    "        `sqrtS_augment ** 2`, and the variational term from `spherical_harmonic_features_moments`.\n",
    "        \"\"\"\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        # prior covariance adjusted by the diagonal variational parameters\n",
    "        spectrum = kernel.spectrum() # [O L]\n",
    "        S_augment = jnp.square(self.sqrtS_augment) # [O L]\n",
    "        Kxx = kernel.from_spectrum(spectrum * S_augment, x, x) # [O]\n",
    "\n",
    "        # variational covariance\n",
    "        Kzz_diag = self.Kzz_diag(spectrum) # [O M]\n",
    "        Kxz = self.Kxz(x) # [M], shared across outputs\n",
    "\n",
    "        # `spherical_harmonic_features_moments` takes (Kxz, Kzz_inv_diag, m, sqrtS); passing Kxx\n",
    "        # as well would shift every argument by one, so the prior term is added afterwards.\n",
    "        mean, covariance = jax.vmap(\n",
    "            lambda Kzz_diag, m, sqrtS: spherical_harmonic_features_moments(Kxz, Kzz_diag, m, sqrtS)\n",
    "        )(Kzz_diag, self.m, self.sqrtS)\n",
    "        return mean, Kxx + covariance\n",
    "\n",
    "    @jax.jit\n",
    "    def diag(self, x: Float[Array, \"N D\"]) -> tfd.Normal:\n",
    "        \"\"\"Marginal predictive distribution of every output at every row of `x`.\"\"\"\n",
    "        mean, variance = jax.vmap(self.moments)(x)\n",
    "        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))\n",
    "\n",
    "    @jax.jit\n",
    "    def pathwise_sample_single(self, x: Float[Array, \"D\"], *, key: Key) -> Float[Array, \"O\"]:\n",
    "        \"\"\"Value at a single input `x` of the posterior sample determined by `key`, per output.\"\"\"\n",
    "        # One independent key per output.\n",
    "        output_dim_keys = jax.random.split(key, self.num_outputs)\n",
    "\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        Kxz = self.Kxz(x)\n",
    "        Kzz_diag = self.Kzz_diag(kernel.spectrum())\n",
    "\n",
    "        return jax.vmap(\n",
    "            lambda Kzz_diag, m, sqrtS, key: pathwise_sample_spherical_harmonic_features_posterior(\n",
    "                Kxz, Kzz_diag, m, sqrtS, key=key\n",
    "        ))(Kzz_diag, self.m, self.sqrtS, output_dim_keys)\n",
    "\n",
    "    @jax.jit\n",
    "    def pathwise_sample(self, x: Float[Array, \"N D\"], *, key: Key) -> Float[Array, \"N O\"]:\n",
    "        \"\"\"\n",
    "        Evaluate one posterior sample at every row of `x`.\n",
    "\n",
    "        The same `key` is deliberately passed to every row so that all rows are evaluated\n",
    "        with the same random draw.\n",
    "        \"\"\"\n",
    "        return jax.vmap(lambda x: self.pathwise_sample_single(x, key=key))(x)\n",
    "\n",
    "\n",
    "# TODO verify that this is correct \n",
    "@jax.jit\n",
    "def sphere_expmap(x: Float[Array, \"N D\"], v: Float[Array, \"N D\"]) -> Float[Array, \"N D\"]:\n",
    "    theta = jnp.linalg.norm(v, axis=-1, keepdims=True)\n",
    "\n",
    "    t = x + v\n",
    "    first_order_approx = t / jnp.linalg.norm(t, axis=-1, keepdims=True)\n",
    "    true_expmap = jnp.cos(theta) * x + jnp.sin(theta) * v / theta\n",
    "\n",
    "    return jnp.where(\n",
    "        theta < 1e-12,\n",
    "        first_order_approx,\n",
    "        true_expmap,\n",
    "    )\n",
    "\n",
    "\n",
    "@jax.jit \n",
    "def sphere_to_tangent(x: Float[Array, \"N D\"], v: Float[Array, \"N D\"]) -> Float[Array, \"N D\"]:\n",
    "    v_x = jnp.sum(x * v, axis=-1, keepdims=True)\n",
    "    return v - v_x * x\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class SphereResidualDeepGP(Module):\n",
    "    \"\"\"\n",
    "    Residual deep GP on the sphere.\n",
    "\n",
    "    Each hidden layer produces a vector at every input, which is projected with\n",
    "    `sphere_to_tangent` and applied with `sphere_expmap` to move the input. The output layer\n",
    "    is then evaluated at the moved inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    hidden_layers: list[MultioutputSphericalHarmonicFeaturesPosterior] = param_field()\n",
    "    output_layer: SphericalHarmonicFeaturesPosterior = param_field()\n",
    "    num_samples: int = static_field(1)\n",
    "\n",
    "    @property \n",
    "    def posterior(self) -> Posterior:\n",
    "        \"\"\"Posterior of the output layer.\"\"\"\n",
    "        return self.output_layer.posterior\n",
    "\n",
    "    def prior_kl(self) -> Float:\n",
    "        \"\"\"Sum of the prior KL terms of every hidden layer and of the output layer.\"\"\"\n",
    "        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()\n",
    "\n",
    "    def sample_moments(self, x: Float[Array, \"N D\"], *, key: Key) -> tuple[Float[Array, \"N\"], Float[Array, \"N\"]]:\n",
    "        \"\"\"\n",
    "        Push `x` through one random draw of the hidden layers, then return the output layer's\n",
    "        `(mean, variance)` at the resulting points.\n",
    "        \"\"\"\n",
    "        hidden_layer_keys = jax.random.split(key, len(self.hidden_layers))\n",
    "        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):\n",
    "            v = layer.diag(x).sample(seed=hidden_layer_key)\n",
    "            u = sphere_to_tangent(x, v)\n",
    "            x = sphere_expmap(x, u)\n",
    "        return jax.vmap(self.output_layer.moments)(x)\n",
    "\n",
    "    def diag(self, x: Float[Array, \"N D\"], *, key: Key) -> tfd.MixtureSameFamily:\n",
    "        \"\"\"\n",
    "        Marginal predictive distribution at every row of `x`, as an equally weighted mixture of\n",
    "        `num_samples` normals, one per draw of the hidden layers.\n",
    "        \"\"\"\n",
    "        sample_keys = jax.random.split(key, self.num_samples)\n",
    "\n",
    "        # In MixtureSameFamily batch size goes last; hence, out_axes = 1\n",
    "        mean, variance = jax.vmap(lambda k: self.sample_moments(x, key=k), out_axes=1)(sample_keys) \n",
    "\n",
    "        return tfd.MixtureSameFamily(\n",
    "            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)), \n",
    "            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)), \n",
    "        )\n",
    "\n",
    "    def pathwise_sample(self, x: Float[Array, \"N D\"], *, key: Key) -> Float[Array, \"N\"]:\n",
    "        \"\"\"\n",
    "        Evaluate one pathwise sample of the whole model at every row of `x`.\n",
    "\n",
    "        `key` is split into one sub-key per hidden layer plus one for the output layer, so no\n",
    "        layer reuses the parent key after it has been split.\n",
    "        \"\"\"\n",
    "        *hidden_layer_keys, output_layer_key = jax.random.split(key, len(self.hidden_layers) + 1)\n",
    "        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):\n",
    "            v = layer.pathwise_sample(x, key=hidden_layer_key)\n",
    "            u = sphere_to_tangent(x, v)\n",
    "            x = sphere_expmap(x, u)\n",
    "        return self.output_layer.pathwise_sample(x, key=output_layer_key)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DeepGaussianLikelihood(Module):\n",
    "    \"\"\"Gaussian likelihood for mixture predictives: adds `noise_variance` to every component.\"\"\"\n",
    "\n",
    "    # Softplus bijector keeps the trainable noise variance positive.\n",
    "    noise_variance: Float = param_field(jnp.array(1.0), bijector=tfb.Softplus())\n",
    "\n",
    "    @jax.jit\n",
    "    def diag(self, pf: tfd.MixtureSameFamily) -> tfd.MixtureSameFamily:\n",
    "        \"\"\"Return a mixture like `pf` whose component variances are increased by `noise_variance`.\n",
    "\n",
    "        The mixture weights and component means of `pf` are reused unchanged; the new\n",
    "        components are Normal distributions.\n",
    "        \"\"\"\n",
    "        components = pf.components_distribution\n",
    "        noisy_variance = components.variance() + self.noise_variance\n",
    "        noisy_components = tfd.Normal(loc=components.mean(), scale=jnp.sqrt(noisy_variance))\n",
    "        return tfd.MixtureSameFamily(\n",
    "            mixture_distribution=pf.mixture_distribution,\n",
    "            components_distribution=noisy_components,\n",
    "        )\n",
    "\n",
    "\n",
    "def create_residual_deep_gp_with_spherical_harmonic_features(\n",
    "    num_layers: int, total_hidden_variance: float, max_ell: int, x: Float[Array, \"N D\"], num_samples: int = 3, *,\n",
    "    nu: float = 2.5\n",
    ") -> SphereResidualDeepGP:\n",
    "    \"\"\"Assemble a `SphereResidualDeepGP` with `num_layers - 1` hidden layers and one output layer.\n",
    "\n",
    "    Args:\n",
    "        num_layers: Total number of layers; `num_layers - 1` of them are hidden.\n",
    "        total_hidden_variance: Divided by `max(num_layers - 1, 1)` to obtain the variance of\n",
    "            each hidden kernel.\n",
    "        max_ell: Passed as `max_ell` to the spherical harmonics and to every kernel.\n",
    "        x: Inputs; only `x.shape[1]` is read, giving `sphere_dim = x.shape[1] - 1`.\n",
    "        num_samples: Forwarded to `SphereResidualDeepGP`.\n",
    "        nu: Passed as `nu` to every kernel.\n",
    "\n",
    "    Returns:\n",
    "        The assembled model.\n",
    "    \"\"\"\n",
    "    sphere_dim = x.shape[1] - 1\n",
    "    num_hidden_layers = num_layers - 1\n",
    "\n",
    "    # Hidden and output kernels share the same nu and kappa values.\n",
    "    shared_nu = jnp.array(nu)\n",
    "    shared_kappa = jnp.array(1.0)\n",
    "    # The max(..., 1) guard avoids dividing by zero when there are no hidden layers.\n",
    "    per_layer_hidden_variance = jnp.array(total_hidden_variance / max(num_hidden_layers, 1))\n",
    "\n",
    "    # A single harmonics object is shared by all hidden layers and the output layer.\n",
    "    spherical_harmonics = SphericalHarmonics(max_ell=max_ell, sphere_dim=sphere_dim)\n",
    "\n",
    "    hidden_layers = [\n",
    "        MultioutputSphericalHarmonicFeaturesPosterior(\n",
    "            posterior=MultioutputDummyPosterior(\n",
    "                prior=MultioutputPrior(\n",
    "                    kernel=MultioutputSphereMaternKernel(\n",
    "                        num_outputs=sphere_dim + 1,\n",
    "                        sphere_dim=sphere_dim,\n",
    "                        nu=shared_nu,\n",
    "                        kappa=shared_kappa,\n",
    "                        variance=per_layer_hidden_variance,\n",
    "                        max_ell=max_ell,\n",
    "                    )\n",
    "                )\n",
    "            ),\n",
    "            spherical_harmonics=spherical_harmonics,\n",
    "        )\n",
    "        for _ in range(num_hidden_layers)\n",
    "    ]\n",
    "\n",
    "    output_kernel = SphereMaternKernel(\n",
    "        sphere_dim=sphere_dim,\n",
    "        nu=shared_nu,\n",
    "        kappa=shared_kappa,\n",
    "        variance=jnp.array(1.0),\n",
    "        max_ell=max_ell,\n",
    "    )\n",
    "    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())\n",
    "    output_layer = SphericalHarmonicFeaturesPosterior(\n",
    "        posterior=output_posterior, spherical_harmonics=spherical_harmonics\n",
    "    )\n",
    "\n",
    "    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notes on data\n",
    "- Tangent vectors need not have unit norm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "import plotly.express as px \n",
    "from plotly import graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "\n",
    "\n",
    "# Both files are read without a header row; the column names are assigned here.\n",
    "mean_inputs = pd.read_csv(\n",
    "    \"../mean_inputs.csv\",\n",
    "    names=['x', 'y', 'z'],\n",
    "    header=None,\n",
    ")\n",
    "mean_outputs = pd.read_csv(\n",
    "    \"../mean_outputs.csv\",\n",
    "    names=['x', 'y', 'z', 'u', 'v', 'w'],\n",
    "    header=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the inputs once and reuse the same array for the model factory and later cells.\n",
    "# The factory only reads `x.shape[1]`, so passing `x` instead of the raw DataFrame values\n",
    "# builds the same model.\n",
    "x = jnp.asarray(mean_inputs.values)\n",
    "model = create_residual_deep_gp_with_spherical_harmonic_features(\n",
    "    num_layers=5,\n",
    "    total_hidden_variance=0.5,\n",
    "    max_ell=10,\n",
    "    x=x,\n",
    "    num_samples=1,\n",
    "    nu=1.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.key(12)\n",
    "\n",
    "\n",
    "# Sphere (background): a 100 x 100 grid of points on the unit sphere with constant colour values.\n",
    "theta, phi = jnp.meshgrid(\n",
    "    jnp.linspace(0, 2 * jnp.pi, 100),\n",
    "    jnp.linspace(0, jnp.pi, 100),\n",
    ")\n",
    "sphere_inputs = jnp.stack(\n",
    "    [\n",
    "        jnp.sin(phi) * jnp.cos(theta),\n",
    "        jnp.sin(phi) * jnp.sin(theta),\n",
    "        jnp.cos(phi),\n",
    "    ],\n",
    "    axis=-1,\n",
    ")\n",
    "sphere_outputs = jnp.zeros(theta.shape)\n",
    "\n",
    "\n",
    "# Ambient vector field: a pathwise sample of the first hidden layer at the data points.\n",
    "v = model.hidden_layers[0].pathwise_sample(x, key=key)\n",
    "ambient_inputs = jnp.concatenate((x, v), axis=-1)\n",
    "ambient_outputs = jnp.zeros(v.shape[:1])\n",
    "\n",
    "# Projected vector field: the same sample after `sphere_to_tangent`.\n",
    "u = sphere_to_tangent(x, v)\n",
    "projected_inputs = jnp.concatenate((x, u), axis=-1)\n",
    "projected_outputs = jnp.zeros(u.shape[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the ambient and the projected vector fields side by side.\n",
    "# NOTE: r_scatter and marker_size are not referenced in this cell; they are kept in case\n",
    "# other cells rely on them.\n",
    "r_scatter = 1.01\n",
    "marker_size = 3\n",
    "\n",
    "\n",
    "def add_vector_field_panel(figure, col, vector_field, sphere_xyz, sphere_color):\n",
    "    \"\"\"Add a uniformly grey sphere and a black cone plot to column `col` of `figure`.\n",
    "\n",
    "    Args:\n",
    "        figure: Figure created with `make_subplots` (one row of 'surface' subplots).\n",
    "        col: 1-based subplot column to draw into.\n",
    "        vector_field: Array whose columns 0-2 are cone positions and columns 3-5 are the\n",
    "            vector components.\n",
    "        sphere_xyz: Grid of sphere coordinates with x, y, z along the last axis.\n",
    "        sphere_color: Surface colour values for the sphere grid.\n",
    "    \"\"\"\n",
    "    figure.add_trace(\n",
    "        go.Surface(\n",
    "            x=sphere_xyz[:, :, 0],\n",
    "            y=sphere_xyz[:, :, 1],\n",
    "            z=sphere_xyz[:, :, 2],\n",
    "            surfacecolor=sphere_color,\n",
    "            colorscale=['lightgrey', 'lightgrey'],\n",
    "            showscale=False,\n",
    "        ),\n",
    "        row=1, col=col\n",
    "    )\n",
    "    figure.add_trace(\n",
    "        go.Cone(\n",
    "            x=vector_field[:, 0],\n",
    "            y=vector_field[:, 1],\n",
    "            z=vector_field[:, 2],\n",
    "            u=vector_field[:, 3],\n",
    "            v=vector_field[:, 4],\n",
    "            w=vector_field[:, 5],\n",
    "            colorscale=['black', 'black'],\n",
    "            sizemode='scaled',\n",
    "            sizeref=1.2,\n",
    "            showscale=False,\n",
    "        ),\n",
    "        row=1, col=col\n",
    "    )\n",
    "\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=1,\n",
    "    cols=2,\n",
    "    subplot_titles=(\"Ambient\", \"Projected\"),\n",
    "    specs=[[{'type': 'surface'}, {'type': 'surface'}]]\n",
    ")\n",
    "\n",
    "# Column 1 shows the ambient field, column 2 the projected field; both panels share\n",
    "# the same styling, so only the plotted array differs.\n",
    "for panel_col, panel_field in enumerate((ambient_inputs, projected_inputs), start=1):\n",
    "    add_vector_field_panel(fig, panel_col, panel_field, sphere_inputs, sphere_outputs)\n",
    "\n",
    "# Hide the axes of both 3D scenes.\n",
    "fig.update_layout(\n",
    "    scene=dict(\n",
    "        xaxis=dict(visible=False),\n",
    "        yaxis=dict(visible=False),\n",
    "        zaxis=dict(visible=False),\n",
    "    ),\n",
    "    scene2=dict(\n",
    "        xaxis=dict(visible=False),\n",
    "        yaxis=dict(visible=False),\n",
    "        zaxis=dict(visible=False),\n",
    "    ),\n",
    "    width=800,\n",
    "    height=600,\n",
    "    showlegend=False,\n",
    ")\n",
    "fig.write_image(\"gvf_construction-projected.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write every array to \"<name>.csv\" without header or index columns.\n",
    "data = [ambient_inputs, ambient_outputs, projected_inputs, projected_outputs]\n",
    "names = [\"ambient-inputs\", \"ambient-outputs\", \"projected-inputs\", \"projected-outputs\"]\n",
    "\n",
    "for name, datum in zip(names, data):\n",
    "    frame = pd.DataFrame(datum)\n",
    "    frame.to_csv(f\"{name}.csv\", header=False, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
