from jax import config 
config.update("jax_enable_x64", True)


import jax 
import jax.numpy as jnp
import gpjax 
import optax 
from gpjax.distributions import GaussianDistribution
from gpjax.typing import Float
from gpjax.base import static_field, param_field
from jax.tree_util import Partial
import tensorflow_probability.substrates.jax as tfp
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Key

from matplotlib import pyplot as plt 

from dataclasses import dataclass, InitVar
from abc import abstractmethod


@jax.jit
def sph_dot_product(sph1: Float, sph2: Float) -> Float:
    """
    Inner product in R^3 of two unit-sphere points given in spherical
    (colatitude, longitude) coordinates.

    Evaluates <p, q> = sin(t1) sin(t2) cos(l1 - l2) + cos(t1) cos(t2) directly,
    without converting to cartesian coordinates. Leading dimensions broadcast.
    """
    theta_a, phi_a = sph1[..., 0], sph1[..., 1]
    theta_b, phi_b = sph2[..., 0], sph2[..., 1]
    azimuthal_term = jnp.sin(theta_a) * jnp.sin(theta_b) * jnp.cos(phi_a - phi_b)
    polar_term = jnp.cos(theta_a) * jnp.cos(theta_b)
    return azimuthal_term + polar_term


# NOTE jitting this doesn't help
def sph_gegenbauer(x, y, max_ell: int, alpha: float = 0.5):
    """Gegenbauer polynomials of degrees 0..max_ell at <x, y>, with x and y given in spherical coordinates."""
    cos_angle = sph_dot_product(x, y)
    return gegenbauer(x=cos_angle, max_ell=max_ell, alpha=alpha)


# NOTE jitting this doesn't help
def sph_gegenbauer_single(x, y, ell: int, alpha: float = 0.5):
    """Single Gegenbauer polynomial of degree `ell` at <x, y>, with x and y given in spherical coordinates."""
    cos_angle = sph_dot_product(x, y)
    return gegenbauer_single(x=cos_angle, ell=ell, alpha=alpha)


def array(x):
    """Convert `x` to a JAX array requesting dtype float64 (honoured when `jax_enable_x64` is on)."""
    target_dtype = jnp.float64
    return jnp.array(x, dtype=target_dtype)


@jax.jit
def sph_to_car(sph):
    """
    Map spherical (colatitude, longitude) coordinates to cartesian (x, y, z) on
    the unit sphere. Accepts a single point or any batch (leading dimensions).
    """
    theta = sph[..., 0]
    phi = sph[..., 1]
    sin_theta = jnp.sin(theta)
    components = (sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), jnp.cos(theta))
    return jnp.stack(components, axis=-1)


@jax.jit
def car_to_sph(car):
    """
    Map cartesian (x, y, z) points on the unit sphere to spherical
    (colatitude, longitude) coordinates; the inverse of `sph_to_car`.

    Colatitude lies in [0, pi] and longitude is shifted into [0, 2*pi).
    Leading dimensions broadcast.
    """
    x, y, z = car[..., 0], car[..., 1], car[..., 2]
    # Floating-point round-off (e.g. after normalisation) can push |z| marginally
    # above 1; clip so arccos returns the pole instead of NaN.
    colat = jnp.arccos(jnp.clip(z, -1.0, 1.0))
    lon = jnp.arctan2(y, x)
    # arctan2 returns values in (-pi, pi]; move them into [0, 2*pi).
    lon = (lon + 2 * jnp.pi) % (2 * jnp.pi)
    return jnp.stack([colat, lon], axis=-1)


"""
Conversion between Hodge (block-structured) and flat representations of matrices and coordinates.
"""

@jax.jit
def flatten_matrix(matrix):
    """
    Assemble a block matrix into a single matrix.

    Input matrix has shape (nx, ny, *block_shape); block (i, j) is placed at
    block-row i and block-column j of the output.
    """
    block_rows = [jnp.hstack(list(row_blocks)) for row_blocks in matrix]
    return jnp.vstack(block_rows)


@Partial(jax.jit, static_argnames=('spherical',))
def unflatten_matrix(matrix, spherical=True):
    """
    Split a flat matrix into square blocks; the inverse of `flatten_matrix`.

    Input matrix has shape (nx * dim, ny * dim), where dim is 2 when `spherical`
    and 3 otherwise. The output has shape (nx, ny, dim, dim), with
    out[i, j] = matrix[i*dim:(i+1)*dim, j*dim:(j+1)*dim].
    """
    block = 2 if spherical else 3
    n_block_rows = matrix.shape[0] // block
    n_block_cols = matrix.shape[1] // block
    # (rows, cols) -> (block_row, in_block_row, block_col, in_block_col) -> reorder axes.
    tiled = matrix.reshape(n_block_rows, block, n_block_cols, block)
    return tiled.transpose(0, 2, 1, 3)


def flatten_coord(coord):
    """
    Flatten an array of coordinates into a 1-D array, in row-major order.
    """
    return jnp.reshape(coord, (-1,))


def unflatten_coord(coord_flat, spherical=True, extra_dims=()):
    """
    Un-flatten coordinates to shape (n, dim, *extra_dims).

    Args:
        coord_flat: Flat array, e.g. as produced by `flatten_coord`.
        spherical: If True, dim = 2 (colatitude, longitude); otherwise dim = 3 (x, y, z).
        extra_dims: Trailing dimensions per coordinate component; any sequence of
            ints (list or tuple). Defaults to none. An immutable default is used so
            the default object can never be shared and mutated across calls.

    Returns:
        Array of shape (n, dim, *extra_dims), with n inferred from the input size.
    """
    dim = 2 if spherical else 3
    return coord_flat.reshape(-1, dim, *extra_dims)


from pathlib import Path
from typing import Callable

import numpy as np
from jax import Array


class FundamentalSystemNotPrecomputedError(ValueError):
    """Raised when no precomputed fundamental-system file exists for a given dimension."""

    def __init__(self, dimension: int):
        super().__init__(
            f"Fundamental system for dimension {dimension} has not been precomputed."
        )


def fundamental_set_loader(dimension: int, load_dir="fundamental_system") -> Callable[[int], Array]:
    """
    Read the precomputed fundamental system for `dimension` and return a lookup function.

    The file `../<load_dir>/fs_<dimension>D.npz` is read eagerly into memory. The path is
    relative to the current working directory unless `load_dir` is absolute.

    Raises:
        FundamentalSystemNotPrecomputedError: If the file does not exist.

    Returns:
        A function mapping a degree to the array stored under `degree_<degree>`. That
        function raises ValueError for degrees absent from the file.
    """
    file_name = Path("../") / load_dir / f"fs_{dimension}D.npz"
    if not file_name.exists():
        raise FundamentalSystemNotPrecomputedError(dimension)

    with np.load(file_name) as archive:
        cache = dict(archive.items())

    def load(degree: int) -> Array:
        key = f"degree_{degree}"
        try:
            return cache[key]
        except KeyError:
            raise ValueError(f"key: {key} not in cache.") from None

    return load


@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))
def gegenbauer(x: Float[Array, "N D"], max_ell: int, alpha: float = 0.5) -> Float[Array, "N L"]:
    """
    Compute the Gegenbauer polynomials Cᵅ₀(x), ..., Cᵅ_max_ell(x) recursively.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array of any shape; polynomials are evaluated elementwise.
        max_ell: Highest degree to evaluate (static, >= 0).
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static).

    Returns:
        Array of shape (*x.shape, max_ell + 1) whose last axis indexes the degree,
        i.e. `out[..., n] = Cᵅₙ(x)`.
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)
    res = jnp.empty((*x.shape, max_ell + 1), dtype=x.dtype)
    res = res.at[..., 0].set(C_0)

    # max_ell is static, so branch in Python. lax.cond would still trace the
    # recursion branch, including an out-of-bounds write at index 1.
    if max_ell == 0:
        return res

    C_1 = 2 * alpha * x
    res = res.at[..., 1].set(C_1)

    def step(n: int, res_and_Cs: tuple[Float, Float, Float]) -> tuple[Float, Float, Float]:
        res, C, C_prev = res_and_Cs
        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C
        res = res.at[..., n].set(C)
        return res, C, C_prev

    return jax.lax.fori_loop(2, max_ell + 1, step, (res, C_1, C_0))[0]


@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call 
def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:
    """
    Compute the single Gegenbauer polynomial Cᵅ_ell(x) recursively.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array of any shape; the polynomial is evaluated elementwise.
        ell: Degree of the polynomial. It is traced, not static, so one compilation
            serves every degree. Negative degrees are unsupported (they return Cᵅ₁(x)).
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static).

    Returns:
        Cᵅ_ell(x), with the same shape as `x`.
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)
    C_1 = 2 * alpha * x

    def step(Cs_and_n):
        C, C_prev, n = Cs_and_n
        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C
        return C, C_prev, n + 1

    def cond(Cs_and_n):
        n = Cs_and_n[2]
        return n <= ell

    # Keep the loop counter in the polynomials' dtype. A hard-coded float64 counter
    # would promote a float32 `x` inside `step`, and while_loop rejects a carry whose
    # dtype changes between iterations.
    n_start = jnp.asarray(2, dtype=C_1.dtype)

    return jax.lax.cond(
        ell == 0,
        lambda: C_0,
        lambda: jax.lax.while_loop(cond, step, (C_1, C_0, n_start))[0],
    )


@dataclass
class SphericalHarmonics(gpjax.Module):
    """
    Spherical harmonics inducing features for sparse inference in Gaussian processes.

    The spherical harmonics Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the sphere and,
    as such, they form an orthogonal basis.

    To construct the harmonics, we use a fundamental set of points on the sphere {vᵢ}ᵢ and compute
    b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b now forms a complete basis on the sphere, and we can orthogonalise it via
    a Cholesky decomposition. The Cholesky decomposition only needs to run once, during
    initialisation.

    Attributes:
        max_ell: Highest frequency (degree) n; degrees 0..max_ell are used.
        sphere_dim: Dimension d of the sphere Sᵈ (points live in R^(d+1)).
        alpha: Gegenbauer parameter (d - 1) / 2, set in `__post_init__`.
        orth_basis: Per-degree lower-triangular Cholesky factors, set in `__post_init__`.
        Vs: Per-degree fundamental-set points loaded via `fundamental_set_loader`.
        num_phases_per_frequency: Number of fundamental-set points (phases) per degree.
        num_phases: Total number of harmonics over all degrees.
    """

    max_ell: int = static_field()
    sphere_dim: int = static_field()
    alpha: float = static_field(init=False)
    orth_basis: Array = static_field(init=False)
    Vs: list[Array] = static_field(init=False)
    num_phases_per_frequency: list[int] = static_field(init=False)
    num_phases: int = static_field(init=False)


    @property
    def levels(self):
        # Integer degrees 0..max_ell.
        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)
    

    def __post_init__(self) -> None:
        """
        Load the fundamental set for every degree and precompute the orthogonal basis.

        Raises:
            FundamentalSystemNotPrecomputedError: From `fundamental_set_loader`, if no
                precomputed file exists for dimension sphere_dim + 1.
            ValueError: From the loader, if some degree 0..max_ell is missing from the file.

        Returns:
            None
        """
        # Ambient dimension: Sᵈ lives in R^(d+1).
        dim = self.sphere_dim + 1

        # Try loading a pre-computed fundamental set.
        fund_set = fundamental_set_loader(dim)

        # Gegenbauer parameter for this sphere: (dim - 2) / 2 == (sphere_dim - 1) / 2.
        self.alpha = (dim - 2.0) / 2.0

        # Fundamental-set points for every degree 0..max_ell.
        self.Vs = [fund_set(n) for n in self.levels]

        # pre-compute and save the orthogonal basis 
        self.orth_basis = self._orthogonalise_basis()


        # Cache per-degree and total phase counts (rows of each V) instead of recomputing them.
        self.num_phases_per_frequency = [v.shape[0] for v in self.Vs]
        self.num_phases = sum(self.num_phases_per_frequency)


    @property
    def Ls(self) -> list[Array]:
        """
        Alias for the orthogonal basis (Cholesky factors) at every frequency.
        """
        return self.orth_basis

    def _orthogonalise_basis(self) -> list[Array]:
        """
        For every degree n, compute the lower Cholesky factor of
        B_n = alpha / (alpha + n) * Cᵅₙ(V_n V_nᵀ), where V_n is that degree's fundamental set.

        Returns:
            List of lower-triangular factors, one per degree 0..max_ell.
        """
        alpha = self.alpha
        # Split into length-1 arrays so jax.tree.map can zip them leaf-by-leaf with the list of Vs.
        levels = jnp.split(self.levels, self.max_ell + 1)
        const = alpha / (alpha + self.levels.astype(jnp.float64))
        const = jnp.split(const, self.max_ell + 1)

        def _func(v, n, c):
            # Gram matrix of pairwise inner products between fundamental-set points.
            x = jnp.matmul(v, v.T)
            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)
            # Tiny diagonal jitter before factorising.
            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))
            return L

        return jax.tree.map(_func, self.Vs, levels, const)

    def custom_gegenbauer_single(self, x, ell, alpha):
        # Evaluates all degrees 0..self.max_ell and selects degree `ell`; `ell` may be a traced index.
        return gegenbauer(x, self.max_ell, alpha)[..., ell]

    @jax.jit
    def polynomial_expansion(self, X: Float[Array, "N D"]) -> Float[Array, "M N"]:
        """
        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.

        Args:
            X: Input points, one per row.

        Returns:
            The harmonics evaluated at the input, stacked over all degrees along axis 0,
            so M equals the total number of phases.
        """
        levels = jnp.split(self.levels, self.max_ell + 1)

        def _func(v, n, L):
            # Zonal functions Cᵅₙ(<vᵢ, x>), then whiten with the lower Cholesky factor.
            vxT = jnp.dot(v, X.T)
            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)
            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)
            return harmonic

        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)
        return jnp.concatenate(harmonics, axis=0)
    
    def __eq__(self, other: "SphericalHarmonics") -> bool:
        """
        Check if two spherical harmonic features are equal.

        Args:
            other: The other spherical harmonic features.

        Returns:
            A boolean indicating if the two features are equal.
        """
        # Given the first two parameters, the rest are deterministic. 
        # The user must not mutate all other fields, but that is not enforced for now.
        return (
            self.max_ell == other.max_ell 
            and self.sphere_dim == other.sphere_dim 
        )    


import warnings 
from typing import Optional
from jaxtyping import Num
from gpjax.typing import ScalarFloat


def mse(y_true: Array, y_pred: Array) -> Array:
    """Mean squared error between `y_true` and `y_pred`, averaged over all elements."""
    residual = y_true - y_pred
    return jnp.mean(jnp.square(residual))


def nlpd(y_true, y_pred, std_pred):
    """
    Negative log predictive density of `y_true` under independent Gaussians
    N(y_pred, std_pred²), averaged over all elements.
    """
    predictive = tfd.Normal(loc=y_pred, scale=std_pred)
    log_density = predictive.log_prob(y_true)
    return -jnp.mean(log_density)


import jax 
import jax.numpy as jnp
import numpy as np
import pandas as pd
import netCDF4

from jaxtyping import Array
from jax.tree_util import Partial 
import plotly.express as px 


def sphere_uniform_grid(n: int) -> Array:
    """
    Approximately uniform grid of `n` points on the unit sphere (Fibonacci lattice).

    Returns:
        Array of shape (n, 3) with the points in cartesian coordinates.
    """
    golden_ratio = (1 + jnp.sqrt(5)) / 2
    idx = jnp.arange(n)
    # Longitudes advance by a fixed irrational fraction of a turn.
    lon = 2 * jnp.pi * idx / golden_ratio
    # Evenly spaced z = cos(colat) gives equal-area bands.
    colat = jnp.arccos(1 - 2 * (idx + 0.5) / n)
    return sph_to_car(jnp.column_stack((colat, lon)))


@jax.tree_util.Partial(jax.jit, static_argnames=("with_replacement",))
def closest_point_mask(targets: Array, x: Array, with_replacement: bool) -> Array:
    """
    Boolean mask over `x` selecting, for every target, its closest point in `x`.

    Args: 
        targets (Array): targets in cartesian coordinates, one per row.
        x (Array): points in cartesian coordinates for which to produce the mask, one per row.
        with_replacement (bool): Static. If True, several targets may select the same
            point, so fewer entries than targets can be set. If False, targets are
            processed in order and each greedily takes its closest still-unused point.
            This needs at least as many points as targets for the picks to be distinct.

    Returns:
        Boolean array with one entry per row of `x`.
    """

    # Can do euclidean squared distance instead of spherical, since minimisation is invariant to monotonic transformations
    distances = jnp.sum((targets[:, None] - x[None, :]) ** 2, -1)

    # with_replacement is static, so select the strategy in Python. jax.lax.cond would
    # trace and compile both strategies on every call.
    if with_replacement:
        mask_indices = jnp.argmin(distances, axis=1)
    else:
        num_targets = targets.shape[0]

        def assign_next(i, carry):
            closest_indices, available_mask = carry
            masked_distances = jnp.where(available_mask, distances[i], jnp.inf)
            closest_idx = jnp.argmin(masked_distances)
            return (
                closest_indices.at[i].set(closest_idx),
                available_mask.at[closest_idx].set(False),
            )

        init = (
            jnp.zeros(num_targets, dtype=jnp.int64),
            jnp.ones(x.shape[0], dtype=bool),
        )
        # A traced loop instead of a Python loop that unrolls once per target.
        mask_indices, _ = jax.lax.fori_loop(0, num_targets, assign_next, init)

    mask = jnp.zeros(x.shape[0], dtype=bool)
    return mask.at[mask_indices].set(True)


def angles_to_radians_colat(x: Array) -> Array:
    """Convert an angle in degrees to radians and shift it by +pi/2 (inverse of `radians_to_angles_colat`)."""
    # NOTE(review): a conventional colatitude is pi/2 - latitude; this maps +90 deg to pi.
    # Confirm the intended hemisphere convention.
    radians = jnp.pi * x / 180
    return radians + jnp.pi / 2

def angles_to_radians_lon(x: Array) -> Array:
    """Convert a longitude in degrees to radians."""
    scaled = jnp.pi * x
    return scaled / 180

def angles_to_radians(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` whose `colat` and `lon` columns are converted from degrees to radians."""
    return df.assign(
        colat=angles_to_radians_colat(df["colat"]),
        lon=angles_to_radians_lon(df["lon"]),
    )

def radians_to_angles_colat(x: Array) -> Array:
    """Convert radians to degrees and shift by -90 (inverse of `angles_to_radians_colat`)."""
    degrees = 180 * x / jnp.pi
    return degrees - 90

def radians_to_angles_lon(x: Array) -> Array:
    """Convert a longitude in radians to degrees."""
    scaled = 180 * x
    return scaled / jnp.pi

def radians_to_angles(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` whose `colat` and `lon` columns are converted from radians to degrees."""
    return df.assign(
        colat=radians_to_angles_colat(df["colat"]),
        lon=radians_to_angles_lon(df["lon"]),
    )


def match_to_uniform_grid_mask(x: Array, n: int, with_replacement: bool = True) -> Array:
    """
    Boolean mask over `x` marking the points closest to an `n`-point Fibonacci grid.

    Args: 
        x (Array): Points in cartesian coordinates for which to create the mask.
        n (int): Number of grid points.
        with_replacement (bool): Forwarded to `closest_point_mask`.
    """
    grid = sphere_uniform_grid(n)
    return closest_point_mask(targets=grid, x=x, with_replacement=with_replacement)


def to_test_dataframe(df: pd.DataFrame, n: int, with_replacement: bool = True) -> pd.DataFrame:
    """
    Keep only the rows of `df` closest to an `n`-point uniform grid on the sphere.

    The `colat`/`lon` columns are passed to `sph_to_car`, so they are presumably in
    radians (TODO confirm against callers).
    """
    coords_sph = df[['colat', 'lon']].values
    coords_car = sph_to_car(coords_sph)
    keep = match_to_uniform_grid_mask(x=coords_car, n=n, with_replacement=with_replacement)
    return df[keep.tolist()]



import pandas as pd 


def plot_results(x_train, y_train, x_test, y_test, mean, var, var_x, sample, history):
    """
    Build a 3x2 diagnostic figure for vector-field regression on (lat, lon) inputs.

    Column 0 of the location arrays (`x_train`, `x_test`) is latitude and column 1 is
    longitude. Column 0 of the vector arrays (`y_train`, `y_test`, `mean`) is the latitude
    component and column 1 the longitude component. Longitude is drawn on the x-axis, so
    every quiver receives (dlon, dlat) as (U, V).

    Args:
        x_train, y_train: Training locations and vectors.
        x_test, y_test: Test locations and ground-truth vectors.
        mean: Predictive mean at `x_test`.
        var: Predictive uncertainty on the grid `var_x`, drawn with `pcolormesh`.
        var_x: Grid for `var`; `var_x[..., 0]` is latitude, `var_x[..., 1]` longitude.
        sample: A posterior sample at `x_test`. Assumed to share `mean`'s (dlat, dlon)
            column layout; TODO confirm against the caller.
        history: Training objective per iteration (labelled negative ELBO).

    Returns:
        The matplotlib Figure.
    """
    var_lat, var_lon = var_x[..., 0], var_x[..., 1]
    x_train_lat, x_train_lon = x_train[:, 0], x_train[:, 1]
    y_train_dlat, y_train_dlon = y_train[:, 0], y_train[:, 1]
    x_test_lat, x_test_lon = x_test[:, 0], x_test[:, 1]
    y_test_dlat, y_test_dlon = y_test[:, 0], y_test[:, 1]
    mean_dlat, mean_dlon = mean[:, 0], mean[:, 1]
    sample_dlat, sample_dlon = sample[:, 0], sample[:, 1]

    nrows = 3
    ncols = 2
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 4), layout="constrained")

    def _label_map_axes(ax, title):
        # Shared labelling for the lon/lat panels.
        ax.set_xlabel("lon")
        ax.set_ylabel("lat")
        ax.set_title(title)

    # top left, prediction with test data 
    q = axs[0][0].quiver(x_test_lon, x_test_lat, mean_dlon, mean_dlat, angles="uv")
    # Private matplotlib API: compute the autoscaled arrow scale now, so all vector
    # panels can reuse it and stay visually comparable.
    q._init()
    scale = q.scale
    axs[0][0].quiver(x_train_lon, x_train_lat, y_train_dlon, y_train_dlat, angles="uv", color="red", scale=scale)
    _label_map_axes(axs[0][0], "Predictive Mean")

    # top right, uncertainty
    c = axs[0][1].pcolormesh(var_lon, var_lat, var, vmin=var.min(), vmax=var.max())
    fig.colorbar(c, ax=axs[0][1])
    _label_map_axes(axs[0][1], "Predictive Uncertainty")

    # middle left, true test data
    axs[1][0].quiver(x_test_lon, x_test_lat, y_test_dlon, y_test_dlat, angles="uv", scale=scale)
    _label_map_axes(axs[1][0], "Ground truth")

    # middle right, difference 
    y_diff = y_test - mean
    y_diff_dlat, y_diff_dlon = y_diff[:, 0], y_diff[:, 1]
    axs[1][1].quiver(x_test_lon, x_test_lat, y_diff_dlon, y_diff_dlat, angles="uv", scale=scale)
    _label_map_axes(axs[1][1], "Prediction Error")

    # bottom left, sample from posterior: (dlon, dlat) as (U, V), like every other quiver
    axs[2][0].quiver(x_test_lon, x_test_lat, sample_dlon, sample_dlat, angles="uv", scale=scale)
    _label_map_axes(axs[2][0], "Sample from Posterior")

    # bottom right, training history
    axs[2][1].plot(history)
    axs[2][1].set_xlabel("Iteration")
    axs[2][1].set_ylabel("Negative ELBO")
    axs[2][1].set_title("Training History")

    return fig


# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from gpjax.fit import (
    _check_batch_size,
    _check_log_rate,
    _check_model,
    _check_num_iters,
    _check_optim,
    _check_train_data,
    _check_verbose,
    get_batch,
)
from beartype.typing import (
    Any,
    Callable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import jax
from jax import (
    jit,
    value_and_grad,
)
from jax._src.random import _check_prng_key
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import scipy

from gpjax.base import Module
from gpjax.dataset import Dataset
from gpjax.objectives import AbstractObjective
from gpjax.scan import vscan
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

# Type variable bound to gpjax's `Module`: `fit_deep` returns a model of the same type it receives.
ModuleModel = TypeVar("ModuleModel", bound=Module)


def fit_deep(  # noqa: PLR0913
    *,
    model: ModuleModel,
    objective: Union[AbstractObjective, Callable[[KeyArray, ModuleModel, Dataset], ScalarFloat]],
    train_data: Dataset,
    optim: ox.GradientTransformation,
    key: KeyArray,
    num_iters: Optional[int] = 100,
    batch_size: Optional[int] = -1,
    log_rate: Optional[int] = 10,
    verbose: Optional[bool] = True,
    unroll: Optional[int] = 1,
    safe: Optional[bool] = True,
) -> Tuple[ModuleModel, Array]:
    r"""Train a Module model with respect to a supplied stochastic objective.

    The objective is called as `objective(key, model, batch)` with a fresh PRNG key at
    every iteration, so it may use randomness (e.g. Monte-Carlo estimates).
    Optimisers used here should originate from Optax.

    Example:
    ```python
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import optax as ox
        >>> import gpjax as gpx
        >>>
        >>> # (1) Create a dataset:
        >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
        >>> D = gpx.Dataset(X, y)
        >>>
        >>> # (2) Define your model:
        >>> class LinearModel(gpx.base.Module):
                weight: float = gpx.base.param_field()
                bias: float = gpx.base.param_field()

                def __call__(self, x):
                    return self.weight * x + self.bias

        >>> model = LinearModel(weight=1.0, bias=1.0)
        >>>
        >>> # (3) Define your loss; it receives a PRNG key as its first argument:
        >>> def loss(key, model, batch):
        ...     return jnp.mean((batch.y - model(batch.X)) ** 2)
        >>>
        >>> # (4) Train!
        >>> trained_model, history = fit_deep(
                model=model, objective=loss, train_data=D, optim=ox.sgd(0.001),
                key=jr.key(1), num_iters=1000,
            )
    ```

    Args:
        model (Module): The model Module to be optimised.
        objective (Objective): The objective to minimise, called as
            `objective(key, model, batch)`.
        train_data (Dataset): The training data to be used for the optimisation.
        optim (GradientTransformation): The Optax optimiser that is to be used for
            learning a parameter set.
        key (KeyArray): Required PRNG key. It is split into one key per iteration, and
            each of those is split again into independent keys for mini-batch
            selection and for the objective.
        num_iters (Optional[int]): The number of optimisation steps to run. Defaults
            to 100.
        batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1
            (i.e. full batch).
        log_rate (Optional[int]): Intended logging frequency. NOTE: it is currently
            only validated and is not forwarded to the scan. Defaults to 10.
        verbose (Optional[bool]): Whether to print the training loading bar (uses
            gpjax's `vscan` instead of `jax.lax.scan`). Defaults to True.
        unroll (int): The number of unrolled steps to use for the optimisation.
            Defaults to 1.
        safe (Optional[bool]): Whether to validate the inputs with gpjax's checks
            before training. Defaults to True.

    Returns
    -------
        Tuple[Module, Array]: A Tuple comprising the optimised model and training
            history respectively.
    """
    if safe:
        # Check inputs.
        _check_model(model)
        _check_train_data(train_data)
        _check_optim(optim)
        _check_num_iters(num_iters)
        _check_batch_size(batch_size)
        _check_prng_key("fit", key)
        _check_log_rate(log_rate)
        _check_verbose(verbose)

    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
    def loss(key: Key, model: Module, batch: Dataset) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(key, model.constrain(), batch)

    # Unconstrained space model.
    model = model.unconstrain()

    # Initialise optimiser state.
    state = optim.init(model)

    # One random key per iteration to scan over.
    iter_keys = jr.split(key, num_iters)

    # Optimisation step.
    def step(carry, key):
        model, opt_state = carry

        # Independent keys for batch selection and for the objective. Reusing one key
        # for both would correlate the mini-batch with the objective's randomness.
        batch_key, loss_key = jr.split(key)

        if batch_size != -1:
            batch = get_batch(train_data, batch_size, batch_key)
        else:
            batch = train_data

        loss_val, loss_gradient = jax.value_and_grad(loss, argnums=1)(loss_key, model, batch)
        updates, opt_state = optim.update(loss_gradient, opt_state, model)
        model = ox.apply_updates(model, updates)

        carry = model, opt_state
        return carry, loss_val

    # Optimisation scan.
    scan = vscan if verbose else jax.lax.scan

    # Optimisation loop.
    (model, _), history = scan(step, (model, state), iter_keys, unroll=unroll)

    # Constrained space.
    model = model.constrain()

    return model, history


from gpjax.base import static_field, param_field
from gpjax.kernels import AbstractKernel
from gpjax.likelihoods import AbstractLikelihood
from gpjax.gps import AbstractPosterior
import tensorflow_probability.substrates.jax.bijectors as tfb
from jax.scipy.special import gammaln
from jaxtyping import Int


@jax.jit 
def comb(N, k) -> Int:
    """
    Binomial coefficient C(N, k) = N! / (k! (N - k)!), computed via log-gamma and
    rounded to int64. Works elementwise on arrays.
    """
    log_binom = gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1)
    return jnp.round(jnp.exp(log_binom)).astype(jnp.int64)


# A one-element tuple: ("sphere_dim") without the comma is just a string. JAX happens
# to accept a bare string here, but the tuple states the intent explicitly.
@Partial(jax.jit, static_argnames=("sphere_dim",))
def num_phases_in_frequency(sphere_dim: int, frequency: Int) -> Int:
    """
    Number of linearly independent spherical harmonics of degree ℓ = `frequency` on Sᵈ,
    with d = `sphere_dim`.

    Uses N(d, ℓ) = C(ℓ + d - 2, ℓ - 1) + C(ℓ + d - 1, ℓ) for ℓ >= 1, and N(d, 0) = 1.
    For example, this gives 2ℓ + 1 on the 2-sphere. Works elementwise on `frequency`.
    """
    l, d = frequency, sphere_dim
    return jnp.where(
        l == 0, 
        jnp.ones_like(l, dtype=jnp.int64), 
        comb(l + d - 2, l - 1) + comb(l + d - 1, l),
    )


@Partial(jax.jit, static_argnames=('max_ell', 'alpha', 'sphere_dim'))
def sphere_isotropic_kernel(spectrum: Float, z: Float, *, max_ell: int, alpha: float, sphere_dim: int) -> Float:
    """
    Isotropic kernel on Sᵈ evaluated at inner products `z`, via the addition theorem:

        k(z) = Σ_ℓ spectrum[ℓ] · N(d, ℓ) · Cᵅ_ℓ(z) / Cᵅ_ℓ(1),   ℓ = 0..max_ell,

    where N(d, ℓ) = num_phases_in_frequency(d, ℓ) and d = `sphere_dim`.
    """
    degrees = jnp.arange(max_ell + 1)
    multiplicity = num_phases_in_frequency(sphere_dim=sphere_dim, frequency=degrees)
    at_one = gegenbauer(1.0, max_ell=max_ell, alpha=alpha)
    zonal = gegenbauer(z, max_ell=max_ell, alpha=alpha)
    weights = multiplicity / at_one * zonal
    return jnp.dot(weights, spectrum)


@Partial(jax.jit, static_argnames=('dim',))
def matern_spectrum(ell: Float, kappa: Float, nu: Float, variance: Float, dim: int) -> Float:
    """
    Matérn spectral weights on Sᵈ (d = `dim`) for degrees `ell`:

        Φ_ℓ ∝ (1 + λ_ℓ κ² / (2ν))^-(ν + d/2),   λ_ℓ = ℓ(ℓ + d - 1),

    normalised so that Σ_ℓ N(d, ℓ) Φ_ℓ = variance.
    """
    eigenvalues = ell * (ell + dim - 1)
    log_density = -(nu + dim / 2) * jnp.log1p((eigenvalues * kappa**2) / (2 * nu))

    # Shift by the max in log space before exponentiating, for numerical stability.
    density = jnp.exp(log_density - jnp.max(log_density))

    multiplicity = num_phases_in_frequency(frequency=ell, sphere_dim=dim)
    total = jnp.dot(multiplicity, density)
    return variance * density / total


@jax.jit 
def pairwise_dot(x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]:
    """Matrix of inner products <x_i, y_j> between all rows of `x` and `y`."""
    return jnp.einsum("ik,jk->ij", x, y)

@jax.jit 
def pairwise_dot_diag(x: Float[Array, "N D"]) -> Float[Array, "N"]:
    """Self inner products <x_i, x_i> (squared norms), one per row of `x`."""
    return jnp.sum(jnp.square(x), axis=-1)


@dataclass
class SphereMaternKernel(Module):
    """
    Matérn kernel on the sphere Sᵈ (d = `sphere_dim`). It is defined by its spectrum
    over degrees 0..max_ell and evaluated on inner products via the addition theorem.

    Attributes:
        sphere_dim: Sphere dimension d (static).
        kappa: Lengthscale-like parameter (Softplus-constrained).
        nu: Smoothness parameter (Softplus-constrained).
        variance: Kernel value at z = 1, i.e. k(x, x) for unit-norm inputs (Softplus-constrained).
        max_ell: Truncation degree of the spectral expansion (static).
        alpha: Gegenbauer parameter (d - 1) / 2, set in `__post_init__`.
    """
    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array(1.5), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    max_ell: int = static_field(25)
    alpha: float = static_field(init=False)

    def __post_init__(self):
        # float64 parameters for numerical stability of the spectral sums.
        self.kappa = jnp.asarray(self.kappa, dtype=jnp.float64)
        self.nu = jnp.asarray(self.nu, dtype=jnp.float64)
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)
        self.alpha = (self.sphere_dim - 1.0) / 2.0

    @property 
    def ells(self):
        # Degrees 0..max_ell as floats.
        return jnp.arange(self.max_ell + 1, dtype=jnp.float64)
    
    def spectrum(self) -> Num[Array, " L"]:
        """Normalised Matérn spectral weights, one per degree 0..max_ell."""
        return matern_spectrum(self.ells, self.kappa, self.nu, self.variance, dim=self.sphere_dim)
    
    @jax.jit 
    def __call__(self, z: Float[Array, "N M"]) -> Float[Array, "N M"]:
        """Kernel values at inner products `z` (elementwise; `z` may also be 1-D, e.g. from `prepare_inputs_diag`)."""
        return sphere_isotropic_kernel(self.spectrum(), z, max_ell=self.max_ell, alpha=self.alpha, sphere_dim=self.sphere_dim)

    # TODO this could inherit from a sphere isotropic kernel class and thus the multioutput kernel would also have these methods
    def prepare_inputs_full(self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """Pairwise inner products between the rows of `x1` and `x2`."""
        return pairwise_dot(x1, x2)
    
    def prepare_inputs_diag(self, x: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Self inner products of the rows of `x` (1 for unit-norm points)."""
        return pairwise_dot_diag(x)
    

@jax.jit 
def pairwise_distance(x: Float[Array, "N D"], y: Float[Array, "M D"]) -> Float[Array, "N M"]:
    """Euclidean distances between all rows of `x` and `y`, floored at 1e-18."""
    diff = x[:, None] - y[None]
    sq_dist = jnp.sum(jnp.square(diff), axis=-1)
    # Floor before the sqrt: its derivative is infinite at 0 (coincident points).
    return jnp.sqrt(jnp.clip(sq_dist, min=1e-36))


@jax.jit 
def pairwise_distance_diag(x: Float[Array, "N D"]) -> Float[Array, "N"]:
    """Self-distances of the rows of `x`, which are all zero."""
    return jnp.zeros(x.shape[:1])


@jax.jit 
def euclidean_matern32_kernel(z: Float, variance: Float) -> Float:
    """Matérn-3/2 kernel variance · (1 + √3 z) · exp(-√3 z) at lengthscale-scaled distances `z`."""
    sqrt3_z = jnp.sqrt(3) * z
    return variance * (1 + sqrt3_z) * jnp.exp(-sqrt3_z)


@dataclass 
class EuclideanMaternKernel32(Module):
    """
    Matérn-3/2 kernel on Euclidean inputs, with shared or per-dimension lengthscales.

    Attributes:
        num_inputs: Input dimension D.
        kappa: Lengthscale(s), Softplus-constrained. A scalar or shape (1,) shares one
            lengthscale; a leading dimension of D gives one per input dimension.
        variance: Kernel variance, Softplus-constrained.
    """
    num_inputs: int = static_field(init=True)
    kappa: Float = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    variance: Float = param_field(jnp.array([1.0]), bijector=tfb.Softplus())

    def __post_init__(self):
        # float64 for numerical stability
        self.kappa = jnp.asarray(self.kappa, dtype=jnp.float64)
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        # A 0-d kappa is also accepted: it broadcasts like shape (1,).
        if self.kappa.ndim > 0 and self.kappa.shape[0] not in (self.num_inputs, 1):
            raise ValueError(
                f"kappa must be a scalar or have leading dimension 1 or num_inputs={self.num_inputs}, "
                f"got shape {self.kappa.shape}."
            )

    def prepare_inputs_full(self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """Pairwise Euclidean distances after dividing the inputs by the lengthscales."""
        return pairwise_distance(x1 / self.kappa, x2 / self.kappa)

    def prepare_inputs_diag(self, x: Float[Array, "N D"]) -> Float[Array, "N"]:
        """Zero self-distances, one per row of `x`."""
        return pairwise_distance_diag(x / self.kappa)
    
    def __call__(self, z: Float[Array, "N M"]) -> Float[Array, "N M"]:
        """Kernel values at the scaled distances `z`."""
        return euclidean_matern32_kernel(z, self.variance)


@dataclass 
class MultioutputEuclideanMaternKernel32(Module):
    """
    Independent Matérn-3/2 kernels, one per output, on Euclidean inputs.

    Attributes:
        num_inputs: Input dimension D.
        num_outputs: Number of outputs O.
        kappa: Lengthscales, broadcast in `__post_init__` against (O, 1). For example,
            shape (1,) becomes (O, 1) and shape (D,) becomes (O, D).
        variance: Per-output variances, broadcast to shape (O,).
    """
    num_inputs: int = static_field()
    num_outputs: int = static_field()
    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())

    def _validate_params(self) -> None:
        # float64 for numerical stability
        self.kappa = jnp.asarray(self.kappa, dtype=jnp.float64)
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)        

        # shape for multioutput
        kappa_shape = jnp.broadcast_shapes(self.kappa.shape, (self.num_outputs, 1))
        self.kappa = jnp.broadcast_to(self.kappa, kappa_shape)
        self.variance = jnp.broadcast_to(self.variance, (self.num_outputs,))

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if self.kappa.shape[1] not in (self.num_inputs, 1):
            raise ValueError(
                f"kappa must have 1 or num_inputs={self.num_inputs} lengthscales per output, "
                f"got shape {self.kappa.shape}."
            )


    def __post_init__(self):
        self._validate_params()

    def prepare_inputs_full(self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]) -> Float[Array, "O N M"]:
        """Per-output pairwise distances, each computed with that output's lengthscales."""
        return jax.vmap(lambda k: pairwise_distance(x1 / k, x2 / k))(self.kappa)

    def prepare_inputs_diag(self, x: Float[Array, "N D"]) -> Float[Array, "O N"]:
        """Per-output zero self-distances."""
        return jax.vmap(lambda k: pairwise_distance_diag(x / k))(self.kappa)

    def __call__(self, z: Float[Array, "O N M"]) -> Float[Array, "O N M"]:
        """Kernel values, pairing each output's distances with that output's variance."""
        return jax.vmap(euclidean_matern32_kernel)(z, self.variance)
      

@dataclass 
class MultioutputSphereMaternKernel(Module):
    """
    Independent Matérn kernels on the sphere Sᵈ (d = `sphere_dim`), one per output.

    Attributes:
        num_outputs: Number of outputs O.
        sphere_dim: Sphere dimension d (static).
        kappa, nu, variance: Per-output Matérn parameters, broadcast to shape (O,).
        max_ell: Truncation degree of the spectral expansion (static).
        alpha: Gegenbauer parameter (d - 1) / 2, set in `__post_init__`.
    """
    num_outputs: int = static_field()
    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array([1.5]), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    max_ell: int = static_field(25)
    alpha: float = static_field(init=False)

    def _validate_params(self) -> None:
        # float64 for numerical stability
        self.kappa = jnp.asarray(self.kappa, dtype=jnp.float64)
        self.nu = jnp.asarray(self.nu, dtype=jnp.float64)
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)

        # shape for multioutput
        self.kappa = jnp.broadcast_to(self.kappa, (self.num_outputs,))
        self.nu = jnp.broadcast_to(self.nu, (self.num_outputs,))
        self.variance = jnp.broadcast_to(self.variance, (self.num_outputs,))

    def __post_init__(self):
        self._validate_params()
        self.alpha = (self.sphere_dim - 1) / 2

    @property 
    def ells(self):
        # Integer degrees 0..max_ell.
        return jnp.arange(self.max_ell + 1)
    
    @jax.jit 
    def spectrum(self) -> Num[Array, "O L"]:
        """Per-output normalised Matérn spectral weights over degrees 0..max_ell."""
        return jax.vmap(
            lambda kappa, nu, variance: matern_spectrum(self.ells, kappa, nu, variance, dim=self.sphere_dim)
        )(self.kappa, self.nu, self.variance)
    
    @jax.jit 
    def __call__(self, z: Float[Array, "N M"]) -> Float[Array, "O N M"]:
        """Per-output kernel values at inner products `z`, stacked along a leading output axis."""
        return jax.vmap(
            lambda spectrum: sphere_isotropic_kernel(spectrum, z, max_ell=self.max_ell, alpha=self.alpha, sphere_dim=self.sphere_dim)
        )(self.spectrum())
    
    def prepare_inputs_full(self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """Pairwise inner products, shared by all outputs."""
        return pairwise_dot(x1, x2)
    
    def prepare_inputs_diag(self, x: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Self inner products of the rows of `x`, shared by all outputs."""
        return pairwise_dot_diag(x)
    

@dataclass 
class MultioutputPrior(Module):
    """
    Zero-mean GP prior with independent outputs, built from a multi-output kernel.

    Attributes:
        kernel: Multi-output kernel; `kernel(kernel.prepare_inputs_full(x, x))` is expected
            to return covariances of shape (O, n, n).
        num_outputs: Number of outputs O, copied from the kernel.
        jitter: Diagonal jitter added for numerical stability.
    """
    kernel: MultioutputSphereMaternKernel = param_field()
    num_outputs: int = static_field(init=False)
    jitter: Float = static_field(1e-12)

    def __post_init__(self):
        self.num_outputs = self.kernel.num_outputs

    def full(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalFullCovariance:
        """Prior at `x` with full covariance: batch over outputs, event over the n inputs."""
        n = x.shape[0]

        z = self.kernel.prepare_inputs_full(x, x)
        covariance = self.kernel(z)
        # jnp.eye takes an integer size (a tuple raised a TypeError). The (n, n) identity
        # broadcasts across the leading output axis.
        covariance += self.jitter * jnp.eye(n, dtype=jnp.float64)
        # One zero mean vector per covariance matrix.
        mean = jnp.zeros(covariance.shape[:-1])
        return tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)

    def diag(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalDiag:
        """Prior at `x` keeping only the marginal variances, shape (O, n)."""
        n = x.shape[0]

        z = self.kernel.prepare_inputs_diag(x)
        variance = self.kernel(z)
        variance += jnp.ones((self.num_outputs, n), dtype=jnp.float64) * self.jitter
        mean = jnp.zeros(variance.shape)
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance))

    def __call__(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalFullCovariance:
        return self.full(x)
    

@dataclass 
class Prior(Module):
    """Zero-mean GP prior for a single output."""
    kernel: SphereMaternKernel = param_field()
    # Diagonal jitter added to prior (co)variances for numerical stability.
    jitter: Float = static_field(1e-12)

    def full(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalFullCovariance:
        """Joint prior over the rows of `x` as a single N-dimensional Gaussian."""
        n = x.shape[0]

        z = self.kernel.prepare_inputs_full(x, x)
        covariance = self.kernel(z)
        covariance += self.jitter * jnp.eye(n, dtype=jnp.float64)
        # One mean entry per covariance row. Using `shape[:2]` would give an
        # [N, N] mean for an [N, N] covariance, i.e. a batch of N distributions
        # rather than one N-dimensional Gaussian.
        mean = jnp.zeros(covariance.shape[:-1])
        return tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)

    def diag(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalDiag:
        """Independent marginal prior at each row of `x`."""
        n = x.shape[0]

        z = self.kernel.prepare_inputs_diag(x)
        variance = self.kernel(z)
        variance += jnp.ones(n, dtype=jnp.float64) * self.jitter
        mean = jnp.zeros(variance.shape)
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance))

    def __call__(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalFullCovariance:
        """Alias for `full`."""
        return self.full(x)
    

@dataclass
class GaussianLikelihood(Module):
    """Homoscedastic Gaussian observation noise."""
    # Kept positive through a Softplus bijector during optimisation.
    noise_variance: Float = param_field(jnp.array(1.0), bijector=tfb.Softplus())

    def full(self, pf: tfd.MultivariateNormalFullCovariance) -> tfd.MultivariateNormalFullCovariance:
        """Add `noise_variance` to the diagonal of the latent covariance.

        The identity is sized from the event (last) dimension, so batched
        covariances of shape [..., N, N] are handled as well as [N, N].
        """
        covariance = pf.covariance()
        covariance += jnp.eye(covariance.shape[-1], dtype=jnp.float64) * self.noise_variance
        return tfd.MultivariateNormalFullCovariance(loc=pf.mean(), covariance_matrix=covariance)

    def diag(self, pf: tfd.MultivariateNormalDiag) -> tfd.MultivariateNormalDiag:
        """Add `noise_variance` to every marginal latent variance."""
        variance = pf.variance()
        variance += jnp.ones(variance.shape, dtype=jnp.float64) * self.noise_variance
        return tfd.MultivariateNormalDiag(loc=pf.mean(), scale_diag=jnp.sqrt(variance))

    def __call__(self, pf: tfd.MultivariateNormalFullCovariance) -> tfd.MultivariateNormalFullCovariance:
        """Alias for `full`."""
        return self.full(pf)
    

@dataclass
class Posterior(Module):
    """Container pairing a single-output prior with a likelihood.

    It has no behaviour of its own: layers read `prior` (e.g. its kernel and
    jitter) and the objective reads `likelihood.noise_variance` from it.
    """
    prior: Prior = param_field()
    # NOTE: the factories in this module pass a DeepGaussianLikelihood here.
    likelihood: GaussianLikelihood = param_field()


@dataclass
class MultioutputPosterior(Module):
    """Container pairing a multi-output prior with a likelihood."""
    prior: MultioutputPrior = param_field()
    likelihood: GaussianLikelihood = param_field()
    # Derived from the prior in __post_init__; not an __init__ argument.
    num_outputs: int = static_field(init=False)

    def __post_init__(self):
        # Mirror the prior's output count so layers can read it from here.
        self.num_outputs = self.prior.num_outputs


@Partial(jax.jit, static_argnames=('jitter',))
def inducing_points_moments_full(Kxx: Float, Kzx: Float, Kzz: Float, m: Float, sqrtS: Float, jitter: float = 1e-12) -> tuple[Float, Float]:
    """Predictive mean and covariance of a whitened inducing-point posterior.

    With Kzz = Lzz Lzzᵀ, A = Lzz⁻¹ Kzx and S = sqrtS sqrtSᵀ:
        mean       = Aᵀ m
        covariance = Kxx - Aᵀ A + Aᵀ S A

    Args:
        Kxx: prior covariance at the test inputs, [N, N].
        Kzx: cross-covariance between inducing and test inputs, [M, N].
        Kzz: prior covariance at the inducing inputs, [M, M].
        m: whitened variational mean, [M].
        sqrtS: factor of the whitened variational covariance, [M, M].
        jitter: diagonal jitter (static) added to Kzz and Kxx, and once more to
            the final covariance.

    Returns:
        (mean [N], covariance [N, N]).
    """
    # Local import: this module only imports jax / jax.numpy at the top.
    from jax.scipy.linalg import solve_triangular

    Kzz = Kzz + jnp.eye(Kzz.shape[0], dtype=jnp.float64) * jitter
    Kxx = Kxx + jnp.eye(Kxx.shape[0], dtype=jnp.float64) * jitter

    Lzz = jnp.linalg.cholesky(Kzz)
    # Lzz is lower-triangular, so a triangular solve is cheaper and more stable
    # than a general LU-based solve.
    Lzz_inv_Kzx = solve_triangular(Lzz, Kzx, lower=True)
    sqrtS_T_Lzz_inv_Kzx = sqrtS.T @ Lzz_inv_Kzx

    covariance = (
        Kxx
        + sqrtS_T_Lzz_inv_Kzx.T @ sqrtS_T_Lzz_inv_Kzx
        - Lzz_inv_Kzx.T @ Lzz_inv_Kzx
    )
    covariance = covariance + jnp.eye(covariance.shape[0], dtype=jnp.float64) * jitter

    mean = Lzz_inv_Kzx.T @ m
    return mean, covariance


@Partial(jax.jit, static_argnames=('jitter',))
def inducing_points_moments_diag(Kxx_diag: Float, Kzx: Float, Kzz: Float, m: Float, sqrtS: Float, jitter: float = 1e-12) -> tuple[Float, Float]:
    """Predictive mean and marginal variances of a whitened inducing-point posterior.

    This is the diagonal of `inducing_points_moments_full`: with
    A = Lzz⁻¹ Kzx and B = sqrtSᵀ A, the variance at test point n is
        Kxx_diag[n] + sum_m B[m, n]² - sum_m A[m, n]².

    Args:
        Kxx_diag: prior variances at the test inputs, [N].
        Kzx: cross-covariance between inducing and test inputs, [M, N].
        Kzz: prior covariance at the inducing inputs, [M, M].
        m: whitened variational mean, [M].
        sqrtS: factor of the whitened variational covariance, [M, M].
        jitter: diagonal jitter (static) added to Kzz and Kxx_diag, and once more
            to the final variances.

    Returns:
        (mean [N], variance [N]).
    """
    # Local import: this module only imports jax / jax.numpy at the top.
    from jax.scipy.linalg import solve_triangular

    Kzz = Kzz + jnp.eye(Kzz.shape[0], dtype=jnp.float64) * jitter
    Kxx_diag = Kxx_diag + jitter

    Lzz = jnp.linalg.cholesky(Kzz)
    # Triangular solve against the lower Cholesky factor.
    Lzz_inv_Kzx = solve_triangular(Lzz, Kzx, lower=True)
    sqrtS_T_Lzz_inv_Kzx = sqrtS.T @ Lzz_inv_Kzx

    variance = (
        Kxx_diag
        + jnp.sum(sqrtS_T_Lzz_inv_Kzx ** 2, axis=0)
        - jnp.sum(Lzz_inv_Kzx ** 2, axis=0)
    )
    variance = variance + jnp.ones(variance.shape, dtype=jnp.float64) * jitter

    mean = Lzz_inv_Kzx.T @ m
    return mean, variance


@Partial(jax.jit, static_argnames=('jitter',))
def spherical_harmonic_features_moments_full(Kxx: Float, Kxz: Float, Kzz_diag: Float, m: Float, sqrtS: Float, jitter: float = 1e-12) -> tuple[Float, Float]:
    """Predictive moments for whitened features with a diagonal Kzz.

    Because Kzz is diagonal, its Cholesky factor is just sqrt(Kzz_diag + jitter).
    With B = Kxz / sqrt(Kzz_diag + jitter):
        mean       = B m
        covariance = Kxx + B sqrtS sqrtSᵀ Bᵀ
    The usual "- B Bᵀ" correction is deliberately not subtracted; the design
    assumes it is already absorbed into the `Kxx` supplied by the caller.
    """
    num_points = Kxx.shape[0]
    Kxx = Kxx + jitter * jnp.eye(num_points, dtype=jnp.float64)

    diag_cholesky = jnp.sqrt(Kzz_diag + jitter)
    whitened_cross = Kxz / diag_cholesky
    projected = whitened_cross @ sqrtS

    covariance = Kxx + projected @ projected.T
    mean = whitened_cross @ m
    return mean, covariance


@Partial(jax.jit, static_argnames=('jitter',))
def spherical_harmonic_features_moments_diag(Kxx_diag: Float, Kxz: Float, Kzz_diag_inv: Float, m: Float, sqrtS: Float, jitter: float = 1e-12) -> tuple[Float, Float]:
    """Predictive mean and marginal variances for whitened diagonal-Kzz features.

    `Kzz_diag_inv` is the elementwise inverse of diag(Kzz). The inverse Cholesky
    scale 1 / sqrt(Kzz_diag + jitter) is written in terms of the inverse as
    sqrt(Kzz_diag_inv) / sqrt(1 + jitter * Kzz_diag_inv), so Kzz_diag itself is
    never formed. The "- B Bᵀ" correction is not subtracted; the design assumes it
    is already absorbed into `Kxx_diag`.
    """
    Kxx_diag = Kxx_diag + jitter

    inv_scale = jnp.sqrt(Kzz_diag_inv) / jnp.sqrt(1 + jitter * Kzz_diag_inv)
    whitened_cross = Kxz * inv_scale  # [N, M]
    projected = whitened_cross @ sqrtS  # [N, M]

    variance = Kxx_diag + jnp.sum(projected ** 2, axis=1)
    mean = whitened_cross @ m
    return mean, variance


@jax.jit
def whitened_prior_kl(m: Float, sqrtS: Float) -> Float:
    """KL( N(m, sqrtS sqrtSᵀ) || N(0, I) ) for whitened variational parameters.

    Closed form, with S = sqrtS sqrtSᵀ and M = len(m):
        KL = 0.5 * (tr(S) + mᵀm - M - log det S)
    where tr(S) = ||sqrtS||_F² and log det S = 2 log|det sqrtS|.

    Working with sqrtS directly avoids forming S and re-factorising it with a
    Cholesky decomposition. That re-factorisation loses precision and can fail
    when sqrtS has very small diagonal entries.

    Args:
        m: whitened variational mean, [M].
        sqrtS: square factor of the whitened variational covariance, [M, M].

    Returns:
        Scalar KL divergence.
    """
    num_inducing = m.shape[0]
    trace_term = jnp.sum(jnp.square(sqrtS))
    mahalanobis_term = jnp.sum(jnp.square(m))
    _, log_abs_det_sqrtS = jnp.linalg.slogdet(sqrtS)
    return 0.5 * (trace_term + mahalanobis_term - num_inducing - 2.0 * log_abs_det_sqrtS)


def inducing_points_prior_kl(m: Float, sqrtS: Float) -> Float:
    """KL regulariser of the inducing-point posterior.

    The inducing-point parameterisation is whitened, so this is exactly the
    whitened KL against a standard normal.
    """
    return whitened_prior_kl(m=m, sqrtS=sqrtS)


@dataclass 
class DummyPosterior(Module):
    """Prior-only container (no likelihood) for a single-output layer."""
    prior: Prior = param_field()


@dataclass 
class MultioutputDummyPosterior(Module):
    """Prior-only container (no likelihood) used by multi-output hidden layers.

    Hidden layers only need the prior (kernel and jitter) and the output count;
    the observation likelihood lives on the output layer.
    """
    # Always a multi-output prior: __post_init__ reads `prior.kernel.num_outputs`.
    prior: MultioutputPrior = param_field()
    # Derived from the kernel in __post_init__; not an __init__ argument.
    num_outputs: int = static_field(init=False)

    def __post_init__(self):
        self.num_outputs = self.prior.kernel.num_outputs


@dataclass 
class MultioutputInducingPointsPosterior(Module):
    """Whitened inducing-point variational posterior for a multi-output layer.

    All outputs share the inducing inputs `z`. Output o has its own whitened
    variational mean `m[o]` and covariance factor `sqrtS[o]`, kept triangular by
    the FillTriangular bijector.
    """
    posterior: MultioutputPosterior = param_field()
    z: Float[Array, "M D"] = param_field(trainable=False)
    # Leading axis indexes outputs (see __post_init__).
    m: Float[Array, "O M"] = param_field(init=False)
    sqrtS: Float[Array, "O M M"] = param_field(init=False, bijector=tfb.FillTriangular())

    num_outputs: int = static_field(init=False)
    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        self.num_outputs = self.posterior.num_outputs
        self.num_inducing = self.z.shape[0]

        # Start at the whitened prior: zero mean and identity factor per output.
        self.m = jnp.zeros((self.num_outputs, self.num_inducing))
        self.sqrtS = jnp.repeat(jnp.expand_dims(jnp.eye(self.num_inducing), axis=0), self.num_outputs, axis=0)

    @property
    def jitter(self):
        """Diagonal jitter taken from the prior, as in InducingPointsPosterior."""
        return self.posterior.prior.jitter

    # TODO Remove if not needed (probably isn't)
    def full(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalFullCovariance:
        """Joint predictive at `x`, batched over outputs."""
        kernel = self.posterior.prior.kernel 
        z = self.z

        xx = kernel.prepare_inputs_full(x, x)
        zx = kernel.prepare_inputs_full(z, x)
        zz = kernel.prepare_inputs_full(z, z)

        # Leading axis of each Gram matrix indexes outputs.
        Kxx = kernel(xx)
        Kzx = kernel(zx)
        Kzz = kernel(zz)

        jitter = self.jitter  # static Python float, forwarded as a static argument
        mean, covariance = jax.vmap(
            lambda kxx, kzx, kzz, m, sqrtS: inducing_points_moments_full(kxx, kzx, kzz, m, sqrtS, jitter=jitter)
        )(Kxx, Kzx, Kzz, self.m, self.sqrtS)
        return tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)

    @jax.jit 
    def diag(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalDiag:
        """Marginal predictive at each row of `x`, batched over outputs."""
        kernel = self.posterior.prior.kernel 
        z = self.z

        xx_diag = kernel.prepare_inputs_diag(x)
        zx = kernel.prepare_inputs_full(z, x)
        zz = kernel.prepare_inputs_full(z, z)

        Kxx_diag = kernel(xx_diag)
        Kzx = kernel(zx)
        Kzz = kernel(zz)

        jitter = self.jitter
        mean, variance = jax.vmap(
            lambda kxx, kzx, kzz, m, sqrtS: inducing_points_moments_diag(kxx, kzx, kzz, m, sqrtS, jitter=jitter)
        )(Kxx_diag, Kzx, Kzz, self.m, self.sqrtS)
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance))

    def __call__(self, x: Float[Array, "N D"], z: Float[Array, "M D"] | None = None) -> tfd.MultivariateNormalFullCovariance:
        """Alias for `full(x)`.

        `z` is accepted only for backward compatibility and is ignored: the
        inducing inputs are always `self.z`. Forwarding it to `full`, which takes
        only `x`, raised a TypeError.
        """
        return self.full(x)
    
    @jax.jit 
    def prior_kl(self) -> Float:
        """Sum over outputs of the whitened KL(q(u_o) || p(u_o))."""
        return jnp.sum(jax.vmap(inducing_points_prior_kl)(self.m, self.sqrtS), axis=0)
    

@dataclass 
class InducingPointsPosterior(Module):
    """Whitened inducing-point variational posterior for a single output.

    q(u) is N(m, sqrtS sqrtSᵀ) in whitened coordinates; `sqrtS` is kept
    triangular by the FillTriangular bijector.
    """
    posterior: Posterior = param_field()
    z: Float[Array, "M D"] = param_field(trainable=False)
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())

    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        self.num_inducing = self.z.shape[0]

        # Start at the whitened prior: zero mean, identity factor.
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)

    @property 
    def jitter(self):
        """Diagonal jitter taken from the prior."""
        return self.posterior.prior.jitter

    # TODO Remove if not needed (probably isn't)
    def full(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalFullCovariance:
        """Joint predictive at the rows of `x`."""
        kernel = self.posterior.prior.kernel 
        z = self.z

        xx = kernel.prepare_inputs_full(x, x)
        zx = kernel.prepare_inputs_full(z, x)
        zz = kernel.prepare_inputs_full(z, z)

        Kxx = kernel(xx)
        Kzx = kernel(zx)
        Kzz = kernel(zz)

        mean, covariance = inducing_points_moments_full(Kxx, Kzx, Kzz, self.m, self.sqrtS, jitter=self.jitter)
        return tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalDiag:
        """Marginal predictive at each row of `x`."""
        kernel = self.posterior.prior.kernel 
        z = self.z

        xx_diag = kernel.prepare_inputs_diag(x)
        zx = kernel.prepare_inputs_full(z, x)
        zz = kernel.prepare_inputs_full(z, z)

        Kxx_diag = kernel(xx_diag)
        Kzx = kernel(zx)
        Kzz = kernel(zz)

        mean, variance = inducing_points_moments_diag(Kxx_diag, Kzx, Kzz, self.m, self.sqrtS, jitter=self.jitter)
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance))

    def __call__(self, x: Float[Array, "N D"], z: Float[Array, "M D"] | None = None) -> tfd.MultivariateNormalFullCovariance:
        """Alias for `full(x)`.

        `z` is accepted only for backward compatibility and is ignored: the
        inducing inputs are always `self.z`. Forwarding it to `full`, which takes
        only `x`, raised a TypeError.
        """
        return self.full(x)
    
    def prior_kl(self) -> Float:
        """Whitened KL(q(u) || p(u))."""
        return inducing_points_prior_kl(self.m, self.sqrtS)
    

@dataclass
class SphericalHarmonicFeaturesPosterior(Module):
    """Whitened variational posterior using spherical harmonic inducing features.

    The cross-covariance between f(x) and the features is
    `spherical_harmonics.polynomial_expansion(x)`. The inverse feature variances
    are the kernel spectrum, repeated once per phase of each level up to
    `spherical_harmonics.max_ell`. Kernel levels above that truncation are not
    represented by features; their prior contribution to the marginal variance
    is scaled per level by `sqrtS_augment`.
    """
    posterior: Posterior = param_field()
    # The harmonic basis is fixed, so it is excluded from training, matching
    # MultioutputSphericalHarmonicFeaturesPosterior.
    spherical_harmonics: SphericalHarmonics = param_field(trainable=False)
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    # Per-level scales, applied squared to the prior spectrum in `diag`.
    sqrtS_augment: Float[Array, "L"] = param_field(init=False)

    def __post_init__(self):
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)
        # Zero on levels covered by the features, one on the levels above them.
        self.sqrtS_augment = jnp.ones(self.kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)

    @property 
    def jitter(self):
        """Diagonal jitter taken from the prior."""
        return self.posterior.prior.jitter

    @property 
    def kernel(self):
        """Kernel of the prior."""
        return self.posterior.prior.kernel 

    @property 
    def num_inducing(self):
        """Number of inducing features (total spherical harmonic phases)."""
        return self.spherical_harmonics.num_phases

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalDiag:
        """Marginal predictive at each row of `x`."""
        shf = self.spherical_harmonics
        kernel = self.kernel

        xx_diag = kernel.prepare_inputs_diag(x)
        spectrum = kernel.spectrum()

        # This already accounts for the subtraction of the identity matrix from S'
        S_augment = jnp.square(self.sqrtS_augment)
        Kxx_diag = sphere_isotropic_kernel(
            spectrum * S_augment, 
            xx_diag, max_ell=kernel.max_ell, alpha=kernel.alpha, sphere_dim=kernel.sphere_dim
        )

        # The moments function takes the *inverse* Kzz diagonal, which for these
        # features is the spectrum repeated once per phase of each level.
        repeats = np.array(shf.num_phases_per_frequency)
        Kzz_diag_inv = jnp.repeat(
            spectrum[:shf.max_ell + 1], 
            repeats=repeats,
            total_repeat_length=shf.num_phases,
        )
        Kxz = shf.polynomial_expansion(x).T

        mean, variance = spherical_harmonic_features_moments_diag(
            Kxx_diag, Kxz, Kzz_diag_inv, self.m, self.sqrtS, jitter=self.jitter
        )
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance))
    
    def prior_kl(self) -> Float:
        """Whitened KL(q(u) || p(u))."""
        return whitened_prior_kl(self.m, self.sqrtS)


@dataclass
class MultioutputSphericalHarmonicFeaturesPosterior(Module):
    """Spherical-harmonic-feature posterior with independent parameters per output.

    Every output shares the (fixed) harmonic basis but has its own whitened
    mean, covariance factor and per-level augmentation scales; the leading axis
    of `m`, `sqrtS` and `sqrtS_augment` indexes outputs.
    """
    num_outputs: int = static_field(init=False)

    posterior: MultioutputPosterior = param_field()
    spherical_harmonics: SphericalHarmonics = param_field(trainable=False)
    m: Float[Array, "O M"] = param_field(init=False)
    sqrtS: Float[Array, "O M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    sqrtS_augment: Float[Array, "O L"] = param_field(init=False)

    def __post_init__(self):
        self.num_outputs = self.posterior.num_outputs

        num_outputs = self.num_outputs
        num_inducing = self.num_inducing
        num_levels = self.kernel.max_ell + 1

        self.m = jnp.zeros((num_outputs, num_inducing))
        self.sqrtS = jnp.broadcast_to(jnp.eye(num_inducing), (num_outputs, num_inducing, num_inducing))
        # Zero on levels covered by the features, one on the levels above them.
        augment_row = jnp.ones(num_levels).at[:self.spherical_harmonics.max_ell + 1].set(0.0)
        self.sqrtS_augment = jnp.broadcast_to(augment_row, (num_outputs, num_levels))

    @property 
    def jitter(self):
        """Diagonal jitter taken from the prior."""
        return self.posterior.prior.jitter

    @property 
    def num_inducing(self):
        """Number of inducing features per output (total harmonic phases)."""
        return self.spherical_harmonics.num_phases
    
    @property 
    def kernel(self):
        """Multi-output kernel of the prior."""
        return self.posterior.prior.kernel 
    
    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.MultivariateNormalDiag:
        """Marginal predictive at each row of `x`, batched over outputs."""
        kernel = self.kernel
        shf = self.spherical_harmonics

        # Prior covariance adjusted by the diagonal variational parameters.
        xx_diag = kernel.prepare_inputs_diag(x)
        spectrum = kernel.spectrum()  # [O, L]
        augmented_spectrum = spectrum * jnp.square(self.sqrtS_augment)  # [O, L]
        Kxx_diag = jax.vmap(
            lambda s: sphere_isotropic_kernel(s, xx_diag, max_ell=kernel.max_ell, alpha=kernel.alpha, sphere_dim=kernel.sphere_dim)
        )(augmented_spectrum)

        # The moments function takes the *inverse* Kzz diagonal: the spectrum
        # repeated once per phase of each feature level.
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        Kzz_diag_inv = jax.vmap(
            lambda s: jnp.repeat(s, repeats=repeats, total_repeat_length=total_repeat_length)
        )(spectrum[:, :shf.max_ell + 1])
        Kxz = shf.polynomial_expansion(x).T  # shared by all outputs

        jitter = self.jitter
        mean, variance = jax.vmap(
            lambda kxx_diag, kzz_diag_inv, m, sqrtS: spherical_harmonic_features_moments_diag(kxx_diag, Kxz, kzz_diag_inv, m, sqrtS, jitter=jitter)
        )(Kxx_diag, Kzz_diag_inv, self.m, self.sqrtS)
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance))

    @jax.jit
    def prior_kl(self) -> Float:
        """Sum over outputs of the whitened KL(q(u_o) || p(u_o))."""
        return jnp.sum(jax.vmap(whitened_prior_kl)(self.m, self.sqrtS), axis=0)


@jax.jit
def expected_log_likelihood(y: Float, m: Float, f_var: Float, eps_var: Float) -> Float:
    """Gaussian expected log-likelihood E_{f ~ N(m, f_var)}[log N(y | f, eps_var)].

    The per-point terms are summed over the last axis.
    """
    residual_sq = jnp.square(y - m)
    per_point = jnp.log(2 * jnp.pi) + jnp.log(eps_var) + (residual_sq + f_var) / eps_var
    return -0.5 * jnp.sum(per_point, axis=-1)


# NOTE: this definition is shadowed by an identical `negative_elbo` defined
# further down this module; only that later definition is live at import time.
@jax.jit 
def negative_elbo(p: "SphereResidualDeepGP | EuclideanDeepGP", x: Float, y: Float, *, key: Key) -> Float:
    """Negative ELBO of a deep GP: -(expected log-likelihood - prior KL).

    `p` must expose `diag(x, key=...)` (a mixture predictive), `prior_kl()` and
    `posterior.likelihood.noise_variance`, as the deep-GP models do.
    """
    pf_diag = p.diag(x, key=key)
    m, f_var = pf_diag.mean(), pf_diag.variance()
    eps_var = p.posterior.likelihood.noise_variance
    return -(expected_log_likelihood(y, m, f_var, eps_var) - p.prior_kl())


# Exponential map on the unit sphere: exp_x(v) = cos(|v|) x + sin(|v|) v / |v|,
# valid for unit-norm x and tangent v (<x, v> = 0).
@jax.jit
def sphere_expmap(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """Move each row of `x` along the tangent vector in the matching row of `v`.

    Rows with |v| < 1e-12 use the first-order retraction (x + v) / |x + v|
    instead of the closed form, which divides by |v|.

    The closed-form branch is evaluated with a "safe" norm (1.0 on the
    small-norm rows), so the discarded branch never computes 0/0. `jnp.where`
    only selects values; gradients still flow through both branches, so a NaN
    in the discarded branch would otherwise make gradients NaN at v = 0.
    """
    sq_norm = jnp.sum(v * v, axis=-1, keepdims=True)
    theta = jnp.sqrt(sq_norm)  # only used for the branch test (no gradient path)
    is_small = theta < 1e-12

    # Equal to `theta` on rows that use the closed form; 1.0 elsewhere.
    safe_theta = jnp.sqrt(jnp.where(is_small, 1.0, sq_norm))

    t = x + v
    first_order_approx = t / jnp.linalg.norm(t, axis=-1, keepdims=True)
    true_expmap = jnp.cos(safe_theta) * x + jnp.sin(safe_theta) * v / safe_theta

    return jnp.where(is_small, first_order_approx, true_expmap)


# Orthogonal projection onto the tangent space of the unit sphere at x.
@jax.jit
def sphere_to_tangent(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """Remove from each row of `v` its component along the matching row of `x`.

    Computes v - <x, v> x, which is the tangent-space projection when the rows
    of `x` have unit norm.
    """
    radial = (x * v).sum(axis=-1, keepdims=True)
    return v - radial * x


@dataclass 
class SphereResidualDeepGP(Module):
    """Residual deep GP on the sphere.

    Each hidden layer produces an ambient-space vector at every input. The
    vector is projected onto the tangent space at that input, and the input is
    moved along it with the sphere exponential map. The output layer is then
    evaluated at the warped inputs. Predictions are an equally weighted mixture
    over `num_samples` independently sampled warps.
    """
    hidden_layers: list[MultioutputInducingPointsPosterior] = param_field()
    output_layer: InducingPointsPosterior = param_field()
    num_samples: int = static_field(1)

    @property 
    def posterior(self) -> Posterior:
        """Posterior (prior + likelihood) of the output layer."""
        return self.output_layer.posterior

    def _warp_inputs(self, x: Float[Array, "N D"], *, key: Key) -> Float[Array, "N D"]:
        """Push `x` through one sampled realisation of every hidden layer."""
        hidden_layer_keys = jr.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key)
            v = v.T  # [D N] -> [N D]
            v = sphere_to_tangent(x, v)
            x = sphere_expmap(x, v)
        return x

    def diag_one_sample(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MultivariateNormalDiag:
        """Output-layer marginals at one sampled warp of `x`."""
        return self.output_layer.diag(self._warp_inputs(x, key=key))
    
    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equal-weight mixture of `num_samples` diagonal predictives."""
        sample_keys = jr.split(key, self.num_samples)

        def sample_mean_stddev(key: Key) -> tuple[Float, Float]:
            px = self.diag_one_sample(x, key=key)
            return px.mean(), px.stddev()
        
        means, stddevs = jax.vmap(sample_mean_stddev)(sample_keys)
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=jnp.ones(self.num_samples) / self.num_samples),
            components_distribution=tfd.MultivariateNormalDiag(loc=means, scale_diag=stddevs),
        )
    
    def full_one_sample(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MultivariateNormalFullCovariance:
        """Output-layer joint predictive at one sampled warp of `x`."""
        return self.output_layer.full(self._warp_inputs(x, key=key))
    
    def full(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equal-weight mixture of `num_samples` full-covariance predictives."""
        sample_keys = jr.split(key, self.num_samples)

        def sample_mean_cov(key: Key) -> tuple[Float, Float]:
            px = self.full_one_sample(x, key=key)
            return px.mean(), px.covariance()
        
        means, covariances = jax.vmap(sample_mean_cov)(sample_keys)
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=jnp.ones(self.num_samples) / self.num_samples),
            components_distribution=tfd.MultivariateNormalFullCovariance(loc=means, covariance_matrix=covariances),
        )
    
    def prior_kl(self) -> Float:
        """Sum of the variational KL terms of all layers."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()


@dataclass
class DeepGaussianLikelihood(Module):
    """Homoscedastic Gaussian noise for mixture-valued deep-GP predictives."""
    noise_variance: Float = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    
    def diag(self, pf: tfd.MixtureSameFamily) -> tfd.MixtureSameFamily:
        """Add `noise_variance` to every component's marginal variances.

        The mixture weights are reused unchanged.
        """
        components = pf.components_distribution
        component_mean = components.mean()
        component_var = components.variance()
        noisy_var = component_var + jnp.ones(component_var.shape, dtype=jnp.float64) * self.noise_variance
        noisy_components = tfd.MultivariateNormalDiag(loc=component_mean, scale_diag=jnp.sqrt(noisy_var))
        return tfd.MixtureSameFamily(
            mixture_distribution=pf.mixture_distribution,
            components_distribution=noisy_components,
        )


@dataclass 
class EuclideanDeepGP(Module):
    """Deep GP whose hidden layers displace the inputs additively.

    Each hidden layer's sampled output is added to its input (x <- x + v); the
    output layer is evaluated at the final inputs. Predictions are an equally
    weighted mixture over `num_samples` sampled warps.
    """
    # The factories in this module build multi-output inducing-point hidden layers.
    hidden_layers: list[MultioutputInducingPointsPosterior] = param_field()
    output_layer: InducingPointsPosterior = param_field()
    num_samples: int = static_field(1)

    @property 
    def posterior(self) -> Posterior:
        """Posterior (prior + likelihood) of the single-output output layer."""
        return self.output_layer.posterior

    def _warp_inputs(self, x: Float[Array, "N D"], *, key: Key) -> Float[Array, "N D"]:
        """Push `x` through one sampled realisation of every hidden layer."""
        hidden_layer_keys = jr.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key)
            v = v.T  # [D N] -> [N D]
            x = x + v  # euclidean expmap
        return x

    def diag_one_sample(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MultivariateNormalDiag:
        """Output-layer marginals at one sampled warp of `x`."""
        return self.output_layer.diag(self._warp_inputs(x, key=key))
    
    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equal-weight mixture of `num_samples` diagonal predictives."""
        sample_keys = jr.split(key, self.num_samples)

        def sample_mean_stddev(key: Key) -> tuple[Float, Float]:
            px = self.diag_one_sample(x, key=key)
            return px.mean(), px.stddev()
        
        means, stddevs = jax.vmap(sample_mean_stddev)(sample_keys)
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=jnp.ones(self.num_samples) / self.num_samples),
            components_distribution=tfd.MultivariateNormalDiag(loc=means, scale_diag=stddevs),
        )
    
    def prior_kl(self) -> Float:
        """Sum of the variational KL terms of all layers."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()
    

# Concrete deep-GP model classes that the `DeepGP` type variable is constrained to.
_DEEP_GP_MODEL_CLASSES = (EuclideanDeepGP, SphereResidualDeepGP)
DeepGP = TypeVar("DeepGP", *_DEEP_GP_MODEL_CLASSES)
    

@jax.jit 
def negative_elbo(p: DeepGP, x: Float, y: Float, *, key: Key) -> Float:
    """Negative ELBO of a deep GP: -(expected log-likelihood - prior KL).

    `p.diag(x, key=key)` returns a mixture predictive. The Gaussian expected
    log-likelihood depends on the predictive only through its first two
    moments, so evaluating it at the mixture's mean and variance equals
    averaging it over the mixture components.

    Args:
        p: deep-GP model exposing `diag(x, key=...)`, `prior_kl()` and
            `posterior.likelihood.noise_variance`.
        x: inputs.
        y: targets (the log-likelihood is summed over the last axis).
        key: PRNG key for sampling the hidden-layer warps.
    """
    pf_diag = p.diag(x, key=key)
    m, f_var = pf_diag.mean(), pf_diag.variance()
    eps_var = p.posterior.likelihood.noise_variance
    return -(expected_log_likelihood(y, m, f_var, eps_var) - p.prior_kl())


from scipy.cluster.vq import kmeans
from numpy.random import Generator


def jax_key_to_numpy_generator(key: Key) -> Generator:
    """Derive a NumPy `Generator` deterministically from a JAX PRNG key.

    Uses the public `jax.random.key_data`, which returns the raw uint32 key bits
    for both typed keys (`jax.random.key`) and legacy uint32 keys
    (`jax.random.PRNGKey`). The private `_base_array` attribute exists only on
    typed keys.
    """
    return np.random.default_rng(np.asarray(jax.random.key_data(key)))


def kmeans_inducing_points(key: Key, x: Float[Array, "N D"], num_inducing: int) -> Float[Array, "M D"]:
    """Pick inducing inputs as k-means centroids of `x`.

    If more inducing points are requested than there are rows in `x`, `x` is
    clustered repeatedly with k = N until fewer than N+1 points remain, then
    once more for the remainder. Every k-means pass draws from the same
    generator derived from `key`, so the result is fully determined by `key`.

    NOTE: scipy's `kmeans` drops centroids that end up with no assigned
    observations, so fewer than `num_inducing` rows can be returned.

    Raises:
        ValueError: if `x` has no rows.
    """
    rng = jax_key_to_numpy_generator(key)

    data = np.asarray(x, dtype=np.float64)
    n = data.shape[0]
    if n == 0:
        raise ValueError("kmeans_inducing_points: x must contain at least one point")

    centers = []
    remaining = num_inducing
    while remaining > n:
        centers.append(kmeans(data, n, seed=rng)[0])
        remaining -= n
    centers.append(kmeans(data, remaining, seed=rng)[0])
    return jnp.asarray(np.concatenate(centers, axis=0), dtype=jnp.float64)

def create_residual_deep_gp_with_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = False, 
    nu: float = 1.5, kernel_max_ell: int | None = None
) -> SphereResidualDeepGP:
    """Build a residual deep GP on the sphere with inducing-point layers.

    Args:
        num_layers: number of GP layers (`num_layers - 1` hidden + 1 output).
        total_hidden_variance: kernel variance split equally across hidden layers.
        num_inducing: inducing points per layer; also sets the kernel truncation
            level when `kernel_max_ell` is None.
        x: training inputs; `x.shape[1] - 1` is the sphere dimension.
        num_samples: number of sampled warps used for prediction.
        key: PRNG key for the k-means initialisation of the inducing inputs.
        train_inducing: make the output layer's inducing inputs trainable
            (hidden-layer inducing inputs stay fixed).
        nu: Matérn smoothness used by every layer.
        kernel_max_ell: kernel truncation level; derived from `num_inducing` if None.
    """
    sphere_dim = x.shape[1] - 1
    num_hidden = num_layers - 1

    smoothness = jnp.array(nu)
    kappa = jnp.array(1.0)
    per_layer_variance = total_hidden_variance / max(num_hidden, 1)

    centroids = kmeans_inducing_points(key, x, num_inducing)
    # Put the k-means centroids back onto the unit sphere; shared by all layers.
    z_on_sphere = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)

    if kernel_max_ell is None:
        kernel_max_ell = num_phases_to_num_levels(num_inducing, sphere_dim=sphere_dim)

    def build_hidden_layer() -> MultioutputInducingPointsPosterior:
        hidden_kernel = MultioutputSphereMaternKernel(
            num_outputs=sphere_dim + 1,
            sphere_dim=sphere_dim,
            nu=smoothness,
            kappa=kappa,
            variance=per_layer_variance,
            max_ell=kernel_max_ell,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        return MultioutputInducingPointsPosterior(posterior=hidden_posterior, z=z_on_sphere)  # TODO set z to be trainable maybe

    hidden_layers = [build_hidden_layer() for _ in range(num_hidden)]

    output_kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=smoothness,
        kappa=kappa,
        variance=jnp.array(1.0),
        max_ell=kernel_max_ell,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = InducingPointsPosterior(posterior=output_posterior, z=z_on_sphere)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def num_phases_to_num_levels(num_phases: int, *, sphere_dim: int) -> int:
    """Highest level L whose levels 0..L together fit within `num_phases` phases.

    Consumes `num_phases_in_frequency(frequency=ell, ...)` phases per level,
    presumably the number of spherical harmonics of degree ell. An exact fit
    returns the last consumed level; otherwise the level before it. Returns -1
    when `num_phases` is 0.
    """
    remaining = num_phases
    level = 0
    while remaining > 0:
        remaining -= num_phases_in_frequency(frequency=level, sphere_dim=sphere_dim)
        level += 1

    if remaining == 0:
        return level - 1
    # The last level overshot the budget, so drop it.
    return level - 2


def create_residual_deep_gp_with_spherical_harmonic_features(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, 
    nu: float = 1.5, kernel_max_ell: int | None = None
) -> SphereResidualDeepGP:
    """Build a residual deep GP on the sphere with spherical-harmonic-feature layers.

    The harmonic basis is truncated at the highest level that fits within
    `num_inducing` phases, and is shared by every layer. `key` is accepted for
    interface parity with the other factories but is not used here.
    """
    sphere_dim = x.shape[1] - 1
    num_hidden = num_layers - 1

    smoothness = jnp.array(nu)
    kappa = jnp.array(1.0)
    per_layer_variance = total_hidden_variance / max(num_hidden, 1)

    feature_max_ell = num_phases_to_num_levels(num_inducing, sphere_dim=sphere_dim)
    if kernel_max_ell is None:
        kernel_max_ell = feature_max_ell
    harmonics = SphericalHarmonics(max_ell=feature_max_ell, sphere_dim=sphere_dim)

    def build_hidden_layer() -> MultioutputSphericalHarmonicFeaturesPosterior:
        hidden_kernel = MultioutputSphereMaternKernel(
            num_outputs=sphere_dim + 1,
            sphere_dim=sphere_dim,
            nu=smoothness,
            kappa=kappa,
            variance=per_layer_variance,
            max_ell=kernel_max_ell,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        return MultioutputSphericalHarmonicFeaturesPosterior(posterior=hidden_posterior, spherical_harmonics=harmonics)

    hidden_layers = [build_hidden_layer() for _ in range(num_hidden)]

    output_kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=smoothness,
        kappa=kappa,
        variance=jnp.array(1.0),
        max_ell=kernel_max_ell,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = SphericalHarmonicFeaturesPosterior(posterior=output_posterior, spherical_harmonics=harmonics)

    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_euclidean_deep_gp_with_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = True
) -> EuclideanDeepGP:
    """Build a Euclidean deep GP (additive warps) with Matérn-3/2 inducing-point layers.

    Inputs live in R^{D+1} with D = x.shape[1] - 1. All layers share the k-means
    inducing inputs and a unit per-dimension `kappa`. When `train_inducing` is
    True, every layer's inducing inputs are trainable.
    """
    ambient_dim = x.shape[1]  # == sphere_dim + 1
    num_hidden = num_layers - 1

    per_layer_variance = total_hidden_variance / max(num_hidden, 1)
    kappa = jnp.ones(ambient_dim)

    z = kmeans_inducing_points(key, x, num_inducing)

    def build_hidden_layer() -> MultioutputInducingPointsPosterior:
        hidden_kernel = MultioutputEuclideanMaternKernel32(
            variance=per_layer_variance,
            kappa=kappa,
            num_outputs=ambient_dim,
            num_inputs=ambient_dim,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        layer = MultioutputInducingPointsPosterior(posterior=hidden_posterior, z=z)
        return layer.replace_trainable(z=True) if train_inducing else layer

    hidden_layers = [build_hidden_layer() for _ in range(num_hidden)]

    output_kernel = EuclideanMaternKernel32(
        num_inputs=ambient_dim,
        variance=jnp.array(1.0),
        kappa=kappa,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = InducingPointsPosterior(posterior=output_posterior, z=z)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return EuclideanDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_euclidean_deep_gp_with_input_geometric_layer_and_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = True, 
) -> EuclideanDeepGP | SphereResidualDeepGP:
    """Build a Euclidean deep GP whose first layer is a sphere-aware Matérn GP.

    Layer 1 uses a multi-output sphere Matérn kernel with inducing inputs
    projected onto the sphere. The remaining hidden layers and the output layer
    are Euclidean Matérn-3/2 inducing-point GPs.

    With `num_layers == 1` there is no hidden layer, so this falls back to the
    single-layer residual sphere model (a SphereResidualDeepGP). The caller's
    `train_inducing` choice is forwarded to it.
    """
    if num_layers == 1:
        return create_residual_deep_gp_with_inducing_points(
            num_layers, total_hidden_variance, num_inducing, x, num_samples=num_samples, key=key,
            train_inducing=train_inducing,
        )
    
    sphere_dim = x.shape[1] - 1

    input_nu = jnp.array(1.5)

    input_variance = total_hidden_variance / max(num_layers - 1, 1)
    hidden_variance = input_variance
    output_variance = jnp.array(1.0)

    input_kappa = jnp.array(1.0)
    hidden_kappa = jnp.ones(sphere_dim + 1)
    output_kappa = hidden_kappa

    z = kmeans_inducing_points(key, x, num_inducing)
    # The geometric input layer needs inducing inputs on the unit sphere.
    input_z = z / jnp.linalg.norm(z, axis=-1, keepdims=True)
    hidden_z = z
    output_z = z

    hidden_layers = []

    # Input layer: sphere Matérn kernel; its inducing inputs stay fixed on the sphere.
    kernel = MultioutputSphereMaternKernel(
        num_outputs=sphere_dim + 1, 
        sphere_dim=sphere_dim, 
        variance=input_variance,
        kappa=input_kappa,
        nu=input_nu,
    )
    prior = MultioutputPrior(kernel=kernel)
    posterior = MultioutputDummyPosterior(prior=prior)
    layer = MultioutputInducingPointsPosterior(posterior=posterior, z=input_z)
    hidden_layers.append(layer)

    # Remaining hidden layers: Euclidean Matérn-3/2.
    for _ in range(num_layers - 2):
        kernel = MultioutputEuclideanMaternKernel32(
            kappa=hidden_kappa,
            variance=hidden_variance, 
            num_inputs=sphere_dim + 1,
            num_outputs=sphere_dim + 1,
        )
        prior = MultioutputPrior(kernel=kernel)
        posterior = MultioutputDummyPosterior(prior=prior)
        layer = MultioutputInducingPointsPosterior(posterior=posterior, z=hidden_z)
        if train_inducing:
            layer = layer.replace_trainable(z=True)
        hidden_layers.append(layer)

    kernel = EuclideanMaternKernel32(
        kappa=output_kappa,
        variance=output_variance,
        num_inputs=sphere_dim + 1,
    )
    prior = Prior(kernel=kernel)
    likelihood = DeepGaussianLikelihood()
    posterior = Posterior(prior=prior, likelihood=likelihood)
    output_layer = InducingPointsPosterior(posterior=posterior, z=output_z)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return EuclideanDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_model(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, name: str, nu: float = 1.5, 
    kernel_max_ell: int | None = None
) -> DeepGP: 
    """
    Build a deep GP model selected by `name`.

    The residual variants forward `nu` and `kernel_max_ell` to their factory.
    The two Euclidean variants ignore them.

    Raises:
        ValueError: if `name` is not a known model name.
    """
    shared = (num_layers, total_hidden_variance, num_inducing, x)

    # Builders are lambdas so each factory name is only resolved when it is selected.
    builders = {
        'residual+inducing_points': lambda: create_residual_deep_gp_with_inducing_points(
            *shared, num_samples=num_samples, key=key, nu=nu, kernel_max_ell=kernel_max_ell
        ),
        'euclidean+inducing_points': lambda: create_euclidean_deep_gp_with_inducing_points(
            *shared, num_samples=num_samples, key=key,
        ),
        'euclidean_with_geometric_input+inducing_points': lambda: create_euclidean_deep_gp_with_input_geometric_layer_and_inducing_points(
            *shared, num_samples=num_samples, key=key,
        ),
        'residual+spherical_harmonic_features': lambda: create_residual_deep_gp_with_spherical_harmonic_features(
            *shared, num_samples=num_samples, key=key, nu=nu, kernel_max_ell=kernel_max_ell
        ),
        'residual+hodge+spherical_harmonic_features': lambda: create_hodge_residual_deep_gp_with_spherical_harmonic_features(
            *shared, num_samples=num_samples, key=key, nu=nu, kernel_max_ell=kernel_max_ell
        ),
    }

    build = builders.get(name)
    if build is None:
        raise ValueError(f"Unknown model name: {name}")
    return build()



# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from beartype.typing import (
    Any,
    Callable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import jax
import jax.random as jr
import optax as ox

from gpjax.base import Module
from gpjax.dataset import Dataset
from gpjax.objectives import AbstractObjective
from gpjax.scan import vscan
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

# Type variable for `fit`: any gpjax Module subclass. `fit` returns a model of the same type it was given.
ModuleModel = TypeVar("ModuleModel", bound=Module)


def fit(  # noqa: PLR0913
    *,
    model: ModuleModel,
    objective: Union[AbstractObjective, Callable[..., ScalarFloat]],
    x: Float, 
    y: Float,
    optim: ox.GradientTransformation,
    key: KeyArray,
    num_iters: Optional[int] = 100,
    batch_size: Optional[int] = -1,
    verbose: Optional[bool] = True,
    unroll: Optional[int] = 1,
) -> Tuple[ModuleModel, Array]:
    r"""Train a Module model with respect to a supplied objective function.

    This is a variant of `gpjax.fit`. Data are passed as separate `x`/`y`
    arrays rather than a `Dataset`. The objective is called as
    `objective(model, x, y, key=key)` and receives a fresh random key at every
    iteration, which suits Monte Carlo objectives such as deep GP ELBOs.
    Optimisers should originate from Optax.

    Example:
    ```python
        >>> import jax.random as jr
        >>> import optax as ox
        >>>
        >>> trained_model, history = fit(
        ...     model=model,
        ...     objective=deep_negative_elbo,
        ...     x=train_x,
        ...     y=train_y,
        ...     optim=ox.adam(learning_rate=0.01),
        ...     key=jr.key(0),
        ...     num_iters=1000,
        ... )
    ```

    Args:
        model (Module): The model Module to be optimised.
        objective: Callable `objective(model, x, y, *, key) -> scalar` to be
            minimised. It is evaluated on the constrained model.
        x (Float): Training inputs; the leading axis indexes data points.
        y (Float): Training targets, aligned with `x` along the leading axis.
        optim (GradientTransformation): The Optax optimiser that is used for
            learning a parameter set.
        key (KeyArray): Random key. It is split into one key per iteration.
            Each iteration key is used for mini-batch selection (when
            `batch_size != -1`) and is passed to the objective.
        num_iters (int): The number of optimisation steps to run. Defaults to 100.
        batch_size (int): Mini-batch size, drawn with replacement. Defaults to -1,
            i.e. full batch.
        verbose (bool): If True, use gpjax's `vscan` (prints a training progress
            bar); otherwise use `jax.lax.scan`. Defaults to True.
        unroll (int): The number of unrolled steps passed to the scan. Defaults to 1.

    Returns
    -------
        Tuple[Module, Array]: The optimised model (in constrained space) and the
            per-iteration loss values.
    """

    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
    def loss(model: Module, x: Float, y: Float, *, key: Key) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(model.constrain(), x, y, key=key)

    # Optimise in unconstrained space so that bijector-constrained params stay valid.
    model = model.unconstrain()

    # Initialise optimiser state.
    state = optim.init(model)

    # One key per iteration, scanned over.
    iter_keys = jr.split(key, num_iters)

    # Optimisation step.
    def step(carry, key):
        model, opt_state = carry

        if batch_size != -1:
            batch_x, batch_y = get_batch(x, y, batch_size, key)
        else:
            batch_x, batch_y = x, y

        loss_val, loss_gradient = jax.value_and_grad(loss)(model, batch_x, batch_y, key=key)
        updates, opt_state = optim.update(loss_gradient, opt_state, model)
        model = ox.apply_updates(model, updates)

        carry = model, opt_state
        return carry, loss_val

    # vscan additionally reports progress; both scan over the per-iteration keys.
    scan = vscan if verbose else jax.lax.scan

    # Optimisation loop.
    (model, _), history = scan(step, (model, state), iter_keys, unroll=unroll)

    # Back to constrained space.
    model = model.constrain()

    return model, history


def get_batch(x: Float, y: Float, batch_size: int, key: KeyArray) -> tuple[Float, Float]:
    """Sample a mini-batch of (x, y) pairs uniformly with replacement.

    Args:
        x: Inputs; rows along the leading axis are data points.
        y: Targets, aligned with `x` along the leading axis.
        batch_size: Number of rows to draw.
        key: Random key used to select the row indices.

    Returns
    -------
        tuple: `(x_batch, y_batch)`, each with leading dimension `batch_size`.
    """
    # The same indices are used for x and y so inputs and targets stay paired.
    indices = jr.choice(key, x.shape[0], (batch_size,), replace=True)
    return x[indices], y[indices]


@jax.jit 
def negative_elbo_ignore_key(p: InducingPointsPosterior, x: Float, y: Float, *, key: Key) -> Float:
    """
    Negative ELBO of a single-layer inducing-point posterior.

    `key` is accepted only so the signature matches what `fit` passes to the
    objective; it is not used.
    """
    marginals = p.diag(x)
    latent_mean = marginals.mean()
    latent_var = marginals.variance()
    noise_var = p.posterior.likelihood.noise_variance

    elbo = expected_log_likelihood(y, latent_mean, latent_var, noise_var) - p.prior_kl()
    return -elbo


@jax.jit 
def deep_negative_elbo(p: EuclideanDeepGP, x: Float, y: Float, *, key: Key) -> Float:
    """
    Monte Carlo negative ELBO of a deep GP.

    `p.diag(x, key=key)` returns a mixture whose components correspond to
    propagated samples. The expected log-likelihood is evaluated per component
    (vmapped over the leading axis of the component moments) and averaged.
    `p.prior_kl()` is then subtracted.
    """
    components = p.diag(x, key=key).components_distribution
    means, variances = components.mean(), components.variance()
    noise_var = p.posterior.likelihood.noise_variance

    per_sample_ell = jax.vmap(
        lambda mean_s, var_s: expected_log_likelihood(y, mean_s, var_s, noise_var)
    )(means, variances)
    averaged_ell = jnp.mean(per_sample_ell, axis=0)

    return -(averaged_ell - p.prior_kl())


import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
from scipy.special import sph_harm


def sphere_meshgrid(n_colat=100, n_lon=100, epsilon=1e-32):
    """
    Build a regular (colatitude, longitude) grid over the sphere.

    Returns `(colat, lon)`, two arrays of shape (n_colat, n_lon).
    Colatitude spans [epsilon, pi - epsilon] along axis 0.
    Longitude spans [0, 2*pi] along axis 1.

    NOTE(review): with the default epsilon=1e-32, `np.pi - epsilon` rounds to
    exactly `np.pi` in float64, so the last colatitude row lies on the south pole.
    """
    colat_axis = np.linspace(0 + epsilon, np.pi - epsilon, n_colat)
    lon_axis = np.linspace(0, 2 * np.pi, n_lon)
    # "ij" indexing puts colatitude on axis 0 and longitude on axis 1.
    colat_grid, lon_grid = np.meshgrid(colat_axis, lon_axis, indexing="ij")
    return colat_grid, lon_grid


def target_f__product_of_sines(sph: Float) -> Float:
    """
    Return sin(3 * colat) * sin(3 * lon).

    `sph` holds (colatitude, longitude) along its last axis.
    """
    colat_factor = jnp.sin(sph[..., 0] * 3)
    lon_factor = jnp.sin(sph[..., 1] * 3)
    return colat_factor * lon_factor


def target_f__sum_of_sines(sph: Float) -> Float:
    """
    Return (|sin(2 colat)| + |sin(2 lon)| - 1) / 2.

    `sph` holds (colatitude, longitude) along its last axis.
    """
    colat, lon = sph[..., 0], sph[..., 1]
    # jax.numpy (as in the other target functions) keeps the result a JAX array
    # and makes the function traceable by jit/vmap; NumPy ops fail on tracers.
    return (jnp.abs(jnp.sin(colat * 2)) + jnp.abs(jnp.sin(lon * 2)) - 1.0) / 2



def rotate(sph, roll: float):
    """
    Rotate points on the unit sphere about the Cartesian x-axis by the angle `roll`.

    Parameters:
    sph (jnp.ndarray): Points in spherical coordinates. The last axis is
        (colatitude, longitude), with colatitude in [0, pi].
    roll (float): Rotation angle in radians. The (y, z) components are rotated
        counter-clockwise in the y-z plane; x is unchanged.

    Returns:
    jnp.ndarray: Rotated points with the same (colatitude, longitude) layout.
        The longitude comes from `arctan2` and therefore lies in (-pi, pi],
        not [0, 2pi).
    """
    colatitude, longitude = sph[..., 0], sph[..., 1]

    # Convert to Cartesian coordinates
    x = jnp.sin(colatitude) * jnp.cos(longitude)
    y = jnp.sin(colatitude) * jnp.sin(longitude)
    z = jnp.cos(colatitude)

    # Apply roll rotation
    new_x = x
    new_y = jnp.cos(roll) * y - jnp.sin(roll) * z
    new_z = jnp.sin(roll) * y + jnp.cos(roll) * z

    # Convert back to spherical coordinates
    new_colatitude = jnp.arccos(new_z / jnp.sqrt(new_x**2 + new_y**2 + new_z**2))
    new_longitude = jnp.arctan2(new_y, new_x)

    return jnp.stack([new_colatitude, new_longitude], axis=-1)


def reversed_spherical_harmonic(sph, m: int, n: int):
    """
    Return the real part of scipy's `sph_harm(m, n, colat, lon)`.

    scipy's `sph_harm(m, n, theta, phi)` takes theta as the azimuthal angle and
    phi as the polar angle. Here colatitude is passed as theta and longitude as
    phi, i.e. the roles are swapped (presumably intentional, given the name).
    The evaluation happens on the host via NumPy.
    """
    first_angle = np.asarray(sph[..., 0])
    second_angle = np.asarray(sph[..., 1])
    harmonic = sph_harm(m, n, first_angle, second_angle)
    return jnp.asarray(harmonic.real)


def target_f__reversed_spherical_harmonic(sph: Float) -> Float:
    """
    Sum of two 'reversed' spherical harmonics.

    The first term is (m=1, n=2) evaluated at the input points.
    The second term is (m=1, n=1) evaluated at the points rotated by pi/2
    about the x-axis.
    """
    at_input = reversed_spherical_harmonic(sph, m=1, n=2)
    rotated_points = rotate(sph, roll=jnp.pi / 2)
    at_rotated = reversed_spherical_harmonic(rotated_points, m=1, n=1)
    return at_input + at_rotated


def target_f(sph: Float[Array, " N D"], *, name: str = "product_of_sines") -> Float[Array, " N"]:
    """
    Evaluate the synthetic target function selected by `name` at `sph`.

    Raises:
        ValueError: if `name` is not a known target function.
    """
    targets = {
        "product_of_sines": target_f__product_of_sines,
        "sum_of_sines": target_f__sum_of_sines,
        "reversed_spherical_harmonic": target_f__reversed_spherical_harmonic,
    }
    selected = targets.get(name)
    if selected is None:
        raise ValueError(f"Unknown target function: {name}")
    return selected(sph)


def add_noise(f: Float[Array, " N"], noise_std: float = 0.01, *, key: Key) -> Float:
    """Return `f` plus i.i.d. zero-mean Gaussian noise with standard deviation `noise_std`."""
    standard_normal = jax.random.normal(key=key, shape=f.shape)
    return f + standard_normal * noise_std


def evaluate_diag(model: DeepGP, x: Float[Array, "N D"], y: Float[Array, " N"], *, key: Key) -> dict[str, float]:
    """
    Score the marginal (per-point) predictions of `model` on (x, y).

    The model's `diag` prediction is passed through its likelihood's `diag` to
    get a predictive over observations. Returns Python floats for "mse" and
    "nlpd", both computed by the `mse`/`nlpd` helpers defined elsewhere.
    """
    predictive = model.posterior.likelihood.diag(model.diag(x, key=key))
    pred_mean = predictive.mean()
    pred_std = predictive.stddev()
    return {
        "mse": mse(y, pred_mean).item(),
        "nlpd": nlpd(y, pred_mean, pred_std).item(),
    }


def evaluate_joint(model: DeepGP, x: Float[Array, "N D"], y: Float[Array, " N"], *, key: Key) -> dict[str, float]:
    """
    Score the joint predictions of `model` on (x, y).

    Builds an equally weighted mixture of full-covariance Gaussians over y, one
    component per propagated sample. The likelihood's noise variance is added
    to each component's diagonal. Reports:
      - "mse": mean squared error of the mixture mean;
      - "nlpd": negative joint log predictive density divided by the number of points.
    """
    pf = model.full(x, key=key)
    pf_components = pf.components_distribution
    m, f_covar = pf_components.mean(), pf_components.covariance()
    # Latent covariance -> observation covariance: add i.i.d. Gaussian noise.
    f_covar = f_covar + model.posterior.likelihood.noise_variance * jnp.eye(f_covar.shape[-1])

    py_batched = tfd.MultivariateNormalFullCovariance(loc=m, covariance_matrix=f_covar)
    py = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=jnp.zeros(model.num_samples)),
        components_distribution=py_batched,
    )

    mean = py.mean()
    metrics = {
        "mse": mse(y, mean).item(),
        # log_prob is the log density; the "negative" in NLPD requires the minus sign.
        "nlpd": (-py.log_prob(y) / y.shape[0]).item(),
    }
    return metrics


import abc 
import cola
from typing import TypeVar 
from jaxtyping import Float, Num
from gpjax.typing import ScalarFloat


# Type variable for annotations accepting any gpjax kernel. The bound is a string
# forward reference, so gpjax.kernels.base need not be imported for it.
Kernel = TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")
    

@jax.jit
def tangent_basis_normalization_matrix(x: Float[Array, "2"]) -> Float[Array, "2 2"]:
    """
    Return the matrix diag(1, 1 / sin(colat)) for a point x = (colat, lon).

    The longitude entry divides by sin(colat), which is zero at colat = 0.
    Callers in this file clip the colatitude before calling it.
    """
    inv_sin_colat = 1.0 / jnp.sin(x[0])
    return jnp.array([
        [1.0, 0.0],
        [0.0, inv_sin_colat],
    ])


# Hodge star on 2-D tangent vectors, applied to row vectors as v @ H:
# (a, b) @ H == (-b, a), a 90-degree rotation in the tangent plane.
# Used to turn curl-free (gradient) fields into divergence-free ones.
hodge_star_matrix = jnp.array([
    [0.0, 1.0],
    [-1.0, 0.0],
])


@Partial(jax.jit, static_argnames=('min_value', ))
def clip_colatitude(x: Float[Array, "2"], min_value: float) -> Float[Array, "2"]:
    """
    Return `x` with its colatitude x[0] raised to `min_value` if it is smaller.

    Only the lower bound is enforced; the longitude x[1] is left untouched.
    This keeps points off colat = 0, where 1 / sin(colat) terms are undefined.
    """
    def raise_to_minimum(point):
        return point.at[0].set(min_value)

    def unchanged(point):
        return point

    return jax.lax.cond(x[0] < min_value, raise_to_minimum, unchanged, x)


@Partial(jax.jit, static_argnames=("max_ell", "alpha"))
def hodge_gegenbauer(t1: Float[Array, "2"], t2: Float[Array, "2"], max_ell: int, alpha: float):
    """
    Mixed second derivatives of the per-level Gegenbauer terms, in normalized tangent frames.

    `sph_gegenbauer(t1, t2)` is evaluated and its first entry dropped (presumably
    level 0, whose derivative vanishes). The result is differentiated once
    w.r.t. t1 and once w.r.t. t2, giving one 2x2 block per remaining level.
    Each side is then rescaled by `tangent_basis_normalization_matrix`.
    """
    def per_level_terms(a, b):
        return sph_gegenbauer(a, b, max_ell=max_ell, alpha=alpha)[1:]

    d_first = jax.jacfwd(per_level_terms, argnums=0)
    d_first_d_second = jax.jacfwd(d_first, argnums=1)
    mixed = d_first_d_second(t1, t2)

    left = tangent_basis_normalization_matrix(t1)
    right = tangent_basis_normalization_matrix(t2)
    return left.T @ mixed @ right
    

@Partial(jax.jit, static_argnames=('max_ell', ))
def hodge_sphere_addition_theorem(x: Float[Array, "2"], y: Float[Array, "2"], max_ell: int) -> Float[Array, "I 2 2"]:
    """
    Compute per-level 2x2 blocks for levels l = 1..max_ell.

    Each block is the output of `hodge_gegenbauer` for that level, scaled by
    (2l + 1) / (l (l + 1)). Contracting the blocks with a spectrum (see
    `addition_theorem_matrix_kernel`) gives a 2x2 kernel matrix.
    """
    alpha = (2.0 - 1.0) / 2.0  # (d - 1) / 2 with d = 2
    ells = jnp.arange(1, max_ell + 1, dtype=jnp.float64)
    level_weights = (2 * ells + 1) / (ells * (ells + 1))
    mixed_derivatives = hodge_gegenbauer(x, y, max_ell, alpha)
    return level_weights[:, None, None] * mixed_derivatives


@jax.jit
def addition_theorem_matrix_kernel(spectrum: Float[Array, "I"], z: Float[Array, "I 2 2"]) -> Float[Array, "2 2"]:
    """Return the spectrum-weighted sum of per-level 2x2 blocks: sum_i spectrum[i] * z[i]."""
    weighted_blocks = spectrum[:, None, None] * z
    return weighted_blocks.sum(axis=0)


@dataclass 
class AbstractHodgeKernel(Module):
    """
    Shared hyperparameters and input handling for the Hodge-Matern 2x2 kernels on S^2.

    Subclasses turn the per-level blocks returned by `prepare_inputs` into a
    2x2 kernel matrix.
    """
    # Matern hyperparameters; trainable and kept positive via Softplus bijectors.
    nu: Float[Array, "1"] = param_field(jnp.array(2.5), bijector=tfp.bijectors.Softplus())
    kappa: Float[Array, "1"] = param_field(jnp.array(1.0), bijector=tfp.bijectors.Softplus())
    variance: Float[Array, "1"] = param_field(jnp.array(1.0), bijector=tfp.bijectors.Softplus())
    # NOTE(review): `alpha` is not read by any method of this class or its visible subclasses.
    alpha: float = static_field(0.5)
    # Truncation: levels 1..max_ell are used.
    max_ell: int = static_field(10)
    # Lower bound applied to input colatitudes (see validate_inputs).
    colatitude_min_value: float = static_field(1e-12) 
    # Passed to matern_spectrum as `dim`.
    sphere_dim: int = static_field(2)

    @property 
    def ells(self):
        # Levels 1..max_ell (level 0 excluded), as float64.
        return jnp.arange(1, self.max_ell + 1, dtype=jnp.float64)
    
    def spectrum(self) -> Float[Array, "I"]:
        """Matern spectral weights for levels 1..max_ell, from `matern_spectrum` (defined elsewhere)."""
        return matern_spectrum(self.ells, self.kappa, self.nu, self.variance, dim=self.sphere_dim)
    
    @jax.jit
    def validate_inputs(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> tuple[Float[Array, "2"], Float[Array, "2"]]:
        """Clip the colatitude of both (colat, lon) inputs to at least `colatitude_min_value`."""
        x = clip_colatitude(x, self.colatitude_min_value)
        y = clip_colatitude(y, self.colatitude_min_value)
        return x, y

    @jax.jit 
    def prepare_inputs(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "I 2 2"]:
        """Return the per-level 2x2 blocks of `hodge_sphere_addition_theorem` for the clipped inputs."""
        x, y = self.validate_inputs(x, y)
        return hodge_sphere_addition_theorem(x, y, self.max_ell)
    

@dataclass
class HodgeMaternCurlFreeKernel(AbstractHodgeKernel):
    """
    Curl-free part of the Hodge-Matern kernel on S^2.

    k(x, y) = sum_l spectrum[l] * z[l], where z holds the per-level 2x2 blocks
    from `prepare_inputs`.
    """

    @jax.jit 
    def from_addition_theorem(self, z: Float[Array, "I 2 2"]) -> Float[Array, "2 2"]:
        """Contract the per-level blocks `z` (one 2x2 block per level) with this kernel's spectrum."""
        return addition_theorem_matrix_kernel(self.spectrum(), z)

    def __call__(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "2 2"]:
        """Evaluate the 2x2 kernel matrix between points x and y, each given as (colat, lon)."""
        z = self.prepare_inputs(x, y)
        return self.from_addition_theorem(z)


@dataclass
class HodgeMaternDivFreeKernel(AbstractHodgeKernel):
    """
    Divergence-free part of the Hodge-Matern kernel on S^2.

    Same spectral contraction as the curl-free kernel, conjugated by the Hodge
    star: k(x, y) = H^T (sum_l spectrum[l] * z[l]) H.
    """

    @jax.jit 
    def from_addition_theorem(self, z: Float[Array, "I 2 2"]) -> Float[Array, "2 2"]:
        """Contract the per-level blocks `z` with the spectrum, then rotate both sides by the Hodge star."""
        dd = addition_theorem_matrix_kernel(self.spectrum(), z)
        H = hodge_star_matrix # [2 2]
        return H.T @ dd @ H

    def __call__(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "2 2"]:
        """Evaluate the 2x2 kernel matrix between points x and y, each given as (colat, lon)."""
        z = self.prepare_inputs(x, y)
        return self.from_addition_theorem(z)


@dataclass
class HodgeMaternKernel(Module):
    """
    Hodge-Matern kernel on S^2: sum of a curl-free and a divergence-free Matern kernel.

    `kappa`, `nu` and `variance` are init-only values. Each is used to construct
    both sub-kernels, so each sub-kernel holds its own copy of the hyperparameters.
    """
    kappa: InitVar[ScalarFloat] = 1.0
    nu: InitVar[ScalarFloat] = 2.5
    variance: InitVar[ScalarFloat] = 1.0

    # Truncation level shared by both sub-kernels.
    max_ell: int = static_field(10)
    # Built in __post_init__.
    curl_free_kernel: HodgeMaternCurlFreeKernel = param_field(init=False)
    div_free_kernel: HodgeMaternDivFreeKernel = param_field(init=False)
    # Lower bound applied to input colatitudes.
    colatitude_min_value: float = static_field(1e-12)

    def __post_init__(self, kappa, nu, variance):
        # Both sub-kernels get this kernel's truncation level and colatitude clipping.
        self.curl_free_kernel = HodgeMaternCurlFreeKernel(kappa=kappa, nu=nu, variance=variance, max_ell=self.max_ell, colatitude_min_value=self.colatitude_min_value)
        self.div_free_kernel = HodgeMaternDivFreeKernel(kappa=kappa, nu=nu, variance=variance, max_ell=self.max_ell, colatitude_min_value=self.colatitude_min_value)

    @jax.jit 
    def from_addition_theorem(self, z: Float[Array, "I 2 2"]) -> Float[Array, "2 2"]:
        """Sum of both sub-kernels evaluated on the same per-level blocks `z`."""
        return self.curl_free_kernel.from_addition_theorem(z) + self.div_free_kernel.from_addition_theorem(z)

    def spectrum(self):
        """Curl-free spectrum followed by the divergence-free spectrum, concatenated."""
        return jnp.concat([self.curl_free_kernel.spectrum(), self.div_free_kernel.spectrum()])

    def validate_inputs(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> tuple[Float[Array, "2"], Float[Array, "2"]]:
        """Clip the colatitude of both (colat, lon) inputs to at least `colatitude_min_value`."""
        x = clip_colatitude(x, self.colatitude_min_value)
        y = clip_colatitude(y, self.colatitude_min_value)
        return x, y
    
    def __call__(self, x: Float[Array, "2"], y: Float[Array, "2"]) -> Float[Array, "2 2"]:
        """Evaluate the 2x2 kernel matrix between points x and y, each given as (colat, lon)."""
        x, y = self.validate_inputs(x, y)
        # The per-level blocks are computed once and shared by both sub-kernels.
        z = hodge_sphere_addition_theorem(x, y, self.max_ell)
        return self.from_addition_theorem(z)


@dataclass 
class AbstractSphericalHarmonicFields(gpjax.Module):
    """
    Base class for tangent vector fields on S^2 built from spherical harmonics.

    The fields are gradients of the spherical harmonics of levels 1..max_ell
    (level 0 is dropped), each divided by sqrt(l (l + 1)). Subclasses set
    `num_fields` in `__post_init__` and implement `__call__`.
    """
    max_ell: int = static_field(10)
    sphere_dim: int = static_field(2)
    # Lower bound applied to the colatitude before differentiating.
    colatitude_min_value: float = static_field(1e-12)
    # Derived in __post_init__.
    spherical_harmonics: SphericalHarmonics = static_field(init=False)
    # Number of harmonics at each level 1..max_ell.
    num_phases_per_frequency: Float[Array, "L"] = static_field(init=False)
    # Total number of harmonics over levels 1..max_ell.
    num_phases: int = static_field(init=False)
    # Set by subclasses.
    num_fields: int = static_field(init=False)

    def __post_init__(self) -> None:
        self.spherical_harmonics = SphericalHarmonics(max_ell=self.max_ell, sphere_dim=self.sphere_dim)
        # Drop level 0: the constant harmonic has a zero gradient field.
        num_phases_per_frequency = self.spherical_harmonics.num_phases_per_frequency[1:]
        self.num_phases_per_frequency = jnp.array(num_phases_per_frequency)
        self.num_phases = sum(num_phases_per_frequency)

    @jax.jit
    def _sph_polynomial_expansion(self, x: Float[Array, "2"]) -> Float[Array, "I"]:
        """Harmonics of levels >= 1 at x = (colat, lon), each divided by sqrt(l (l + 1)) of its level."""
        ells = jnp.arange(1, self.max_ell + 1)
        lambda_ells = ells * (ells + 1) 
        # total_repeat_length makes the output length static under jit.
        normalization_term = jnp.repeat(
            jnp.sqrt(lambda_ells),
            self.num_phases_per_frequency,
            total_repeat_length=self.num_phases
        )
        return self.spherical_harmonics.polynomial_expansion(sph_to_car(x))[1:] / normalization_term

    @jax.jit
    def eigenfields(self, x: Float[Array, "2"]) -> Float[Array, "I 2"]:
        """Gradient fields at x = (colat, lon), one row per harmonic, in the normalized (colat, lon) frame."""
        x = clip_colatitude(x, self.colatitude_min_value)
        Nx = tangent_basis_normalization_matrix(x)
        return jax.jacfwd(self._sph_polynomial_expansion)(x) @ Nx

    @abstractmethod
    def __call__(self, x: Float[Array, "N 2"]) -> Float[Array, "N 2"]:
        pass 

    def __eq__(self, other: object) -> bool:
        """
        Compare configurations only (max_ell, sphere_dim, colatitude_min_value).

        Returns NotImplemented for objects that are not spherical-harmonic
        fields, instead of raising AttributeError.
        """
        if not isinstance(other, AbstractSphericalHarmonicFields):
            return NotImplemented
        return (
            self.max_ell == other.max_ell
            and self.sphere_dim == other.sphere_dim
            and self.colatitude_min_value == other.colatitude_min_value
        )
    

@dataclass 
class SphericalHarmonicFields(AbstractSphericalHarmonicFields):
    """
    Curl-free eigenfields followed by their Hodge-star rotations.

    The rotations are the divergence-free fields, giving 2 * num_phases fields in total.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # One curl-free and one divergence-free field per harmonic.
        self.num_fields = self.num_phases * 2

    @jax.jit
    def __call__(self, x: Float[Array, "2"]) -> Float[Array, "2I 2"]:
        """
        Evaluate all fields at x = (colat, lon).

        Rows [0, I) are the curl-free fields; rows [I, 2I) are the divergence-free fields.
        """
        curl_free = self.eigenfields(x)  # [I 2]
        div_free = curl_free @ hodge_star_matrix
        return jnp.concatenate([curl_free, div_free], axis=0)
    

@dataclass 
class HodgePrior(Module):
    """Prior for a vector-field GP on S^2: a Hodge-Matern kernel plus a static jitter value."""
    kernel: HodgeMaternKernel = param_field()
    # Numerical jitter (static, not trained).
    jitter: float = static_field(1e-12)
    

@dataclass 
class HodgePosterior(Module):
    """Container pairing a `HodgePrior` with a Gaussian likelihood."""
    prior: HodgePrior = param_field()
    likelihood: GaussianLikelihood = param_field()


@Partial(jax.jit, static_argnames=("jitter", ))
def spherical_harmonic_fields_moments(
    Kxz: Float[Array, "2 I"], Kzz_diag: Float[Array, "I"], m: Float[Array, "I"], sqrtS: Float[Array, "I I"], jitter: float
) -> tuple[Float[Array, "2"], Float[Array, "2 2"]]:
    """
    Whitened variational moments of the 2-D tangent vector at a single point.

    With A = Kxz * Lzz_T_inv (each column of Kxz scaled):
        mean       = A @ m
        covariance = (A @ sqrtS) @ (A @ sqrtS)^T + jitter * I

    NOTE(review): despite its name, `Kzz_diag` acts as the *inverse* of the
    inducing covariance diagonal. Lzz_T_inv = sqrt(K / (1 + K * jitter)) equals
    (1 / K + jitter) ** -0.5, i.e. Kzz = diag(1 / Kzz_diag) + jitter * I.
    The caller passes the kernel spectrum here.
    """
    # Per-feature Kzz^{-1/2} for the diagonal, jittered Kzz = diag(1 / Kzz_diag) + jitter.
    Lzz_T_inv = jnp.sqrt(Kzz_diag / (1 + Kzz_diag * jitter))

    Kxz_Lzz_T_inv = Kxz * Lzz_T_inv 
    Kxz_Lzz_T_inv_sqrtS = Kxz_Lzz_T_inv @ sqrtS

    covariance = Kxz_Lzz_T_inv_sqrtS @ Kxz_Lzz_T_inv_sqrtS.T
    covariance = covariance.at[jnp.diag_indices_from(covariance)].add(jitter)

    mean = Kxz_Lzz_T_inv @ m
    return mean, covariance

    
@dataclass 
class SphericalHarmonicFieldsPosterior(Module):
    """
    Sparse variational GP over tangent vector fields on S^2.

    Spherical-harmonic vector fields serve as inducing features, and the
    variational parameters (m, sqrtS) use a whitened parameterisation.
    """
    posterior: HodgePosterior = param_field()
    spherical_harmonic_fields: SphericalHarmonicFields = param_field()
    # Whitened variational mean and covariance factor (FillTriangular bijector), one entry per field.
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())

    def __post_init__(self):
        # Initialise at the whitened prior: zero mean, identity covariance factor.
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)

    @property 
    def jitter(self):
        return self.posterior.prior.jitter

    @property 
    def kernel(self):
        return self.posterior.prior.kernel 

    @property 
    def num_inducing(self):
        # One inducing feature per field (curl-free and divergence-free).
        return self.spherical_harmonic_fields.num_fields

    def prior_kl(self) -> Float:
        """KL term of the whitened variational distribution, via `whitened_prior_kl` (defined elsewhere)."""
        return whitened_prior_kl(self.m, self.sqrtS)
    
    # FIXME This is the main ugly part of the code. It would be nice to refactor, but it's not a priority.
    @jax.jit 
    def Kzz_diag(self) -> Float[Array, "I"]:
        """
        Per-feature spectral weights: curl-free block, then divergence-free block.

        Within each block, level l's spectrum value is repeated once per harmonic
        of that level. The result is passed as `Kzz_diag` to
        `spherical_harmonic_fields_moments`.

        NOTE(review): assumes each sub-kernel's spectrum has at least
        `spherical_harmonic_fields.max_ell` entries (kernel max_ell >= fields
        max_ell). Otherwise the repeat counts no longer match — confirm.
        """
        curl_free_kernel = self.kernel.curl_free_kernel
        div_free_kernel = self.kernel.div_free_kernel

        curl_free_ahats_per_frequency = curl_free_kernel.spectrum()[:self.spherical_harmonic_fields.max_ell]
        div_free_ahats_per_frequency = div_free_kernel.spectrum()[:self.spherical_harmonic_fields.max_ell]

        def repeat_per_phase(x):
            # Expand per-level values to per-harmonic values (static output length for jit).
            return jnp.repeat(
                x, 
                self.spherical_harmonic_fields.num_phases_per_frequency,
                total_repeat_length=self.spherical_harmonic_fields.num_phases,
            )

        return jnp.concatenate([
            repeat_per_phase(curl_free_ahats_per_frequency),
            repeat_per_phase(div_free_ahats_per_frequency),
        ]) 
    
    def Kxz(self, x: Float[Array, "2"]) -> Float[Array, "2 I"]:
        """Field values at x = (colat, lon), transposed to [2, num_fields]."""
        return self.spherical_harmonic_fields(x).T 
    
    @jax.jit 
    def __call__(self, x: Float[Array, "2"]) -> tfd.MultivariateNormalFullCovariance:
        """Gaussian over the 2-D tangent vector at a single point x = (colat, lon)."""
        mean, covariance = spherical_harmonic_fields_moments(
            self.Kxz(x), self.Kzz_diag(), self.m, self.sqrtS, jitter=self.jitter
        )
        return tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)
    
    @jax.jit 
    def diag(self, x: Float[Array, "N 2"]) -> tfd.MultivariateNormalFullCovariance:
        """Independent per-point Gaussians for N points (batch shape [N], event shape [2])."""
        def moments(t):
            pt = self.__call__(t)
            return pt.mean(), pt.covariance()
        
        mean, covariance = jax.vmap(moments)(x)
        return tfd.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)
    

# Small threshold. `expmap_car` switches to its first-order step when |v| < EPS,
# and `expmap_sph` uses EPS as its default minimum colatitude.
EPS = 1e-12


@jax.jit
def tangent_basis(x: Float[Array, "2"]) -> Float[Array, "3 2"]:
    """
    Orthonormal tangent frame at x = (colat, lon), returned as the columns of a 3x2 matrix.

    The columns are the partial derivatives of `sph_to_car` w.r.t. colatitude
    and longitude, each normalized to unit length. The longitude column has
    norm sin(colat), so x must not lie on a pole; callers clip the colatitude first.
    """
    jacobian = jax.jacfwd(sph_to_car)(x)  # [3 2]
    return jacobian / jnp.linalg.norm(jacobian, axis=0, keepdims=True)


@jax.jit
def expmap_car(x: Float[Array, "3"], v: Float[Array, "3"]) -> Float[Array, "3"]:
    """
    Exponential map on the unit sphere in Cartesian coordinates.

    Moves from x (assumed unit-norm) along the tangent vector v by arc length
    |v|: cos|v| * x + sin|v| * v / |v|. When |v| < EPS it instead returns the
    normalized first-order step (x + v) / |x + v|, avoiding division by a
    vanishing norm.
    """
    step_length = jnp.linalg.norm(v)

    def small_step():
        moved = x + v
        return moved / jnp.linalg.norm(moved)

    def geodesic_step():
        return jnp.cos(step_length) * x + jnp.sin(step_length) * v / step_length

    return jax.lax.cond(step_length < EPS, small_step, geodesic_step)


@Partial(jax.jit, static_argnames=("colatitude_min_value", ))
def expmap_sph(x: Float[Array, "2"], v: Float[Array, "2"], colatitude_min_value: float = EPS) -> Float[Array, "2"]:
    """
    Exponential map on the sphere in spherical coordinates.

    `x` is (colat, lon). `v` is expressed in the orthonormal (colat, lon)
    tangent frame returned by `tangent_basis`. The colatitude of x is clipped
    to at least `colatitude_min_value` to avoid NaNs at the pole. The result is
    converted back with `car_to_sph`, so its longitude lies in [0, 2pi).
    """
    base_point = clip_colatitude(x, colatitude_min_value)
    tangent_cartesian = tangent_basis(base_point) @ v
    moved = expmap_car(sph_to_car(base_point), tangent_cartesian)
    return car_to_sph(moved)


@dataclass 
class HodgeResidualDeepGP(Module):
    """
    Residual deep GP on S^2.

    Each hidden layer samples a tangent vector field at the current points and
    moves the points along it with the exponential map. The output layer is
    then evaluated at the displaced points.
    """
    hidden_layers: list[SphericalHarmonicFieldsPosterior] = param_field()
    output_layer: SphericalHarmonicFeaturesPosterior = param_field()
    # Number of Monte Carlo propagations used by `diag`.
    num_samples: int = static_field(1)

    @property 
    def posterior(self):
        return self.output_layer.posterior

    def diag_one_sample(self, x: Float[Array, "N 2"], *, key: Key) -> tfd.MultivariateNormalFullCovariance:
        """
        Propagate (colat, lon) points through one sampled path of the hidden layers.

        Returns the output layer's marginal prediction at the displaced points.
        """
        hidden_layer_keys = jr.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            # Sample one tangent vector per point, then step along it on the sphere.
            v = layer.diag(x).sample(seed=hidden_layer_key)
            x = jax.vmap(expmap_sph, in_axes=(0, 0))(x, v)
        return self.output_layer.diag(sph_to_car(x)) # last layer need cartesian coordinates 
    
    def diag(self, x: Float[Array, "N 2"], *, key: Key) -> tfd.MixtureSameFamily:
        """
        Equally weighted mixture over `num_samples` propagated samples.

        Each component is a diagonal Gaussian built from one sample's predictive mean and stddev.
        """
        sample_keys = jr.split(key, self.num_samples)

        def sample_mean_stddev(key: Key) -> tuple[Float, Float]:
            px = self.diag_one_sample(x, key=key)
            return px.mean(), px.stddev()
        
        means, stddevs = jax.vmap(sample_mean_stddev)(sample_keys)
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=jnp.ones(self.num_samples) / self.num_samples),
            components_distribution=tfd.MultivariateNormalDiag(loc=means, scale_diag=stddevs),
        )
    
    def prior_kl(self) -> Float:
        """Sum of the KL terms of all hidden layers and the output layer."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()
    

def create_hodge_residual_deep_gp_with_spherical_harmonic_features(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, 
    nu: float = 1.5, kernel_max_ell: int | None = None
) -> HodgeResidualDeepGP:
    """
    Build a residual deep GP on S^2 with Hodge-Matern vector-field hidden layers.

    There are `num_layers - 1` hidden layers. Each uses spherical-harmonic
    vector-field features. The output layer is a scalar Matern GP on the
    sphere with spherical-harmonic features.

    Args:
        num_layers: total number of layers, including the output layer.
        total_hidden_variance: variance budget, split evenly across the hidden layers.
        num_inducing: hidden-layer feature budget. Half is used to pick the number
            of harmonic levels, since the fields come in curl-free / divergence-free pairs.
        x: training inputs; only `x.shape[1]` (used as `sphere_dim`) is read.
        num_samples: Monte Carlo samples used by the returned model.
        key: unused; kept so all model factories share one signature.
        nu: Matern smoothness for all layers.
        kernel_max_ell: truncation level for the hidden and output kernels.
            Defaults to the number of levels used by the corresponding features.
            For hidden layers it should be at least that feature level count,
            because the feature covariance slices the kernel spectrum to that length.
    """
    sphere_dim = x.shape[1]

    hidden_nu = jnp.array(nu)
    output_nu = hidden_nu

    hidden_variance = total_hidden_variance / max(num_layers - 1, 1)
    output_variance = jnp.array(1.0)

    hidden_kappa = jnp.array(1.0)
    output_kappa = hidden_kappa

    shf_hidden_max_ell = num_phases_to_num_levels(num_inducing // 2, sphere_dim=sphere_dim)
    # shf_output_max_ell = num_phases_to_num_levels(num_inducing, sphere_dim=sphere_dim) TODO temporary fix 
    shf_output_max_ell = num_phases_to_num_levels(49, sphere_dim=sphere_dim)
    if kernel_max_ell is None:
        hidden_kernel_max_ell = shf_hidden_max_ell
        output_kernel_max_ell = shf_output_max_ell
    else:
        # An explicit truncation level applies to both the hidden and the output kernels.
        hidden_kernel_max_ell = kernel_max_ell
        output_kernel_max_ell = kernel_max_ell
    hidden_spherical_harmonic_fields = SphericalHarmonicFields(max_ell=shf_hidden_max_ell)
    output_spherical_harmonics = SphericalHarmonics(max_ell=shf_output_max_ell, sphere_dim=sphere_dim)

    hidden_layers = []
    for _ in range(num_layers - 1):
        kernel = HodgeMaternKernel(
            kappa=hidden_kappa,
            nu=hidden_nu,
            variance=hidden_variance, 
            max_ell=hidden_kernel_max_ell,
        )
        prior = Prior(kernel=kernel)
        posterior = DummyPosterior(prior=prior)
        layer = SphericalHarmonicFieldsPosterior(posterior=posterior, spherical_harmonic_fields=hidden_spherical_harmonic_fields)
        hidden_layers.append(layer)

    kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=output_nu,
        kappa=output_kappa,
        variance=output_variance,
        max_ell=output_kernel_max_ell,
    )
    prior = Prior(kernel=kernel)
    likelihood = DeepGaussianLikelihood()
    posterior = Posterior(prior=prior, likelihood=likelihood)
    output_layer = SphericalHarmonicFeaturesPosterior(posterior=posterior, spherical_harmonics=output_spherical_harmonics)

    return HodgeResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)



# Experiment configuration read by `train_and_eval` as module globals.
num_test = 5000  # number of test points passed to sphere_uniform_grid
num_plot = 100  # plot grid resolution per axis (num_plot x num_plot points)


total_hidden_variance = 0.0001  # total variance, split across hidden layers by the model factories
num_inducing = 50  # NOTE(review): shadowed by the `num_inducing` parameter of train_and_eval/main in this chunk
train_num_samples = 3  # Monte Carlo samples during training
test_num_samples = 10  # Monte Carlo samples during evaluation and plotting

lr = 0.01  # Adam learning rate
num_iters = 1000  # optimisation steps


def train_and_eval(
    target_function_name: str, num_train: int, model_name: str, num_layers: int, seed: int,
    num_inducing: int = 50, kernel_max_ell: int | None = None, 
): 
    """
    Run one synthetic regression experiment on the sphere.

    Steps:
    - Generate noisy observations of the named target function.
    - Build the named model and train it with Adam on `deep_negative_elbo`.
    - Evaluate marginal test metrics.
    - Plot predictions on a regular (colat, lon) grid.

    Inputs are converted to spherical coordinates for "hodge" models. Other
    settings come from the module-level constants (num_test, num_plot,
    total_hidden_variance, train/test_num_samples, lr, num_iters).

    Returns:
        (trained model, per-iteration loss history, test metrics dict, matplotlib figure)
    """
    key = jax.random.key(seed)
    key, data_key = jax.random.split(key)

    # data 
    train_x = sphere_uniform_grid(num_train)
    test_x = sphere_uniform_grid(num_test)
    plot_sph = jnp.stack(sphere_meshgrid(num_plot, num_plot), axis=-1).reshape(-1, 2)
    plot_x = sph_to_car(plot_sph)

    target_function = lambda x: target_f(car_to_sph(x), name=target_function_name)
    train_key, test_key, plot_key = jax.random.split(data_key, 3)
    train_y = add_noise(target_function(train_x), key=train_key)
    test_y = add_noise(target_function(test_x), key=test_key)
    plot_y = add_noise(target_function(plot_x), key=plot_key).reshape(num_plot, num_plot)

    if "hodge" in model_name:
        train_x = car_to_sph(train_x)
        test_x = car_to_sph(test_x)
        plot_x = car_to_sph(plot_x)

    # model
    key, model_key = jax.random.split(key)
    model = create_model(
        num_layers=num_layers, total_hidden_variance=total_hidden_variance, 
        num_inducing=num_inducing, x=train_x, num_samples=train_num_samples, 
        name=model_name, key=model_key, kernel_max_ell=kernel_max_ell
    )

    # Separate keys for training, evaluation and plotting so their Monte Carlo
    # samples are not correlated through a reused key.
    fit_key, eval_key, predict_key = jax.random.split(key, 3)

    # training
    optim = ox.adam(learning_rate=lr)
    model_opt, history = fit(
        model=model, 
        objective=deep_negative_elbo, 
        x=train_x, 
        y=train_y, 
        optim=optim, 
        key=fit_key, 
        num_iters=num_iters,
    )

    # testing
    model_opt = model_opt.replace(num_samples=test_num_samples)
    metrics_diag = evaluate_diag(model_opt, test_x, test_y, key=eval_key)

    # plotting: predictive mean, ground truth, error and stddev on the plot grid
    plot_sph = plot_sph.reshape(num_plot, num_plot, 2)
    pplot = model_opt.diag(plot_x, key=predict_key)
    mean, stddev = pplot.mean(), pplot.stddev()
    mean, stddev = mean.reshape(num_plot, num_plot), stddev.reshape(num_plot, num_plot)

    fig = plot_predictions_and_error(plot_sph, plot_y, mean, stddev)

    return model_opt, history, metrics_diag, fig


def plot_predictions_and_error(plot_sph: Float[Array, "N D"], y: Float[Array, "N D"], mean: Float[Array, "N D"], stddev: Float[Array, "N D"]):
    """
    Draw a 2x2 figure of predictions on a (lon, colat) grid.

    Panels: predictive mean, ground truth, root squared error, and predictive
    standard deviation.

    `plot_sph` holds (colat, lon) in its last axis; `y`, `mean` and `stddev`
    share its leading grid shape (train_and_eval passes [num_plot, num_plot]).
    Each panel gets its own colour scale spanning that panel's data range.
    """
    nrows, ncols = 2, 2
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 4), layout="constrained")

    lon, colat = plot_sph[..., 1], plot_sph[..., 0]
    rse = jnp.sqrt(jnp.square(y - mean))

    panels = [
        (axs[0][0], mean, "Predictive Mean"),
        (axs[0][1], y, "Ground Truth"),
        (axs[1][0], rse, "Root Squared Error"),
        (axs[1][1], stddev, "Predictive Uncertainty (Standard Deviation)"),
    ]
    for ax, values, title in panels:
        mesh = ax.pcolormesh(lon, colat, values, vmin=jnp.min(values), vmax=jnp.max(values))
        fig.colorbar(mesh, ax=ax)
        ax.set_xlabel("lon")
        # The vertical axis is colatitude (plot_sph[..., 0]), not latitude.
        ax.set_ylabel("colat")
        ax.set_title(title)

    return fig


import os 


def save_results(experiment_settings: dict[str, Any], metrics: dict[str, float], fig: plt.Figure, *, dir_path: str):
    """
    Write one experiment's metrics and figure into a directory named after its settings.

    The target directory is ``<dir_path>/<k1>=<v1>-<k2>=<v2>-...``, built from
    ``experiment_settings`` in insertion order, and is created if it does not
    exist. Two files are written there:

    - ``metrics.csv``: ``metrics`` as a single-row pandas table (row index 0,
      which is also written as the first CSV column);
    - ``plot.pdf``: ``fig`` saved with a tight bounding box.

    Returns:
        None.
    """
    name_parts = [f"{key}={value}" for key, value in experiment_settings.items()]
    out_dir = os.path.join(dir_path, "-".join(name_parts))
    os.makedirs(out_dir, exist_ok=True)

    metrics_table = pd.DataFrame(metrics, index=[0])
    metrics_table.to_csv(os.path.join(out_dir, "metrics.csv"))

    fig.savefig(os.path.join(out_dir, "plot.pdf"), bbox_inches="tight")


def main(target_function_name, num_train, model_name, num_layers, seed, num_inducing: int = 50, kernel_max_ell: int | None = None):
    """
    Run one synthetic experiment: train and evaluate, print the results, and save them.

    The experiment settings are printed. They are then passed to ``save_results``,
    which uses them to name the output directory under ``results/synthetic``.
    The evaluation metrics and the figure returned by ``train_and_eval`` are
    printed and saved respectively.
    """
    # Key order is significant: save_results joins the items in insertion order
    # to form the results directory name.
    experiment_settings = dict(
        target_function_name=target_function_name,
        num_train=num_train,
        model_name=model_name,
        num_layers=num_layers,
        num_inducing=num_inducing,
        seed=seed,
        kernel_max_ell=kernel_max_ell,
    )
    print(experiment_settings)

    # The trained model and the optimisation history are not used here.
    _, _, metrics, fig = train_and_eval(
        target_function_name,
        num_train,
        model_name,
        num_layers,
        seed,
        num_inducing=num_inducing,
        kernel_max_ell=kernel_max_ell,
    )
    print(metrics)

    save_results(
        experiment_settings=experiment_settings,
        metrics=metrics,
        fig=fig,
        dir_path="results/synthetic",
    )


import argparse 


if __name__ == "__main__":
    # Each option's dest matches a parameter name of main(), so the parsed
    # namespace can be forwarded directly as keyword arguments.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--target_function_name", type=str, default="reversed_spherical_harmonic")
    cli_parser.add_argument("--num_train", type=int, required=True)
    cli_parser.add_argument("--model_name", type=str, required=True)
    cli_parser.add_argument("--num_layers", type=int, required=True)
    cli_parser.add_argument("--seed", type=int, required=True)
    # NOTE(review): this CLI default (49) differs from main()'s own default
    # num_inducing=50 — confirm which value is intended.
    cli_parser.add_argument("--num_inducing", type=int, default=49)
    cli_parser.add_argument("--kernel_max_ell", type=int, default=None)
    cli_args = cli_parser.parse_args()

    main(**vars(cli_args))