{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import config \n",
    "config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "\n",
    "import jax \n",
    "import jax.numpy as jnp\n",
    "import gpjax \n",
    "import optax \n",
    "from gpjax.distributions import GaussianDistribution\n",
    "from gpjax.typing import Float\n",
    "from jax.tree_util import Partial\n",
    "import tensorflow_probability.substrates.jax as tfp\n",
    "import tensorflow_probability.substrates.jax.distributions as tfd\n",
    "from jaxtyping import Key\n",
    "\n",
    "from matplotlib import pyplot as plt \n",
    "\n",
    "from dataclasses import dataclass, InitVar\n",
    "from abc import abstractmethod\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def sph_dot_product(sph1: Float, sph2: Float) -> Float:\n",
    "    \"\"\"\n",
    "    Computes dot product in R^3 of two points on the sphere in spherical coordinates.\n",
    "    \"\"\"\n",
    "    colat1, lon1 = sph1[..., 0], sph1[..., 1]\n",
    "    colat2, lon2 = sph2[..., 0], sph2[..., 1]\n",
    "    return jnp.sin(colat1) * jnp.sin(colat2) * jnp.cos(lon1 - lon2) + jnp.cos(colat1) * jnp.cos(colat2)\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))\n",
    "def gegenbauer(x, max_ell: int, alpha: float = 0.5):\n",
    "    \"\"\"\n",
    "    Compute the gegenbauer polynomial Cᵅₙ(x) recursively.\n",
    "\n",
    "    Cᵅ₀(x) = 1\n",
    "    Cᵅ₁(x) = 2αx\n",
    "    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n\n",
    "\n",
    "    Args:\n",
    "        level: The order of the polynomial.\n",
    "        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.\n",
    "        x: Input array.\n",
    "\n",
    "    Returns:\n",
    "        The Gegenbauer polynomial evaluated at `x`.\n",
    "    \"\"\"\n",
    "    C_0 = jnp.ones_like(x, dtype=x.dtype)\n",
    "    C_1 = 2 * alpha * x\n",
    "    \n",
    "    res = jnp.empty((max_ell + 1, *x.shape), dtype=x.dtype)\n",
    "    res = res.at[0].set(C_0)\n",
    "\n",
    "    def step(n, res_and_Cs):\n",
    "        res, C, C_prev = res_and_Cs\n",
    "        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C\n",
    "        res = res.at[n].set(C)\n",
    "        return res, C, C_prev\n",
    "    \n",
    "    return jax.lax.cond(\n",
    "        max_ell == 0,\n",
    "        lambda: res,\n",
    "        lambda: jax.lax.fori_loop(2, max_ell + 1, step, (res.at[1].set(C_1), C_1, C_0))[0],\n",
    "    )\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call \n",
    "def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:\n",
    "    \"\"\"\n",
    "    Compute the gegenbauer polynomial Cᵅₙ(x) recursively.\n",
    "\n",
    "    Cᵅ₀(x) = 1\n",
    "    Cᵅ₁(x) = 2αx\n",
    "    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n\n",
    "\n",
    "    Args:\n",
    "        level: The order of the polynomial.\n",
    "        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.\n",
    "        x: Input array.\n",
    "\n",
    "    Returns:\n",
    "        The Gegenbauer polynomial evaluated at `x`.\n",
    "    \"\"\"\n",
    "    C_0 = jnp.ones_like(x, dtype=x.dtype)\n",
    "    C_1 = 2 * alpha * x\n",
    "\n",
    "    def step(Cs_and_n):\n",
    "        C, C_prev, n = Cs_and_n\n",
    "        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C\n",
    "        return C, C_prev, n + 1\n",
    "\n",
    "    def cond(Cs_and_n):\n",
    "        n = Cs_and_n[2]\n",
    "        return n <= ell\n",
    "\n",
    "    return jax.lax.cond(\n",
    "        ell == 0,\n",
    "        lambda: C_0,\n",
    "        lambda: jax.lax.while_loop(cond, step, (C_1, C_0, jnp.array(2, jnp.float64)))[0],\n",
    "    )\n",
    "\n",
    "\n",
    "# NOTE jitting this doesn't help\n",
    "def sph_gegenbauer(x, y, max_ell: int, alpha: float = 0.5):\n",
    "    return gegenbauer(x=sph_dot_product(x, y), max_ell=max_ell, alpha=alpha)\n",
    "\n",
    "\n",
    "# NOTE jitting this doesn't help\n",
    "def sph_gegenbauer_single(x, y, ell: int, alpha: float = 0.5):\n",
    "    return gegenbauer_single(x=sph_dot_product(x, y), ell=ell, alpha=alpha)\n",
    "\n",
    "\n",
    "def array(x):\n",
    "    return jnp.array(x, dtype=jnp.float64)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def sph_to_car(sph):\n",
    "    \"\"\"\n",
    "    From spherical (colat, lon) coordinates to cartesian, single point.\n",
    "    \"\"\"\n",
    "    colat, lon = sph[..., 0], sph[..., 1]\n",
    "    z = jnp.cos(colat)\n",
    "    r = jnp.sin(colat)\n",
    "    x = r * jnp.cos(lon)\n",
    "    y = r * jnp.sin(lon)\n",
    "    return jnp.stack([x, y, z], axis=-1)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def car_to_sph(car):\n",
    "    x, y, z = car[..., 0], car[..., 1], car[..., 2]\n",
    "    colat = jnp.arccos(z)\n",
    "    lon = jnp.arctan2(y, x)\n",
    "    return jnp.stack([colat, lon], axis=-1)\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "Conversion between hodge and flat coordinates.\n",
    "\"\"\"\n",
    "\n",
    "@jax.jit\n",
    "def flatten_matrix(matrix):\n",
    "    \"\"\"\n",
    "    Input matrix has shape (nx, ny, *block_shape).\n",
    "    \"\"\"\n",
    "    out = jnp.vstack([jnp.hstack([block for block in row_blocks]) for row_blocks in matrix])\n",
    "    return out\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('spherical',))\n",
    "def unflatten_matrix(matrix, spherical=True):\n",
    "    \"\"\"\n",
    "    Input matrix has shape (nx, ny).\n",
    "    \"\"\"\n",
    "    dim = 2 if spherical else 3\n",
    "    out = jnp.array([\n",
    "        jnp.split(row_block, indices_or_sections=matrix.shape[1]//dim, axis=1)\n",
    "        for row_block in jnp.split(matrix, indices_or_sections=matrix.shape[0]//dim, axis=0)\n",
    "    ])\n",
    "    return out\n",
    "\n",
    "\n",
    "def flatten_coord(coord):\n",
    "    \"\"\"\n",
    "    Flatten coordinates to 1d array.\n",
    "    \"\"\"\n",
    "    return jnp.ravel(coord)\n",
    "\n",
    "\n",
    "def unflatten_coord(coord_flat, spherical=True, extra_dims=[]):\n",
    "    \"\"\"\n",
    "    Un-flatten coordinates to (n, 2, *extra_dims).\n",
    "    \"\"\"\n",
    "    return coord_flat.reshape(-1, 2 if spherical else 3, *extra_dims)\n",
    "\n",
    "\n",
    "from pathlib import Path\n",
    "from typing import Callable\n",
    "\n",
    "import numpy as np\n",
    "from jax import Array\n",
    "\n",
    "\n",
    "class FundamentalSystemNotPrecomputedError(ValueError):\n",
    "\n",
    "    def __init__(self, dimension: int):\n",
    "        message = f\"Fundamental system for dimension {dimension} has not been precomputed.\"\n",
    "        super().__init__(message)\n",
    "\n",
    "\n",
    "def fundamental_set_loader(dimension: int, load_dir=\"fundamental_system\") -> Callable[[int], Array]:\n",
    "    load_dir = Path(\"../../\") / load_dir\n",
    "    file_name = load_dir / f\"fs_{dimension}D.npz\"\n",
    "\n",
    "    cache = {}\n",
    "    if file_name.exists():\n",
    "        with np.load(file_name) as f:\n",
    "            cache = {k: v for (k, v) in f.items()}\n",
    "    else:\n",
    "        raise FundamentalSystemNotPrecomputedError(dimension)\n",
    "\n",
    "    def load(degree: int) -> Array:\n",
    "        key = f\"degree_{degree}\"\n",
    "        if key not in cache:\n",
    "            raise ValueError(f\"key: {key} not in cache.\")\n",
    "        return cache[key]\n",
    "\n",
    "    return load\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class SphericalHarmonics(gpjax.Module):\n",
    "    \"\"\"\n",
    "    Spherical harmonics inducing features for sparse inference in Gaussian processes.\n",
    "\n",
    "    The spherical harmonics, Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the sphere and,\n",
    "    as such, they form an orthogonal basis.\n",
    "\n",
    "    To construct the harmonics, we use a a fundamental set of points on the sphere {vᵢ}ᵢ and compute\n",
    "    b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b now forms a complete basis on the sphere and we can orthogoalise it via\n",
    "    a Cholesky decomposition. However, we only need to run the Cholesky decomposition once during\n",
    "    initialisation.\n",
    "\n",
    "    Attributes:\n",
    "        num_frequencies: The number of frequencies, up to which, we compute the harmonics.\n",
    "\n",
    "    Returns:\n",
    "        An instance of the spherical harmonics features.\n",
    "    \"\"\"\n",
    "\n",
    "    max_ell: int = gpjax.base.static_field()\n",
    "    sphere_dim: int = gpjax.base.static_field()\n",
    "    alpha: float = gpjax.base.static_field(init=False)\n",
    "    orth_basis: Array = gpjax.base.param_field(init=False, trainable=False)\n",
    "    Vs: list[Array] = gpjax.base.param_field(init=False, trainable=False)\n",
    "\n",
    "    @property\n",
    "    def levels(self):\n",
    "        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)\n",
    "    \n",
    "    def __post_init__(self) -> None:\n",
    "        \"\"\"\n",
    "        Initialise the parameters of the spherical harmonic features and return a `Param` object.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        dim = self.sphere_dim + 1\n",
    "\n",
    "        # Try loading a pre-computed fundamental set.\n",
    "        fund_set = fundamental_set_loader(dim)\n",
    "\n",
    "        # initialise the Gegenbauer lookup table and compute the relevant constants on the sphere.\n",
    "        self.alpha = (dim - 2.0) / 2.0\n",
    "\n",
    "        # initialise the parameters Vs. Set them to non-trainable if we do not truncate the phase.\n",
    "        self.Vs = [fund_set(n) for n in self.levels]\n",
    "\n",
    "        # pre-compute and save the orthogonal basis \n",
    "        self.orth_basis = self._orthogonalise_basis()\n",
    "\n",
    "\n",
    "    @property\n",
    "    def Ls(self) -> list[Array]:\n",
    "        \"\"\"\n",
    "        Alias for the orthogonal basis at every frequency.\n",
    "        \"\"\"\n",
    "        return self.orth_basis\n",
    "\n",
    "    @property\n",
    "    def num_phase_in_frequency(self) -> list[int]:\n",
    "        \"\"\"\n",
    "        Get the total number of phases/harmonics at every frequency.\n",
    "\n",
    "        Returns:\n",
    "            A list with the number of phases per frequency.\n",
    "        \"\"\"\n",
    "        return jax.tree.map(lambda x: x.shape[0], self.Vs)\n",
    "\n",
    "    @property\n",
    "    def num_inducing(self) -> int:\n",
    "        \"\"\"\n",
    "        Computes the total number of inducing features, as the sum of all phases.\n",
    "\n",
    "        Args:\n",
    "            param: A `Param` initialised with the spherical harmonic features.\n",
    "\n",
    "        Returns:\n",
    "            The total number of inducing features.\n",
    "        \"\"\"\n",
    "        return sum(self.num_phase_in_frequency)\n",
    "\n",
    "    def _orthogonalise_basis(self) -> None:\n",
    "        \"\"\"\n",
    "        Compute the basis from the fundamental set and orthogonalise it via Cholesky decomposition.\n",
    "        \"\"\"\n",
    "        alpha = self.alpha\n",
    "        levels = jnp.split(self.levels, self.max_ell + 1)\n",
    "        const = alpha / (alpha + self.levels.astype(jnp.float64))\n",
    "        const = jnp.split(const, self.max_ell + 1)\n",
    "\n",
    "        def _func(v, n, c):\n",
    "            x = jnp.matmul(v, v.T)\n",
    "            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)\n",
    "            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))\n",
    "            return L\n",
    "\n",
    "        return jax.tree.map(_func, self.Vs, levels, const)\n",
    "\n",
    "    def custom_gegenbauer_single(self, x, ell, alpha):\n",
    "        return gegenbauer(x, self.max_ell, alpha)[ell]\n",
    "\n",
    "    @jax.jit\n",
    "    def polynomial_expansion(self, X: Float[Array, \"N D\"]) -> Float[Array, \"M N\"]:\n",
    "        \"\"\"\n",
    "        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.\n",
    "\n",
    "        Args:\n",
    "            X: Input Array.\n",
    "\n",
    "        Returns:\n",
    "            The harmonics evaluated at the input as a polynomial expansion of the basis.\n",
    "        \"\"\"\n",
    "        levels = jnp.split(self.levels, self.max_ell + 1)\n",
    "\n",
    "        def _func(v, n, L):\n",
    "            vxT = jnp.dot(v, X.T)\n",
    "            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)\n",
    "            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)\n",
    "            return harmonic\n",
    "\n",
    "        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)\n",
    "        return jnp.concatenate(harmonics, axis=0)\n",
    "    \n",
    "    def __eq__(self, other: \"SphericalHarmonics\") -> bool:\n",
    "        \"\"\"\n",
    "        Check if two spherical harmonic features are equal.\n",
    "\n",
    "        Args:\n",
    "            other: The other spherical harmonic features.\n",
    "\n",
    "        Returns:\n",
    "            A boolean indicating if the two features are equal.\n",
    "        \"\"\"\n",
    "        # Given the first two parameters, the rest are deterministic. \n",
    "        # The user must not mutate all other fields, but that is not enforced for now.\n",
    "        return (\n",
    "            self.max_ell == other.max_ell \n",
    "            and self.sphere_dim == other.sphere_dim \n",
    "        )\n",
    "    \n",
    "\n",
    "\n",
    "import warnings \n",
    "from typing import Optional\n",
    "from jaxtyping import Num\n",
    "from gpjax.typing import ScalarFloat\n",
    "\n",
    "\n",
    "def _check_precision(\n",
    "    X: Optional[Num[Array, \"...\"]], y: Optional[Num[Array, \"...\"]]\n",
    ") -> None:\n",
    "    r\"\"\"Checks the precision of $`X`$ and $`y`.\"\"\"\n",
    "    if X is not None and X.dtype != jnp.float64:\n",
    "        warnings.warn(\n",
    "            \"X is not of type float64. \"\n",
    "            f\"Got X.dtype={X.dtype}. This may lead to numerical instability. \",\n",
    "            stacklevel=2,\n",
    "        )\n",
    "\n",
    "    if y is not None and y.dtype != jnp.float64:\n",
    "        warnings.warn(\n",
    "            \"y is not of type float64.\"\n",
    "            f\"Got y.dtype={y.dtype}. This may lead to numerical instability.\",\n",
    "            stacklevel=2,\n",
    "        )\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class VectorDataset(gpjax.Dataset):\n",
    "    X: Optional[Num[Array, \"N D\"]] = None\n",
    "    y: Optional[Num[Array, \"M Q\"]] = None\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        r\"\"\"Checks that the shapes of $`X`$ and $`y`$ are compatible,\n",
    "        and provides warnings regarding the precision of $`X`$ and $`y`$.\"\"\"\n",
    "        # _check_shape(self.X, self.y)\n",
    "        _check_precision(self.X, self.y)\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class VectorZeroMean(gpjax.mean_functions.AbstractMeanFunction):\n",
    "    dim: int = gpjax.base.static_field(1)\n",
    "\n",
    "    def __call__(self, x: Float[Array, \"N D\"]) -> Float[Array, \"N E\"]:\n",
    "        return jnp.zeros((x.shape[0] * self.dim))\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class AnalyticalVectorGaussianIntegrator(gpjax.integrators.AbstractIntegrator):\n",
    "    r\"\"\"Compute the analytical integral of a Gaussian likelihood.\n",
    "\n",
    "    When the likelihood function is Gaussian, the integral can be computed in closed\n",
    "    form. For a Gaussian likelihood $`p(y|f) = \\mathcal{N}(y|f, \\sigma^2)`$ and a\n",
    "    variational distribution $`q(f) = \\mathcal{N}(f|m, s)`$, the expected\n",
    "    log-likelihood is given by\n",
    "    ```math\n",
    "    \\mathbb{E}_{q(f)}[\\log p(y|f)] = -\\frac{1}{2}\\left(\\log(2\\pi\\sigma^2) + \\frac{1}{\\sigma^2}((y-m)^2 + s)\\right)\n",
    "    ```\n",
    "    \"\"\"\n",
    "\n",
    "    def integrate(\n",
    "        self,\n",
    "        fun: Callable,\n",
    "        y: Float[Array, \"N D\"],\n",
    "        mean: Float[Array, \"N D\"],\n",
    "        covariance: Float[Array, \"N D D\"],\n",
    "        likelihood: gpjax.likelihoods.Gaussian,\n",
    "    ) -> Float[Array, \" N\"]:\n",
    "        r\"\"\"Compute a Gaussian integral.\n",
    "\n",
    "        Args:\n",
    "            fun (Callable): The Gaussian likelihood to be integrated.\n",
    "            y (Float[Array, 'N D']): The observed response variable.\n",
    "            mean (Float[Array, 'N D']): The mean of the variational distribution.\n",
    "            covariance (Float[Array, 'N D D']): The block diagonal covariance of the variational\n",
    "                distribution.\n",
    "            likelihood (Gaussian): The Gaussian likelihood function.\n",
    "\n",
    "        Returns:\n",
    "            Float[Array, 'N']: The expected log likelihood.\n",
    "        \"\"\"\n",
    "        d = y.shape[-1]\n",
    "        obs_var = likelihood.obs_stddev.squeeze() ** 2 # [1]\n",
    "        sq_error = jnp.sum(jnp.square(y - mean), axis=-1) # [N]\n",
    "        log2pi = jnp.log(2.0 * jnp.pi) # [1]\n",
    "        # jax.debug.print(f\"{covariance.shape=}, {jnp.trace(covariance, axis1=1, axis2=2).shape=}\")\n",
    "        val = (\n",
    "            d * (log2pi + jnp.log(obs_var)) # [1]\n",
    "            + (sq_error + jnp.trace(covariance, axis1=1, axis2=2)) / obs_var # ([N] + [N]) / [1] -> [N]\n",
    "        )\n",
    "        return -0.5 * val\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class VectorGaussian(gpjax.likelihoods.Gaussian):\n",
    "    integrator: gpjax.integrators.AbstractIntegrator = gpjax.base.static_field(AnalyticalVectorGaussianIntegrator())\n",
    "\n",
    "\n",
    "import abc \n",
    "import cola\n",
    "from typing import TypeVar \n",
    "from jaxtyping import Float, Num\n",
    "from gpjax.typing import ScalarFloat\n",
    "\n",
    "\n",
    "Kernel = TypeVar(\"Kernel\", bound=\"gpjax.kernels.base.AbstractKernel\")\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class AbstractVectorKernelComputation:\n",
    "    r\"\"\"Abstract class for vector kernel computations.\"\"\"\n",
    "\n",
    "    def gram(\n",
    "        self,\n",
    "        kernel: Kernel,\n",
    "        x: Num[Array, \"N D\"],\n",
    "    ) -> cola.ops.LinearOperator:\n",
    "        r\"\"\"Compute Gram covariance operator of the kernel function.\n",
    "\n",
    "        Args:\n",
    "            kernel (AbstractKernel): the kernel function.\n",
    "            x (Num[Array, \"N N\"]): The inputs to the kernel function.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            LinearOperator: Gram covariance operator of the kernel function.\n",
    "        \"\"\"\n",
    "        Kxx = self.cross_covariance(kernel, x, x)\n",
    "        return cola.PSD(cola.ops.Dense(Kxx))\n",
    "\n",
    "    @abc.abstractmethod\n",
    "    def cross_covariance(\n",
    "        self, kernel: Kernel, x: Num[Array, \"N D\"], y: Num[Array, \"M D\"]\n",
    "    ) -> Float[Array, \"N M\"]:\n",
    "        r\"\"\"For a given kernel, compute the NxM gram matrix on an a pair\n",
    "        of input matrices with shape NxD and MxD.\n",
    "\n",
    "        Args:\n",
    "            kernel (AbstractKernel): the kernel function.\n",
    "            x (Num[Array,\"N D\"]): The first input matrix.\n",
    "            y (Num[Array,\"M D\"]): The second input matrix.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            Float[Array, \"N M\"]: The computed cross-covariance.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def diagonal(self, kernel: Kernel, inputs: Num[Array, \"N D\"]) -> cola.ops.BlockDiag:\n",
    "        r\"\"\"For a given kernel, compute the elementwise diagonal of the\n",
    "        NxN gram matrix on an input matrix of shape NxD.\n",
    "\n",
    "        Args:\n",
    "            kernel (AbstractKernel): the kernel function.\n",
    "            inputs (Float[Array, \"N D\"]): The input matrix.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            Diagonal: The computed diagonal variance entries.\n",
    "        \"\"\"\n",
    "        return cola.PSD(cola.ops.BlockDiag(diag=jax.vmap(lambda x: kernel(x, x))(inputs)))\n",
    "    \n",
    "\n",
    "class DenseVectorKernelComputation(AbstractVectorKernelComputation):\n",
    "    r\"\"\"Dense kernel computation class. Operations with the kernel assume\n",
    "    a dense gram matrix structure.\n",
    "    \"\"\"\n",
    "\n",
    "    def cross_covariance(\n",
    "        self, kernel: Kernel, x: Float[Array, \"N D\"], y: Float[Array, \"M D\"]\n",
    "    ) -> Float[Array, \"2N 2M\"]:\n",
    "        r\"\"\"Compute the cross-covariance matrix.\n",
    "\n",
    "        For a given kernel, compute the NxM covariance matrix on a pair of input\n",
    "        matrices of shape $`NxD`$ and $`MxD`$.\n",
    "\n",
    "        Args:\n",
    "            kernel (Kernel): the kernel function.\n",
    "            x (Float[Array,\"N D\"]): The input matrix.\n",
    "            y (Float[Array,\"M D\"]): The input matrix.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            Float[Array, \"2N 2M\"]: The computed cross-covariance.\n",
    "        \"\"\"\n",
    "        cross_cov = jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(y))(x)\n",
    "        # flatten for consistency\n",
    "        return flatten_matrix(cross_cov)\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class VectorZeroMean(gpjax.mean_functions.AbstractMeanFunction):\n",
    "    space_dim: int = gpjax.base.static_field(2)\n",
    "    output_dim: int = gpjax.base.static_field(1)\n",
    "\n",
    "    def __call__(self, x: Float[Array, \"N D\"]) -> Float[Array, \"N E\"]:\n",
    "        return jnp.zeros((x.shape[0] * self.space_dim, self.output_dim))\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class AbstractVectorKernel(gpjax.kernels.AbstractKernel):\n",
    "    r\"\"\"Base vector kernel class.\"\"\"\n",
    "\n",
    "    compute_engine: AbstractVectorKernelComputation = gpjax.base.static_field(DenseVectorKernelComputation())\n",
    "    active_dims: Optional[list[int]] = gpjax.base.static_field(None)\n",
    "    name: str = gpjax.base.static_field(\"AbstractVectorKernel\")\n",
    "    \n",
    "\n",
    "@jax.jit\n",
    "def tangent_basis_normalization_matrix(x: Float[Array, \"2\"]) -> Float[Array, \"2 2\"]:\n",
    "    return jnp.array([\n",
    "        [1.0, 0.0], \n",
    "        [0.0, 1.0 / jnp.sin(x[0])],\n",
    "    ])\n",
    "\n",
    "\n",
    "hodge_star_matrix = jnp.array([\n",
    "    [0.0, 1.0],\n",
    "    [-1.0, 0.0],\n",
    "])\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=('min_value', ))\n",
    "def _ensure_colatitude_nonzero(x: Float[Array, \"N 2\"], min_value: float) -> Float[Array, \"N 2\"]:\n",
    "    return x.at[..., 0].set(jnp.where(x[..., 0] == 0, min_value, x[..., 0]))\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def matern_spectral_density(ell: ScalarFloat, kappa: ScalarFloat, nu: ScalarFloat, variance: ScalarFloat) -> ScalarFloat:\n",
    "    lambda_ells = ell * (ell + 1)\n",
    "    \n",
    "    # Compute log of Phi_nu_ells to avoid underflow\n",
    "    log_Phi_nu_ells = -(nu + 1) * jnp.log1p((lambda_ells * kappa**2) / (2 * nu))\n",
    "    \n",
    "    # Subtract max value for numerical stability\n",
    "    max_log_Phi = jnp.max(log_Phi_nu_ells)\n",
    "    Phi_nu_ells = jnp.exp(log_Phi_nu_ells - max_log_Phi)\n",
    "    \n",
    "    # Normalize the density, so that it sums to 1\n",
    "    num_harmonics_per_ell = 2 * ell + 1\n",
    "    normalizer = jnp.dot(num_harmonics_per_ell, Phi_nu_ells)\n",
    "    return variance * Phi_nu_ells / normalizer\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class AbstractHodgeKernel(AbstractVectorKernel):\n",
    "    nu: ScalarFloat = gpjax.param_field(jnp.array(2.5), bijector=tfp.bijectors.Softplus())\n",
    "    kappa: ScalarFloat = gpjax.param_field(jnp.array(1.0), bijector=tfp.bijectors.Softplus())\n",
    "    variance: ScalarFloat = gpjax.param_field(jnp.array(1.0), bijector=tfp.bijectors.Softplus())\n",
    "    alpha: float = gpjax.base.static_field(0.5)\n",
    "    max_ell: int = gpjax.base.static_field(10)\n",
    "    colatitude_min_value: float = gpjax.base.static_field(1e-12) # NOTE not sure what exact value to use here\n",
    "    spherical_harmonic_fields: \"AbstractSphericalHarmonicFields\" = gpjax.base.static_field(None)\n",
    "\n",
    "    @property\n",
    "    def ells(self) -> Float[Array, \"\"]:\n",
    "        return jnp.arange(1, self.max_ell + 1)\n",
    "    \n",
    "    def spectral_density(self) -> ScalarFloat:\n",
    "        return matern_spectral_density(self.ells, self.kappa, self.nu, self.variance)\n",
    "    \n",
    "    @jax.jit\n",
    "    def weighted_gegenbauer(self, x: Float, y: Float, weights: Float) -> Float:\n",
    "        lambda_ells = self.ells * (self.ells + 1)\n",
    "        values = sph_gegenbauer(x, y, self.max_ell, self.alpha)[1:]\n",
    "        return weights * values / lambda_ells\n",
    "    \n",
    "    @jax.jit\n",
    "    def dd_weighted_gegenbauer(self, x: Float[Array, \"2\"], y: Float[Array, \"2\"], weights: Float) -> Float[Array, \"2 2\"]:\n",
    "        return jax.jacfwd(jax.jacfwd(lambda x, y: self.weighted_gegenbauer(x, y, weights), argnums=0), argnums=1)(x, y)\n",
    "    \n",
    "    @jax.jit\n",
    "    def validate_inputs(self, x: Float[Array, \"2\"], y: Float[Array, \"2\"]) -> tuple[Float[Array, \"2\"], Float[Array, \"2\"]]:\n",
    "        x = _ensure_colatitude_nonzero(x, self.colatitude_min_value)\n",
    "        y = _ensure_colatitude_nonzero(y, self.colatitude_min_value)\n",
    "        return x, y\n",
    "\n",
    "    def _pathwise_sample_from_weights(self, x: Float[Array, \"S N 2\"], w: Float[Array, \"I\"]) -> Float[Array, \"S N 2\"]:\n",
    "        Phi_x = jax.vmap(self.spherical_harmonic_fields)(x) # [S N I 2]\n",
    "        ahats_per_frequency = self.spectral_density() # [I]\n",
    "        ahats_per_phase = jnp.repeat(\n",
    "            ahats_per_frequency, \n",
    "            self.spherical_harmonic_fields.num_phases_per_frequency,\n",
    "            total_repeat_length=self.spherical_harmonic_fields.num_phases\n",
    "        ) # [I]\n",
    "        tilde_Phi_x = jnp.einsum('snid, i -> snid', Phi_x, jnp.sqrt(ahats_per_phase)) # [N I 2]\n",
    "        return jnp.einsum('snid, si -> snd', tilde_Phi_x, w)\n",
    "    \n",
    "    @Partial(jax.jit, static_argnames=('num_samples',))\n",
    "    def pathwise_sample_from_weights(\n",
    "        self, \n",
    "        x: Float[Array, \"N 2\"] | Float[Array, \"S N 2\"],\n",
    "        w: Float[Array, \"I\"], \n",
    "        num_samples: int = 1\n",
    "    ) -> Float[Array, \"N 2\"]:\n",
    "        x_shape = jnp.broadcast_shapes(x.shape, (num_samples, 1, 1))\n",
    "        x = jnp.broadcast_to(x, x_shape)\n",
    "        return self._pathwise_sample_from_weights(x, w)\n",
    "    \n",
    "    @Partial(jax.jit, static_argnames=('num_samples',))\n",
    "    def sample_weights(self, key: Key, num_samples: int = 1) -> Float[Array, \"I\"]:\n",
    "        return jax.random.normal(key, shape=(num_samples, self.spherical_harmonic_fields.num_phases), dtype=jnp.float64)\n",
    "    \n",
    "    def pathwise_sample(self, key: Key, x: Float[Array, \"N 2\"], num_samples: int = 1) -> Float[Array, \"N 2\"]:\n",
    "        w = self.sample_weights(key, num_samples)\n",
    "        return self.pathwise_sample_from_weights(x, w, num_samples)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class HodgeMaternCurlFreeKernel(AbstractHodgeKernel):\n",
    "\n",
    "    def __post_init__(self):\n",
    "        try: \n",
    "            self.spherical_harmonic_fields = CurlFreeSphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)\n",
    "        except FundamentalSystemNotPrecomputedError as e:\n",
    "            warnings.warn(\n",
    "                f\"{e}\",\n",
    "                f\"Pathwise sampling will not be available unless max_ell is sufficiently reduced.\"\n",
    "            )\n",
    "\n",
    "    def __call__(self, x: Float[Array, \"2\"], y: Float[Array, \"2\"]) -> Float[Array, \"2 2\"]:\n",
    "        x, y = self.validate_inputs(x, y)\n",
    "        weights = self.spectral_density() * (2 * self.ells + 1)\n",
    "        dd = jnp.sum(self.dd_weighted_gegenbauer(x, y, weights=weights), axis=0)\n",
    "\n",
    "        Nx, Ny = tangent_basis_normalization_matrix(x), tangent_basis_normalization_matrix(y)\n",
    "        return Nx.T @ dd @ Ny\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class HodgeMaternDivFreeKernel(AbstractHodgeKernel):\n",
    "\n",
    "    def __post_init__(self):\n",
    "        try: \n",
    "            self.spherical_harmonic_fields = DivFreeSphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)\n",
    "        except FundamentalSystemNotPrecomputedError as e:\n",
    "            warnings.warn(\n",
    "                f\"{e}\",\n",
    "                f\"Pathwise sampling will not be available unless max_ell is sufficiently reduced.\"\n",
    "            )\n",
    "\n",
    "    @jax.jit\n",
    "    def __call__(self, x: Float[Array, \"2\"], y: Float[Array, \"2\"]) -> Float[Array, \"2 2\"]:\n",
    "        x, y = self.validate_inputs(x, y)\n",
    "        weights = self.spectral_density() * (2 * self.ells + 1)\n",
    "        dd = jnp.sum(self.dd_weighted_gegenbauer(x, y, weights=weights), axis=0)\n",
    "\n",
    "        Nx, Ny = tangent_basis_normalization_matrix(x), tangent_basis_normalization_matrix(y)\n",
    "        H = hodge_star_matrix\n",
    "        return H.T @ Nx.T @ dd @ Ny @ H\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class HodgeMaternKernel(AbstractVectorKernel):\n",
    "    kappa: InitVar[ScalarFloat] = 1.0\n",
    "    nu: InitVar[ScalarFloat] = 2.5\n",
    "    variance: InitVar[ScalarFloat] = 1.0\n",
    "    colatitude_min_value: InitVar[ScalarFloat] = 1e-12\n",
    "\n",
    "    max_ell: int = gpjax.base.static_field(10)\n",
    "    curl_free_kernel: HodgeMaternCurlFreeKernel = gpjax.param_field(init=False)\n",
    "    div_free_kernel: HodgeMaternDivFreeKernel = gpjax.param_field(init=False)\n",
    "\n",
    "    def __post_init__(self, kappa, nu, variance, colatitude_min_value):\n",
    "        self.curl_free_kernel = HodgeMaternCurlFreeKernel(kappa=kappa, nu=nu, variance=variance, max_ell=self.max_ell, colatitude_min_value=colatitude_min_value)\n",
    "        self.div_free_kernel = HodgeMaternDivFreeKernel(kappa=kappa, nu=nu, variance=variance, max_ell=self.max_ell, colatitude_min_value=colatitude_min_value)\n",
    "\n",
    "    def spectral_density(self):\n",
    "        return jnp.concat([self.curl_free_kernel.spectral_density(), self.div_free_kernel.spectral_density()])\n",
    "    \n",
    "    def __call__(self, x: Float[Array, \"2\"], y: Float[Array, \"2\"]) -> Float[Array, \"2 2\"]:\n",
    "        return self.curl_free_kernel(x, y) + self.div_free_kernel(x, y)\n",
    "    \n",
    "    def sample_weights(self, key: Key, num_samples: int = 1) -> Float[Array, \"I\"]:\n",
    "        return jnp.concatenate([\n",
    "            self.curl_free_kernel.sample_weights(key, num_samples), \n",
    "            self.div_free_kernel.sample_weights(key, num_samples),\n",
    "        ], axis=-1)\n",
    "\n",
    "    def pathwise_sample_from_weights(self, x: Float[Array, \"N 2\"], w: Float[Array, \"I\"], num_samples: int = 1) -> Float[Array, \"N 2\"]:\n",
    "        curl_free_w, div_free_w = jnp.split(w, 2, axis=-1)\n",
    "        curl_free_sample = self.curl_free_kernel.pathwise_sample_from_weights(x, curl_free_w, num_samples)\n",
    "        div_free_sample = self.div_free_kernel.pathwise_sample_from_weights(x, div_free_w, num_samples)\n",
    "        return curl_free_sample + div_free_sample\n",
    "    \n",
    "    def pathwise_sample(self, key: Key, x: Float[Array, \"N 2\"], num_samples: int = 1) -> Float[Array, \"N 2\"]:\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x:  The input locations. Can be [N 2] or [S N 2]. Internally, it is broadcasted to [S N 2], then\n",
    "                flattened, then processed, and then reshaped back to [S N 2].\n",
    "        \n",
    "        \"\"\"\n",
    "        w = self.sample_weights(key, num_samples) # [S I]\n",
    "        return self.pathwise_sample_from_weights(x, w, num_samples)\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class AbstractSphericalHarmonicFields(gpjax.Module):\n",
    "    max_ell: int = gpjax.base.static_field(10)\n",
    "    sphere_dim: int = gpjax.base.static_field(2)\n",
    "    _colatitude_min_value: float = gpjax.base.static_field(1e-12)\n",
    "    spherical_harmonics: SphericalHarmonics = gpjax.base.static_field(init=False)\n",
    "    num_phases_per_frequency: Float[Array, \"L\"] = gpjax.base.param_field(init=False, trainable=False)\n",
    "    num_phases: int = gpjax.base.static_field(init=False)\n",
    "    num_fields: int = gpjax.base.static_field(init=False)\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        self.spherical_harmonics = SphericalHarmonics(max_ell=self.max_ell, sphere_dim=self.sphere_dim)\n",
    "        num_phases_per_frequency = self.spherical_harmonics.num_phase_in_frequency[1:]\n",
    "        self.num_phases_per_frequency = jnp.array(num_phases_per_frequency)\n",
    "        self.num_phases = sum(num_phases_per_frequency)\n",
    "\n",
    "    @jax.jit\n",
    "    def _sph_polynomial_expansion(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N 2\"]:\n",
    "        ells = jnp.arange(1, self.max_ell + 1)\n",
    "        lambda_ells = ells * (ells + 1)\n",
    "        normalization_factor = jnp.repeat(\n",
    "            jnp.sqrt(lambda_ells),\n",
    "            self.num_phases_per_frequency,\n",
    "            total_repeat_length=self.num_phases\n",
    "        )\n",
    "        return self.spherical_harmonics.polynomial_expansion(sph_to_car(x))[1:] / normalization_factor\n",
    "\n",
    "    @jax.jit\n",
    "    def _field_polynomial_expansion_single(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N 2\"]:\n",
    "        Nx = tangent_basis_normalization_matrix(x)\n",
    "        return jax.jacfwd(self._sph_polynomial_expansion)(x) @ Nx\n",
    "    \n",
    "    @jax.jit\n",
    "    def _field_polynomial_expansion(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N I 2\"]:\n",
    "        x = _ensure_colatitude_nonzero(x, self._colatitude_min_value)\n",
    "        return jax.vmap(self._field_polynomial_expansion_single)(x)\n",
    "\n",
    "    @abstractmethod\n",
    "    def __call__(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N 2\"]:\n",
    "        pass \n",
    "    \n",
    "    def __eq__(self, other: \"AbstractSphericalHarmonicFields\") -> bool:\n",
    "        return self.max_ell == other.max_ell and self.sphere_dim == other.sphere_dim and self._colatitude_min_value == other._colatitude_min_value\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class CurlFreeSphericalHarmonicFields(AbstractSphericalHarmonicFields):\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        super().__post_init__()\n",
    "        self.num_fields = self.num_phases\n",
    "\n",
    "    @jax.jit\n",
    "    def __call__(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N I 2\"]:\n",
    "        return self._field_polynomial_expansion(x)\n",
    "    \n",
    "    def __eq__(self, other: \"CurlFreeSphericalHarmonicFields\") -> bool:\n",
    "        return self.max_ell == other.max_ell and self.sphere_dim == other.sphere_dim\n",
    "    \n",
    "\n",
    "@dataclass \n",
    "class DivFreeSphericalHarmonicFields(AbstractSphericalHarmonicFields):\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        super().__post_init__()\n",
    "        self.num_fields = self.num_phases\n",
    "\n",
    "    @jax.jit\n",
    "    def __call__(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N I 2\"]:\n",
    "        H = hodge_star_matrix\n",
    "        return self._field_polynomial_expansion(x) @ H\n",
    "    \n",
    "    def __eq__(self, other: \"DivFreeSphericalHarmonicFields\") -> bool:\n",
    "        return self.max_ell == other.max_ell and self.sphere_dim == other.sphere_dim\n",
    "    \n",
    "\n",
    "@dataclass \n",
    "class SphericalHarmonicFields(AbstractSphericalHarmonicFields):\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        super().__post_init__()\n",
    "        self.num_fields = 2 * self.num_phases\n",
    "\n",
    "    @jax.jit\n",
    "    def __call__(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N 2I 2\"]:\n",
    "        \"\"\"\n",
    "        Returns curl-free and divergence-free fields concatenated.\n",
    "        \"\"\"\n",
    "        H = hodge_star_matrix\n",
    "        v = self._field_polynomial_expansion(x) # [N I 2]\n",
    "        return jnp.concat([v, v @ H], axis=-2)\n",
    "    \n",
    "    def __eq__(self, other: \"SphericalHarmonicFields\") -> bool:\n",
    "        return self.max_ell == other.max_ell and self.sphere_dim == other.sphere_dim\n",
    "    \n",
    "\n",
    "@dataclass \n",
    "class AbstractVectorSHF(gpjax.variational_families.AbstractVariationalFamily):\n",
    "    r\"\"\"The orthonormal generalized variational family of probability distributions.\n",
    "\n",
    "    The variational family is $`q(f(\\cdot)) = \\int p(f(\\cdot)\\mid u) q(u) \\mathrm{d}u`$, where\n",
    "    $`u = f(z)`$ are the function values at the inducing inputs $`z`$\n",
    "    and the distribution over the inducing inputs is\n",
    "    $`q(u) = \\mathcal{N}(\\mu, S)`$.  We parameterise this over\n",
    "    $`\\mu`$ and $`sqrt`$ with $`S = sqrt sqrt^{\\top}`$.\n",
    "    \"\"\"\n",
    "    max_ell: int = gpjax.base.static_field(1)\n",
    "    jitter: ScalarFloat = gpjax.base.static_field(1e-6)\n",
    "    variational_mean: Float[Array, \"N 1\"] | None = gpjax.base.param_field(None)\n",
    "    variational_root_covariance: Float[Array, \"N N\"] = gpjax.base.param_field(\n",
    "        None, bijector=tfp.bijectors.FillTriangular()\n",
    "    )\n",
    "    spherical_harmonic_fields: AbstractSphericalHarmonicFields = gpjax.base.static_field(init=False)\n",
    "    sphere_dim: int = gpjax.base.static_field(2)\n",
    "    num_inducing: int = gpjax.base.static_field(init=False)\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        self.num_inducing = self.spherical_harmonic_fields.num_fields\n",
    "        # Kzz and muz does not change during optimization\n",
    "        self.muz = jnp.zeros((self.num_inducing, 1))\n",
    "\n",
    "        if self.variational_mean is None:\n",
    "            self.variational_mean = jnp.zeros((self.num_inducing, 1))        \n",
    "\n",
    "        if self.variational_root_covariance is None:\n",
    "            self.variational_root_covariance = jnp.eye(self.num_inducing) + self.jitter\n",
    "\n",
    "    def _repeat_per_phase(self, x: Float[Array, \"N 2\"]) -> Float[Array, \"N 2 I\"]:\n",
    "        return jnp.repeat(\n",
    "            x, \n",
    "            self.spherical_harmonic_fields.num_phases_per_frequency,\n",
    "            total_repeat_length=self.spherical_harmonic_fields.num_phases,\n",
    "        )\n",
    "\n",
    "    @abstractmethod\n",
    "    def ahats(self) -> Float[Array, \"I\"]:\n",
    "        pass  \n",
    "\n",
    "    @jax.jit\n",
    "    def Lz_T_inv_diagonal(self):\n",
    "        ahats = self.ahats()\n",
    "        return jnp.sqrt(ahats / (1 + ahats * self.jitter))\n",
    "\n",
    "\n",
    "    def Kzt(self, t: Float[Array, \"N 2\"]) -> Float[Array, \"N I 2\"]:\n",
    "        r\"\"\"Compute the cross-covariance between the inducing inputs and the test inputs.\n",
    "\n",
    "        Args:\n",
    "            t (Float[Array, \"N 2\"]): The test inputs.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            Float[Array, \"N (2 max_ell + 1)\"]: The cross-covariance between the inducing inputs and the test inputs.\n",
    "        \"\"\"\n",
    "        fields = self.spherical_harmonic_fields(t) # [N 2 I]\n",
    "        return jnp.permute_dims(fields, (0, 1, 2)).reshape(self.num_inducing, -1)\n",
    "    \n",
    "    def prior_kl(self) -> ScalarFloat:\n",
    "        # Unpack variational parameters\n",
    "        mu = self.variational_mean\n",
    "        sqrt = self.variational_root_covariance\n",
    "        sqrt = cola.ops.Triangular(sqrt)\n",
    "\n",
    "        # Unpack mean function and kernel\n",
    "        muz = self.muz # TODO maybe allow non-zero prior mean. This would necessitate setting the first position of the mean to the prior mean constant\n",
    "\n",
    "        S = sqrt @ sqrt.T\n",
    "\n",
    "        qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)\n",
    "        pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()))\n",
    "\n",
    "        return qu.kl_divergence(pu) # TODO efficiency here can be improved by using the fact that Kzz_jittered is diagonal \n",
    "    \n",
    "    def predict(self, test_inputs: Float[Array, \"N D\"]) -> GaussianDistribution:\n",
    "        t = test_inputs\n",
    "\n",
    "        # Unpack variational parameters\n",
    "        mu = self.variational_mean\n",
    "        sqrt = self.variational_root_covariance # [I I]\n",
    "\n",
    "        # Unpack mean function and kernel\n",
    "        mean_function = self.posterior.prior.mean_function\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        # Compute posterior covariance\n",
    "        Ktt = kernel.gram(t) # [2N 2N]\n",
    "        Ktz = self.Kzt(t).mT # [I 2N]\n",
    "        Lz_T_inv = self.Lz_T_inv_diagonal()\n",
    "\n",
    "        Ktz_Lz_T_inv = Ktz * Lz_T_inv\n",
    "        Ktz_Lz_T_inv_sqrt = Ktz_Lz_T_inv @ sqrt # [2N I] @ [I I] -> [2N I]\n",
    "        covariance = (\n",
    "            Ktt \n",
    "            + Ktz_Lz_T_inv_sqrt @ Ktz_Lz_T_inv_sqrt.mT\n",
    "            - Ktz_Lz_T_inv @ Ktz_Lz_T_inv.mT\n",
    "        )\n",
    "        covariance = cola.PSD(covariance + cola.ops.I_like(covariance) * self.jitter) # add jitter for spectral stability\n",
    "\n",
    "        # Compute posterior mean \n",
    "        mut = mean_function(t)\n",
    "        muz = self.muz\n",
    "\n",
    "        mean = (\n",
    "            mut \n",
    "            + Ktz_Lz_T_inv @ (mu - muz) # [2N I] @ [I 1] -> [2N 1]\n",
    "        )\n",
    "\n",
    "        return GaussianDistribution(\n",
    "            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance\n",
    "        )\n",
    "    \n",
    "    @Partial(jax.jit, static_argnames=('num_samples',))\n",
    "    def _pathwise_sample(self, key: Key, test_inputs: Float[Array, \"S N 2\"], num_samples: int) -> Float[Array, \"S N 2\"]:\n",
    "        Ktt_key, S_key = jax.random.split(key)\n",
    "\n",
    "        t = test_inputs # [S N 2]\n",
    "\n",
    "        # Unpack variational parameters\n",
    "        m = self.variational_mean\n",
    "        m = jnp.squeeze(m, axis=-1)\n",
    "        sqrt = self.variational_root_covariance # [I I]\n",
    "\n",
    "        # Unpack mean function and kernel\n",
    "        kernel = self.posterior.prior.kernel\n",
    "\n",
    "        # Compute posterior covariance\n",
    "        w = kernel.sample_weights(Ktt_key, num_samples) # [S I]\n",
    "        Ktt_sample = kernel.pathwise_sample_from_weights(t, w, num_samples) # [S N 2]\n",
    "\n",
    "        Phi_t = jax.vmap(self.spherical_harmonic_fields)(t) # [S N I 2]\n",
    "        Lz_T_inv = self.Lz_T_inv_diagonal() # [I]\n",
    "        tilde_Phi_t = jnp.einsum('snid, i -> snid', Phi_t, Lz_T_inv) # [S N I 2]\n",
    "\n",
    "        S_sample = jax.random.multivariate_normal(\n",
    "            S_key, mean=m, cov=sqrt @ sqrt.T, shape=(num_samples,)\n",
    "        ) # [S I]\n",
    "\n",
    "        covariance_sample = (\n",
    "            Ktt_sample \n",
    "            + jnp.einsum('snid, si -> snd', tilde_Phi_t, S_sample - w)\n",
    "        )\n",
    "\n",
    "        # Compute posterior mean \n",
    "        mean = jnp.einsum('snid, i -> snd', tilde_Phi_t, m)\n",
    "\n",
    "        return mean + covariance_sample\n",
    "    \n",
    "\n",
    "    def pathwise_sample(self, key: Key, test_inputs: Float[Array, \"N D\"], num_samples: int) -> Float[Array, \"S N D\"]:\n",
    "        \"\"\"\n",
    "        Args:   \n",
    "            key: The random key.\n",
    "            test_inputs: The input locations. Can be [N D] or [S N D]. Internally, it is broadcasted to [S N D] at the beginning, \n",
    "            then flattened, then processed, and then reshaped back to [S N D].\n",
    "            num_samples: The number of samples to draw.\n",
    "        \"\"\"\n",
    "        test_inputs_shape = jnp.broadcast_shapes(test_inputs.shape, (num_samples, 1, 1))\n",
    "        test_inputs = jnp.broadcast_to(test_inputs, test_inputs_shape)\n",
    "        return self._pathwise_sample(key, test_inputs, num_samples)\n",
    "\n",
    "\n",
    "\n",
    "class CurlFreeVectorSHF(AbstractVectorSHF):\n",
    "        \n",
    "    def __post_init__(self) -> None:\n",
    "        self.spherical_harmonic_fields = CurlFreeSphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)\n",
    "        super().__post_init__()\n",
    "\n",
    "    def ahats(self):\n",
    "        ahats_per_frequency = self.posterior.prior.kernel.spectral_density()[:self.max_ell]\n",
    "        return self._repeat_per_phase(ahats_per_frequency)\n",
    "    \n",
    "\n",
    "@dataclass \n",
    "class DivFreeVectorSHF(AbstractVectorSHF):\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        self.spherical_harmonic_fields = DivFreeSphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)\n",
    "        super().__post_init__()\n",
    "\n",
    "    def ahats(self):\n",
    "        ahats_per_frequency = self.posterior.prior.kernel.spectral_density()[:self.max_ell]\n",
    "        return self._repeat_per_phase(ahats_per_frequency)\n",
    "    \n",
    "\n",
    "@dataclass \n",
    "class VectorSHF(AbstractVectorSHF):\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        self.spherical_harmonic_fields = SphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)\n",
    "        super().__post_init__()\n",
    "\n",
    "    def ahats(self):\n",
    "        curl_free_kernel = self.posterior.prior.kernel.curl_free_kernel\n",
    "        div_free_kernel = self.posterior.prior.kernel.div_free_kernel\n",
    "\n",
    "        curl_free_ahats_per_frequency = curl_free_kernel.spectral_density()[:self.max_ell]\n",
    "        div_free_ahats_per_frequency = div_free_kernel.spectral_density()[:self.max_ell]\n",
    "\n",
    "        return jnp.concatenate([\n",
    "            self._repeat_per_phase(curl_free_ahats_per_frequency),\n",
    "            self._repeat_per_phase(div_free_ahats_per_frequency),\n",
    "        ]) \n",
    "    \n",
    "\n",
    "def variational_family_from_kernel(kernel: type[AbstractVectorKernel]) -> type[AbstractVectorSHF]:\n",
    "    if issubclass(kernel, HodgeMaternCurlFreeKernel):\n",
    "        return CurlFreeVectorSHF\n",
    "    elif issubclass(kernel, HodgeMaternDivFreeKernel):\n",
    "        return DivFreeVectorSHF\n",
    "    elif issubclass(kernel, HodgeMaternKernel):\n",
    "        return VectorSHF\n",
    "    else:\n",
    "        raise ValueError(\"Unknown kernel type.\")\n",
    "    \n",
    "\n",
    "from jaxtyping import Key \n",
    "\n",
    "# TODO Should consider double jax.vmap without reshaping and using the batched functionality of MultivariateNormalFullCovariance\n",
    "@jax.jit\n",
    "def sample_from_marginal(\n",
    "    key: Key, \n",
    "    model: gpjax.gps.AbstractPrior | gpjax.variational_families.AbstractVariationalGaussian,\n",
    "    x: Float[Array, \"N D\"] | Float[Array, \"S N D\"],\n",
    ") -> Float[Array, \"S N O\"]:\n",
    "    \"\"\"\n",
    "    Sample from the marginal distribution of the model at the input locations.\n",
    "\n",
    "    Args:\n",
    "        key: The random key.\n",
    "        model: The model object.\n",
    "        x: The input locations. Can be [N D] or [S N D]. Internally, it is broadcasted to [S N D] at the beginning.\n",
    "    \"\"\"\n",
    "\n",
    "    def moments(t: Float[Array, \"D\"]) -> tuple[Float[Array, \"O\"], Float[Array, \"O O\"]]:\n",
    "        pt = model(t)\n",
    "        return pt.mean(), pt.covariance()\n",
    "    \n",
    "    means, covariance_matrices = jax.vmap(jax.vmap(moments))(x[:, :, None]) # [S N O], [S N O O]\n",
    "\n",
    "    # NOTE we should probably add expand to num_samples here\n",
    "    marginal_pt = tfp.distributions.MultivariateNormalFullCovariance(loc=means, covariance_matrix=covariance_matrices)\n",
    "    return marginal_pt.sample(seed=key, sample_shape=())\n",
    "\n",
    "\n",
    "EPS = 1e-12\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def tangent_basis(x: Float[Array, \"3\"]) -> Float[Array, \"3\"]:\n",
    "    tb = jax.jacfwd(sph_to_car)(x)\n",
    "    tb /= jnp.linalg.norm(tb, axis=0, keepdims=True)\n",
    "    return tb \n",
    "\n",
    "@jax.jit\n",
    "def expmap_car(x: Float[Array, \"3\"], v: Float[Array, \"3\"]) -> Float[Array, \"3\"]:\n",
    "    def first_order_taylor():\n",
    "        t = x + v \n",
    "        return t / jnp.linalg.norm(t)\n",
    "\n",
    "    theta = jnp.linalg.norm(v)\n",
    "    return jax.lax.cond(\n",
    "        theta < EPS,\n",
    "        first_order_taylor,\n",
    "        lambda: jnp.cos(theta) * x + jnp.sin(theta) * v / theta,\n",
    "    )\n",
    "\n",
    "\n",
    "@Partial(jax.jit, static_argnames=(\"colatitude_min_value\", ))\n",
    "def expmap_sph(x: Float[Array, \"D\"], v: Float[Array, \"D\"], colatitude_min_value: float = EPS) -> Float[Array, \"D\"]:\n",
    "    \"\"\"\n",
    "    Exponential map on the sphere taking x in spherical coordinates and v in the 'canonical' coordinate frame. \n",
    "    This function internally ensures that the colatitude of x is not too small to avoid nans.\n",
    "    \"\"\"\n",
    "    x = _ensure_colatitude_nonzero(x, colatitude_min_value)\n",
    "    x_prime = sph_to_car(x)\n",
    "    v_prime = tangent_basis(x) @ v\n",
    "    return car_to_sph(expmap_car(x_prime, v_prime))\n",
    "\n",
    "\n",
    "from dataclasses import InitVar\n",
    "\n",
    "\n",
    "@dataclass \n",
    "class IdentityPosterior(gpjax.gps.AbstractPosterior):\n",
    "    likelihood: None = gpjax.base.static_field(None)\n",
    "\n",
    "    def predict(self, test_inputs: Float[Array, \"N D\"]) -> GaussianDistribution:\n",
    "        return self.prior(test_inputs)\n",
    "    \n",
    "\n",
    "@dataclass \n",
    "class AbstractDeepGP(gpjax.Module):\n",
    "    layers: list[gpjax.variational_families.AbstractVariationalGaussian] = gpjax.base.param_field(init=True)\n",
    "    num_samples: int = gpjax.base.static_field(10)\n",
    "    num_layers: int = gpjax.base.static_field(init=False)\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        self.num_layers = len(self.layers)\n",
    "\n",
    "    @property \n",
    "    def hidden_layers(self) -> list[gpjax.variational_families.AbstractVariationalGaussian]:\n",
    "        return self.layers[:-1]\n",
    "    \n",
    "    @property\n",
    "    def output_layer(self) -> gpjax.variational_families.AbstractVariationalGaussian:\n",
    "        return self.layers[-1]\n",
    "\n",
    "    def prior_kl(self) -> ScalarFloat:\n",
    "        return sum(layer.prior_kl() for layer in self.layers)\n",
    "    \n",
    "    @abstractmethod\n",
    "    def sample_from_hidden(self, key: Key, x: Float[Array, \"N D\"]) -> Float[Array, \"S N D\"]:\n",
    "        pass\n",
    "\n",
    "    def output_predict(self, x: Float[Array, \"S N D\"]) -> tfd.MixtureSameFamily:\n",
    "        \"\"\"\n",
    "        Predict through the output layer. \n",
    "\n",
    "        Args:\n",
    "            x (Float[Array, \"S N D\"]): The input data. \n",
    "        \"\"\"\n",
    "        def moments(t: Float[Array, \"N D\"]) -> tuple[Float[Array, \"N\"], Float[Array, \"N N\"]]:\n",
    "            pt = self.output_layer(t)\n",
    "            return pt.mean(), pt.covariance()\n",
    "        \n",
    "        means, covariance_matrices = jax.vmap(moments)(x)\n",
    "        return tfd.MixtureSameFamily(\n",
    "            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)),\n",
    "            components_distribution=tfd.MultivariateNormalFullCovariance(loc=means, covariance_matrix=covariance_matrices),\n",
    "        )\n",
    "    \n",
    "    def predict(self, key: Key, x: Float[Array, \"N D\"]) -> tfd.MixtureSameFamily:\n",
    "        \"\"\"\n",
    "        Predict through the entire model. \n",
    "\n",
    "        Args:\n",
    "            x (Float[Array, \"N D\"]): The input data. \n",
    "        \"\"\"\n",
    "        return self.output_predict(self.sample_from_hidden(key, x))\n",
    "    \n",
    "    def __call__(self, key: Key, x: Float[Array, \"N D\"]) -> GaussianDistribution:\n",
    "        raise self.predict(key, x)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ResidualDeepGP(AbstractDeepGP):\n",
    "\n",
    "    def sample_from_hidden(self, key: Key, x: Float[Array, \"N D\"]) -> Float[Array, \"S N D\"]:\n",
    "        \"\"\"\n",
    "        Predict through the hidden layers.\n",
    "\n",
    "        Args:\n",
    "            x (Float[Array, \"N D\"]): The input data. Either of shape [N D] or [S N D].\n",
    "        \"\"\"\n",
    "        x_shape = jnp.broadcast_shapes(x.shape, (self.num_samples, 1, 1))\n",
    "        x = jnp.broadcast_to(x, x_shape)\n",
    "\n",
    "        def step(key: Key, layer, x: Array) -> Array:\n",
    "            v = sample_from_marginal(key=key, model=layer, x=x)\n",
    "            return jax.vmap(jax.vmap(expmap_sph, in_axes=(0, 0)), in_axes=(0, 0))(x, v)\n",
    "\n",
    "        key_per_hidden_layer = jax.random.split(key, self.num_layers - 1)\n",
    "        for layer, key in zip(self.hidden_layers, key_per_hidden_layer):\n",
    "            x = step(key, layer, x)\n",
    "        return x\n",
    "    \n",
    "    def pathwise_sample_from_hidden(self, key: Key, x: Float[Array, \"N D\"]) -> Float[Array, \"S N D\"]:\n",
    "        \"\"\"\n",
    "        Predict through the hidden layers.\n",
    "\n",
    "        Args:\n",
    "            x (Float[Array, \"N D\"]): The input data. Either of shape [N D] or [S N D].\n",
    "        \"\"\"\n",
    "        x_shape = jnp.broadcast_shapes(x.shape, (self.num_samples, 1, 1))\n",
    "        x = jnp.broadcast_to(x, x_shape)\n",
    "        def step(key, layer, x: Array) -> Array:\n",
    "            v = layer.pathwise_sample(key, x, self.num_samples)\n",
    "            return jax.vmap(jax.vmap(expmap_sph, in_axes=(0, 0)), in_axes=(0, 0))(x, v)\n",
    "\n",
    "        key_per_hidden_layer = jax.random.split(key, self.num_layers - 1)\n",
    "        for layer, key in zip(self.hidden_layers, key_per_hidden_layer):\n",
    "            x = step(key, layer, x)\n",
    "        return x\n",
    "    \n",
    "    def pathwise_sample(self, key: Key, x: Float[Array, \"N D\"]) -> Float[Array, \"S N D\"]:\n",
    "        hidden_key, output_key = jax.random.split(key)\n",
    "\n",
    "        x = self.pathwise_sample_from_hidden(hidden_key, x)\n",
    "        return self.output_layer.pathwise_sample(output_key, x, self.num_samples)\n",
    "    \n",
    "\n",
    "\n",
    "class DeepVectorELBO(gpjax.objectives.AbstractObjective):\n",
    "    def step(\n",
    "        self,\n",
    "        key: Key, \n",
    "        variational_family: AbstractDeepGP,\n",
    "        train_data: gpjax.Dataset,\n",
    "    ) -> ScalarFloat:\n",
    "        r\"\"\"Compute the evidence lower bound of a variational approximation.\n",
    "\n",
    "        Compute the evidence lower bound under this model. In short, this requires\n",
    "        evaluating the expectation of the model's log-likelihood under the variational\n",
    "        approximation. To this, we sum the KL divergence from the variational posterior\n",
    "        to the prior. When batching occurs, the result is scaled by the batch size\n",
    "        relative to the full dataset size.\n",
    "\n",
    "        Args:\n",
    "            variational_family (AbstractVariationalFamily): The variational\n",
    "                approximation for whose parameters we should maximise the ELBO with\n",
    "                respect to.\n",
    "            train_data (Dataset): The training data for which we should maximise the\n",
    "                ELBO with respect to.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "            ScalarFloat: The evidence lower bound of the variational approximation for\n",
    "                the current model parameter set.\n",
    "        \"\"\"\n",
    "        # KL[q(f(·)) || p(f(·))]\n",
    "        kl = variational_family.prior_kl()\n",
    "\n",
    "        # ∫[log(p(y|f(·))) q(f(·))] df(·)\n",
    "        var_exp = deep_vector_variational_expectation(key, variational_family, train_data)\n",
    "\n",
    "        # For batch size b, we compute  n/b * Σᵢ[ ∫log(p(y|f(xᵢ))) q(f(xᵢ)) df(xᵢ)] - KL[q(f(·)) || p(f(·))]\n",
    "        return self.constant * (\n",
    "            jnp.sum(var_exp)\n",
    "            * variational_family.output_layer.posterior.likelihood.num_datapoints\n",
    "            / train_data.n\n",
    "            - kl\n",
    "        )\n",
    "\n",
    "\n",
    "@jax.jit \n",
    "def moments(model: AbstractVectorSHF, x: Array) -> tuple[Array, Array]:\n",
    "    def mean_and_covariance(x):\n",
    "        pf = model(x)\n",
    "        py = model.posterior.likelihood(pf) # FIXME This won't work for the prior \n",
    "        return py.mean(), py.covariance()\n",
    "    return jax.vmap(mean_and_covariance)(x[:, None])\n",
    "\n",
    "\n",
    "def deep_vector_variational_expectation(\n",
    "    key: Key, \n",
    "    variational_family: AbstractDeepGP,\n",
    "    train_data: gpjax.Dataset,\n",
    ") -> Float[Array, \" N\"]:\n",
    "    r\"\"\"Compute the variational expectation.\n",
    "\n",
    "    Compute the expectation of our model's log-likelihood under our variational\n",
    "    distribution. Batching can be done here to speed up computation.\n",
    "\n",
    "    Args:\n",
    "        variational_family (AbstractVariationalFamily): The variational family that we\n",
    "            are using to approximate the posterior.\n",
    "        train_data (Dataset): The batch for which the expectation should be computed\n",
    "            for.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "        Array: The expectation of the model's log-likelihood under our variational\n",
    "            distribution.\n",
    "    \"\"\"\n",
    "    # Unpack training batch\n",
    "    x, y = train_data.X, train_data.y # [N, D] [N, O]\n",
    "\n",
    "    # Variational distribution q(f(·)) = N(f(·); μ(·), Σ(·, ·))\n",
    "    q = variational_family\n",
    "    x = q.sample_from_hidden(key, x)\n",
    "\n",
    "    # reshape because samples \n",
    "    num_samples = x.shape[0]\n",
    "    y = jnp.broadcast_to(y, (num_samples, *y.shape)).reshape(-1, y.shape[-1])\n",
    "    x = x.reshape(-1, x.shape[-1]) # [S N D] -> [S * N D]\n",
    "\n",
    "    # Compute variational mean, μ(x), and variance, diag(Σ(x, x)), at the training\n",
    "    # inputs, x\n",
    "    mean, covariance = moments(q.output_layer, x) # [S * N O], [S * N O O]\n",
    "\n",
    "    # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x)\n",
    "    # There is no need to handle likelihoods of different samples in some special way, \n",
    "    # since likelihood of mixture is the mixture of likelihoods\n",
    "    expectation = q.output_layer.posterior.likelihood.expected_log_likelihood(\n",
    "        y, mean, covariance\n",
    "    )\n",
    "    return expectation / num_samples # MC estimate of the inner expectation requires dividing by the number of samples\n",
    "\n",
    "\n",
    "def moments_unconditional(model, x):\n",
    "    def mean_and_covariance(x):\n",
    "        pf = model(x)\n",
    "        py = model.posterior.likelihood(pf) # FIXME This won't work for the prior \n",
    "        return py.mean(), py.covariance()\n",
    "    return jax.vmap(mean_and_covariance)(x[:, None])\n",
    "\n",
    "\n",
    "def moments_deep(key: Key, model: AbstractDeepGP, x):\n",
    "    x = model.sample_from_hidden(key, x)\n",
    "    means, covs = jax.vmap(lambda t: moments_unconditional(model.output_layer, t))(x) # map over sample dimension\n",
    "    return means, covs \n",
    "\n",
    "\n",
    "def pathwise_moments_deep(key: Key, model: AbstractDeepGP, x):\n",
    "    x = model.pathwise_sample_from_hidden(key, x)\n",
    "    means, covs = jax.vmap(lambda t: moments_unconditional(model.output_layer, t))(x) # map over sample dimension\n",
    "    return means, covs\n",
    "\n",
    "\n",
    "def mse(y_true: Array, y_pred: Array) -> Array:\n",
    "    return jnp.mean(jnp.sum(jnp.square(y_true - y_pred), axis=-1))\n",
    "\n",
    "\n",
    "def pred_nll(y_true, y_pred, std_pred):\n",
    "    return -jnp.mean(\n",
    "        tfd.MultivariateNormalFullCovariance(loc=y_pred, covariance_matrix=std_pred).log_prob(y_true)\n",
    "    )\n",
    "\n",
    "\n",
    "def evaluate(key: Key, model, test_data: VectorDataset):\n",
    "    x_test, y_test = test_data.X, test_data.y\n",
    "    mean, cov = moments_deep(key, model, x_test)\n",
    "    return {\n",
    "        'mse': mse(y_test, mean).item(), \n",
    "        'pnll': pred_nll(y_test, mean, cov).item(),\n",
    "    }\n",
    "\n",
    "\n",
    "import jax \n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import netCDF4\n",
    "\n",
    "from jaxtyping import Array\n",
    "from jax.tree_util import Partial \n",
    "import plotly.express as px \n",
    "\n",
    "\n",
    "def sphere_uniform_grid(n: int) -> Array:\n",
    "    # Fibonacci lattice method \n",
    "    phi = (1 + jnp.sqrt(5)) / 2  # Golden ratio\n",
    "    \n",
    "    indices = jnp.arange(n)\n",
    "    theta = 2 * jnp.pi * indices / phi\n",
    "    phi = jnp.arccos(1 - 2 * (indices + 0.5) / n)\n",
    "    return sph_to_car(jnp.column_stack((phi, theta)))\n",
    "\n",
    "\n",
    "@jax.tree_util.Partial(jax.jit, static_argnames=(\"with_replacement\",))\n",
    "def closest_point_mask(targets: Array, x: Array, with_replacement: bool) -> Array:\n",
    "    \"\"\"\n",
    "    Args: \n",
    "        targets (Array): targets in cartesian coordinates.\n",
    "        x (Array): points in cartesian coordinates for which to produce the mask.\n",
    "    \"\"\"\n",
    "\n",
    "    # Can do euclidean squared distance instead of spherical, since minimisation is invariant to monotonic transformations\n",
    "    distances = jnp.sum((targets[:, None] - x[None, :]) ** 2, -1)\n",
    "\n",
    "    def closest_point_mask_with_replacement():\n",
    "        return jnp.argmin(distances, axis=1)\n",
    "    \n",
    "    def closest_point_mask_without_replacement():\n",
    "        num_targets = targets.shape[0]\n",
    "        closest_indices = jnp.zeros(num_targets, dtype=jnp.int64)\n",
    "        available_mask = jnp.ones(x.shape[0], dtype=bool)\n",
    "\n",
    "        for i in range(num_targets):\n",
    "            masked_distances = jnp.where(available_mask, distances[i], jnp.inf)\n",
    "            closest_idx = jnp.argmin(masked_distances)\n",
    "            closest_indices = closest_indices.at[i].set(closest_idx)\n",
    "            available_mask = available_mask.at[closest_idx].set(False)\n",
    "\n",
    "        return closest_indices\n",
    "    \n",
    "    mask_indices = jax.lax.cond(\n",
    "        with_replacement, \n",
    "        closest_point_mask_with_replacement, \n",
    "        closest_point_mask_without_replacement,\n",
    "    )\n",
    "    mask = jnp.zeros(x.shape[0], dtype=jnp.bool)\n",
    "    return mask.at[mask_indices].set(True)\n",
    "\n",
    "\n",
    "def angles_to_radians_colat(x: Array) -> Array:\n",
    "    return jnp.pi * x / 180 + jnp.pi / 2\n",
    "\n",
    "def angles_to_radians_lon(x: Array) -> Array:\n",
    "    return jnp.pi * x / 180 \n",
    "\n",
    "def angles_to_radians(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    return df.assign(\n",
    "        colat=lambda df: angles_to_radians_colat(df.colat),\n",
    "        lon=lambda df: angles_to_radians_lon(df.lon),\n",
    "    )\n",
    "\n",
    "def radians_to_angles_colat(x: Array) -> Array:\n",
    "    return 180 * x / jnp.pi - 90 \n",
    "\n",
    "def radians_to_angles_lon(x: Array) -> Array:\n",
    "    return 180 * x / jnp.pi \n",
    "\n",
    "def radians_to_angles(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    return df.assign(\n",
    "        colat=lambda df: radians_to_angles_colat(df.colat),\n",
    "        lon=lambda df: radians_to_angles_lon(df.lon),\n",
    "    )\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def sph_to_car(sph: Array) -> Array:\n",
    "    \"\"\"\n",
    "    Args: \n",
    "        sph (Array): points in spherical coordinates.\n",
    "    \"\"\"\n",
    "    colat, lon = sph[..., 0], sph[..., 1]\n",
    "    z = jnp.cos(colat)\n",
    "    r = jnp.sin(colat)\n",
    "    x = r * jnp.cos(lon)\n",
    "    y = r * jnp.sin(lon)\n",
    "    return jnp.stack([x, y, z], axis=-1)\n",
    "\n",
    "\n",
    "def sphere_meshgrid(n: int) -> Float:\n",
    "    \"\"\"\n",
    "    Create a meshgrid on the sphere.\n",
    "    \"\"\"\n",
    "    colat = jnp.linspace(0, jnp.pi, n)\n",
    "    lon = jnp.linspace(0, 2 * jnp.pi, n)\n",
    "    colat, lon = jnp.meshgrid(colat, lon, indexing=\"ij\")\n",
    "    return jnp.stack([colat, lon], axis=-1)\n",
    "\n",
    "\n",
    "era5_file_path = \"../data/era5.nc\"\n",
    "era5_dataset = netCDF4.Dataset(era5_file_path,'r')\n",
    "era5_lon = angles_to_radians_lon(era5_dataset.variables['longitude'][:].data.astype(np.float64))\n",
    "era5_colat = angles_to_radians_colat(era5_dataset.variables['latitude'][:].data.astype(np.float64))\n",
    "era5_lon_mesh, era5_colat_mesh = jnp.meshgrid(era5_lon, era5_colat)\n",
    "\n",
    "\n",
    "def read_era5(time: int, level: int) -> pd.DataFrame:\n",
    "    level = {\n",
    "        0: 0, \n",
    "        7: 1, \n",
    "        15: 2, \n",
    "    }[level]\n",
    "    u = era5_dataset.variables['u'][time, level].data.astype(np.float64)\n",
    "    v = era5_dataset.variables['v'][time, level].data.astype(np.float64)\n",
    "    df = pd.DataFrame({\n",
    "        \"lon\": era5_lon_mesh.flatten(),\n",
    "        \"colat\": era5_colat_mesh.flatten(),\n",
    "        \"u\": u.flatten(),\n",
    "        \"v\": v.flatten(),\n",
    "    })\n",
    "    return df\n",
    "\n",
    "\n",
    "def match_to_uniform_grid_mask(x: Array, n: int, with_replacement: bool = True) -> Array:\n",
    "    \"\"\"\n",
    "    Args: \n",
    "        x (Array): Points in cartesian coordinates for which to create the mask.\n",
    "    \"\"\"\n",
    "    return closest_point_mask(\n",
    "        targets=sphere_uniform_grid(n),\n",
    "        x=x,\n",
    "        with_replacement=with_replacement,\n",
    "    )\n",
    "\n",
    "\n",
    "def to_test_dataframe(df: pd.DataFrame, n: int, with_replacement: bool = True) -> pd.DataFrame:\n",
    "    sph = df[['colat', 'lon']].values\n",
    "    mask = match_to_uniform_grid_mask(\n",
    "        x=sph_to_car(sph), n=n, with_replacement=with_replacement,\n",
    "    ).tolist()\n",
    "    return df[mask]\n",
    "\n",
    "\n",
    "\n",
    "import pandas as pd \n",
    "import math \n",
    "from datetime import datetime, timedelta\n",
    "from skyfield.api import load, EarthSatellite, utc\n",
    "from skyfield.toposlib import wgs84\n",
    "\n",
    "\n",
    "def datetime_range(start, stop, step=timedelta(minutes=1)):\n",
    "    current = start\n",
    "    while current < stop:\n",
    "        yield current\n",
    "        current += step\n",
    "\n",
    "\n",
    "def load_aeolus_and_timescale():\n",
    "    ts = load.timescale()\n",
    "\n",
    "    # Aeolus TLE data\n",
    "    line1 = \"1 43600U 18066A   21153.73585495  .00031128  00000-0  12124-3 0  9990\"\n",
    "    line2 = \"2 43600  96.7150 160.8035 0006915  90.4181 269.7884 15.87015039160910\"\n",
    "\n",
    "    aeolus = EarthSatellite(line1, line2, \"AEOLUS\", ts)\n",
    "    return aeolus, ts\n",
    "\n",
    "\n",
    "def read_aeolus(start: datetime, stop: datetime, step=timedelta(minutes=1)) -> pd.DataFrame:\n",
    "    if start.tzinfo is None:\n",
    "        start = start.replace(tzinfo=utc)\n",
    "    if stop.tzinfo is None:\n",
    "        stop = stop.replace(tzinfo=utc)\n",
    "\n",
    "    aeolus, ts = load_aeolus_and_timescale()\n",
    "    time = list(datetime_range(start, stop, step))\n",
    "    lat, lon = wgs84.latlon_of(aeolus.at(ts.from_datetimes(time)))\n",
    "\n",
    "    # convert to colatitude [0, pi] and longitude [0, 2pi]\n",
    "    colat, lon = lat.radians + math.pi / 2, lon.radians + math.pi\n",
    "\n",
    "    return pd.DataFrame({\n",
    "        \"time\": time,\n",
    "        \"colat\": colat,\n",
    "        \"lon\": lon,\n",
    "    })\n",
    "\n",
    "def to_train_dataframe(aeolus: pd.DataFrame, era5: pd.DataFrame, with_replacement: bool = True) -> tuple[pd.DataFrame, pd.DataFrame]:\n",
    "    targets = sph_to_car(aeolus[['colat', 'lon']].values)\n",
    "    x = sph_to_car(era5[['colat', 'lon']].values)\n",
    "    mask = closest_point_mask(\n",
    "        targets=targets, \n",
    "        x=x, \n",
    "        with_replacement=with_replacement,\n",
    "    ).tolist()\n",
    "    return era5[mask], era5[~np.array(mask)]\n",
    "\n",
    "\n",
    "def to_train_test_dataframes(aeolus: pd.DataFrame, era5: pd.DataFrame, test_size: int, with_replacement: bool = True) -> tuple[pd.DataFrame, pd.DataFrame]:\n",
    "    train_df, rest_df = to_train_dataframe(aeolus=aeolus, era5=era5, with_replacement=with_replacement)\n",
    "    test_df = to_test_dataframe(rest_df, n=test_size, with_replacement=with_replacement)\n",
    "    return train_df, test_df\n",
    "\n",
    "\n",
    "def train_test_sets(\n",
    "    time: int, \n",
    "    level: int, \n",
    "    start: datetime, \n",
    "    stop: datetime, \n",
    "    step: timedelta, \n",
    "    test_size: int, \n",
    "    with_replacement: bool = True,\n",
    ") -> tuple[Array, Array, Array, Array]:\n",
    "    aeolus = read_aeolus(start=start, stop=stop, step=step)\n",
    "    era5 = read_era5(time, level)\n",
    "    \n",
    "    # split data\n",
    "    df_train, df_test = to_train_test_dataframes(\n",
    "        aeolus=aeolus, era5=era5, test_size=test_size, with_replacement=with_replacement\n",
    "    )\n",
    "\n",
    "    # Inputs and target\n",
    "    X_train, X_test = df_train[[\"colat\", \"lon\"]].to_numpy(), df_test[[\"colat\", \"lon\"]].to_numpy()\n",
    "    y_train, y_test = df_train[[\"v\", \"u\"]].to_numpy(), df_test[[\"v\", \"u\"]].to_numpy()\n",
    "\n",
    "    # Convert to jnp arrays (not sure if this is necessary)\n",
    "    X_train, X_test = jnp.array(X_train), jnp.array(X_test)\n",
    "    y_train, y_test = jnp.array(y_train), jnp.array(y_test)\n",
    "\n",
    "    # Normalize (sort of) targets\n",
    "    norm_constant = jnp.mean(jax.vmap(jnp.linalg.norm)(y_train))\n",
    "    y_train /= norm_constant\n",
    "    y_test /= norm_constant\n",
    "    return X_train, X_test, y_train, y_test\n",
    "\n",
    "\n",
    "def build_layers(\n",
    "    num_layers: int,\n",
    "    hidden_kernel: type[AbstractVectorKernel], \n",
    "    output_kernel: type[AbstractVectorKernel],\n",
    "    likelihood: VectorGaussian,\n",
    "    hidden_variance: float = 0.01,\n",
    "    kappa: float = 1.0,\n",
    "    max_ell_variational: int = 9,\n",
    "    max_ell_prior: int = 30, \n",
    ") -> list[AbstractVectorSHF]:    \n",
    "    layers = []\n",
    "\n",
    "    # hidden layers \n",
    "    hidden_variational_family = variational_family_from_kernel(hidden_kernel)\n",
    "    for _ in range(num_layers - 1):\n",
    "        kernel = hidden_kernel(variance=hidden_variance, max_ell=max_ell_prior, kappa=kappa)\n",
    "        mean_function = VectorZeroMean()\n",
    "        prior = gpjax.gps.Prior(kernel=kernel, mean_function=mean_function)\n",
    "        posterior = IdentityPosterior(prior=prior)\n",
    "        layer = hidden_variational_family(posterior=posterior, max_ell=max_ell_variational)\n",
    "        layers.append(layer)\n",
    "    \n",
    "    # output layer\n",
    "    output_variational_family = variational_family_from_kernel(output_kernel)\n",
    "    kernel = output_kernel(max_ell=max_ell_prior, kappa=kappa)\n",
    "    mean_function = VectorZeroMean()\n",
    "    prior = gpjax.gps.Prior(kernel=kernel, mean_function=mean_function)\n",
    "    posterior = prior * likelihood\n",
    "    layer = output_variational_family(posterior=posterior, max_ell=max_ell_variational)\n",
    "    layers.append(layer)\n",
    "\n",
    "    return layers\n",
    "\n",
    "\n",
    "def plot_results(x_train, y_train, x_test, y_test, mean, var, var_x, sample, history):\n",
    "    var_lat, var_lon = var_x[..., 0], var_x[..., 1]\n",
    "    x_train_lat, x_train_lon = x_train[:, 0], x_train[:, 1]\n",
    "    y_train_dlat, y_train_dlon = y_train[:, 0], y_train[:, 1]\n",
    "    x_test_lat, x_test_lon = x_test[:, 0], x_test[:, 1]\n",
    "    y_test_dlat, y_test_dlon = y_test[:, 0], y_test[:, 1]\n",
    "    mean_dlat, mean_dlon = mean[:, 0], mean[:, 1]\n",
    "\n",
    "\n",
    "    nrows = 3\n",
    "    ncols = 2\n",
    "    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 4), layout=\"constrained\")\n",
    "\n",
    "    \n",
    "    # top left, prediction with test data \n",
    "    q = axs[0][0].quiver(x_test_lon, x_test_lat, mean_dlon, mean_dlat, angles=\"uv\")\n",
    "    q._init()\n",
    "    scale = q.scale\n",
    "    axs[0][0].quiver(x_train_lon, x_train_lat, y_train_dlon, y_train_dlat, angles=\"uv\", color=\"red\", scale=scale)\n",
    "    axs[0][0].set_xlabel(\"lon\")\n",
    "    axs[0][0].set_ylabel(\"lat\")\n",
    "    axs[0][0].set_title(\"Predictive Mean\")\n",
    "    \n",
    "    # top right, uncertainty\n",
    "    c = axs[0][1].pcolormesh(var_lon, var_lat, var, vmin=var.min(), vmax=var.max())\n",
    "    fig.colorbar(c, ax=axs[0][1])\n",
    "    axs[0][1].set_xlabel(\"lon\")\n",
    "    axs[0][1].set_ylabel(\"lat\")\n",
    "    axs[0][1].set_title(\"Predictive Uncertainty\")\n",
    "\n",
    "    # middle left, true test data\n",
    "    axs[1][0].quiver(x_test_lon, x_test_lat, y_test_dlon, y_test_dlat, angles=\"uv\", scale=scale)\n",
    "    axs[1][0].set_xlabel(\"lon\")\n",
    "    axs[1][0].set_ylabel(\"lat\")\n",
    "    axs[1][0].set_title(\"Ground truth\")\n",
    "\n",
    "    # middle right, difference \n",
    "    y_diff = y_test - mean\n",
    "    y_diff_dlat, y_diff_dlon = y_diff[:, 0], y_diff[:, 1]\n",
    "    axs[1][1].quiver(x_test_lon, x_test_lat, y_diff_dlon, y_diff_dlat, angles=\"uv\", scale=scale)\n",
    "    axs[1][1].set_xlabel(\"lon\")\n",
    "    axs[1][1].set_ylabel(\"lat\")\n",
    "    axs[1][1].set_title(\"Prediction Error\")\n",
    "\n",
    "    # bottom left, sample from posterior \n",
    "    q = axs[2][0].quiver(x_test_lon, x_test_lat, sample[:, 0], sample[:, 1], angles=\"uv\", scale=scale)\n",
    "    axs[2][0].set_xlabel(\"lon\")\n",
    "    axs[2][0].set_ylabel(\"lat\")\n",
    "    axs[2][0].set_title(\"Sample from Posterior\")\n",
    "\n",
    "    # bottom right, training history\n",
    "    axs[2][1].plot(history)\n",
    "    axs[2][1].set_xlabel(\"Iteration\")\n",
    "    axs[2][1].set_ylabel(\"Negative ELBO\")\n",
    "    axs[2][1].set_title(\"Training History\")\n",
    "\n",
    "    # plt.tight_layout()\n",
    "    return fig\n",
    "\n",
    "\n",
    "# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "\n",
    "from gpjax.fit import (\n",
    "    _check_batch_size,\n",
    "    _check_log_rate,\n",
    "    _check_model,\n",
    "    _check_num_iters,\n",
    "    _check_optim,\n",
    "    _check_train_data,\n",
    "    _check_verbose,\n",
    "    get_batch,\n",
    ")\n",
    "from beartype.typing import (\n",
    "    Any,\n",
    "    Callable,\n",
    "    Optional,\n",
    "    Tuple,\n",
    "    TypeVar,\n",
    "    Union,\n",
    ")\n",
    "import jax\n",
    "from jax import (\n",
    "    jit,\n",
    "    value_and_grad,\n",
    ")\n",
    "from jax._src.random import _check_prng_key\n",
    "from jax.flatten_util import ravel_pytree\n",
    "import jax.numpy as jnp\n",
    "import jax.random as jr\n",
    "import optax as ox\n",
    "import scipy\n",
    "\n",
    "from gpjax.base import Module\n",
    "from gpjax.dataset import Dataset\n",
    "from gpjax.objectives import AbstractObjective\n",
    "from gpjax.scan import vscan\n",
    "from gpjax.typing import (\n",
    "    Array,\n",
    "    KeyArray,\n",
    "    ScalarFloat,\n",
    ")\n",
    "\n",
    "ModuleModel = TypeVar(\"ModuleModel\", bound=Module)\n",
    "\n",
    "\n",
    "def fit_deep(  # noqa: PLR0913\n",
    "    *,\n",
    "    model: ModuleModel,\n",
    "    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],\n",
    "    train_data: Dataset,\n",
    "    optim: ox.GradientTransformation,\n",
    "    key: KeyArray,\n",
    "    num_iters: Optional[int] = 100,\n",
    "    batch_size: Optional[int] = -1,\n",
    "    log_rate: Optional[int] = 10,\n",
    "    verbose: Optional[bool] = True,\n",
    "    unroll: Optional[int] = 1,\n",
    "    safe: Optional[bool] = True,\n",
    ") -> Tuple[ModuleModel, Array]:\n",
    "    r\"\"\"Train a Module model with respect to a supplied Objective function.\n",
    "    Optimisers used here should originate from Optax.\n",
    "\n",
    "    Example:\n",
    "    ```python\n",
    "        >>> import jax.numpy as jnp\n",
    "        >>> import jax.random as jr\n",
    "        >>> import optax as ox\n",
    "        >>> import gpjax as gpx\n",
    "        >>>\n",
    "        >>> # (1) Create a dataset:\n",
    "        >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]\n",
    "        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)\n",
    "        >>> D = gpx.Dataset(X, y)\n",
    "        >>>\n",
    "        >>> # (2) Define your model:\n",
    "        >>> class LinearModel(gpx.base.Module):\n",
    "                weight: float = gpx.base.param_field()\n",
    "                bias: float = gpx.base.param_field()\n",
    "\n",
    "                def __call__(self, x):\n",
    "                    return self.weight * x + self.bias\n",
    "\n",
    "        >>> model = LinearModel(weight=1.0, bias=1.0)\n",
    "        >>>\n",
    "        >>> # (3) Define your loss function:\n",
    "        >>> class MeanSquareError(gpx.objectives.AbstractObjective):\n",
    "                def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:\n",
    "                    return jnp.mean((train_data.y - model(train_data.X)) ** 2)\n",
    "        >>>\n",
    "        >>> loss = MeanSqaureError()\n",
    "        >>>\n",
    "        >>> # (4) Train!\n",
    "        >>> trained_model, history = gpx.fit(\n",
    "                model=model, objective=loss, train_data=D, optim=ox.sgd(0.001), num_iters=1000\n",
    "            )\n",
    "    ```\n",
    "\n",
    "    Args:\n",
    "        model (Module): The model Module to be optimised.\n",
    "        objective (Objective): The objective function that we are optimising with\n",
    "            respect to.\n",
    "        train_data (Dataset): The training data to be used for the optimisation.\n",
    "        optim (GradientTransformation): The Optax optimiser that is to be used for\n",
    "            learning a parameter set.\n",
    "        num_iters (Optional[int]): The number of optimisation steps to run. Defaults\n",
    "            to 100.\n",
    "        batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1\n",
    "            (i.e. full batch).\n",
    "        key (Optional[KeyArray]): The random key to use for the optimisation batch\n",
    "            selection. Defaults to jr.key(42).\n",
    "        log_rate (Optional[int]): How frequently the objective function's value should\n",
    "            be printed. Defaults to 10.\n",
    "        verbose (Optional[bool]): Whether to print the training loading bar. Defaults\n",
    "            to True.\n",
    "        unroll (int): The number of unrolled steps to use for the optimisation.\n",
    "            Defaults to 1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "        Tuple[Module, Array]: A Tuple comprising the optimised model and training\n",
    "            history respectively.\n",
    "    \"\"\"\n",
    "    if safe:\n",
    "        # Check inputs.\n",
    "        _check_model(model)\n",
    "        _check_train_data(train_data)\n",
    "        _check_optim(optim)\n",
    "        _check_num_iters(num_iters)\n",
    "        _check_batch_size(batch_size)\n",
    "        _check_prng_key(\"fit\", key)\n",
    "        _check_log_rate(log_rate)\n",
    "        _check_verbose(verbose)\n",
    "\n",
    "    # Unconstrained space loss function with stop-gradient rule for non-trainable params.\n",
    "    def loss(key: Key, model: Module, batch: Dataset) -> ScalarFloat:\n",
    "        model = model.stop_gradient()\n",
    "        return objective(key, model.constrain(), batch)\n",
    "\n",
    "    # Unconstrained space model.\n",
    "    model = model.unconstrain()\n",
    "\n",
    "    # Initialise optimiser state.\n",
    "    state = optim.init(model)\n",
    "\n",
    "    # Mini-batch random keys to scan over.\n",
    "    iter_keys = jr.split(key, num_iters)\n",
    "\n",
    "    # Optimisation step.\n",
    "    def step(carry, key):\n",
    "        model, opt_state = carry\n",
    "\n",
    "        if batch_size != -1:\n",
    "            batch = get_batch(train_data, batch_size, key)\n",
    "        else:\n",
    "            batch = train_data\n",
    "\n",
    "        loss_val, loss_gradient = jax.value_and_grad(loss, argnums=1)(key, model, batch)\n",
    "        updates, opt_state = optim.update(loss_gradient, opt_state, model)\n",
    "        model = ox.apply_updates(model, updates)\n",
    "\n",
    "        carry = model, opt_state\n",
    "        return carry, loss_val\n",
    "\n",
    "    # Optimisation scan.\n",
    "    scan = vscan if verbose else jax.lax.scan\n",
    "\n",
    "    # Optimisation loop.\n",
    "    (model, _), history = scan(step, (model, state), (iter_keys), unroll=unroll)\n",
    "\n",
    "    # Constrained space.\n",
    "    model = model.constrain()\n",
    "\n",
    "    return model, history"
   ]
  },
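  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment\n",
    "Fit a residual deep GP with Hodge-Matern layers to ERA5 wind observations sampled along the Aeolus satellite track, then evaluate the predictive mean on a held-out, approximately uniform test grid."
   ]
  },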
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "seed = 0\n",
    "time = 7\n",
    "level = 15\n",
    "max_ell_prior = 9\n",
    "max_ell_variational = 9\n",
    "total_hidden_variance = 0.0001\n",
    "lr = 0.01\n",
    "num_iters = 1000\n",
    "num_layers = 3\n",
    "num_samples = 3\n",
    "test_size = 5000\n",
    "batch_size = -1\n",
    "save_dir = \".\"\n",
    "num_hours = 24\n",
    "step_minutes = 1\n",
    "num_test_samples = 10\n",
    "\n",
    "\n",
    "### RANDOMNESS \n",
    "key = jax.random.key(seed)\n",
    "data_key, train_key, test_key, plot_key = jax.random.split(key, 4)\n",
    "\n",
    "\n",
    "### DATA     \n",
    "n_plot_uncertainty = 100\n",
    "\n",
    "# aeolus track\n",
    "start = datetime(2019, 1, 1, 9)\n",
    "stop = start + timedelta(hours=num_hours)\n",
    "step = timedelta(minutes=step_minutes)\n",
    "train_size = (stop - start) // step \n",
    "with_replacement = True\n",
    "\n",
    "# load \n",
    "X_train, X_test, y_train, y_test = train_test_sets(time, level, start, stop, step, test_size, with_replacement)\n",
    "train_data = VectorDataset(X_train, y_train)\n",
    "test_data = VectorDataset(X_test, y_test)\n",
    "X_uncertainty = sphere_meshgrid(n_plot_uncertainty).reshape(-1, 2)\n",
    "\n",
    "\n",
    "### MODEL\n",
    "# settings \n",
    "kappa = 1.0\n",
    "hidden_variance = total_hidden_variance / max(num_layers - 1, 1)\n",
    "obs_variance = 1.0\n",
    "obs_stddev = obs_variance ** 0.5\n",
    "hidden_kernel = HodgeMaternKernel\n",
    "output_kernel = HodgeMaternKernel\n",
    "\n",
    "experiment_name = f\"{time=}_{level=}_{step_minutes=}_{num_layers=}_{seed=}_{max_ell_variational=}_{num_test_samples=}_{max_ell_prior=}_{num_samples=}_{total_hidden_variance=}_{num_iters=}_{lr=}_{num_hours=}\"\n",
    "print(f\"Running experiment: {experiment_name}\")\n",
    "# build \n",
    "likelihood = VectorGaussian(num_datapoints=train_data.n, obs_stddev=obs_stddev)\n",
    "layers = build_layers(\n",
    "    num_layers=num_layers,\n",
    "    hidden_kernel=hidden_kernel,\n",
    "    output_kernel=output_kernel,\n",
    "    likelihood=likelihood,\n",
    "    hidden_variance=hidden_variance,\n",
    "    kappa=kappa,\n",
    "    max_ell_variational=max_ell_variational,\n",
    "    max_ell_prior=max_ell_prior,\n",
    ")\n",
    "model = ResidualDeepGP(layers=layers, num_samples=num_samples)\n",
    "\n",
    "\n",
    "### FIT\n",
    "# train \n",
    "objective = jax.jit(DeepVectorELBO(negative=True))\n",
    "optim = optax.adam(learning_rate=lr)\n",
    "model_opt, history = fit_deep(\n",
    "    model=model,\n",
    "    objective=objective,\n",
    "    train_data=train_data,\n",
    "    optim=optim,\n",
    "    num_iters=num_iters,\n",
    "    key=train_key,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "# model_opt, history = model, None\n",
    "\n",
    "\n",
    "# test\n",
    "model_opt = model_opt.replace(num_samples=num_test_samples)\n",
    "mean_test = jnp.mean(moments_deep(test_key, model_opt, X_test)[0], axis=0)"
   ]
  },
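  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A sketch (not part of the original run): score the fitted model with the\n",
    "# `evaluate` helper defined above, which reports test MSE and the predictive\n",
    "# NLL averaged over the MC samples.\n",
    "metrics = evaluate(test_key, model_opt, test_data)\n",
    "print(metrics)"
   ]
  },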
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, X_test_smaller, _, y_test_smaller = train_test_sets(time, level, start, stop, step, 2000, with_replacement)\n",
    "x_test_smaller_lat, x_test_smaller_lon = X_test_smaller[:, 0], X_test_smaller[:, 1]\n",
    "y_test_smaller_dlat, y_test_smaller_dlon = y_test_smaller[:, 0], y_test_smaller[:, 1]\n",
    "\n",
    "x_test_smaller_lat *= 180 / jnp.pi \n",
    "x_test_smaller_lat -= 90\n",
    "x_test_smaller_lon *= 180 / jnp.pi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 200\n",
    "X_uncertainty = sphere_meshgrid(n).reshape(-1, 2)\n",
    "_, cov_plot = pathwise_moments_deep(plot_key, model_opt, X_uncertainty)\n",
    "uncertainty = jax.vmap(jax.vmap(jnp.linalg.norm))(cov_plot) # we define uncertainty as the average norm of the covariance matrices in the mixture\n",
    "uncertainty = jnp.mean(uncertainty, axis=0)\n",
    "uncertainty = uncertainty.reshape(n, n)\n",
    "\n",
    "model_opt = model_opt.replace(num_samples=1)\n",
    "sample_test = model_opt.pathwise_sample(plot_key, X_test).squeeze()\n",
    "X_uncertainty = X_uncertainty.reshape(n, n, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train, y_train, x_test, y_test, var, var_x, sample, history = X_train, y_train, X_test, y_test, uncertainty, X_uncertainty, sample_test, history\n",
    "var_lat, var_lon = var_x[..., 0], var_x[..., 1]\n",
    "x_train_lat, x_train_lon = x_train[:, 0], x_train[:, 1]\n",
    "y_train_dlat, y_train_dlon = y_train[:, 0], y_train[:, 1]\n",
    "x_test_lat, x_test_lon = x_test[:, 0], x_test[:, 1]\n",
    "y_test_dlat, y_test_dlon = y_test[:, 0], y_test[:, 1]\n",
    "\n",
    "\n",
    "mean = mean_test\n",
    "mean_dlat, mean_dlon = mean[:, 0], mean[:, 1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train_lat *= 180 / jnp.pi\n",
    "x_train_lat -= 90\n",
    "x_train_lon *= 180 / jnp.pi\n",
    "x_test_lat *= 180 / jnp.pi\n",
    "x_test_lat -= 90\n",
    "x_test_lon *= 180 / jnp.pi\n",
    "var_lon *= 180 / jnp.pi\n",
    "var_lat *= 180 / jnp.pi\n",
    "var_lat -= 90\n",
    "\n",
    "\n",
    "aeolus = read_aeolus(start=start, stop=stop, step=step)\n",
    "aeolus_colat = radians_to_angles_colat(aeolus.colat)\n",
    "aeolus_lon = radians_to_angles_lon(aeolus.lon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cartopy.crs as ccrs\n",
    "import cartopy.feature as cfeature\n",
    "import matplotlib.colors as mc \n",
    "\n",
    "\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Computer Modern Roman']\n",
    "plt.rcParams['text.usetex'] = True\n",
    "\n",
    "\n",
    "def truncate_colormap(cmap: mc.Colormap, minval=0.0, maxval=1.0, n=100):\n",
    "    new_cmap = mc.LinearSegmentedColormap.from_list(\n",
    "        'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval),\n",
    "        cmap(np.linspace(minval, maxval, n)))\n",
    "    return new_cmap\n",
    "\n",
    "\n",
    "viridis = truncate_colormap(plt.get_cmap('viridis'), 0.05, 1.0)\n",
    "whites = truncate_colormap(plt.get_cmap('Greys'), 0.05, 0.6)\n",
    "blacks = truncate_colormap(plt.get_cmap('Greys'), 0.8, 1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ground Truth\n",
    "Comparison of 5.5km, 2.0km, and 0.1km "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_level_ground_truth(level: int, n: int = 2000):\n",
    "    _, X, _, y = train_test_sets(time, level, start, stop, step, n, with_replacement)\n",
    "    x_lat, x_lon = X[:, 0], X[:, 1]\n",
    "    y_dlat, y_dlon = y[:, 0], y[:, 1]\n",
    "\n",
    "    x_lat *= 180 / jnp.pi \n",
    "    x_lat -= 90\n",
    "    x_lon *= 180 / jnp.pi\n",
    "    return x_lat, x_lon, y_dlat, y_dlon\n",
    "\n",
    "\n",
    "data = {}\n",
    "for l in [0, 7, 15]:\n",
    "    data[l] = get_level_ground_truth(l, 2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for l, (x_lat, x_lon, y_dlat, y_dlon) in data.items():\n",
    "    projection = ccrs.Robinson()\n",
    "    transform = ccrs.PlateCarree()\n",
    "\n",
    "    fig = plt.figure(figsize=(10, 4))\n",
    "    ax = plt.axes(projection=projection)\n",
    "    ax.stock_img()\n",
    "    ax.add_feature(cfeature.COASTLINE, linewidth=0.75, color='black', zorder=1)\n",
    "\n",
    "    q = ax.quiver(\n",
    "        x_lon, \n",
    "        x_lat, \n",
    "        y_dlon, \n",
    "        y_dlat, \n",
    "        angles=\"uv\", \n",
    "        transform=transform, \n",
    "        zorder=2,\n",
    "        width=0.0013,\n",
    "        headwidth=3.5,\n",
    "    )\n",
    "    q._init()\n",
    "    scale = q.scale\n",
    "\n",
    "    plt.tight_layout()\n",
    "    fig.savefig(f\"../experiments/hodge/plots/ground_truth-{l=}.png\", dpi=600, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictions\n",
    "- mean and standard deviation\n",
    "- two pathwise samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean(x):\n",
    "    m = jnp.mean(moments_deep(test_key, model_opt, x)[0], axis=0)\n",
    "    m_dlat, m_dlon = m[:, 0], m[:, 1]\n",
    "    return m_dlat, m_dlon\n",
    "\n",
    "x = X_test_smaller\n",
    "x_lat, x_lon = x_test_smaller_lat, x_test_smaller_lon\n",
    "y_dlat, y_dlon = get_mean(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 4))\n",
    "ax = plt.axes(projection=ccrs.Robinson())\n",
    "ax.add_feature(cfeature.COASTLINE, linewidth=0.75, color='lightgrey', zorder=1)\n",
    "\n",
    "\n",
    "c = ax.pcolormesh(\n",
    "    var_lon, \n",
    "    var_lat, \n",
    "    var, \n",
    "    vmin=var.min(), \n",
    "    vmax=var.max(), \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    shading='auto', \n",
    "    edgecolors=None, \n",
    "    cmap=viridis, \n",
    "    zorder=0\n",
    ")\n",
    "\n",
    "q = ax.quiver(\n",
    "    x_lon, \n",
    "    x_lat, \n",
    "    y_dlon, \n",
    "    y_dlat, \n",
    "    angles=\"uv\", \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    zorder=2,\n",
    "    # scale=45,\n",
    "    # increase arrow thickness\n",
    "    width=0.0013,\n",
    "    # decrease arrowhead size\n",
    "    headwidth=3.5,\n",
    ")\n",
    "q._init()\n",
    "scale = q.scale\n",
    "\n",
    "# ax.outline_patch.set_visible(False)\n",
    "# ax.background_patch.set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.savefig(f\"../experiments/hodge/plots/mean_and_variance-{level=}.png\", dpi=600, bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pathwise_sample(x, key: Key):\n",
    "    s = jnp.squeeze(model_opt.pathwise_sample(key, x)) # [N D]\n",
    "    # s += jax.random.normal(key, s.shape) * model_opt.output_layer.posterior.likelihood.obs_stddev\n",
    "\n",
    "    s_dlat, s_dlon = s[:, 0], s[:, 1]\n",
    "    return s_dlat, s_dlon\n",
    "\n",
    "key = jax.random.key(0)\n",
    "\n",
    "x = X_test_smaller\n",
    "x_lat, x_lon = x_test_smaller_lat, x_test_smaller_lon\n",
    "k1, k2 = jax.random.split(key, 2)\n",
    "y1_dlat, y1_dlon = get_pathwise_sample(x, k1)\n",
    "y2_dlat, y2_dlon = get_pathwise_sample(x, k2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "projection = ccrs.Robinson()\n",
    "transform = ccrs.PlateCarree()\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "ax = plt.axes(projection=projection)\n",
    "ax.stock_img()\n",
    "ax.add_feature(cfeature.COASTLINE, linewidth=0.75, color='black', zorder=1)\n",
    "\n",
    "q = ax.quiver(\n",
    "    x_lon, \n",
    "    x_lat, \n",
    "    y1_dlon, \n",
    "    y1_dlat, \n",
    "    angles=\"uv\", \n",
    "    transform=transform, \n",
    "    zorder=2,\n",
    "    width=0.0013,\n",
    "    headwidth=3.5,\n",
    ")\n",
    "q._init()\n",
    "scale = q.scale\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.savefig(f\"../experiments/hodge/plots/pathwise_sample1-{level=}.png\", dpi=600, bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "projection = ccrs.Robinson()\n",
    "transform = ccrs.PlateCarree()\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "ax = plt.axes(projection=projection)\n",
    "ax.stock_img()\n",
    "ax.add_feature(cfeature.COASTLINE, linewidth=0.75, color='black', zorder=1)\n",
    "\n",
    "q = ax.quiver(\n",
    "    x_lon, \n",
    "    x_lat, \n",
    "    y2_dlon, \n",
    "    y2_dlat, \n",
    "    angles=\"uv\", \n",
    "    transform=transform, \n",
    "    zorder=2,\n",
    "    width=0.0013,\n",
    "    headwidth=3.5,\n",
    ")\n",
    "q._init()\n",
    "scale = q.scale\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.savefig(f\"../experiments/hodge/plots/pathwise_sample2-{level=}.png\", dpi=600, bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean(x):\n",
    "    m = jnp.mean(moments_deep(test_key, model_opt, x)[0], axis=0)\n",
    "    m_dlat, m_dlon = m[:, 0], m[:, 1]\n",
    "    return m_dlat, m_dlon\n",
    "\n",
    "x = X_test_smaller\n",
    "x_lat, x_lon = x_test_smaller_lat, x_test_smaller_lon\n",
    "mean_smaller_dlat, mean_smaller_dlon = get_mean(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the scales\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "ax = plt.axes(projection=ccrs.Robinson())\n",
    "q = ax.quiver(x_test_lon, x_test_lat, y_test_dlat, y_test_dlon, angles=\"uv\", transform=ccrs.PlateCarree());\n",
    "q._init()\n",
    "scale = q.scale\n",
    "\n",
    "\n",
    "# get the scales\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "ax = plt.axes(projection=ccrs.Robinson())\n",
    "q = ax.quiver(x_test_smaller_lon, x_test_smaller_lat, y_test_smaller_dlat, y_test_smaller_dlon, angles=\"uv\", transform=ccrs.PlateCarree());\n",
    "q._init()\n",
    "scale_smaller = q.scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff_dlon = y_test_dlon - mean_dlon\n",
    "diff_dlat = y_test_dlat - mean_dlat\n",
    "error = jnp.linalg.norm(jnp.stack([diff_dlon, diff_dlat], axis=-1), axis=-1)\n",
    "error = (error - error.min()) / (error.max() - error.min()) ** 0.3\n",
    "\n",
    "\n",
    "diff_smaller_dlon = y_test_smaller_dlon - mean_smaller_dlon\n",
    "diff_smaller_dlat = y_test_smaller_dlat - mean_smaller_dlat\n",
    "error_smaller = jnp.linalg.norm(jnp.stack([diff_smaller_dlon, diff_smaller_dlat], axis=-1), axis=-1)\n",
    "error_smaller = (error_smaller - error_smaller.min()) / (error_smaller.max() - error_smaller.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8, 4))\n",
    "ax = plt.axes(projection=ccrs.Robinson())\n",
    "ax.add_feature(cfeature.COASTLINE, linewidth=0.75, color='lightgrey', zorder=1)\n",
    "\n",
    "\n",
    "c = ax.pcolormesh(\n",
    "    var_lon, \n",
    "    var_lat, \n",
    "    var, \n",
    "    vmin=var.min(), \n",
    "    vmax=var.max(), \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    shading='auto', \n",
    "    edgecolors=None, \n",
    "    cmap=viridis, \n",
    "    zorder=0,\n",
    ")\n",
    "\n",
    "\n",
    "q = ax.quiver(\n",
    "    x_test_lon, \n",
    "    x_test_lat, \n",
    "    diff_dlon, \n",
    "    diff_dlat, \n",
    "    angles=\"uv\", \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    scale=scale, \n",
    "    # color=truncate_colormap(plt.get_cmap('Greys'), 0.65, 1.0)(error),\n",
    "    color='black',\n",
    "    zorder=2,\n",
    "    # linewidth=0.1,\n",
    "    headwidth=3.5,\n",
    "    # headlength=2.0,\n",
    ")\n",
    "q._init()\n",
    "\n",
    "fig.savefig(f\"../experiments/hodge/plots/uncertainty_and_difference-{level=}.png\", dpi=600, bbox_inches='tight', transparent=True)\n",
    "# fig.savefig(\"../experiments/hodge/plots/uncertainty_and_difference-gradient_black_arrows-lightgrey_coastline-narrow_head.png\", dpi=600, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8, 4))\n",
    "ax = plt.axes(projection=ccrs.Robinson())\n",
    "ax.add_feature(cfeature.COASTLINE, linewidth=0.75, color='lightgrey', zorder=1)\n",
    "\n",
    "\n",
    "c = ax.pcolormesh(\n",
    "    var_lon, \n",
    "    var_lat, \n",
    "    var, \n",
    "    vmin=var.min(), \n",
    "    vmax=var.max(), \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    shading='auto', \n",
    "    edgecolors=None, \n",
    "    cmap=viridis, \n",
    "    zorder=0,\n",
    ")\n",
    "\n",
    "\n",
    "q = ax.quiver(\n",
    "    x_test_smaller_lon, \n",
    "    x_test_smaller_lat, \n",
    "    diff_smaller_dlon, \n",
    "    diff_smaller_dlat, \n",
    "    angles=\"uv\", \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    scale=scale_smaller, \n",
    "    # color=truncate_colormap(plt.get_cmap('Greys'), 0.9, 1.0)(error),\n",
    "    color='black',\n",
    "    zorder=2,\n",
    "    # linewidth=0.1,\n",
    "    headwidth=2.0,\n",
    "    # headlength=2.0,\n",
    ")\n",
    "q._init()\n",
    "\n",
    "fig.savefig(\"../experiments/hodge/plots/uncertainty_and_difference-black_arrows-lightgrey_coastline-narrow_head.png\", dpi=600, bbox_inches='tight')\n",
    "# fig.savefig(\"../experiments/hodge/plots/uncertainty_and_difference-gradient_black_arrows-lightgrey_coastline-narrow_head-2000_arrows.png\", dpi=600, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time = 7 \n",
    "level = 15 \n",
    "start = datetime(2019, 1, 1, 9)\n",
    "stop = start + timedelta(hours=24)\n",
    "step = timedelta(minutes=1)\n",
    "with_replacement = True\n",
    "\n",
    "_, X_test_smaller, _, y_test_smaller = train_test_sets(time, level, start, stop, step, 1750, with_replacement)\n",
    "x_test_smaller_lat, x_test_smaller_lon = X_test_smaller[:, 0], X_test_smaller[:, 1]\n",
    "y_test_smaller_dlat, y_test_smaller_dlon = y_test_smaller[:, 0], y_test_smaller[:, 1]\n",
    "\n",
    "x_test_smaller_lat *= 180 / jnp.pi \n",
    "x_test_smaller_lat -= 90\n",
    "x_test_smaller_lon *= 180 / jnp.pi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(4, 4))\n",
    "ax = plt.axes(projection=ccrs.Orthographic(central_longitude=-100))\n",
    "ax.add_feature(cfeature.COASTLINE, linewidth=0.75)\n",
    "ax.gridlines()\n",
    "ax.stock_img()\n",
    "\n",
    "q = ax.quiver(\n",
    "    x_test_smaller_lon, x_test_smaller_lat, y_test_smaller_dlon, y_test_smaller_dlat, angles=\"uv\", \n",
    "    transform=ccrs.PlateCarree(), \n",
    "    # increase arrow size\n",
    "    scale=45,\n",
    "    # increase arrow thickness\n",
    "    width=0.003,\n",
    "    # decrease arrowhead size\n",
    "    headwidth=3.5,\n",
    "    color=\"black\"\n",
    ")\n",
    "q._init()\n",
    "\n",
    "\n",
    "color = 'indianred'\n",
    "ax.plot(aeolus_lon, aeolus_colat, color=color, alpha=0.45, transform=ccrs.Geodetic())\n",
    "ax.scatter(aeolus_lon, aeolus_colat, color=color, s=5, transform=ccrs.Geodetic())\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"../experiments/hodge/plots/aeolus_track_and_ground_truth.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to render plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# satellite track (red dots)\n",
    "a_inputs = sph_to_car(aeolus[['colat', 'lon']].values)\n",
    "a_outputs = jnp.zeros((a_inputs.shape[0], 1))\n",
    "\n",
    "# ground truth (black arrows) \n",
    "tbx = jax.vmap(tangent_basis)(X_test_smaller)\n",
    "b = jnp.concat([\n",
    "    sph_to_car(X_test_smaller),\n",
    "    jax.vmap(jnp.matmul)(tbx, y_test_smaller)\n",
    "], axis=-1)\n",
    "\n",
    "\n",
    "# save the data as csv using the names of the variables\n",
    "data = [\n",
    "    a_inputs, \n",
    "    a_outputs, \n",
    "    b, \n",
    "]\n",
    "\n",
    "names = [\n",
    "    'a-inputs-half',\n",
    "    'a-outputs-half',\n",
    "    'b-1750', \n",
    "]\n",
    "\n",
    "\n",
    "for datum, name in zip(data, names):\n",
    "    pd.DataFrame(datum).to_csv(f\"../plotting/aeolus_track/{name}.csv\", header=False, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
