from jax import config 
# Enable 64-bit (float64 / int64) arrays globally. Without this, JAX silently downcasts to
# 32 bits, including in the places below that explicitly request jnp.float64.
config.update("jax_enable_x64", True)


import jax 
import jax.numpy as jnp
import gpjax 
import optax 
from gpjax.distributions import GaussianDistribution
from gpjax.typing import Float
from jax.tree_util import Partial
import tensorflow_probability.substrates.jax as tfp
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Key, Array

from matplotlib import pyplot as plt 

from dataclasses import dataclass, InitVar
from abc import abstractmethod


@jax.jit
def sph_dot_product(sph1: Float, sph2: Float) -> Float:
    """
    Return the R^3 inner product of two unit vectors given in spherical coordinates.

    The last axis of each input holds (colatitude, longitude); leading axes broadcast.
    Uses <p, q> = sin θ₁ sin θ₂ cos(φ₁ - φ₂) + cos θ₁ cos θ₂.
    """
    theta_a, phi_a = sph1[..., 0], sph1[..., 1]
    theta_b, phi_b = sph2[..., 0], sph2[..., 1]
    in_plane = jnp.sin(theta_a) * jnp.sin(theta_b) * jnp.cos(phi_a - phi_b)
    polar = jnp.cos(theta_a) * jnp.cos(theta_b)
    return in_plane + polar


@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))
def gegenbauer(x, max_ell: int, alpha: float = 0.5):
    """
    Evaluate the Gegenbauer polynomials Cᵅ₀(x), ..., Cᵅ_max_ell(x) with the three-term recursion.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array of any shape.
        max_ell: Highest degree to evaluate (static, must be >= 0).
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static).

    Returns:
        Array of shape (max_ell + 1, *x.shape) whose n-th slice is Cᵅₙ(x).

    Raises:
        ValueError: If `max_ell` is negative.
    """
    if max_ell < 0:
        raise ValueError(f"max_ell must be non-negative, got {max_ell}.")

    C_0 = jnp.ones_like(x, dtype=x.dtype)
    res = jnp.zeros((max_ell + 1, *x.shape), dtype=x.dtype).at[0].set(C_0)

    # `max_ell` is static, so a Python branch suffices. `lax.cond` would still trace the
    # degree-1 branch (including an out-of-bounds write) when max_ell == 0.
    if max_ell == 0:
        return res

    C_1 = 2 * alpha * x
    res = res.at[1].set(C_1)

    def step(n, carry):
        res, C, C_prev = carry
        C_next = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n
        return res.at[n].set(C_next), C_next, C

    return jax.lax.fori_loop(2, max_ell + 1, step, (res, C_1, C_0))[0]


@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call
def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:
    """
    Evaluate the single Gegenbauer polynomial Cᵅ_ell(x) with the three-term recursion.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Floating-point input array.
        ell: Degree of the polynomial; may be a traced integer.
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static).

    Returns:
        Cᵅ_ell(x), with the same shape and dtype as `x`.
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)
    C_1 = 2 * alpha * x

    def step(carry):
        C, C_prev, n = carry
        # Keep the counter an integer and cast it to x's dtype for the arithmetic. A strong
        # float64 counter would promote float32 inputs to float64 and break the while_loop carry.
        n_f = n.astype(x.dtype)
        C_next = (2 * x * (n_f + alpha - 1) * C - (n_f + 2 * alpha - 2) * C_prev) / n_f
        return C_next, C, n + 1

    def cond(carry):
        return carry[2] <= ell

    return jax.lax.cond(
        ell == 0,
        lambda: C_0,
        lambda: jax.lax.while_loop(cond, step, (C_1, C_0, jnp.asarray(2, dtype=jnp.int32)))[0],
    )


# NOTE jitting this doesn't help
def sph_gegenbauer(x, y, max_ell: int, alpha: float = 0.5):
    """
    Evaluate every Gegenbauer polynomial of degree 0..max_ell at the inner product of two
    points given in spherical (colatitude, longitude) coordinates.

    Returns:
        The stacked output of `gegenbauer`, with shape (max_ell + 1, *dot_product_shape).
    """
    cos_angle = sph_dot_product(x, y)
    return gegenbauer(x=cos_angle, max_ell=max_ell, alpha=alpha)


# NOTE jitting this doesn't help
def sph_gegenbauer_single(x, y, ell: int, alpha: float = 0.5):
    """
    Evaluate the Gegenbauer polynomial of degree `ell` at the inner product of two points
    given in spherical (colatitude, longitude) coordinates.
    """
    cos_angle = sph_dot_product(x, y)
    return gegenbauer_single(x=cos_angle, ell=ell, alpha=alpha)


def array(x):
    """Convert `x` to a JAX array of dtype float64 (downcast to float32 unless x64 is enabled)."""
    target_dtype = jnp.float64
    return jnp.array(x, dtype=target_dtype)


@jax.jit
def sph_to_car(sph):
    """
    Map spherical (colatitude, longitude) coordinates on the unit sphere to Cartesian (x, y, z).

    The last axis of `sph` holds (colatitude, longitude). Leading axes are preserved, so the
    result has shape (*sph.shape[:-1], 3).
    """
    colatitude = sph[..., 0]
    longitude = sph[..., 1]
    sin_col = jnp.sin(colatitude)
    components = (
        sin_col * jnp.cos(longitude),
        sin_col * jnp.sin(longitude),
        jnp.cos(colatitude),
    )
    return jnp.stack(components, axis=-1)


@jax.jit
def car_to_sph(car):
    """
    Map Cartesian points on the unit sphere to spherical (colatitude, longitude) coordinates.

    Args:
        car: Array whose last axis holds (x, y, z); leading axes are preserved.

    Returns:
        Array of shape (*car.shape[:-1], 2) with colatitude in [0, π] and longitude from arctan2.
    """
    x, y, z = car[..., 0], car[..., 1], car[..., 2]
    # Clip z into arccos's domain. Round-off can push |z| marginally above 1 for points that
    # lie on the sphere up to floating-point error, which would otherwise yield NaN.
    colat = jnp.arccos(jnp.clip(z, -1.0, 1.0))
    lon = jnp.arctan2(y, x)
    return jnp.stack([colat, lon], axis=-1)


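"""
Conversion between the per-point block layout (e.g. (N, M, 2, 2) kernel blocks, (N, 2) coordinates)
and the flattened layout used by dense linear algebra.
"""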

@jax.jit
def flatten_matrix(matrix):
    """
    Assemble an array of blocks into a single dense matrix.

    Input matrix has shape (nx, ny, *block_shape). For 2-D blocks of shape (b0, b1) the result
    has shape (nx * b0, ny * b1) with ``out[i*b0 + a, j*b1 + c] == matrix[i, j, a, c]``. Any
    extra trailing block axes are kept. For 1-D (or scalar) blocks the result has shape
    (nx, ny * block_size).

    Implemented as a transpose + reshape instead of a Python-level hstack/vstack over every
    block, so the traced program no longer grows with nx * ny.
    """
    if matrix.ndim < 4:
        # 0-D / 1-D blocks: concatenating each row of blocks is a plain reshape.
        return matrix.reshape(matrix.shape[0], -1)
    nx, ny, b0, b1, *rest = matrix.shape
    perm = (0, 2, 1, 3, *range(4, matrix.ndim))
    return jnp.transpose(matrix, perm).reshape(nx * b0, ny * b1, *rest)


@Partial(jax.jit, static_argnames=('spherical',))
def unflatten_matrix(matrix, spherical=True):
    """
    Split a dense (nx * dim, ny * dim) matrix into its (dim, dim) blocks.

    `dim` is 2 when `spherical` is True and 3 otherwise. The result has shape
    (nx, ny, dim, dim) with ``out[i, j] == matrix[i*dim:(i+1)*dim, j*dim:(j+1)*dim]``, i.e. the
    inverse of `flatten_matrix` for square blocks. Implemented as reshape + transpose so the
    traced program size does not grow with the number of blocks.

    Raises:
        ValueError: If either dimension of `matrix` is not a multiple of `dim`.
    """
    dim = 2 if spherical else 3
    n_rows, n_cols = matrix.shape
    if n_rows % dim or n_cols % dim:
        raise ValueError(
            f"Matrix shape {matrix.shape} is not divisible into ({dim}, {dim}) blocks."
        )
    nx, ny = n_rows // dim, n_cols // dim
    return matrix.reshape(nx, dim, ny, dim).transpose(0, 2, 1, 3)


def flatten_coord(coord):
    """Return `coord` (e.g. an (N, 2) array of points) as a 1-D array in row-major order."""
    return jnp.reshape(coord, (-1,))


def unflatten_coord(coord_flat, spherical=True, extra_dims=()):
    """
    Un-flatten coordinates to shape (n, 2 or 3, *extra_dims).

    Args:
        coord_flat: Flat array whose size is a multiple of dim * prod(extra_dims).
        spherical: If True each point has 2 components (colatitude, longitude), otherwise 3.
        extra_dims: Trailing dimensions to restore. Any iterable of ints; an immutable tuple
            default avoids the shared-mutable-default pitfall.

    Returns:
        The reshaped array.
    """
    dim = 2 if spherical else 3
    return coord_flat.reshape(-1, dim, *extra_dims)


from pathlib import Path
from typing import Callable

import numpy as np
from jax import Array


class FundamentalSystemNotPrecomputedError(ValueError):
    """Raised when no precomputed fundamental-system file exists for a given dimension."""

    def __init__(self, dimension: int):
        super().__init__(
            f"Fundamental system for dimension {dimension} has not been precomputed."
        )


def fundamental_set_loader(
    dimension: int, load_dir="fundamental_system", root="../"
) -> Callable[[int], Array]:
    """
    Load the precomputed fundamental systems for `dimension` and return a per-degree getter.

    The file ``<root>/<load_dir>/fs_<dimension>D.npz`` is read eagerly. Every array is loaded
    into memory and the file is closed, so the returned function performs no I/O.

    Args:
        dimension: Ambient dimension of the sphere.
        load_dir: Directory containing the ``fs_<dimension>D.npz`` files, relative to `root`.
        root: Base directory that `load_dir` is resolved against. The default "../" keeps the
            previous behaviour (relative to the parent of the current working directory); pass
            an explicit path to make loading independent of the CWD.

    Returns:
        A function mapping a degree to the array stored under the key ``degree_<degree>``.

    Raises:
        FundamentalSystemNotPrecomputedError: If the file for `dimension` does not exist.
    """
    file_name = Path(root) / load_dir / f"fs_{dimension}D.npz"
    if not file_name.exists():
        raise FundamentalSystemNotPrecomputedError(dimension)

    with np.load(file_name) as f:
        cache = dict(f.items())

    def load(degree: int) -> Array:
        """Return the stored array for `degree`; raises ValueError if it was not precomputed."""
        key = f"degree_{degree}"
        try:
            return cache[key]
        except KeyError:
            raise ValueError(f"key: {key} not in cache.") from None

    return load


@dataclass
class SphericalHarmonics(gpjax.Module):
    """
    Spherical harmonics inducing features for sparse inference in Gaussian processes.

    The spherical harmonics, Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the sphere and,
    as such, they form an orthogonal basis.

    To construct the harmonics, we use a fundamental set of points on the sphere {vᵢ}ᵢ and compute
    b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b now forms a complete basis on the sphere and we can orthogonalise it via
    a Cholesky decomposition. However, we only need to run the Cholesky decomposition once during
    initialisation.

    Attributes:
        max_ell: The highest frequency (degree) up to which the harmonics are computed.
        sphere_dim: Intrinsic dimension of the sphere; the ambient dimension is sphere_dim + 1.
        alpha: Gegenbauer constant (dim - 2) / 2, set in `__post_init__`.
        orth_basis: One lower Cholesky factor per degree 0..max_ell, set in `__post_init__`.
        Vs: One fundamental-system array per degree 0..max_ell, loaded from disk. Presumably of
            shape (num_phases_n, dim); the code uses shape[0] as the number of phases.

    Returns:
        An instance of the spherical harmonics features.
    """

    max_ell: int = gpjax.base.static_field()
    sphere_dim: int = gpjax.base.static_field()
    alpha: float = gpjax.base.static_field(init=False)
    orth_basis: list[Array] = gpjax.base.param_field(init=False, trainable=False)
    Vs: list[Array] = gpjax.base.param_field(init=False, trainable=False)

    @property
    def levels(self):
        """Degrees 0, 1, ..., max_ell as an int32 array."""
        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)
    
    def __post_init__(self) -> None:
        """
        Load the fundamental set, set `alpha`, and pre-compute the per-degree Cholesky factors.

        Returns:
            None

        Raises:
            FundamentalSystemNotPrecomputedError: If no fundamental-system file exists for
                dimension sphere_dim + 1 (raised by `fundamental_set_loader`).
            ValueError: If some degree in 0..max_ell is missing from that file.
        """
        dim = self.sphere_dim + 1

        # Try loading a pre-computed fundamental set.
        fund_set = fundamental_set_loader(dim)

        # Gegenbauer constant alpha = (d - 2) / 2 for the sphere embedded in R^d.
        self.alpha = (dim - 2.0) / 2.0

        # Fundamental-system points for every degree 0..max_ell (declared non-trainable above).
        self.Vs = [fund_set(n) for n in self.levels]

        # pre-compute and save the orthogonal basis
        self.orth_basis = self._orthogonalise_basis()


    @property
    def Ls(self) -> list[Array]:
        """
        Alias for the orthogonal basis (one lower Cholesky factor per frequency).
        """
        return self.orth_basis

    @property
    def num_phase_in_frequency(self) -> list[int]:
        """
        Get the total number of phases/harmonics at every frequency.

        Returns:
            A list with the number of phases per frequency (the leading dimension of each `Vs`).
        """
        return jax.tree.map(lambda x: x.shape[0], self.Vs)

    @property
    def num_inducing(self) -> int:
        """
        Computes the total number of inducing features, as the sum of all phases.

        Returns:
            The total number of inducing features.
        """
        return sum(self.num_phase_in_frequency)

    def _orthogonalise_basis(self) -> list[Array]:
        """
        Compute the basis from the fundamental set and orthogonalise it via Cholesky decomposition.

        For each degree n with fundamental points V_n this forms
        B_n = alpha / (alpha + n) * Cᵅₙ(V_n V_nᵀ) and returns its lower Cholesky factor.

        Returns:
            A list with one lower-triangular factor per degree 0..max_ell.
        """
        alpha = self.alpha
        # One single-element array per degree, so tree.map can zip them with `Vs`.
        levels = jnp.split(self.levels, self.max_ell + 1)
        const = alpha / (alpha + self.levels.astype(jnp.float64))
        const = jnp.split(const, self.max_ell + 1)

        def _func(v, n, c):
            # Inner products between the fundamental points of this degree.
            x = jnp.matmul(v, v.T)
            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)
            # NOTE(review): a 1e-16 diagonal jitter is at or below float64 round-off relative to
            # entries of order one, so it may have no stabilising effect. Confirm intended value.
            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))
            return L

        return jax.tree.map(_func, self.Vs, levels, const)

    def custom_gegenbauer_single(self, x, ell, alpha):
        """
        Evaluate Cᵅ_ell(x) by computing every degree 0..max_ell with `gegenbauer` and indexing.

        `ell` may be a traced integer; the work done is fixed by the static `max_ell`.
        """
        return gegenbauer(x, self.max_ell, alpha)[ell]

    @jax.jit
    def polynomial_expansion(self, X: Float[Array, "N D"]) -> Float[Array, "M N"]:
        """
        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.

        Args:
            X: Input Array of Cartesian points.

        Returns:
            The harmonics evaluated at the input, with the rows of all degrees concatenated along
            axis 0 (M = num_inducing).
        """
        levels = jnp.split(self.levels, self.max_ell + 1)

        def _func(v, n, L):
            vxT = jnp.dot(v, X.T)
            # Zonal functions Cᵅₙ(<v_i, x>) for every fundamental point v_i of this degree.
            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)
            # Whiten with the lower Cholesky factor: harmonic = L⁻¹ zonal.
            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)
            return harmonic

        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)
        return jnp.concatenate(harmonics, axis=0)
    
    def __eq__(self, other: "SphericalHarmonics") -> bool:
        """
        Check if two spherical harmonic features are equal.

        Args:
            other: The other spherical harmonic features.

        Returns:
            A boolean indicating if the two features are equal.
        """
        # Given the first two parameters, the rest are deterministic.
        # The user must not mutate all other fields, but that is not enforced for now.
        # NOTE(review): defining __eq__ without __hash__ makes instances unhashable. Confirm
        # this is acceptable wherever the object is stored as a static (jit-hashed) field.
        return (
            self.max_ell == other.max_ell 
            and self.sphere_dim == other.sphere_dim 
        )
    


import warnings 
from typing import Optional
from jaxtyping import Num
from gpjax.typing import ScalarFloat


def _check_precision(
    X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
) -> None:
    r"""Warn when $`X`$ or $`y`$ is present but not of dtype float64.

    Args:
        X: Input array, or None to skip its check.
        y: Target array, or None to skip its check.
    """
    # One shared message template, so both warnings have the same, correctly spaced wording
    # (the y message previously read "float64.Got").
    for name, value in (("X", X), ("y", y)):
        if value is not None and value.dtype != jnp.float64:
            warnings.warn(
                f"{name} is not of type float64. "
                f"Got {name}.dtype={value.dtype}. This may lead to numerical instability.",
                stacklevel=2,
            )


@dataclass 
class VectorDataset(gpjax.Dataset):
    """
    Dataset for vector-valued targets.

    `X` and `y` are not required to have matching leading dimensions (annotated N and M), so no
    shape-compatibility check is run; only float64 precision warnings are emitted.
    """

    X: Optional[Num[Array, "N D"]] = None
    y: Optional[Num[Array, "M Q"]] = None

    def __post_init__(self) -> None:
        r"""Warn about non-float64 $`X`$ / $`y`$; the shape check is deliberately skipped."""
        _check_precision(X=self.X, y=self.y)


# NOTE(review): `VectorZeroMean` is defined again further down in this module (with `space_dim`
# and `output_dim` fields). That later definition rebinds the name, so this version is unused
# once the module is imported. Consider deleting it.
@dataclass 
class VectorZeroMean(gpjax.mean_functions.AbstractMeanFunction):
    """Zero mean returning a flat 1-D vector of length ``x.shape[0] * dim``."""
    dim: int = gpjax.base.static_field(1)

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N E"]:
        # NOTE(review): the result is 1-D (N * dim,), not the 2-D shape in the annotation.
        return jnp.zeros((x.shape[0] * self.dim))


@dataclass
class AnalyticalVectorGaussianIntegrator(gpjax.integrators.AbstractIntegrator):
    r"""Closed-form expected log-likelihood for an isotropic, vector-valued Gaussian likelihood.

    For $`p(\mathbf{y}|\mathbf{f}) = \mathcal{N}(\mathbf{y}|\mathbf{f}, \sigma^2 I_D)`$ and a
    variational marginal $`q(\mathbf{f}) = \mathcal{N}(\mathbf{f}|\mathbf{m}, S)`$ at each of the
    N points, the expected log-likelihood is
    ```math
    \mathbb{E}_{q}[\log p(\mathbf{y}|\mathbf{f})] = -\frac{1}{2}\left(D\log(2\pi\sigma^2) + \frac{1}{\sigma^2}\left(\lVert\mathbf{y}-\mathbf{m}\rVert^2 + \operatorname{tr} S\right)\right)
    ```
    """

    def integrate(
        self,
        fun: Callable,
        y: Float[Array, "N D"],
        mean: Float[Array, "N D"],
        covariance: Float[Array, "N D D"],
        likelihood: gpjax.likelihoods.Gaussian,
    ) -> Float[Array, " N"]:
        r"""Evaluate the closed-form Gaussian expectation at every point.

        Args:
            fun (Callable): The Gaussian likelihood to be integrated (unused; closed form).
            y (Float[Array, 'N D']): The observed response variable.
            mean (Float[Array, 'N D']): The mean of the variational distribution.
            covariance (Float[Array, 'N D D']): The per-point covariance blocks of the
                variational distribution.
            likelihood (Gaussian): The Gaussian likelihood; its `obs_stddev` is squeezed to a
                single noise scale shared by all output dimensions.

        Returns:
            Float[Array, 'N']: The expected log likelihood per point.
        """
        out_dim = y.shape[-1]
        noise_var = likelihood.obs_stddev.squeeze() ** 2  # scalar
        residual_sq = jnp.square(y - mean).sum(axis=-1)  # [N]
        cov_trace = jnp.trace(covariance, axis1=1, axis2=2)  # [N]
        log_norm = out_dim * (jnp.log(2.0 * jnp.pi) + jnp.log(noise_var))  # scalar
        return -0.5 * (log_norm + (residual_sq + cov_trace) / noise_var)


@dataclass 
class VectorGaussian(gpjax.likelihoods.Gaussian):
    """Gaussian likelihood whose expected log-likelihood is computed by
    `AnalyticalVectorGaussianIntegrator` (one noise variance shared by all output dimensions)."""
    integrator: gpjax.integrators.AbstractIntegrator = gpjax.base.static_field(AnalyticalVectorGaussianIntegrator())


import abc 
import cola
from typing import TypeVar 
from jaxtyping import Float, Num
from gpjax.typing import ScalarFloat


# Type variable for the `kernel` arguments of the kernel-computation classes below, bound
# (via a string forward reference) to gpjax's AbstractKernel.
Kernel = TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")


@dataclass
class AbstractVectorKernelComputation:
    r"""Base class describing how a vector-valued kernel is assembled into covariance operators.

    Subclasses implement `cross_covariance`. `gram` and `diagonal` are derived from kernel
    evaluations.
    NOTE(review): the class does not inherit from `abc.ABC`, so `@abc.abstractmethod` is not
    enforced at instantiation time.
    """

    def gram(
        self,
        kernel: Kernel,
        x: Num[Array, "N D"],
    ) -> cola.ops.LinearOperator:
        r"""Build the PSD Gram operator of `kernel` on `x`.

        Args:
            kernel (AbstractKernel): the kernel function.
            x (Num[Array, "N D"]): The inputs to the kernel function.

        Returns
        -------
            LinearOperator: Dense PSD operator wrapping ``cross_covariance(kernel, x, x)``.
        """
        dense_gram = cola.ops.Dense(self.cross_covariance(kernel, x, x))
        return cola.PSD(dense_gram)

    @abc.abstractmethod
    def cross_covariance(
        self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
    ) -> Float[Array, "N M"]:
        r"""Compute the cross-covariance matrix between the rows of `x` and `y`.

        Args:
            kernel (AbstractKernel): the kernel function.
            x (Num[Array,"N D"]): The first input matrix.
            y (Num[Array,"M D"]): The second input matrix.

        Returns
        -------
            Float[Array, "N M"]: The computed cross-covariance.
        """
        raise NotImplementedError

    def diagonal(self, kernel: Kernel, inputs: Num[Array, "N D"]) -> cola.ops.BlockDiag:
        r"""Block-diagonal operator of the per-point self-covariances k(x_i, x_i).

        Args:
            kernel (AbstractKernel): the kernel function.
            inputs (Float[Array, "N D"]): The input matrix.

        Returns
        -------
            BlockDiag: PSD-annotated block-diagonal operator of the self-covariance blocks.
        """
        def _self_covariance(point):
            return kernel(point, point)

        blocks = jax.vmap(_self_covariance)(inputs)
        # NOTE(review): confirm `cola.ops.BlockDiag` accepts a `diag=` keyword in the installed
        # CoLA version.
        return cola.PSD(cola.ops.BlockDiag(diag=blocks))
    

class DenseVectorKernelComputation(AbstractVectorKernelComputation):
    r"""Dense kernel computation: every cross-covariance is materialised as one dense matrix."""

    def cross_covariance(
        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
    ) -> Float[Array, "2N 2M"]:
        r"""Compute the flattened cross-covariance matrix.

        Evaluates `kernel` on every pair (x_i, y_j) and assembles the resulting blocks with
        `flatten_matrix` into a single dense matrix.

        Args:
            kernel (Kernel): the kernel function.
            x (Float[Array,"N D"]): The input matrix.
            y (Float[Array,"M D"]): The input matrix.

        Returns
        -------
            Float[Array, "2N 2M"]: The computed cross-covariance.
        """
        # Inner vmap over the rows of y, outer vmap over the rows of x -> [N, M, *block].
        pairwise = jax.vmap(
            jax.vmap(lambda xi, yj: kernel(xi, yj), in_axes=(None, 0)),
            in_axes=(0, None),
        )(x, y)
        return flatten_matrix(pairwise)


@dataclass 
class VectorZeroMean(gpjax.mean_functions.AbstractMeanFunction):
    """Zero mean for vector-valued GPs: `space_dim` rows per input point, `output_dim` columns."""
    space_dim: int = gpjax.base.static_field(2)
    output_dim: int = gpjax.base.static_field(1)

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N E"]:
        num_rows = x.shape[0] * self.space_dim
        return jnp.zeros(shape=(num_rows, self.output_dim))


@dataclass
class AbstractVectorKernel(gpjax.kernels.AbstractKernel):
    r"""Base vector kernel class.

    Defaults to `DenseVectorKernelComputation`, which evaluates the kernel on every input pair
    and flattens the resulting blocks into one dense matrix.
    """

    compute_engine: AbstractVectorKernelComputation = gpjax.base.static_field(DenseVectorKernelComputation())
    active_dims: Optional[list[int]] = gpjax.base.static_field(None)
    name: str = gpjax.base.static_field("AbstractVectorKernel")
    

@jax.jit
def tangent_basis_normalization_matrix(x: Float[Array, "2"]) -> Float[Array, "2 2"]:
    """
    Return diag(1, 1 / sin(colatitude)) for a single point x = (colatitude, longitude).

    Presumably used to normalise the longitude direction, whose coordinate basis vector has
    length sin(colatitude) on the unit sphere. Diverges where sin(colatitude) == 0.
    """
    inv_sin = 1.0 / jnp.sin(x[0])
    return jnp.array([[1.0, 0.0], [0.0, inv_sin]])


# 90° rotation applied to row vectors from the right: v @ H maps (a, b) to (-b, a).
# Used to turn curl-free fields/kernels into their divergence-free counterparts.
hodge_star_matrix = jnp.array([
    [0.0, 1.0],
    [-1.0, 0.0],
])


@Partial(jax.jit, static_argnames=('min_value', ))
def _ensure_colatitude_nonzero(x: Float[Array, "N 2"], min_value: float) -> Float[Array, "N 2"]:
    """
    Replace colatitudes (index 0 of the last axis) that are exactly zero with `min_value`.

    Presumably guards the 1 / sin(colatitude) factor at the north pole. All other entries,
    including non-zero colatitudes, are returned unchanged.
    """
    colatitude = x[..., 0]
    safe_colatitude = jnp.where(colatitude == 0, min_value, colatitude)
    return x.at[..., 0].set(safe_colatitude)


@jax.jit
def matern_spectral_density(ell: ScalarFloat, kappa: ScalarFloat, nu: ScalarFloat, variance: ScalarFloat) -> ScalarFloat:
    """
    Normalised Matérn spectral density evaluated at the degrees `ell`.

    Uses Φ(ℓ) ∝ (1 + ℓ(ℓ+1) κ² / (2ν))^-(ν+1). It is computed in log space and shifted by its
    maximum to avoid underflow, then rescaled so that Σ_ℓ (2ℓ + 1) Φ(ℓ) == variance over the
    supplied degrees.

    Args:
        ell: 1-D array of degrees.
        kappa: Length-scale parameter.
        nu: Smoothness parameter.
        variance: Total variance the density is normalised to.

    Returns:
        Array with the same shape as `ell`.
    """
    eigenvalues = ell * (ell + 1)
    log_density = -(nu + 1) * jnp.log1p(eigenvalues * kappa**2 / (2 * nu))
    density = jnp.exp(log_density - jnp.max(log_density))
    multiplicities = 2 * ell + 1
    return variance * density / jnp.dot(multiplicities, density)


@dataclass 
class AbstractHodgeKernel(AbstractVectorKernel):
    """
    Base class for Hodge–Matérn kernels on S² built from weighted Gegenbauer polynomials.

    Holds the Matérn hyperparameters (`nu`, `kappa`, `variance`, kept positive via Softplus
    bijectors), the truncation degree `max_ell`, and an optional spherical-harmonic-fields object
    used for pathwise (weight-space) sampling. Subclasses set `spherical_harmonic_fields` and
    implement `__call__`.
    """
    nu: ScalarFloat = gpjax.param_field(jnp.array(2.5), bijector=tfp.bijectors.Softplus())
    kappa: ScalarFloat = gpjax.param_field(jnp.array(1.0), bijector=tfp.bijectors.Softplus())
    variance: ScalarFloat = gpjax.param_field(jnp.array(1.0), bijector=tfp.bijectors.Softplus())
    alpha: float = gpjax.base.static_field(0.5)
    max_ell: int = gpjax.base.static_field(10)
    colatitude_min_value: float = gpjax.base.static_field(1e-12) # NOTE not sure what exact value to use here
    spherical_harmonic_fields: "AbstractSphericalHarmonicFields" = gpjax.base.static_field(None)

    @property
    def ells(self) -> Num[Array, "L"]:
        """Degrees 1..max_ell (the constant degree 0 is excluded)."""
        return jnp.arange(1, self.max_ell + 1)
    
    def spectral_density(self) -> Float[Array, "L"]:
        """Normalised Matérn spectral density at each degree in `ells`."""
        return matern_spectral_density(self.ells, self.kappa, self.nu, self.variance)
    
    @jax.jit
    def weighted_gegenbauer(self, x: Float, y: Float, weights: Float) -> Float:
        """
        Per-degree terms weights_ℓ * Cᵅ_ℓ(<x, y>) / (ℓ(ℓ+1)) for ℓ = 1..max_ell.

        `x` and `y` are points in spherical (colatitude, longitude) coordinates.
        """
        lambda_ells = self.ells * (self.ells + 1)
        # Drop degree 0, whose Gegenbauer polynomial is the constant 1.
        values = sph_gegenbauer(x, y, self.max_ell, self.alpha)[1:]
        return weights * values / lambda_ells
    
    @jax.jit
    def dd_weighted_gegenbauer(self, x: Float[Array, "2"], y: Float[Array, "2"], weights: Float) -> Float[Array, "L 2 2"]:
        """Mixed second derivative ∂²/∂x∂y of every per-degree term; shape [L, 2 (x), 2 (y)]."""
        return jax.jacfwd(jax.jacfwd(lambda x, y: self.weighted_gegenbauer(x, y, weights), argnums=0), argnums=1)(x, y)
    
    @jax.jit
    def validate_inputs(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> tuple[Float[Array, "2"], Float[Array, "2"]]:
        """Replace exactly-zero colatitudes of both points with `colatitude_min_value`."""
        x = _ensure_colatitude_nonzero(x, self.colatitude_min_value)
        y = _ensure_colatitude_nonzero(y, self.colatitude_min_value)
        return x, y

    def _pathwise_sample_from_weights(self, x: Float[Array, "S N 2"], w: Float[Array, "S I"]) -> Float[Array, "S N 2"]:
        """Weight-space sample: Σ_i sqrt(a_i) w_si φ_i(x_sn) for every sample s and point n."""
        Phi_x = jax.vmap(self.spherical_harmonic_fields)(x) # [S N I 2]
        ahats_per_frequency = self.spectral_density() # [L]
        # Repeat each degree's density once per phase of that degree.
        ahats_per_phase = jnp.repeat(
            ahats_per_frequency, 
            self.spherical_harmonic_fields.num_phases_per_frequency,
            total_repeat_length=self.spherical_harmonic_fields.num_phases
        ) # [I]
        tilde_Phi_x = jnp.einsum('snid, i -> snid', Phi_x, jnp.sqrt(ahats_per_phase)) # [S N I 2]
        return jnp.einsum('snid, si -> snd', tilde_Phi_x, w)
    
    @Partial(jax.jit, static_argnames=('num_samples',))
    def pathwise_sample_from_weights(
        self, 
        x: Float[Array, "N 2"] | Float[Array, "S N 2"],
        w: Float[Array, "S I"], 
        num_samples: int = 1
    ) -> Float[Array, "S N 2"]:
        """Broadcast `x` to [num_samples, N, 2] and evaluate the weight-space samples."""
        x_shape = jnp.broadcast_shapes(x.shape, (num_samples, 1, 1))
        x = jnp.broadcast_to(x, x_shape)
        return self._pathwise_sample_from_weights(x, w)
    
    @Partial(jax.jit, static_argnames=('num_samples',))
    def sample_weights(self, key: Key, num_samples: int = 1) -> Float[Array, "S I"]:
        """Draw standard-normal weights, one row of length `num_phases` per sample."""
        return jax.random.normal(key, shape=(num_samples, self.spherical_harmonic_fields.num_phases), dtype=jnp.float64)
    
    def pathwise_sample(self, key: Key, x: Float[Array, "N 2"], num_samples: int = 1) -> Float[Array, "S N 2"]:
        """Draw `num_samples` pathwise prior samples evaluated at `x`."""
        w = self.sample_weights(key, num_samples)
        return self.pathwise_sample_from_weights(x, w, num_samples)


@dataclass
class HodgeMaternCurlFreeKernel(AbstractHodgeKernel):
    """
    Curl-free Hodge–Matérn kernel on S².

    For points x, y in (colatitude, longitude) coordinates it returns the 2x2 block
    Nxᵀ [Σ_ℓ w_ℓ ∂x∂y Cᵅ_ℓ(<x, y>) / (ℓ(ℓ+1))] Ny, with w_ℓ = (2ℓ + 1) · spectral_density_ℓ and
    N = diag(1, 1 / sin(colatitude)).
    """

    def __post_init__(self):
        try:
            self.spherical_harmonic_fields = CurlFreeSphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)
        except FundamentalSystemNotPrecomputedError as e:
            # `warnings.warn` takes a single message. The second positional argument is the
            # warning *category*, so passing a string there raised TypeError instead of warning.
            warnings.warn(
                f"{e} Pathwise sampling will not be available unless max_ell is sufficiently reduced."
            )

    def __call__(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "2 2"]:
        x, y = self.validate_inputs(x, y)
        weights = self.spectral_density() * (2 * self.ells + 1)
        # Sum the per-degree mixed second derivatives -> [2, 2].
        dd = jnp.sum(self.dd_weighted_gegenbauer(x, y, weights=weights), axis=0)

        Nx, Ny = tangent_basis_normalization_matrix(x), tangent_basis_normalization_matrix(y)
        return Nx.T @ dd @ Ny


@dataclass
class HodgeMaternDivFreeKernel(AbstractHodgeKernel):
    """
    Divergence-free Hodge–Matérn kernel on S².

    The curl-free block Nxᵀ [Σ_ℓ w_ℓ ∂x∂y Cᵅ_ℓ(<x, y>) / (ℓ(ℓ+1))] Ny, conjugated by the
    Hodge-star rotation H: Hᵀ (·) H.
    """

    def __post_init__(self):
        try:
            self.spherical_harmonic_fields = DivFreeSphericalHarmonicFields(max_ell=self.max_ell, sphere_dim=2)
        except FundamentalSystemNotPrecomputedError as e:
            # `warnings.warn` takes a single message. The second positional argument is the
            # warning *category*, so passing a string there raised TypeError instead of warning.
            warnings.warn(
                f"{e} Pathwise sampling will not be available unless max_ell is sufficiently reduced."
            )

    @jax.jit
    def __call__(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "2 2"]:
        x, y = self.validate_inputs(x, y)
        weights = self.spectral_density() * (2 * self.ells + 1)
        # Sum the per-degree mixed second derivatives -> [2, 2].
        dd = jnp.sum(self.dd_weighted_gegenbauer(x, y, weights=weights), axis=0)

        Nx, Ny = tangent_basis_normalization_matrix(x), tangent_basis_normalization_matrix(y)
        H = hodge_star_matrix
        return H.T @ Nx.T @ dd @ Ny @ H


@dataclass
class HodgeMaternKernel(AbstractVectorKernel):
    """
    Hodge–Matérn kernel on S²: the sum of a curl-free and a divergence-free Hodge–Matérn kernel.

    `kappa`, `nu`, `variance` and `colatitude_min_value` are init-only values used to construct
    both components. Each component then holds its own copy of the hyperparameters.
    """
    kappa: InitVar[ScalarFloat] = 1.0
    nu: InitVar[ScalarFloat] = 2.5
    variance: InitVar[ScalarFloat] = 1.0
    colatitude_min_value: InitVar[ScalarFloat] = 1e-12

    max_ell: int = gpjax.base.static_field(10)
    curl_free_kernel: HodgeMaternCurlFreeKernel = gpjax.param_field(init=False)
    div_free_kernel: HodgeMaternDivFreeKernel = gpjax.param_field(init=False)

    def __post_init__(self, kappa, nu, variance, colatitude_min_value):
        self.curl_free_kernel = HodgeMaternCurlFreeKernel(kappa=kappa, nu=nu, variance=variance, max_ell=self.max_ell, colatitude_min_value=colatitude_min_value)
        self.div_free_kernel = HodgeMaternDivFreeKernel(kappa=kappa, nu=nu, variance=variance, max_ell=self.max_ell, colatitude_min_value=colatitude_min_value)

    def spectral_density(self):
        """Curl-free then div-free spectral densities, concatenated."""
        return jnp.concat([self.curl_free_kernel.spectral_density(), self.div_free_kernel.spectral_density()])
    
    def __call__(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "2 2"]:
        """2x2 covariance block between x and y: sum of the two component kernels."""
        return self.curl_free_kernel(x, y) + self.div_free_kernel(x, y)
    
    def sample_weights(self, key: Key, num_samples: int = 1) -> Float[Array, "S 2I"]:
        """
        Draw independent standard-normal weights for both components.

        Returns:
            Array [num_samples, 2 * num_phases]: curl-free weights followed by div-free weights.
        """
        # Split the key. Reusing `key` for both components draws identical weights (same key,
        # same shape), which makes the curl-free and div-free parts of every sample perfectly
        # correlated.
        curl_free_key, div_free_key = jax.random.split(key)
        return jnp.concatenate([
            self.curl_free_kernel.sample_weights(curl_free_key, num_samples), 
            self.div_free_kernel.sample_weights(div_free_key, num_samples),
        ], axis=-1)

    def pathwise_sample_from_weights(self, x: Float[Array, "N 2"], w: Float[Array, "S 2I"], num_samples: int = 1) -> Float[Array, "S N 2"]:
        """Split `w` into its curl-free / div-free halves and sum the two component samples."""
        curl_free_w, div_free_w = jnp.split(w, 2, axis=-1)
        curl_free_sample = self.curl_free_kernel.pathwise_sample_from_weights(x, curl_free_w, num_samples)
        div_free_sample = self.div_free_kernel.pathwise_sample_from_weights(x, div_free_w, num_samples)
        return curl_free_sample + div_free_sample
    
    def pathwise_sample(self, key: Key, x: Float[Array, "N 2"], num_samples: int = 1) -> Float[Array, "S N 2"]:
        """
        Draw `num_samples` pathwise prior samples of the full (curl-free + div-free) field.

        Args:
            key: PRNG key.
            x: The input locations, [N 2] or [S N 2]. The components broadcast them to
                [num_samples N 2].
            num_samples: Number of samples S.

        Returns:
            Samples of shape [S N 2].
        """
        w = self.sample_weights(key, num_samples) # [S 2I]
        return self.pathwise_sample_from_weights(x, w, num_samples)


@dataclass 
class AbstractSphericalHarmonicFields(gpjax.Module):
    """
    Base class for vector spherical-harmonic fields on S² derived from scalar harmonics.

    `__post_init__` builds a `SphericalHarmonics` instance up to degree `max_ell` and records the
    number of phases of each degree 1..max_ell (degree 0 is excluded). Subclasses set
    `num_fields` and implement `__call__`.
    NOTE(review): the class does not visibly inherit from `abc.ABC`; confirm whether
    `@abstractmethod` is enforced through gpjax.Module.
    """
    max_ell: int = gpjax.base.static_field(10)
    sphere_dim: int = gpjax.base.static_field(2)
    _colatitude_min_value: float = gpjax.base.static_field(1e-12)
    spherical_harmonics: SphericalHarmonics = gpjax.base.static_field(init=False)
    # Number of phases for each degree 1..max_ell.
    num_phases_per_frequency: Num[Array, "L"] = gpjax.base.param_field(init=False, trainable=False)
    # Total number of phases over degrees 1..max_ell.
    num_phases: int = gpjax.base.static_field(init=False)
    # Number of vector fields returned by `__call__`; set by subclasses.
    num_fields: int = gpjax.base.static_field(init=False)

    def __post_init__(self) -> None:
        self.spherical_harmonics = SphericalHarmonics(max_ell=self.max_ell, sphere_dim=self.sphere_dim)
        # Drop degree 0 (the constant harmonic).
        num_phases_per_frequency = self.spherical_harmonics.num_phase_in_frequency[1:]
        self.num_phases_per_frequency = jnp.array(num_phases_per_frequency)
        self.num_phases = sum(num_phases_per_frequency)

    @jax.jit
    def _sph_polynomial_expansion(self, x: Float[Array, "2"]) -> Float[Array, "I"]:
        """
        Scalar harmonics of degree >= 1 at one point, each divided by sqrt(ℓ(ℓ+1)).

        NOTE(review): the single-point shapes are inferred from the callers, which differentiate
        this per point with `jacfwd` inside a `vmap`. Confirm before passing batches directly.
        """
        ells = jnp.arange(1, self.max_ell + 1)
        lambda_ells = ells * (ells + 1)
        # One normalisation factor per phase.
        normalization_factor = jnp.repeat(
            jnp.sqrt(lambda_ells),
            self.num_phases_per_frequency,
            total_repeat_length=self.num_phases
        )
        # [1:] drops the degree-0 row of the expansion.
        return self.spherical_harmonics.polynomial_expansion(sph_to_car(x))[1:] / normalization_factor

    @jax.jit
    def _field_polynomial_expansion_single(self, x: Float[Array, "2"]) -> Float[Array, "I 2"]:
        """Jacobian of the normalised harmonics w.r.t. (colatitude, longitude), times diag(1, 1/sin(colatitude))."""
        Nx = tangent_basis_normalization_matrix(x)
        return jax.jacfwd(self._sph_polynomial_expansion)(x) @ Nx
    
    @jax.jit
    def _field_polynomial_expansion(self, x: Float[Array, "N 2"]) -> Float[Array, "N I 2"]:
        """Evaluate the fields at every point; exactly-zero colatitudes are replaced first."""
        x = _ensure_colatitude_nonzero(x, self._colatitude_min_value)
        return jax.vmap(self._field_polynomial_expansion_single)(x)

    @abstractmethod
    def __call__(self, x: Float[Array, "N 2"]) -> Float[Array, "N I 2"]:
        pass 
    
    def __eq__(self, other: "AbstractSphericalHarmonicFields") -> bool:
        # The remaining fields are derived deterministically from these three.
        return self.max_ell == other.max_ell and self.sphere_dim == other.sphere_dim and self._colatitude_min_value == other._colatitude_min_value


@dataclass 
class CurlFreeSphericalHarmonicFields(AbstractSphericalHarmonicFields):
    """
    Curl-free vector spherical harmonics: the normalised surface gradients of the scalar
    harmonics of degree >= 1, one field per phase (`num_fields == num_phases`).
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        self.num_fields = self.num_phases

    @jax.jit
    def __call__(self, x: Float[Array, "N 2"]) -> Float[Array, "N I 2"]:
        """Evaluate every curl-free field at the N points `x` (colatitude, longitude)."""
        return self._field_polynomial_expansion(x)
    
    def __eq__(self, other: object) -> bool:
        """
        Equal iff `other` is the same field type with the same defining parameters.

        The concrete type is compared because curl-free and div-free field sets with identical
        parameters are not interchangeable. Delegating to the base class also compares
        `_colatitude_min_value`, which changes the evaluated fields. Ignoring either could let a
        jit cache keyed on this static field reuse the wrong compiled function.
        """
        if type(other) is not type(self):
            return NotImplemented
        return super().__eq__(other)
    

@dataclass 
class DivFreeSphericalHarmonicFields(AbstractSphericalHarmonicFields):
    """
    Divergence-free vector spherical harmonics: the curl-free fields rotated by the Hodge-star
    matrix, one field per phase (`num_fields == num_phases`).
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        self.num_fields = self.num_phases

    @jax.jit
    def __call__(self, x: Float[Array, "N 2"]) -> Float[Array, "N I 2"]:
        """Evaluate every divergence-free field at the N points `x` (colatitude, longitude)."""
        H = hodge_star_matrix
        return self._field_polynomial_expansion(x) @ H
    
    def __eq__(self, other: object) -> bool:
        """
        Equal iff `other` is the same field type with the same defining parameters.

        The concrete type is compared because curl-free and div-free field sets with identical
        parameters are not interchangeable. Delegating to the base class also compares
        `_colatitude_min_value`, which changes the evaluated fields.
        """
        if type(other) is not type(self):
            return NotImplemented
        return super().__eq__(other)
    

@dataclass 
class SphericalHarmonicFields(AbstractSphericalHarmonicFields):
    """
    Curl-free and divergence-free vector spherical harmonics together
    (`num_fields == 2 * num_phases`).
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        self.num_fields = 2 * self.num_phases

    @jax.jit
    def __call__(self, x: Float[Array, "N 2"]) -> Float[Array, "N 2I 2"]:
        """
        Returns curl-free and divergence-free fields concatenated along the field axis.
        """
        H = hodge_star_matrix
        v = self._field_polynomial_expansion(x) # [N I 2]
        return jnp.concat([v, v @ H], axis=-2)
    
    def __eq__(self, other: object) -> bool:
        """
        Equal iff `other` is the same field type with the same defining parameters.

        Also compares `_colatitude_min_value` (via the base class), which changes the evaluated
        fields.
        """
        if type(other) is not type(self):
            return NotImplemented
        return super().__eq__(other)
    

@dataclass 
class AbstractVectorSHF(gpjax.variational_families.AbstractVariationalFamily):
    r"""Variational family with vector spherical harmonic (SHF) inducing features.

    The variational family is $`q(f(\cdot)) = \int p(f(\cdot)\mid u) q(u) \mathrm{d}u`$, where
    $`u`$ are the inducing variables attached to the fields in ``spherical_harmonic_fields``
    (one per field, ``num_inducing`` in total), and
    $`q(u) = \mathcal{N}(\mu, S)`$. We parameterise this over
    $`\mu`$ and $`sqrt`$ with $`S = sqrt sqrt^{\top}`$.

    Subclasses assign ``spherical_harmonic_fields`` in their ``__post_init__`` before
    calling ``super().__post_init__()``, and implement ``ahats``.
    """
    max_ell: int = gpjax.base.static_field(1)
    jitter: ScalarFloat = gpjax.base.static_field(1e-6)
    variational_mean: Float[Array, "N 1"] | None = gpjax.base.param_field(None)
    # Lower-triangular root of the variational covariance S = sqrt sqrt^T.
    variational_root_covariance: Float[Array, "N N"] = gpjax.base.param_field(
        None, bijector=tfp.bijectors.FillTriangular()
    )
    spherical_harmonic_fields: AbstractSphericalHarmonicFields = gpjax.base.static_field(init=False)
    sphere_dim: int = gpjax.base.static_field(2)
    num_inducing: int = gpjax.base.static_field(init=False)

    def __post_init__(self) -> None:
        self.num_inducing = self.spherical_harmonic_fields.num_fields
        # Kzz and muz do not change during optimization
        self.muz = jnp.zeros((self.num_inducing, 1))

        if self.variational_mean is None:
            self.variational_mean = jnp.zeros((self.num_inducing, 1))

        if self.variational_root_covariance is None:
            # Jitter goes on the diagonal only; adding the scalar to the whole matrix would
            # also fill the off-diagonal entries of the triangular root.
            self.variational_root_covariance = (1.0 + self.jitter) * jnp.eye(self.num_inducing)

    def _repeat_per_phase(self, x: Float[Array, "L"]) -> Float[Array, "P"]:
        """Repeat each per-frequency entry of ``x`` once per phase of that frequency.

        The output has ``spherical_harmonic_fields.num_phases`` entries.
        """
        return jnp.repeat(
            x,
            self.spherical_harmonic_fields.num_phases_per_frequency,
            total_repeat_length=self.spherical_harmonic_fields.num_phases,
        )

    @abstractmethod
    def ahats(self) -> Float[Array, "I"]:
        """Per-inducing-feature spectral coefficients (one entry per field)."""
        pass

    @jax.jit
    def Lz_T_inv_diagonal(self):
        """Diagonal of Lz^{-T} for the jittered, diagonal inducing covariance.

        sqrt(a / (1 + a * jitter)) equals 1 / sqrt(1 / a + jitter), which is consistent
        with Kzz = diag(1 / ahats) + jitter * I.
        """
        ahats = self.ahats()
        return jnp.sqrt(ahats / (1 + ahats * self.jitter))

    def Kzt(self, t: Float[Array, "N 2"]) -> Float[Array, "I 2N"]:
        r"""Compute the cross-covariance between the inducing features and the test inputs.

        Args:
            t (Float[Array, "N 2"]): The test inputs.

        Returns
        -------
            Float[Array, "I 2N"]: Row ``i`` is field ``i`` evaluated at every test point,
            flattened point-major (column ``2 * n + d`` is component ``d`` at point ``n``).
        """
        fields = self.spherical_harmonic_fields(t) # [N I 2]
        # Move the inducing axis to the front before flattening; reshaping [N I 2] directly
        # to [I, 2N] would mix values belonging to different fields and points.
        # NOTE(review): the (point, component) column order must match kernel.gram(t) — confirm.
        return jnp.swapaxes(fields, 0, 1).reshape(self.num_inducing, -1)

    def prior_kl(self) -> ScalarFloat:
        """KL[q(u) || p(u)] with q(u) = N(mu, sqrt sqrt^T) and p(u) = N(muz, I)."""
        # Unpack variational parameters
        mu = self.variational_mean
        sqrt = self.variational_root_covariance
        sqrt = cola.ops.Triangular(sqrt)

        # Unpack mean function and kernel
        muz = self.muz # TODO maybe allow non-zero prior mean. This would necessitate setting the first position of the mean to the prior mean constant

        S = sqrt @ sqrt.T

        qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
        pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()))

        return qu.kl_divergence(pu) # TODO efficiency here can be improved by using the fact that Kzz_jittered is diagonal 

    def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
        """Joint Gaussian q(f(t)) over all output components at ``test_inputs``.

        mean = m(t) + A (mu - muz)
        cov  = Ktt + A (S - I) A^T + jitter * I,   with A = Ktz Lz^{-T}
        """
        t = test_inputs

        # Unpack variational parameters
        mu = self.variational_mean
        sqrt = self.variational_root_covariance # [I I]

        # Unpack mean function and kernel
        mean_function = self.posterior.prior.mean_function
        kernel = self.posterior.prior.kernel

        # Compute posterior covariance
        Ktt = kernel.gram(t) # [2N 2N]
        Ktz = self.Kzt(t).mT # [2N I]
        Lz_T_inv = self.Lz_T_inv_diagonal() # [I]

        Ktz_Lz_T_inv = Ktz * Lz_T_inv # [2N I]
        Ktz_Lz_T_inv_sqrt = Ktz_Lz_T_inv @ sqrt # [2N I] @ [I I] -> [2N I]
        covariance = (
            Ktt 
            + Ktz_Lz_T_inv_sqrt @ Ktz_Lz_T_inv_sqrt.mT
            - Ktz_Lz_T_inv @ Ktz_Lz_T_inv.mT
        )
        covariance = cola.PSD(covariance + cola.ops.I_like(covariance) * self.jitter) # add jitter for spectral stability

        # Compute posterior mean 
        mut = mean_function(t)
        muz = self.muz

        mean = (
            mut 
            + Ktz_Lz_T_inv @ (mu - muz) # [2N I] @ [I 1] -> [2N 1]
        )

        return GaussianDistribution(
            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
        )

    @Partial(jax.jit, static_argnames=('num_samples',))
    def _pathwise_sample(self, key: Key, test_inputs: Float[Array, "S N 2"], num_samples: int) -> Float[Array, "S N 2"]:
        """Pathwise-conditioned posterior samples at ``test_inputs`` (one per leading index).

        Each sample is ``kernel.pathwise_sample_from_weights(t, w)`` plus the update
        ``tilde_Phi_t (u - w)``, where ``u ~ q(u) = N(m, sqrt sqrt^T)``. ``u`` already has
        mean ``m``, so the variational mean enters only through ``u``.
        """
        Ktt_key, S_key = jax.random.split(key)

        t = test_inputs # [S N 2]

        # Unpack variational parameters
        m = jnp.squeeze(self.variational_mean, axis=-1) # [I]
        sqrt = self.variational_root_covariance # [I I]

        kernel = self.posterior.prior.kernel

        # Function draw from the kernel's weight representation (presumably a prior sample)
        w = kernel.sample_weights(Ktt_key, num_samples) # [S I]
        Ktt_sample = kernel.pathwise_sample_from_weights(t, w, num_samples) # [S N 2]

        # Inducing features at the test inputs, scaled by diag(Lz^{-T})
        Phi_t = jax.vmap(self.spherical_harmonic_fields)(t) # [S N I 2]
        Lz_T_inv = self.Lz_T_inv_diagonal() # [I]
        tilde_Phi_t = jnp.einsum('snid, i -> snid', Phi_t, Lz_T_inv) # [S N I 2]

        # u ~ q(u) = N(m, sqrt sqrt^T)
        u_sample = jax.random.multivariate_normal(
            S_key, mean=m, cov=sqrt @ sqrt.T, shape=(num_samples,)
        ) # [S I]

        # m is carried by u_sample; adding tilde_Phi_t @ m on top would count the mean twice.
        return Ktt_sample + jnp.einsum('snid, si -> snd', tilde_Phi_t, u_sample - w)

    def pathwise_sample(self, key: Key, test_inputs: Float[Array, "N D"], num_samples: int) -> Float[Array, "S N D"]:
        """
        Args:   
            key: The random key.
            test_inputs: The input locations. Can be [N D] or [S N D]; it is broadcast to
                [S N D] (S = ``num_samples``) before sampling.
            num_samples: The number of samples to draw.

        Returns:
            Posterior samples of shape [S N D].
        """
        test_inputs_shape = jnp.broadcast_shapes(test_inputs.shape, (num_samples, 1, 1))
        test_inputs = jnp.broadcast_to(test_inputs, test_inputs_shape)
        return self._pathwise_sample(key, test_inputs, num_samples)



@dataclass
class CurlFreeVectorSHF(AbstractVectorSHF):
    """Variational family whose inducing features are curl-free vector spherical harmonics."""

    def __post_init__(self) -> None:
        self.spherical_harmonic_fields = CurlFreeSphericalHarmonicFields(
            max_ell=self.max_ell, sphere_dim=self.sphere_dim
        )
        super().__post_init__()

    def ahats(self):
        """Kernel spectral density truncated to ``max_ell`` entries, repeated once per phase."""
        ahats_per_frequency = self.posterior.prior.kernel.spectral_density()[:self.max_ell]
        return self._repeat_per_phase(ahats_per_frequency)
    

@dataclass 
class DivFreeVectorSHF(AbstractVectorSHF):
    """Variational family whose inducing features are divergence-free vector spherical harmonics."""

    def __post_init__(self) -> None:
        self.spherical_harmonic_fields = DivFreeSphericalHarmonicFields(
            max_ell=self.max_ell, sphere_dim=self.sphere_dim
        )
        super().__post_init__()

    def ahats(self):
        """Kernel spectral density truncated to ``max_ell`` entries, repeated once per phase."""
        ahats_per_frequency = self.posterior.prior.kernel.spectral_density()[:self.max_ell]
        return self._repeat_per_phase(ahats_per_frequency)
    

@dataclass 
class VectorSHF(AbstractVectorSHF):
    """Variational family with both curl-free and divergence-free inducing features."""

    def __post_init__(self) -> None:
        self.spherical_harmonic_fields = SphericalHarmonicFields(
            max_ell=self.max_ell, sphere_dim=self.sphere_dim
        )
        super().__post_init__()

    def ahats(self):
        """Curl-free then div-free spectral coefficients, matching the field order.

        Each part is the corresponding sub-kernel's spectral density truncated to
        ``max_ell`` entries and repeated once per phase.
        """
        curl_free_kernel = self.posterior.prior.kernel.curl_free_kernel
        div_free_kernel = self.posterior.prior.kernel.div_free_kernel

        curl_free_ahats_per_frequency = curl_free_kernel.spectral_density()[:self.max_ell]
        div_free_ahats_per_frequency = div_free_kernel.spectral_density()[:self.max_ell]

        return jnp.concatenate([
            self._repeat_per_phase(curl_free_ahats_per_frequency),
            self._repeat_per_phase(div_free_ahats_per_frequency),
        ])
    

def variational_family_from_kernel(kernel: type[AbstractVectorKernel]) -> type[AbstractVectorSHF]:
    """
    Map a Hodge–Matérn kernel class to the variational family class that matches it.

    Candidates are checked in the order curl-free, div-free, full Hodge kernel, and the
    first ``issubclass`` match is returned.

    Raises:
        ValueError: if ``kernel`` is not a subclass of any of these kernel classes.
    """
    candidates = (
        (HodgeMaternCurlFreeKernel, CurlFreeVectorSHF),
        (HodgeMaternDivFreeKernel, DivFreeVectorSHF),
        (HodgeMaternKernel, VectorSHF),
    )
    for kernel_base, family in candidates:
        if issubclass(kernel, kernel_base):
            return family
    raise ValueError("Unknown kernel type.")
    

from jaxtyping import Key 

# TODO Should consider double jax.vmap without reshaping and using the batched functionality of MultivariateNormalFullCovariance
@jax.jit
def sample_from_marginal(
    key: Key, 
    model: gpjax.gps.AbstractPrior | gpjax.variational_families.AbstractVariationalGaussian,
    x: Float[Array, "N D"] | Float[Array, "S N D"],
) -> Float[Array, "S N O"]:
    """
    Sample from the per-point marginal distributions of the model at the input locations.

    Each input point is passed to the model on its own, so only the [O]-dimensional
    marginal at every point is sampled; different points are sampled independently.

    Args:
        key: The random key.
        model: The model object; ``model(t)`` must return a distribution with
            ``mean()`` and ``covariance()``.
        x: The input locations, [N D] or [S N D]. A 2-D input is treated as a single
            sample, i.e. it is expanded to [1 N D].

    Returns:
        One draw per input point, of shape [S N O] (S = 1 for a 2-D input).
    """
    if x.ndim == 2:
        # Without this the double vmap below would map over N and D instead of S and N.
        x = x[None]

    def moments(t: Float[Array, "1 D"]) -> tuple[Float[Array, "O"], Float[Array, "O O"]]:
        pt = model(t)
        return pt.mean(), pt.covariance()

    # x[:, :, None] is [S N 1 D]: every point becomes a one-point batch for the model.
    means, covariance_matrices = jax.vmap(jax.vmap(moments))(x[:, :, None]) # [S N O], [S N O O]

    # NOTE we should probably add expand to num_samples here
    marginal_pt = tfp.distributions.MultivariateNormalFullCovariance(loc=means, covariance_matrix=covariance_matrices)
    return marginal_pt.sample(seed=key, sample_shape=())


# Small threshold: tangent vectors with norm below it are treated as zero in `expmap_car`
# (a normalised first-order step is used instead of dividing by the norm). Also the
# default minimum colatitude in `expmap_sph`.
EPS = 1e-12


@jax.jit
def tangent_basis(x: Float[Array, "2"]) -> Float[Array, "3 2"]:
    """
    Unit tangent vectors of the sphere at ``x``, expressed in cartesian coordinates.

    Args:
        x: A point in spherical coordinates (colatitude, longitude).

    Returns:
        A [3, 2] matrix whose columns are the normalised partial derivatives of
        ``sph_to_car`` with respect to colatitude and longitude. The longitude column is
        zero at colatitude 0, where normalising it yields NaNs; ``expmap_sph`` moves the
        colatitude away from 0 before calling this.
    """
    jacobian = jax.jacfwd(sph_to_car)(x)  # [3 2]
    return jacobian / jnp.linalg.norm(jacobian, axis=0, keepdims=True)

@jax.jit
def expmap_car(x: Float[Array, "3"], v: Float[Array, "3"]) -> Float[Array, "3"]:
    """
    Exponential map on the unit sphere in cartesian coordinates.

    Args:
        x: Base point (assumed to be a unit vector).
        v: Tangent vector at ``x`` (assumed orthogonal to ``x``).

    Returns:
        ``cos(|v|) x + sin(|v|) v / |v|``. When ``|v| < EPS``, the normalised first-order
        approximation ``(x + v) / |x + v|`` is returned instead.
    """
    sq_norm = jnp.sum(v * v)
    is_small = sq_norm < EPS ** 2  # same test as |v| < EPS

    # Substitute a dummy norm in the small case so neither `v / theta` nor the gradient
    # of sqrt at 0 produces NaNs. Both branches are evaluated (and differentiated) when
    # this function is vmapped, and a NaN in the discarded branch would otherwise leak
    # into the gradient.
    theta = jnp.sqrt(jnp.where(is_small, 1.0, sq_norm))
    exact = jnp.cos(theta) * x + jnp.sin(theta) * v / theta

    t = x + v
    first_order = t / jnp.linalg.norm(t)

    return jnp.where(is_small, first_order, exact)


@Partial(jax.jit, static_argnames=("colatitude_min_value", ))
def expmap_sph(x: Float[Array, "D"], v: Float[Array, "D"], colatitude_min_value: float = EPS) -> Float[Array, "D"]:
    """
    Move the spherical-coordinate point ``x`` along the tangent vector ``v``.

    ``v`` is expressed in the (colatitude, longitude) tangent frame given by
    ``tangent_basis``. Before anything else, ``x`` goes through
    ``_ensure_colatitude_nonzero`` so that its colatitude is not too small, which
    avoids NaNs. The result is mapped back to spherical coordinates.
    """
    safe_x = _ensure_colatitude_nonzero(x, colatitude_min_value)
    base_car = sph_to_car(safe_x)
    tangent_car = tangent_basis(safe_x) @ v
    moved_car = expmap_car(base_car, tangent_car)
    return car_to_sph(moved_car)


from dataclasses import InitVar


@dataclass 
class IdentityPosterior(gpjax.gps.AbstractPosterior):
    """Posterior without an observation model: its predictions are the prior's.

    Intended for layers that have no likelihood (e.g. the hidden layers built by
    ``build_layers``).
    """
    likelihood: None = gpjax.base.static_field(None)

    def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
        prior = self.prior
        return prior(test_inputs)
    

@dataclass 
class AbstractDeepGP(gpjax.Module):
    """Base class for deep GPs made of a stack of variational layers.

    The last entry of ``layers`` is the output layer; all earlier entries are hidden
    layers, whose samples are produced by ``sample_from_hidden``.
    """
    layers: list[gpjax.variational_families.AbstractVariationalGaussian] = gpjax.base.param_field(init=True)
    # Number of samples S propagated through the model.
    num_samples: int = gpjax.base.static_field(10)
    num_layers: int = gpjax.base.static_field(init=False)

    def __post_init__(self) -> None:
        self.num_layers = len(self.layers)

    @property 
    def hidden_layers(self) -> list[gpjax.variational_families.AbstractVariationalGaussian]:
        """All layers except the last one."""
        return self.layers[:-1]

    @property
    def output_layer(self) -> gpjax.variational_families.AbstractVariationalGaussian:
        """The last layer."""
        return self.layers[-1]

    def prior_kl(self) -> ScalarFloat:
        """Sum of ``prior_kl()`` over all layers (hidden and output)."""
        return sum(layer.prior_kl() for layer in self.layers)

    @abstractmethod
    def sample_from_hidden(self, key: Key, x: Float[Array, "N D"]) -> Float[Array, "S N D"]:
        pass

    def output_predict(self, x: Float[Array, "S N D"]) -> tfd.MixtureSameFamily:
        """
        Predict through the output layer. 

        Args:
            x (Float[Array, "S N D"]): Inputs to the output layer, one set per sample.

        Returns:
            An equally weighted mixture, over the leading sample axis of ``x``, of the
            output layer's joint Gaussian predictions.
        """
        def moments(t: Float[Array, "N D"]) -> tuple[Float[Array, "N"], Float[Array, "N N"]]:
            pt = self.output_layer(t)
            return pt.mean(), pt.covariance()

        means, covariance_matrices = jax.vmap(moments)(x)
        # One (equal) mixture weight per sample actually present in x.
        num_components = x.shape[0]
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=jnp.zeros(num_components)),
            components_distribution=tfd.MultivariateNormalFullCovariance(loc=means, covariance_matrix=covariance_matrices),
        )

    def predict(self, key: Key, x: Float[Array, "N D"]) -> tfd.MixtureSameFamily:
        """
        Predict through the entire model. 

        Args:
            key: Random key used to sample the hidden layers.
            x (Float[Array, "N D"]): The input data. 
        """
        return self.output_predict(self.sample_from_hidden(key, x))

    def __call__(self, key: Key, x: Float[Array, "N D"]) -> tfd.MixtureSameFamily:
        """Alias for ``predict``."""
        return self.predict(key, x)


@dataclass
class ResidualDeepGP(AbstractDeepGP):
    """Deep GP whose hidden layers move points on the sphere by residual steps.

    Each hidden layer yields a tangent vector per point, and the point is moved along it
    with the spherical exponential map ``expmap_sph``.
    """

    def _propagate_through_hidden(self, key: Key, x: Array, draw_tangent) -> Array:
        """Loop shared by the two hidden-layer samplers.

        Broadcasts ``x`` to [S N D] (S = ``num_samples``). For each hidden layer it then
        calls ``draw_tangent(layer_key, layer, x)`` and applies ``expmap_sph`` pointwise.
        """
        x = jnp.broadcast_to(x, jnp.broadcast_shapes(x.shape, (self.num_samples, 1, 1)))
        pointwise_expmap = jax.vmap(jax.vmap(expmap_sph, in_axes=(0, 0)), in_axes=(0, 0))

        layer_keys = jax.random.split(key, self.num_layers - 1)
        for hidden_layer, layer_key in zip(self.hidden_layers, layer_keys):
            tangent = draw_tangent(layer_key, hidden_layer, x)
            x = pointwise_expmap(x, tangent)
        return x

    def sample_from_hidden(self, key: Key, x: Float[Array, "N D"]) -> Float[Array, "S N D"]:
        """
        Propagate ``x`` through the hidden layers using per-point marginal samples.

        Args:
            x (Float[Array, "N D"]): The input data. Either of shape [N D] or [S N D].
        """
        return self._propagate_through_hidden(
            key,
            x,
            lambda layer_key, layer, points: sample_from_marginal(key=layer_key, model=layer, x=points),
        )

    def pathwise_sample_from_hidden(self, key: Key, x: Float[Array, "N D"]) -> Float[Array, "S N D"]:
        """
        Propagate ``x`` through the hidden layers using pathwise function samples.

        Args:
            x (Float[Array, "N D"]): The input data. Either of shape [N D] or [S N D].
        """
        return self._propagate_through_hidden(
            key,
            x,
            lambda layer_key, layer, points: layer.pathwise_sample(layer_key, points, self.num_samples),
        )

    def pathwise_sample(self, key: Key, x: Float[Array, "N D"]) -> Float[Array, "S N D"]:
        """Pathwise sample through the hidden layers and then the output layer."""
        hidden_key, output_key = jax.random.split(key)
        hidden_output = self.pathwise_sample_from_hidden(hidden_key, x)
        return self.output_layer.pathwise_sample(output_key, hidden_output, self.num_samples)
    


class DeepVectorELBO(gpjax.objectives.AbstractObjective):
    """Monte Carlo evidence lower bound for deep vector-valued GPs."""

    def step(
        self,
        key: Key, 
        variational_family: AbstractDeepGP,
        train_data: gpjax.Dataset,
    ) -> ScalarFloat:
        r"""Compute the evidence lower bound of a variational approximation.

        The bound is the expected log-likelihood of the data under the variational
        approximation minus the KL divergence from the variational posterior to the
        prior. With mini-batches, the data term is rescaled from the batch size to the
        full dataset size.

        Args:
            key: Random key for sampling the hidden layers.
            variational_family (AbstractDeepGP): The deep variational approximation
                whose parameters the ELBO is optimised with respect to.
            train_data (Dataset): The (batch of) training data.

        Returns
        -------
            ScalarFloat: The evidence lower bound for the current parameters, multiplied
                by ``self.constant``.
        """
        likelihood = variational_family.output_layer.posterior.likelihood

        # KL[q(f(·)) || p(f(·))], summed over layers
        kl = variational_family.prior_kl()

        # ∫[log(p(y|f(·))) q(f(·))] df(·)
        var_exp = deep_vector_variational_expectation(key, variational_family, train_data)

        # For batch size b: n/b * Σᵢ[ ∫log(p(y|f(xᵢ))) q(f(xᵢ)) df(xᵢ)] - KL[q(f(·)) || p(f(·))]
        data_term = jnp.sum(var_exp) * likelihood.num_datapoints / train_data.n
        return self.constant * (data_term - kl)


@jax.jit 
def moments(model: AbstractVectorSHF, x: Array) -> tuple[Array, Array]:
    """Likelihood-level mean and covariance of ``model`` at every row of ``x``.

    Each row is given to the model as its own one-point batch ([1, D]), so the outputs
    are the per-point predictive moments stacked along the first axis.
    """
    def per_point(point):
        latent = model(point)
        # FIXME This won't work for the prior (it has no likelihood)
        observed = model.posterior.likelihood(latent)
        return observed.mean(), observed.covariance()

    one_point_batches = x[:, None]  # [M D] -> [M 1 D]
    return jax.vmap(per_point)(one_point_batches)


def deep_vector_variational_expectation(
    key: Key, 
    variational_family: AbstractDeepGP,
    train_data: gpjax.Dataset,
) -> Float[Array, " N"]:
    r"""Monte Carlo estimate of the expected log-likelihood under the deep GP.

    Hidden-layer samples are drawn with ``variational_family.sample_from_hidden``, and
    the output layer's per-point predictive moments are evaluated for every
    (sample, point) pair. The expected log-likelihood of each pair is divided by the
    number of samples S, so summing the result gives the per-point Monte Carlo average.

    Args:
        key: Random key for the hidden-layer samples.
        variational_family (AbstractDeepGP): The variational family that we are using
            to approximate the posterior.
        train_data (Dataset): The batch for which the expectation should be computed.

    Returns
    -------
        Array: One term per (sample, point) pair, already divided by S.
    """
    q = variational_family
    output_layer = q.output_layer
    y = train_data.y  # [N O]

    # Variational distribution q(f(·)) = N(f(·); μ(·), Σ(·, ·)), sampled through the hidden layers
    hidden = q.sample_from_hidden(key, train_data.X)  # [S N D]
    num_samples = hidden.shape[0]

    # Fold the sample axis into the data axis so every (sample, point) pair is one row.
    y_rows = jnp.broadcast_to(y, (num_samples, *y.shape)).reshape(-1, y.shape[-1])  # [S*N O]
    x_rows = hidden.reshape(-1, hidden.shape[-1])  # [S*N D]

    mean, covariance = moments(output_layer, x_rows)  # [S*N O], [S*N O O]

    # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x). There is no need to handle likelihoods of
    # different samples in some special way, since likelihood of mixture is the
    # mixture of likelihoods.
    expectation = output_layer.posterior.likelihood.expected_log_likelihood(
        y_rows, mean, covariance
    )
    # MC estimate of the inner expectation requires dividing by the number of samples
    return expectation / num_samples


def moments_unconditional(model, x):
    """Likelihood-level mean and covariance of ``model`` at every row of ``x`` (not jitted).

    Each row is given to the model as its own one-point batch ([1, D]).
    """
    def per_point(point):
        latent = model(point)
        # FIXME This won't work for the prior (it has no likelihood)
        observed = model.posterior.likelihood(latent)
        return observed.mean(), observed.covariance()

    return jax.vmap(per_point)(x[:, None])


def moments_deep(key: Key, model: AbstractDeepGP, x):
    """Per-sample predictive moments of a deep GP at ``x``.

    Draws hidden-layer samples with ``model.sample_from_hidden`` and, for each sample
    (leading axis), computes the output layer's per-point mean and covariance.
    """
    hidden_samples = model.sample_from_hidden(key, x)

    def output_moments(points):
        return moments_unconditional(model.output_layer, points)

    return jax.vmap(output_moments)(hidden_samples)  # map over sample dimension


def pathwise_moments_deep(key: Key, model: AbstractDeepGP, x):
    """Per-sample predictive moments of a deep GP at ``x`` using pathwise hidden samples.

    Like ``moments_deep``, but the hidden layers are sampled with
    ``model.pathwise_sample_from_hidden``.
    """
    hidden_samples = model.pathwise_sample_from_hidden(key, x)

    def output_moments(points):
        return moments_unconditional(model.output_layer, points)

    return jax.vmap(output_moments)(hidden_samples)  # map over sample dimension


def mse(y_true: Array, y_pred: Array) -> Array:
    """Squared Euclidean error along the last axis, averaged over all other axes."""
    squared_error = jnp.square(y_true - y_pred)
    per_point = jnp.sum(squared_error, axis=-1)
    return jnp.mean(per_point)


def pred_nll(y_true, y_pred, std_pred):
    """
    Negative mean Gaussian log-density of ``y_true``.

    Despite its name, ``std_pred`` is used as the full covariance matrix (last two axes)
    of a ``MultivariateNormalFullCovariance`` centred at ``y_pred``.
    """
    predictive = tfd.MultivariateNormalFullCovariance(loc=y_pred, covariance_matrix=std_pred)
    log_densities = predictive.log_prob(y_true)
    return -jnp.mean(log_densities)


def evaluate(key: Key, model, test_data: VectorDataset):
    """
    Test-set MSE and predictive NLL of a deep GP, returned as Python floats.

    NOTE(review): ``moments_deep`` returns per-sample moments ([S N O], [S N O O]), so
    both metrics average over the S samples. Each Gaussian component is scored
    separately, not the equally weighted mixture (mixture mean / log-mean-exp of the
    densities). Confirm this is the intended metric.
    """
    mean, cov = moments_deep(key, model, test_data.X)
    y_test = test_data.y
    return {
        'mse': mse(y_test, mean).item(), 
        'pnll': pred_nll(y_test, mean, cov).item(),
    }


import jax 
import jax.numpy as jnp
import numpy as np
import pandas as pd
import netCDF4

from jaxtyping import Array
from jax.tree_util import Partial 
import plotly.express as px 


def sphere_uniform_grid(n: int) -> Array:
    """
    ``n`` approximately uniform points on the unit sphere (Fibonacci lattice).

    Returns:
        Array: [n, 3] cartesian coordinates.
    """
    golden_ratio = (1 + jnp.sqrt(5)) / 2
    k = jnp.arange(n)
    longitudes = 2 * jnp.pi * k / golden_ratio
    colatitudes = jnp.arccos(1 - 2 * (k + 0.5) / n)
    return sph_to_car(jnp.column_stack((colatitudes, longitudes)))


@jax.tree_util.Partial(jax.jit, static_argnames=("with_replacement",))
def closest_point_mask(targets: Array, x: Array, with_replacement: bool) -> Array:
    """
    Mark, for every target, the closest point of ``x``.

    Args: 
        targets (Array): [T, 3] targets in cartesian coordinates.
        x (Array): [M, 3] points in cartesian coordinates for which to produce the mask.
        with_replacement (bool): If True, several targets may select the same point.
            If False, targets are processed in order and each one takes the closest
            point not already taken by an earlier target.

    Returns:
        Array: boolean mask of shape [M], True where a point of ``x`` was selected.
    """
    # Squared euclidean distance suffices: for points on the unit sphere it is a monotone
    # function of the great-circle distance, and argmin is invariant to that.
    distances = jnp.sum((targets[:, None] - x[None, :]) ** 2, -1)  # [T, M]

    # `with_replacement` is static, so a Python branch traces only the path that is used.
    if with_replacement:
        selected = jnp.argmin(distances, axis=1)
    else:
        def take_closest_available(i, carry):
            chosen, available = carry
            candidate_distances = jnp.where(available, distances[i], jnp.inf)
            closest = jnp.argmin(candidate_distances)
            return chosen.at[i].set(closest), available.at[closest].set(False)

        init = (
            jnp.zeros(targets.shape[0], dtype=int),
            jnp.ones(x.shape[0], dtype=bool),
        )
        selected, _ = jax.lax.fori_loop(0, targets.shape[0], take_closest_available, init)

    return jnp.zeros(x.shape[0], dtype=bool).at[selected].set(True)


def angles_to_radians_colat(x: Array) -> Array:
    """
    Convert latitude in degrees to this file's colatitude in radians.

    NOTE(review): the result is latitude + π/2, so latitude +90° maps to colatitude π
    (the usual convention is π/2 − latitude). ``read_aeolus`` uses the same convention;
    confirm it is intended.
    """
    quarter_turn = jnp.pi / 2
    return jnp.pi * x / 180 + quarter_turn

def angles_to_radians_lon(x: Array) -> Array:
    """Convert longitude in degrees to radians (no offset, no wrapping)."""
    radians = jnp.pi * x / 180
    return radians

def angles_to_radians(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame with the ``colat`` (given as latitude) and ``lon`` columns converted from degrees to radians."""
    return df.assign(
        colat=lambda frame: angles_to_radians_colat(frame.colat),
        lon=lambda frame: angles_to_radians_lon(frame.lon),
    )

def radians_to_angles_colat(x: Array) -> Array:
    """Inverse of the latitude→colatitude conversion: colatitude (radians) to latitude (degrees), i.e. degrees − 90."""
    degrees = 180 * x / jnp.pi
    return degrees - 90

def radians_to_angles_lon(x: Array) -> Array:
    """Convert longitude in radians to degrees (no offset, no wrapping)."""
    degrees = 180 * x / jnp.pi
    return degrees

def radians_to_angles(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame with the ``colat`` and ``lon`` columns converted from radians back to degrees."""
    return df.assign(
        colat=lambda frame: radians_to_angles_colat(frame.colat),
        lon=lambda frame: radians_to_angles_lon(frame.lon),
    )


@jax.jit
def sph_to_car(sph: Array) -> Array:
    """
    Map points on the unit sphere from spherical to cartesian coordinates.

    Args: 
        sph (Array): [..., 2] points as (colatitude, longitude) in radians.

    Returns:
        Array: [..., 3] points (x, y, z) with z = cos(colatitude).
    """
    colat = sph[..., 0]
    lon = sph[..., 1]
    sin_colat = jnp.sin(colat)
    return jnp.stack(
        [sin_colat * jnp.cos(lon), sin_colat * jnp.sin(lon), jnp.cos(colat)],
        axis=-1,
    )


def sphere_meshgrid(n: int) -> Float:
    """
    Regular n × n grid of (colatitude, longitude) pairs covering [0, π] × [0, 2π].

    Returns an array of shape [n, n, 2] whose entry [i, j] is (colatitude_i, longitude_j).
    """
    axes = (jnp.linspace(0, jnp.pi, n), jnp.linspace(0, 2 * jnp.pi, n))
    grid = jnp.meshgrid(*axes, indexing="ij")
    return jnp.stack(grid, axis=-1)


# Module-level ERA5 handles, loaded at import time. The path is relative to the current
# working directory; the netCDF dataset stays open because `read_era5` reads from it.
era5_file_path = "../data/era5.nc"
era5_dataset = netCDF4.Dataset(era5_file_path,'r')
# Coordinates in radians using this file's conventions (colatitude = latitude + pi/2).
era5_lon = angles_to_radians_lon(era5_dataset.variables['longitude'][:].data.astype(np.float64))
era5_colat = angles_to_radians_colat(era5_dataset.variables['latitude'][:].data.astype(np.float64))
# Default 'xy' indexing: both meshes have shape (len(latitude), len(longitude)).
# NOTE(review): assumed to match the layout of the u/v slices flattened in `read_era5` — confirm.
era5_lon_mesh, era5_colat_mesh = jnp.meshgrid(era5_lon, era5_colat)


def read_era5(time: int, level: int) -> pd.DataFrame:
    """
    Read the ERA5 ``u``/``v`` wind fields for one time index and level key.

    ``level`` must be 0, 7 or 15, which select level index 2, 1 and 0 of the variables
    respectively; any other value raises ``KeyError``. Coordinates come from the
    module-level ERA5 meshgrids, and all arrays are flattened into one row per grid point.

    NOTE(review): ``.data`` discards any netCDF mask / fill information — confirm the
    fields contain no missing values.
    """
    level_index = {0: 2, 7: 1, 15: 0}[level]
    wind = {
        name: era5_dataset.variables[name][time, level_index].data.astype(np.float64)
        for name in ("u", "v")
    }
    return pd.DataFrame({
        "lon": era5_lon_mesh.flatten(),
        "colat": era5_colat_mesh.flatten(),
        "u": wind["u"].flatten(),
        "v": wind["v"].flatten(),
    })


def match_to_uniform_grid_mask(x: Array, n: int, with_replacement: bool = True) -> Array:
    """
    Boolean mask over ``x`` marking the points closest to an ``n``-point Fibonacci grid.

    Args: 
        x (Array): Points in cartesian coordinates for which to create the mask.
        n (int): Number of points of the uniform grid on the sphere.
        with_replacement (bool): Forwarded to ``closest_point_mask``.
    """
    grid = sphere_uniform_grid(n)
    return closest_point_mask(targets=grid, x=x, with_replacement=with_replacement)


def to_test_dataframe(df: pd.DataFrame, n: int, with_replacement: bool = True) -> pd.DataFrame:
    """Rows of ``df`` closest to an ``n``-point uniform grid on the sphere, in their original order."""
    cartesian = sph_to_car(df[['colat', 'lon']].values)
    keep = match_to_uniform_grid_mask(x=cartesian, n=n, with_replacement=with_replacement)
    return df[keep.tolist()]



import pandas as pd 
import math 
from datetime import datetime, timedelta
from skyfield.api import load, EarthSatellite, utc
from skyfield.toposlib import wgs84


def datetime_range(start, stop, step=timedelta(minutes=1)):
    """Lazily yield ``start, start + step, start + 2*step, ...`` while strictly before ``stop``."""
    moment = start
    while moment < stop:
        yield moment
        moment += step


def load_aeolus_and_timescale():
    """Return the Aeolus ``EarthSatellite`` (built from a hard-coded TLE) and a skyfield timescale."""
    ts = load.timescale()

    # Aeolus TLE data
    tle_lines = (
        "1 43600U 18066A   21153.73585495  .00031128  00000-0  12124-3 0  9990",
        "2 43600  96.7150 160.8035 0006915  90.4181 269.7884 15.87015039160910",
    )
    satellite = EarthSatellite(*tle_lines, "AEOLUS", ts)
    return satellite, ts


def read_aeolus(start: datetime, stop: datetime, step=timedelta(minutes=1)) -> pd.DataFrame:
    """
    Ground track of the Aeolus satellite from ``start`` (inclusive) to ``stop`` (exclusive).

    Naive datetimes are interpreted as UTC.

    Returns:
        DataFrame with columns ``time``, ``colat`` and ``lon``. ``colat`` is latitude + π/2
        (the same convention as ``angles_to_radians_colat``), and ``lon`` is the longitude
        wrapped into [0, 2π).
    """
    if start.tzinfo is None:
        start = start.replace(tzinfo=utc)
    if stop.tzinfo is None:
        stop = stop.replace(tzinfo=utc)

    aeolus, ts = load_aeolus_and_timescale()
    time = list(datetime_range(start, stop, step))
    lat, lon = wgs84.latlon_of(aeolus.at(ts.from_datetimes(time)))

    # Longitude is wrapped with a modulo rather than shifted by π. A shift would rotate the
    # track by 180° relative to the ERA5 grid, whose longitudes are converted with no
    # offset (angles_to_radians_lon).
    colat = lat.radians + math.pi / 2
    lon = lon.radians % (2 * math.pi)

    return pd.DataFrame({
        "time": time,
        "colat": colat,
        "lon": lon,
    })

def to_train_dataframe(aeolus: pd.DataFrame, era5: pd.DataFrame, with_replacement: bool = True) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split ERA5 rows into those closest to the Aeolus track and all remaining rows.

    Returns:
        (train rows, remaining rows), both in their original order.
    """
    track = sph_to_car(aeolus[['colat', 'lon']].values)
    grid = sph_to_car(era5[['colat', 'lon']].values)
    selected = np.asarray(
        closest_point_mask(targets=track, x=grid, with_replacement=with_replacement)
    )
    return era5[selected], era5[~selected]


def to_train_test_dataframes(aeolus: pd.DataFrame, era5: pd.DataFrame, test_size: int, with_replacement: bool = True) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split ERA5 into training rows along the Aeolus track and a test set.

    The test set is drawn only from rows not used for training: the subset closest to
    a ``test_size``-point uniform grid on the sphere.
    """
    train_df, remaining_df = to_train_dataframe(aeolus=aeolus, era5=era5, with_replacement=with_replacement)
    return train_df, to_test_dataframe(remaining_df, n=test_size, with_replacement=with_replacement)


def train_test_sets(
    time: int, 
    level: int, 
    start: datetime, 
    stop: datetime, 
    step: timedelta, 
    test_size: int, 
    with_replacement: bool = True,
) -> tuple[Array, Array, Array, Array]:
    """
    Build train/test arrays from ERA5 winds sampled along the Aeolus track.

    Inputs are (colat, lon) and targets are (v, u). Both target sets are divided by the
    mean Euclidean norm of the training target vectors.

    Returns:
        (X_train, X_test, y_train, y_test) as JAX arrays.
    """
    aeolus = read_aeolus(start=start, stop=stop, step=step)
    era5 = read_era5(time, level)

    df_train, df_test = to_train_test_dataframes(
        aeolus=aeolus, era5=era5, test_size=test_size, with_replacement=with_replacement
    )

    input_columns = ["colat", "lon"]
    target_columns = ["v", "u"]
    X_train = jnp.array(df_train[input_columns].to_numpy())
    X_test = jnp.array(df_test[input_columns].to_numpy())
    y_train = jnp.array(df_train[target_columns].to_numpy())
    y_test = jnp.array(df_test[target_columns].to_numpy())

    # Normalize (sort of) targets by the average training vector length
    scale = jnp.mean(jax.vmap(jnp.linalg.norm)(y_train))
    return X_train, X_test, y_train / scale, y_test / scale


def build_layers(
    num_layers: int,
    hidden_kernel: type[AbstractVectorKernel], 
    output_kernel: type[AbstractVectorKernel],
    likelihood: VectorGaussian,
    hidden_variance: float = 0.01,
    kappa: float = 1.0,
    max_ell_variational: int = 9,
    max_ell_prior: int = 30, 
) -> list[AbstractVectorSHF]:
    """
    Build ``num_layers - 1`` hidden layers followed by one output layer.

    Every layer uses a zero mean (``VectorZeroMean``) and the variational family chosen by
    ``variational_family_from_kernel`` with ``max_ell_variational``.
    - Hidden layers: ``hidden_kernel(variance=hidden_variance, max_ell=max_ell_prior,
      kappa=kappa)`` wrapped in an ``IdentityPosterior``.
    - Output layer: ``output_kernel(max_ell=max_ell_prior, kappa=kappa)`` conditioned on
      ``likelihood``.
    """
    def make_layer(kernel, posterior_from_prior, family):
        prior = gpjax.gps.Prior(kernel=kernel, mean_function=VectorZeroMean())
        return family(posterior=posterior_from_prior(prior), max_ell=max_ell_variational)

    hidden_family = variational_family_from_kernel(hidden_kernel)
    layers = [
        make_layer(
            hidden_kernel(variance=hidden_variance, max_ell=max_ell_prior, kappa=kappa),
            lambda prior: IdentityPosterior(prior=prior),
            hidden_family,
        )
        for _ in range(num_layers - 1)
    ]

    output_family = variational_family_from_kernel(output_kernel)
    layers.append(
        make_layer(
            output_kernel(max_ell=max_ell_prior, kappa=kappa),
            lambda prior: prior * likelihood,
            output_family,
        )
    )
    return layers


def plot_results(x_train, y_train, x_test, y_test, mean, var, var_x, sample, history):
    """
    Six-panel summary figure of a fitted model.

    Input arrays (``x_*``, ``var_x``) have column 0 = latitude-like coordinate and
    column 1 = longitude. Vector arrays (``y_*``, ``mean``, ``sample``) have column 0 =
    latitudinal and column 1 = longitudinal component. Every quiver draws the
    longitudinal component along the x axis ("lon") and uses the arrow scale of the
    predictive-mean panel.

    Args:
        x_train, y_train: Training inputs [N, 2] and vectors [N, 2].
        x_test, y_test: Test inputs [M, 2] and vectors [M, 2].
        mean: Predictive mean at ``x_test``, [M, 2].
        var: Predictive uncertainty on the grid ``var_x`` (passed to ``pcolormesh``).
        var_x: Grid of coordinates [..., 2] on which ``var`` is given.
        sample: A posterior sample at ``x_test``, [M, 2].
        history: Loss values (negative ELBO) per iteration.

    Returns:
        The matplotlib ``Figure``.
    """
    var_lat, var_lon = var_x[..., 0], var_x[..., 1]
    x_train_lat, x_train_lon = x_train[:, 0], x_train[:, 1]
    y_train_dlat, y_train_dlon = y_train[:, 0], y_train[:, 1]
    x_test_lat, x_test_lon = x_test[:, 0], x_test[:, 1]
    y_test_dlat, y_test_dlon = y_test[:, 0], y_test[:, 1]
    mean_dlat, mean_dlon = mean[:, 0], mean[:, 1]
    sample_dlat, sample_dlon = sample[:, 0], sample[:, 1]
    y_diff = y_test - mean
    y_diff_dlat, y_diff_dlon = y_diff[:, 0], y_diff[:, 1]

    nrows = 3
    ncols = 2
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 4), layout="constrained")

    def decorate(ax, title, xlabel="lon", ylabel="lat"):
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)

    # top left, prediction at the test inputs with training data (red) on top
    q = axs[0][0].quiver(x_test_lon, x_test_lat, mean_dlon, mean_dlat, angles="uv")
    # `_init` is private matplotlib API; it fills in the autoscaled `q.scale` so the
    # other quiver panels can reuse the same arrow scale.
    q._init()
    scale = q.scale
    axs[0][0].quiver(x_train_lon, x_train_lat, y_train_dlon, y_train_dlat, angles="uv", color="red", scale=scale)
    decorate(axs[0][0], "Predictive Mean")

    # top right, uncertainty
    c = axs[0][1].pcolormesh(var_lon, var_lat, var, vmin=var.min(), vmax=var.max())
    fig.colorbar(c, ax=axs[0][1])
    decorate(axs[0][1], "Predictive Uncertainty")

    # middle left, true test data
    axs[1][0].quiver(x_test_lon, x_test_lat, y_test_dlon, y_test_dlat, angles="uv", scale=scale)
    decorate(axs[1][0], "Ground truth")

    # middle right, difference
    axs[1][1].quiver(x_test_lon, x_test_lat, y_diff_dlon, y_diff_dlat, angles="uv", scale=scale)
    decorate(axs[1][1], "Prediction Error")

    # bottom left, sample from posterior: (dlon, dlat) as (U, V), like every other quiver
    axs[2][0].quiver(x_test_lon, x_test_lat, sample_dlon, sample_dlat, angles="uv", scale=scale)
    decorate(axs[2][0], "Sample from Posterior")

    # bottom right, training history
    axs[2][1].plot(history)
    decorate(axs[2][1], "Training History", xlabel="Iteration", ylabel="Negative ELBO")

    return fig


# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from gpjax.fit import (
    _check_batch_size,
    _check_log_rate,
    _check_model,
    _check_num_iters,
    _check_optim,
    _check_train_data,
    _check_verbose,
    get_batch,
)
from beartype.typing import (
    Any,
    Callable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import jax
from jax import (
    jit,
    value_and_grad,
)
from jax._src.random import _check_prng_key
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import scipy

from gpjax.base import Module
from gpjax.dataset import Dataset
from gpjax.objectives import AbstractObjective
from gpjax.scan import vscan
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

# Type variable bound to GPJax `Module`: lets `fit_deep` be annotated as returning
# the same model type it was given.
ModuleModel = TypeVar("ModuleModel", bound=Module)


def fit_deep(  # noqa: PLR0913
    *,
    model: ModuleModel,
    objective: Union[
        AbstractObjective, Callable[[KeyArray, ModuleModel, Dataset], ScalarFloat]
    ],
    train_data: Dataset,
    optim: ox.GradientTransformation,
    key: KeyArray,
    num_iters: Optional[int] = 100,
    batch_size: Optional[int] = -1,
    log_rate: Optional[int] = 10,
    verbose: Optional[bool] = True,
    unroll: Optional[int] = 1,
    safe: Optional[bool] = True,
) -> Tuple[ModuleModel, Array]:
    r"""Train a Module model with respect to a supplied stochastic objective.

    This is a variant of `gpjax.fit` whose objective additionally receives a PRNG
    key as its first argument, i.e. it is called as
    ``objective(key, model, batch)``. A fresh key is supplied at every
    optimisation step, so objectives that need randomness can draw new samples
    each iteration. Optimisers used here should originate from Optax.

    Example:
    ```python
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import optax as ox
        >>> import gpjax as gpx
        >>>
        >>> # (1) Create a dataset:
        >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
        >>> D = gpx.Dataset(X, y)
        >>>
        >>> # (2) Define your model:
        >>> class LinearModel(gpx.base.Module):
                weight: float = gpx.base.param_field()
                bias: float = gpx.base.param_field()

                def __call__(self, x):
                    return self.weight * x + self.bias

        >>> model = LinearModel(weight=1.0, bias=1.0)
        >>>
        >>> # (3) Define your loss function (it receives a PRNG key first):
        >>> def mean_square_error(key, model, train_data):
                return jnp.mean((train_data.y - model(train_data.X)) ** 2)
        >>>
        >>> # (4) Train!
        >>> trained_model, history = fit_deep(
                model=model,
                objective=mean_square_error,
                train_data=D,
                optim=ox.sgd(0.001),
                key=jr.key(1),
                num_iters=1000,
            )
    ```

    Args:
        model (Module): The model Module to be optimised.
        objective (Objective): The objective function that we are optimising with
            respect to. Called as ``objective(key, model, batch)``.
        train_data (Dataset): The training data to be used for the optimisation.
        optim (GradientTransformation): The Optax optimiser that is to be used for
            learning a parameter set.
        key (KeyArray): The random key for the optimisation. It is required. It is
            split into one key per iteration. Each per-iteration key is split again
            into independent keys for mini-batch selection and for the objective.
        num_iters (Optional[int]): The number of optimisation steps to run. Defaults
            to 100.
        batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1
            (i.e. full batch).
        log_rate (Optional[int]): How frequently the objective function's value should
            be printed by the progress bar (only used when ``verbose`` is True).
            Defaults to 10.
        verbose (Optional[bool]): Whether to print the training loading bar. Defaults
            to True.
        unroll (int): The number of unrolled steps to use for the optimisation.
            Defaults to 1.
        safe (Optional[bool]): Whether to validate the inputs before training.
            Defaults to True.

    Returns
    -------
        Tuple[Module, Array]: A Tuple comprising the optimised model and training
            history respectively.
    """
    if safe:
        # Check inputs.
        _check_model(model)
        _check_train_data(train_data)
        _check_optim(optim)
        _check_num_iters(num_iters)
        _check_batch_size(batch_size)
        _check_prng_key("fit", key)
        _check_log_rate(log_rate)
        _check_verbose(verbose)

    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
    def loss(key: KeyArray, model: Module, batch: Dataset) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(key, model.constrain(), batch)

    # Unconstrained space model.
    model = model.unconstrain()

    # Initialise optimiser state.
    state = optim.init(model)

    # One random key per optimisation step, scanned over.
    iter_keys = jr.split(key, num_iters)

    # Optimisation step.
    def step(carry, key):
        model, opt_state = carry

        # Use independent keys for mini-batch selection and for the objective.
        # This keeps the sampled batch uncorrelated with the objective's own
        # randomness.
        batch_key, objective_key = jr.split(key)

        if batch_size != -1:
            batch = get_batch(train_data, batch_size, batch_key)
        else:
            batch = train_data

        loss_val, loss_gradient = jax.value_and_grad(loss, argnums=1)(
            objective_key, model, batch
        )
        updates, opt_state = optim.update(loss_gradient, opt_state, model)
        model = ox.apply_updates(model, updates)

        carry = model, opt_state
        return carry, loss_val

    # Optimisation loop: a progress-bar scan when verbose, otherwise a plain scan.
    if verbose:
        (model, _), history = vscan(
            step, (model, state), iter_keys, unroll=unroll, log_rate=log_rate
        )
    else:
        (model, _), history = jax.lax.scan(
            step, (model, state), iter_keys, unroll=unroll
        )

    # Constrained space.
    model = model.constrain()

    return model, history


import argparse
import os 



if __name__ == "__main__":
    # Imported locally so this entry point does not depend on imports elsewhere
    # in the file.
    from datetime import datetime, timedelta

    import pandas as pd

    parser = argparse.ArgumentParser()
    parser.add_argument("--time", type=int, default=0)
    parser.add_argument("--level", type=int, default=15, choices=[0, 7, 15])
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_ell_prior", type=int, default=9)
    parser.add_argument("--max_ell_variational", type=int, default=9)
    parser.add_argument("--total_hidden_variance", type=float, default=0.0001)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--num_iters", type=int, default=1000)
    parser.add_argument("--num_layers", type=int, default=1)
    parser.add_argument("--num_samples", type=int, default=3)
    parser.add_argument("--test_size", type=int, default=5000)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--save_dir", type=str, default="./results/wind")
    parser.add_argument("--num_hours", type=int, default=24)
    parser.add_argument("--step_minutes", type=int, default=1)
    parser.add_argument("--num_test_samples", type=int, default=10)
    args = parser.parse_args()


    seed = args.seed
    time = args.time
    level = args.level
    max_ell_prior = args.max_ell_prior
    max_ell_variational = args.max_ell_variational
    total_hidden_variance = args.total_hidden_variance
    lr = args.lr
    num_iters = args.num_iters
    num_layers = args.num_layers
    num_samples = args.num_samples
    test_size = args.test_size
    batch_size = args.batch_size
    save_dir = args.save_dir
    num_hours = args.num_hours
    step_minutes = args.step_minutes
    num_test_samples = args.num_test_samples


    ### RANDOMNESS
    # Independent keys for data, training, evaluation and plotting.
    key = jax.random.key(seed)
    data_key, train_key, test_key, plot_key = jax.random.split(key, 4)


    ### DATA
    # The uncertainty map is evaluated on an n_plot_uncertainty x n_plot_uncertainty grid.
    n_plot_uncertainty = 100

    # aeolus track: a window of `num_hours` hours, sampled every `step_minutes` minutes.
    start = datetime(2019, 1, 1, 9)
    stop = start + timedelta(hours=num_hours)
    step = timedelta(minutes=step_minutes)
    with_replacement = True

    # load
    X_train, X_test, y_train, y_test = train_test_sets(time, level, start, stop, step, test_size, with_replacement)
    train_data = VectorDataset(X_train, y_train)
    test_data = VectorDataset(X_test, y_test)
    X_uncertainty = sphere_meshgrid(n_plot_uncertainty).reshape(-1, 2)


    ### MODEL
    # settings
    kappa = 1.0
    # The total hidden variance is shared across the hidden (non-output) layers.
    hidden_variance = total_hidden_variance / max(num_layers - 1, 1)
    obs_variance = 1.0
    obs_stddev = obs_variance ** 0.5
    hidden_kernel = HodgeMaternKernel
    output_kernel = HodgeMaternKernel

    experiment_name = f"{time=}-{level=}-{step_minutes=}-{num_layers=}-{seed=}-{max_ell_variational=}-{num_test_samples=}-{max_ell_prior=}-{num_samples=}-{total_hidden_variance=}-{num_iters=}-{lr=}-{num_hours=}"
    if batch_size != -1:
        # Give mini-batch runs their own directory so they do not overwrite the
        # results of a full-batch run with the same other settings. Full-batch
        # names are unchanged.
        experiment_name += f"-{batch_size=}"
    print(f"Running experiment: {experiment_name}")
    # build
    likelihood = VectorGaussian(num_datapoints=train_data.n, obs_stddev=obs_stddev)
    layers = build_layers(
        num_layers=num_layers,
        hidden_kernel=hidden_kernel,
        output_kernel=output_kernel,
        likelihood=likelihood,
        hidden_variance=hidden_variance,
        kappa=kappa,
        max_ell_variational=max_ell_variational,
        max_ell_prior=max_ell_prior,
    )
    model = ResidualDeepGP(layers=layers, num_samples=num_samples)


    ### FIT
    # train
    objective = jax.jit(DeepVectorELBO(negative=True))
    optim = optax.adam(learning_rate=lr)
    model_opt, history = fit_deep(
        model=model,
        objective=objective,
        train_data=train_data,
        optim=optim,
        num_iters=num_iters,
        key=train_key,
        batch_size=batch_size,
    )


    # test: evaluate with a (possibly larger) number of samples than used in training
    model_opt = model_opt.replace(num_samples=num_test_samples)

    test_metrics = evaluate(test_key, model_opt, test_data)
    metrics_str = ", ".join(f"{k}: {v:.3f}" for k, v in test_metrics.items())
    print(f"Metrics: {metrics_str}")


    ### SAVE RESULTS

    # plot
    mean_test, _ = moments_deep(plot_key, model_opt, X_test)
    mean_test = jnp.mean(mean_test, axis=0)
    _, cov_plot = moments_deep(plot_key, model_opt, X_uncertainty)
    uncertainty = jax.vmap(jax.vmap(jnp.linalg.norm))(cov_plot) # we define uncertainty as the average norm of the covariance matrices in the mixture
    uncertainty = jnp.mean(uncertainty, axis=0)
    uncertainty = uncertainty.reshape(n_plot_uncertainty, n_plot_uncertainty)

    # a single pathwise sample for the "Sample from Posterior" panel
    model_opt = model_opt.replace(num_samples=1)
    sample_test = model_opt.pathwise_sample(plot_key, X_test).squeeze()
    X_uncertainty = X_uncertainty.reshape(n_plot_uncertainty, n_plot_uncertainty, 2)

    fig = plot_results(X_train, y_train, X_test, y_test, mean_test, uncertainty, X_uncertainty, sample_test, history)

    # save
    experiment_dir = os.path.join(save_dir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    # metrics
    metrics_path = os.path.join(experiment_dir, "metrics.csv")
    pd.DataFrame(test_metrics, index=[0]).to_csv(metrics_path, index=False)

    # figure
    fig_path = os.path.join(experiment_dir, "results.pdf")
    fig.suptitle(f"Model: {experiment_name}, Metrics: {metrics_str}", fontsize=6)
    fig.savefig(fig_path, bbox_inches="tight")
    