from jax import config 
config.update("jax_enable_x64", True)


import jax 
import jax.numpy as jnp
import gpjax 
import optax 
from gpjax.distributions import GaussianDistribution
from gpjax.typing import Float
from gpjax.base import static_field, param_field
from jax.tree_util import Partial
import tensorflow_probability.substrates.jax as tfp
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Key

from matplotlib import pyplot as plt 

from dataclasses import dataclass, InitVar
from abc import abstractmethod


@jax.jit
def sph_dot_product(sph1: Float, sph2: Float) -> Float:
    """
    Euclidean inner product in R^3 of two unit-sphere points given in spherical coordinates.

    Each input stores the colatitude at index 0 and the longitude at index 1 of its
    last axis; any leading axes are combined elementwise.
    """
    colat_a = sph1[..., 0]
    lon_a = sph1[..., 1]
    colat_b = sph2[..., 0]
    lon_b = sph2[..., 1]

    # x/y contribution and z contribution of the cartesian dot product.
    horizontal = jnp.sin(colat_a) * jnp.sin(colat_b) * jnp.cos(lon_a - lon_b)
    vertical = jnp.cos(colat_a) * jnp.cos(colat_b)
    return horizontal + vertical


# NOTE jitting this doesn't help
def sph_gegenbauer(x, y, max_ell: int, alpha: float = 0.5):
    """
    Evaluate `gegenbauer` (degrees 0..max_ell) at the dot product of two points
    given in spherical (colat, lon) coordinates.
    """
    cos_angle = sph_dot_product(x, y)
    return gegenbauer(x=cos_angle, max_ell=max_ell, alpha=alpha)


# NOTE jitting this doesn't help
def sph_gegenbauer_single(x, y, ell: int, alpha: float = 0.5):
    """
    Evaluate `gegenbauer_single` (one degree `ell`) at the dot product of two points
    given in spherical (colat, lon) coordinates.
    """
    cos_angle = sph_dot_product(x, y)
    return gegenbauer_single(x=cos_angle, ell=ell, alpha=alpha)


def array(x):
    """Build a JAX array from `x`, requesting the float64 dtype."""
    target_dtype = jnp.float64
    return jnp.array(x, dtype=target_dtype)


@jax.jit
def sph_to_car(sph):
    """
    Convert spherical (colat, lon) coordinates to cartesian (x, y, z) on the unit sphere.

    The last axis of `sph` holds (colat, lon); leading axes are preserved and the
    output has a last axis of length 3.
    """
    colat = sph[..., 0]
    lon = sph[..., 1]

    # Radius of the circle of constant colatitude, projected onto the xy-plane.
    ring_radius = jnp.sin(colat)
    components = [
        ring_radius * jnp.cos(lon),
        ring_radius * jnp.sin(lon),
        jnp.cos(colat),
    ]
    return jnp.stack(components, axis=-1)


@jax.jit
def car_to_sph(car):
    """
    Convert cartesian (x, y, z) points on the unit sphere to spherical (colat, lon).

    The last axis of `car` holds (x, y, z); the output's last axis holds the
    colatitude followed by the longitude wrapped into [0, 2π).

    The z-component is clipped to [-1, 1] before `arccos`, so that points whose
    z rounds marginally outside that interval map to a pole instead of NaN.
    """
    x, y, z = car[..., 0], car[..., 1], car[..., 2]
    colat = jnp.arccos(jnp.clip(z, -1.0, 1.0))
    lon = jnp.arctan2(y, x)
    # arctan2 returns values in [-π, π]; shift and wrap into [0, 2π).
    lon = (lon + 2 * jnp.pi) % (2 * jnp.pi)
    return jnp.stack([colat, lon], axis=-1)



from pathlib import Path
from typing import Callable

import numpy as np
from jax import Array


class FundamentalSystemNotPrecomputedError(ValueError):
    """Raised when no precomputed fundamental system exists for a dimension."""

    def __init__(self, dimension: int):
        super().__init__(
            f"Fundamental system for dimension {dimension} has not been precomputed."
        )


def fundamental_set_loader(dimension: int, load_dir="fundamental_system") -> Callable[[int], Array]:
    """
    Read the precomputed fundamental system for `dimension` and return a lookup function.

    The archive is expected at `../<load_dir>/fs_<dimension>D.npz`, with the `../`
    resolved relative to the current working directory. All of its entries are
    read into memory once, when the loader is created.

    Args:
        dimension: Dimension whose archive should be read.
        load_dir: Directory (joined onto `../`) containing the `.npz` archives.

    Returns:
        A function mapping a degree to the array stored under `degree_<degree>`.

    Raises:
        FundamentalSystemNotPrecomputedError: If the archive file does not exist.
    """
    archive_path = Path("../") / load_dir / f"fs_{dimension}D.npz"
    if not archive_path.exists():
        raise FundamentalSystemNotPrecomputedError(dimension)

    with np.load(archive_path) as archive:
        cache = dict(archive.items())

    def load(degree: int) -> Array:
        # Raises ValueError if the archive has no entry for this degree.
        key = f"degree_{degree}"
        if key in cache:
            return cache[key]
        raise ValueError(f"key: {key} not in cache.")

    return load


@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))
def gegenbauer(x: Float[Array, "N D"], max_ell: int, alpha: float = 0.5) -> Float[Array, "N L"]:
    """
    Compute the Gegenbauer polynomials Cᵅₙ(x) for every degree n = 0..max_ell.

    The values are built with the three-term recurrence

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array (or scalar) at which the polynomials are evaluated.
        max_ell: Highest degree to compute. Static under `jit`.
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.
            Static under `jit`.

    Returns:
        An array of shape `(*x.shape, max_ell + 1)` whose last axis indexes the degree.
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)

    res = jnp.empty((*x.shape, max_ell + 1), dtype=x.dtype)
    res = res.at[..., 0].set(C_0)

    # `max_ell` is static, so branch in Python: the degree-0 case then never traces
    # the write to index 1, which would be out of bounds for a length-1 last axis.
    if max_ell == 0:
        return res

    C_1 = 2 * alpha * x
    res = res.at[..., 1].set(C_1)

    def step(n, carry):
        # carry = (result table, Cᵅₙ₋₁, Cᵅₙ₋₂)
        table, C, C_prev = carry
        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C
        return table.at[..., n].set(C), C, C_prev

    # Runs zero iterations when max_ell == 1.
    return jax.lax.fori_loop(2, max_ell + 1, step, (res, C_1, C_0))[0]


@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call 
def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:
    """
    Compute the single Gegenbauer polynomial Cᵅ_ell(x) via the three-term recurrence.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array.
        ell: The degree of the polynomial. Traced (not static), so changing it does
            not trigger recompilation.
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.
            Static under `jit`.

    Returns:
        The degree-`ell` Gegenbauer polynomial evaluated at `x`.
    """
    degree_zero = jnp.ones_like(x, dtype=x.dtype)
    degree_one = 2 * alpha * x

    def keep_going(state):
        _, _, n = state
        return n <= ell

    def advance(state):
        current, previous, n = state
        following = (2 * x * (n + alpha - 1) * current - (n + 2 * alpha - 2) * previous) / n
        return following, current, n + 1

    def run_recurrence():
        # Start at n = 2 with (Cᵅ₁, Cᵅ₀); for ell == 1 the loop body never runs.
        start = (degree_one, degree_zero, jnp.array(2, jnp.float64))
        return jax.lax.while_loop(keep_going, advance, start)[0]

    return jax.lax.cond(ell == 0, lambda: degree_zero, run_recurrence)


@dataclass
class SphericalHarmonics(gpjax.Module):
    """
    Spherical harmonics inducing features for sparse inference in Gaussian processes.

    The spherical harmonics, Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the sphere and,
    as such, they form an orthogonal basis.

    To construct the harmonics, we use a fundamental set of points on the sphere {vᵢ}ᵢ and compute
    b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b now forms a complete basis on the sphere and we can orthogonalise it via
    a Cholesky decomposition. However, we only need to run the Cholesky decomposition once during
    initialisation.

    Attributes:
        max_ell: The highest frequency; harmonics are built for frequencies 0..max_ell inclusive.
        sphere_dim: Dimension of the sphere; the fundamental set is loaded for `sphere_dim + 1`.
        alpha: Gegenbauer parameter, set in `__post_init__` to `(sphere_dim - 1) / 2`.
        orth_basis: Per-frequency lower Cholesky factors produced by `_orthogonalise_basis`.
        Vs: Per-frequency fundamental point sets returned by `fundamental_set_loader`.
        num_phases_per_frequency: Number of rows of each entry of `Vs`.
        num_phases: Sum of `num_phases_per_frequency`.
    """

    max_ell: int = static_field()
    sphere_dim: int = static_field()
    alpha: float = static_field(init=False)
    orth_basis: Array = static_field(init=False)
    Vs: list[Array] = static_field(init=False)
    # NOTE: `__post_init__` assigns a plain Python list of ints here, not an array.
    num_phases_per_frequency: Float[Array, " L"] = static_field(init=False)
    num_phases: int = static_field(init=False)


    @property
    def levels(self):
        """Integer frequencies 0..max_ell as an int32 array."""
        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)
    

    def __post_init__(self) -> None:
        """
        Load the fundamental set and pre-compute every derived (`init=False`) field.

        Raises:
            FundamentalSystemNotPrecomputedError: Propagated from `fundamental_set_loader`
                when no archive exists for `sphere_dim + 1`.
        """
        dim = self.sphere_dim + 1

        # Load the pre-computed fundamental set for the ambient dimension.
        fund_set = fundamental_set_loader(dim)

        # Gegenbauer parameter for this sphere: (dim - 2) / 2.
        self.alpha = (dim - 2.0) / 2.0

        # One fundamental point set per frequency 0..max_ell.
        self.Vs = [fund_set(n) for n in self.levels]

        # pre-compute and save the orthogonal basis 
        self.orth_basis = self._orthogonalise_basis()


        # set these things instead of computing every time 
        self.num_phases_per_frequency = [v.shape[0] for v in self.Vs]
        self.num_phases = sum(self.num_phases_per_frequency)


    @property
    def Ls(self) -> list[Array]:
        """
        Alias for the orthogonal basis at every frequency.
        """
        return self.orth_basis

    def _orthogonalise_basis(self) -> list[Array]:
        """
        Compute the basis from the fundamental set and orthogonalise it via Cholesky decomposition.

        Returns:
            One lower Cholesky factor per frequency, in the same order as `self.Vs`.
        """
        alpha = self.alpha
        # Split into max_ell + 1 length-1 arrays so that tree.map pairs one level with one V.
        levels = jnp.split(self.levels, self.max_ell + 1)
        # Per-frequency scaling alpha / (alpha + n).
        const = alpha / (alpha + self.levels.astype(jnp.float64))
        const = jnp.split(const, self.max_ell + 1)

        def _func(v, n, c):
            # Gram matrix of the fundamental points, pushed through the scaled Gegenbauer polynomial.
            x = jnp.matmul(v, v.T)
            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)
            # A tiny diagonal term (1e-16) is added before factorising.
            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))
            return L

        return jax.tree.map(_func, self.Vs, levels, const)

    def custom_gegenbauer_single(self, x, ell, alpha):
        """
        Evaluate the degree-`ell` Gegenbauer polynomial at `x`.

        Implemented by computing all degrees 0..max_ell with `gegenbauer` and
        selecting index `ell` along the last axis.
        """
        return gegenbauer(x, self.max_ell, alpha)[..., ell]

    @jax.jit
    def polynomial_expansion(self, X: Float[Array, "N D"]) -> Float[Array, "M N"]:
        """
        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.

        Args:
            X: Input Array.

        Returns:
            The harmonics evaluated at the input as a polynomial expansion of the basis,
            with the per-frequency blocks concatenated along axis 0.
        """
        levels = jnp.split(self.levels, self.max_ell + 1)

        def _func(v, n, L):
            # Inner products between the fundamental points of this frequency and the inputs.
            vxT = jnp.dot(v, X.T)
            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)
            # Solve L @ harmonic = zonal with the lower-triangular factor of this frequency.
            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)
            return harmonic

        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)
        return jnp.concatenate(harmonics, axis=0)
    
    def __eq__(self, other: "SphericalHarmonics") -> bool:
        """
        Check if two spherical harmonic features are equal.

        Only `max_ell` and `sphere_dim` are compared.

        Args:
            other: The other spherical harmonic features.

        Returns:
            A boolean indicating if the two features are equal.
        """
        # Given the first two parameters, the rest are deterministic. 
        # The user must not mutate all other fields, but that is not enforced for now.
        return (
            self.max_ell == other.max_ell 
            and self.sphere_dim == other.sphere_dim 
        )    


import warnings 
from typing import Optional
from jaxtyping import Num
from gpjax.typing import ScalarFloat


@jax.jit 
def mse(y_true: Float[Array, "N"], py: tfd.Distribution) -> Float[Array, ""]:
    """Mean squared error between `y_true` and the mean of the predictive distribution `py`."""
    residual = y_true - py.mean()
    return jnp.mean(jnp.square(residual))


@jax.jit 
def nlpd(y_true: Float[Array, "N"], py: tfd.Distribution) -> Float[Array, ""]:
    """Negative mean of `py.log_prob(y_true)`, i.e. the negative log predictive density."""
    log_density = py.log_prob(y_true)
    return -jnp.mean(log_density)



import jax 
import jax.numpy as jnp
import numpy as np
import pandas as pd
import netCDF4

from jaxtyping import Array
from jax.tree_util import Partial 
import plotly.express as px 


def sphere_uniform_grid(n: int) -> Array:
    """
    Return `n` points spread over the unit sphere using the Fibonacci lattice method.

    The points are generated in (colat, lon) coordinates and returned in cartesian
    coordinates via `sph_to_car`.
    """
    golden_ratio = (1 + jnp.sqrt(5)) / 2

    indices = jnp.arange(n)
    lon = 2 * jnp.pi * indices / golden_ratio
    colat = jnp.arccos(1 - 2 * (indices + 0.5) / n)
    return sph_to_car(jnp.column_stack((colat, lon)))


@jax.tree_util.Partial(jax.jit, static_argnames=("with_replacement",))
def closest_point_mask(targets: Array, x: Array, with_replacement: bool) -> Array:
    """
    Build a boolean mask over `x` marking the points that are closest to some target.

    Args: 
        targets (Array): targets in cartesian coordinates.
        x (Array): points in cartesian coordinates for which to produce the mask.
        with_replacement (bool): if True, several targets may select the same point of
            `x`. If False, targets are processed in order and each one selects the
            closest point not already taken by an earlier target. Static under `jit`.

    Returns:
        Boolean array with one entry per row of `x`; True where the point was selected.
    """

    # Can do euclidean squared distance instead of spherical, since minimisation is invariant to monotonic transformations
    distances = jnp.sum((targets[:, None] - x[None, :]) ** 2, -1)

    # `with_replacement` is static, so branch in Python: only the requested strategy is traced.
    if with_replacement:
        mask_indices = jnp.argmin(distances, axis=1)
    else:
        def claim_closest(i, state):
            chosen, available = state
            # Points already claimed get an infinite distance so they cannot win again.
            masked_distances = jnp.where(available, distances[i], jnp.inf)
            closest_idx = jnp.argmin(masked_distances)
            return chosen.at[i].set(closest_idx), available.at[closest_idx].set(False)

        num_targets = targets.shape[0]
        initial_state = (
            jnp.zeros(num_targets, dtype=int),
            jnp.ones(x.shape[0], dtype=bool),
        )
        # fori_loop keeps the traced program size independent of the number of targets.
        mask_indices, _ = jax.lax.fori_loop(0, num_targets, claim_closest, initial_state)

    mask = jnp.zeros(x.shape[0], dtype=bool)
    return mask.at[mask_indices].set(True)


def angles_to_radians_colat(x: Array) -> Array:
    """Convert degrees to radians and add π/2 (inverse of `radians_to_angles_colat`)."""
    radians = jnp.pi * x / 180
    quarter_turn = jnp.pi / 2
    return radians + quarter_turn

def angles_to_radians_lon(x: Array) -> Array:
    """Convert a longitude from degrees to radians."""
    scaled = jnp.pi * x
    return scaled / 180 

def angles_to_radians(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` whose `colat` and `lon` columns are converted from degrees
    to radians with `angles_to_radians_colat` and `angles_to_radians_lon`.
    """
    def colat_in_radians(frame):
        return angles_to_radians_colat(frame.colat)

    def lon_in_radians(frame):
        return angles_to_radians_lon(frame.lon)

    return df.assign(colat=colat_in_radians, lon=lon_in_radians)

def radians_to_angles_colat(x: Array) -> Array:
    """Convert radians to degrees and subtract 90 (inverse of `angles_to_radians_colat`)."""
    degrees = 180 * x / jnp.pi
    return degrees - 90 

def radians_to_angles_lon(x: Array) -> Array:
    """Convert a longitude from radians to degrees."""
    scaled = 180 * x
    return scaled / jnp.pi 

def radians_to_angles(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` whose `colat` and `lon` columns are converted from radians
    to degrees with `radians_to_angles_colat` and `radians_to_angles_lon`.
    """
    def colat_in_degrees(frame):
        return radians_to_angles_colat(frame.colat)

    def lon_in_degrees(frame):
        return radians_to_angles_lon(frame.lon)

    return df.assign(colat=colat_in_degrees, lon=lon_in_degrees)


def match_to_uniform_grid_mask(x: Array, n: int, with_replacement: bool = True) -> Array:
    """
    Mask the points of `x` that are closest to the nodes of an `n`-point uniform grid.

    Args: 
        x (Array): Points in cartesian coordinates for which to create the mask.
        n (int): Number of grid points generated by `sphere_uniform_grid`.
        with_replacement (bool): Forwarded to `closest_point_mask`.
    """
    grid = sphere_uniform_grid(n)
    return closest_point_mask(targets=grid, x=x, with_replacement=with_replacement)


def to_test_dataframe(df: pd.DataFrame, n: int, with_replacement: bool = True) -> pd.DataFrame:
    """
    Select the rows of `df` whose (`colat`, `lon`) position is matched to an `n`-point
    uniform grid on the sphere by `match_to_uniform_grid_mask`.
    """
    coords = df[['colat', 'lon']].values
    keep = match_to_uniform_grid_mask(
        x=sph_to_car(coords),
        n=n,
        with_replacement=with_replacement,
    )
    # Convert to a plain list of booleans for pandas row selection.
    return df[keep.tolist()]



import pandas as pd 


def _label_lon_lat_axes(ax, title):
    # Apply the shared lon/lat axis labels and the given panel title.
    ax.set_xlabel("lon")
    ax.set_ylabel("lat")
    ax.set_title(title)


def plot_results(x_train, y_train, x_test, y_test, mean, var, var_x, sample, history):
    """
    Draw a 3x2 summary figure of a vector-field prediction.

    Position arrays (`x_train`, `x_test`) are read as column 0 = lat, column 1 = lon,
    and vector arrays (`y_train`, `y_test`, `mean`) as column 0 = dlat, column 1 = dlon.
    `sample` is assumed to use the same column layout as `mean` (NOTE(review): confirm
    against the caller). `var_x[..., 0]` / `var_x[..., 1]` give the lat / lon
    coordinates over which `var` is drawn with `pcolormesh`.

    Panels: predictive mean with training data overlaid in red, predictive
    uncertainty, ground truth, prediction error (`y_test - mean`), a posterior
    sample, and the training history.

    Returns:
        The created matplotlib figure.
    """
    var_lat, var_lon = var_x[..., 0], var_x[..., 1]
    x_train_lat, x_train_lon = x_train[:, 0], x_train[:, 1]
    y_train_dlat, y_train_dlon = y_train[:, 0], y_train[:, 1]
    x_test_lat, x_test_lon = x_test[:, 0], x_test[:, 1]
    y_test_dlat, y_test_dlon = y_test[:, 0], y_test[:, 1]
    mean_dlat, mean_dlon = mean[:, 0], mean[:, 1]
    # Unpack the sample like every other vector field, so that U is the lon-component
    # and V the lat-component in the quiver call below.
    sample_dlat, sample_dlon = sample[:, 0], sample[:, 1]

    nrows = 3
    ncols = 2
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 4), layout="constrained")

    # top left, prediction with test data 
    q = axs[0][0].quiver(x_test_lon, x_test_lat, mean_dlon, mean_dlat, angles="uv")
    # NOTE(review): presumably forces matplotlib to resolve the automatic arrow scale
    # so that `q.scale` is populated; `_init` is a private matplotlib method.
    q._init()
    scale = q.scale
    axs[0][0].quiver(x_train_lon, x_train_lat, y_train_dlon, y_train_dlat, angles="uv", color="red", scale=scale)
    _label_lon_lat_axes(axs[0][0], "Predictive Mean")

    # top right, uncertainty
    c = axs[0][1].pcolormesh(var_lon, var_lat, var, vmin=var.min(), vmax=var.max())
    fig.colorbar(c, ax=axs[0][1])
    _label_lon_lat_axes(axs[0][1], "Predictive Uncertainty")

    # middle left, true test data (same arrow scale as the prediction)
    axs[1][0].quiver(x_test_lon, x_test_lat, y_test_dlon, y_test_dlat, angles="uv", scale=scale)
    _label_lon_lat_axes(axs[1][0], "Ground truth")

    # middle right, difference 
    y_diff = y_test - mean
    y_diff_dlat, y_diff_dlon = y_diff[:, 0], y_diff[:, 1]
    axs[1][1].quiver(x_test_lon, x_test_lat, y_diff_dlon, y_diff_dlat, angles="uv", scale=scale)
    _label_lon_lat_axes(axs[1][1], "Prediction Error")

    # bottom left, sample from posterior 
    axs[2][0].quiver(x_test_lon, x_test_lat, sample_dlon, sample_dlat, angles="uv", scale=scale)
    _label_lon_lat_axes(axs[2][0], "Sample from Posterior")

    # bottom right, training history
    axs[2][1].plot(history)
    axs[2][1].set_xlabel("Iteration")
    axs[2][1].set_ylabel("Negative ELBO")
    axs[2][1].set_title("Training History")

    return fig


# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from gpjax.fit import (
    _check_batch_size,
    _check_log_rate,
    _check_model,
    _check_num_iters,
    _check_optim,
    _check_train_data,
    _check_verbose,
    get_batch,
)
from beartype.typing import (
    Any,
    Callable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import jax
from jax import (
    jit,
    value_and_grad,
)
from jax._src.random import _check_prng_key
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import scipy

from gpjax.base import Module
from gpjax.dataset import Dataset
from gpjax.objectives import AbstractObjective
from gpjax.scan import vscan
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

ModuleModel = TypeVar("ModuleModel", bound=Module)


def fit_deep(  # noqa: PLR0913
    *,
    model: ModuleModel,
    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
    train_data: Dataset,
    optim: ox.GradientTransformation,
    key: KeyArray,
    num_iters: Optional[int] = 100,
    batch_size: Optional[int] = -1,
    log_rate: Optional[int] = 10,
    verbose: Optional[bool] = True,
    unroll: Optional[int] = 1,
    safe: Optional[bool] = True,
) -> Tuple[ModuleModel, Array]:
    r"""Train a Module model with respect to a supplied, key-dependent objective.
    Optimisers used here should originate from Optax.

    Unlike the annotation on `objective` suggests, the objective is invoked with
    three positional arguments: `objective(key, model, batch)`, where `key` is the
    per-iteration random key, `model` is the constrained model and `batch` is the
    data for that step. The same per-iteration key is also used to draw the
    mini-batch when `batch_size != -1`.

    Example (sketch; the objective is any callable with the signature above):
    ```python
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import optax as ox
        >>> import gpjax as gpx
        >>>
        >>> # (1) Create a dataset:
        >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
        >>> D = gpx.Dataset(X, y)
        >>>
        >>> # (2) Define your loss function; note the leading `key` argument:
        >>> def mean_square_error(key, model, train_data):
                return jnp.mean((train_data.y - model(train_data.X)) ** 2)
        >>>
        >>> # (3) Train a previously constructed Module `model`:
        >>> trained_model, history = fit_deep(
                model=model, objective=mean_square_error, train_data=D,
                optim=ox.sgd(0.001), key=jr.key(42), num_iters=1000,
            )
    ```

    Args:
        model (Module): The model Module to be optimised.
        objective (Objective): The objective function that we are optimising with
            respect to. Called as `objective(key, model, batch)`.
        train_data (Dataset): The training data to be used for the optimisation.
        optim (GradientTransformation): The Optax optimiser that is to be used for
            learning a parameter set.
        key (KeyArray): The random key; it is split into `num_iters` per-iteration
            keys. Required (there is no default).
        num_iters (Optional[int]): The number of optimisation steps to run. Defaults
            to 100.
        batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1
            (i.e. full batch).
        log_rate (Optional[int]): How frequently the objective function's value should
            be printed. Defaults to 10. NOTE: within this function the value is only
            validated; it is not forwarded to the scan.
        verbose (Optional[bool]): Whether to print the training loading bar (selects
            `vscan` instead of `jax.lax.scan`). Defaults to True.
        unroll (int): The number of unrolled steps to use for the optimisation.
            Defaults to 1.
        safe (Optional[bool]): Whether to run the input validation checks before
            training. Defaults to True.

    Returns
    -------
        Tuple[Module, Array]: A Tuple comprising the optimised model (in constrained
            space) and training history (one loss value per iteration) respectively.
    """
    if safe:
        # Check inputs.
        _check_model(model)
        _check_train_data(train_data)
        _check_optim(optim)
        _check_num_iters(num_iters)
        _check_batch_size(batch_size)
        _check_prng_key("fit", key)
        _check_log_rate(log_rate)
        _check_verbose(verbose)

    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
    def loss(key: Key, model: Module, batch: Dataset) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(key, model.constrain(), batch)

    # Unconstrained space model.
    model = model.unconstrain()

    # Initialise optimiser state.
    state = optim.init(model)

    # One random key per iteration, scanned over (used for batching and the objective).
    iter_keys = jr.split(key, num_iters)

    # Optimisation step.
    def step(carry, key):
        model, opt_state = carry

        if batch_size != -1:
            batch = get_batch(train_data, batch_size, key)
        else:
            batch = train_data

        # Differentiate with respect to the model only (argument index 1), not the key.
        loss_val, loss_gradient = jax.value_and_grad(loss, argnums=1)(key, model, batch)
        updates, opt_state = optim.update(loss_gradient, opt_state, model)
        model = ox.apply_updates(model, updates)

        carry = model, opt_state
        return carry, loss_val

    # Optimisation scan.
    scan = vscan if verbose else jax.lax.scan

    # Optimisation loop.
    (model, _), history = scan(step, (model, state), (iter_keys), unroll=unroll)

    # Constrained space.
    model = model.constrain()

    return model, history


from gpjax.base import static_field, param_field
from gpjax.kernels import AbstractKernel
from gpjax.likelihoods import AbstractLikelihood
from gpjax.gps import AbstractPosterior
import tensorflow_probability.substrates.jax.bijectors as tfb
from jax.scipy.special import gammaln
from jaxtyping import Int


@jax.jit 
def comb(N, k) -> Int:
    """
    Binomial coefficient "N choose k", computed through log-gamma functions.

    The log of the coefficient is exponentiated, rounded to the nearest integer
    and cast to an integer dtype.
    """
    log_binomial = gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1)
    rounded = jnp.round(jnp.exp(log_binomial))
    return rounded.astype(jnp.int64)


@Partial(jax.jit, static_argnames=("sphere_dim"))
def num_phases_in_frequency(sphere_dim: int, frequency: Int) -> Int:
    """
    Number of phases at the given frequency for a sphere of dimension `sphere_dim`.

    Returns 1 at frequency 0 and
    comb(l + d - 2, l - 1) + comb(l + d - 1, l) otherwise, with l = frequency
    and d = sphere_dim.
    """
    degree = frequency
    constant_only = jnp.ones_like(degree, dtype=jnp.int64)
    higher_degrees = (
        comb(degree + sphere_dim - 2, degree - 1)
        + comb(degree + sphere_dim - 1, degree)
    )
    return jnp.where(degree == 0, constant_only, higher_degrees)


@Partial(jax.jit, static_argnames=("max_ell", "sphere_dim"))
def sphere_addition_theorem(x: Float[Array, "D"], y: Float[Array, "D"], *, max_ell: int, sphere_dim: int) -> Float:
    """
    Per-frequency terms of the addition theorem for two points `x`, `y`.

    For every frequency 0..max_ell, returns the number of phases times the
    Gegenbauer polynomial at <x, y>, normalised by its value at 1. The Gegenbauer
    parameter is (sphere_dim - 1) / 2.
    """
    alpha = (sphere_dim - 1) / 2.0
    degrees = jnp.arange(max_ell + 1)
    multiplicity = num_phases_in_frequency(sphere_dim=sphere_dim, frequency=degrees)
    at_one = gegenbauer(1.0, max_ell=max_ell, alpha=alpha)
    at_cosine = gegenbauer(jnp.dot(x, y), max_ell=max_ell, alpha=alpha)
    return multiplicity / at_one * at_cosine


def addition_theorem_scalar_kernel(spectrum: Float[Array, "I"], z: Float[Array, "I"]) -> Float[Array, ""]:
    """Kernel value as the dot product of the spectrum with the addition-theorem terms `z`."""
    weighted_sum = jnp.dot(spectrum, z)
    return weighted_sum


@Partial(jax.jit, static_argnames=('dim',))
def matern_spectrum(ell: Float, kappa: Float, nu: Float, variance: Float, dim: int) -> Float:
    """
    Matérn spectrum over the frequencies `ell`, normalised and scaled by `variance`.

    The unnormalised log-spectrum is -(nu + dim / 2) * log1p(lambda * kappa² / (2 nu))
    with lambda = ell * (ell + dim - 1). After exponentiation, the values are divided
    by their sum weighted with the number of phases per frequency, then multiplied
    by `variance`.
    """
    eigenvalues = ell * (ell + dim - 1)
    log_density = -(nu + dim / 2) * jnp.log1p((eigenvalues * kappa**2) / (2 * nu))

    # Shift by the maximum before exponentiating, for numerical stability.
    density = jnp.exp(log_density - jnp.max(log_density))

    # Normalise so that the phase-weighted density sums to 1.
    phases_per_ell = num_phases_in_frequency(frequency=ell, sphere_dim=dim)
    total_mass = jnp.dot(phases_per_ell, density)
    return variance * density / total_mass


@dataclass
class SphereMaternKernel(Module):
    """
    Single-output Matérn kernel on the sphere, evaluated through a truncated
    spectrum over the frequencies 0..max_ell.
    """

    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array(1.5), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    max_ell: int = static_field(25)

    def __post_init__(self):
        # Cast every hyperparameter to a float64 array.
        for name in ("kappa", "nu", "variance"):
            setattr(self, name, jnp.asarray(getattr(self, name), dtype=jnp.float64))

    @property 
    def ells(self):
        """Frequencies 0..max_ell as a float64 array."""
        return jnp.arange(self.max_ell + 1, dtype=jnp.float64)
    
    def spectrum(self) -> Num[Array, "I"]:
        """Matérn spectrum at `self.ells` for the current hyperparameters."""
        return matern_spectrum(
            self.ells,
            self.kappa,
            self.nu,
            self.variance,
            dim=self.sphere_dim,
        )

    @jax.jit 
    def from_spectrum(self, spectrum: Float[Array, "M"], x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value between `x` and `y` for an already computed `spectrum`."""
        terms = sphere_addition_theorem(x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim)
        return addition_theorem_scalar_kernel(spectrum, terms)
    
    @jax.jit 
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value between `x` and `y`."""
        current_spectrum = self.spectrum()
        return self.from_spectrum(current_spectrum, x, y)


@dataclass 
class MultioutputSphereMaternKernel(Module):
    """
    Matérn kernel on the sphere with one independent set of hyperparameters per output.
    """

    num_outputs: int = static_field()
    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array([1.5]), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    max_ell: int = static_field(25)

    def _validate_params(self) -> None:
        """Cast the hyperparameters to float64 and broadcast each to shape `(num_outputs,)`."""
        per_output_shape = (self.num_outputs,)
        for name in ("kappa", "nu", "variance"):
            # float64 for numerical stability, then one value per output
            as_float64 = jnp.asarray(getattr(self, name), dtype=jnp.float64)
            setattr(self, name, jnp.broadcast_to(as_float64, per_output_shape))

    def __post_init__(self):
        self._validate_params()

    @property 
    def ells(self):
        """Frequencies 0..max_ell."""
        return jnp.arange(self.max_ell + 1)
    
    @jax.jit 
    def spectrum(self) -> Num[Array, "O L"]:
        """Matérn spectrum for every output, stacked along the leading axis."""
        def spectrum_for_output(kappa, nu, variance):
            return matern_spectrum(self.ells, kappa, nu, variance, dim=self.sphere_dim)

        return jax.vmap(spectrum_for_output)(self.kappa, self.nu, self.variance)
    
    @jax.jit 
    def from_spectrum(self, spectrum: Float[Array, "O L"], x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output kernel values between `x` and `y` for already computed spectra."""
        def kernel_for_output(output_spectrum):
            terms = sphere_addition_theorem(x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim)
            return addition_theorem_scalar_kernel(output_spectrum, terms)

        return jax.vmap(kernel_for_output)(spectrum)
    
    @jax.jit 
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output kernel values between `x` and `y`."""
        current_spectrum = self.spectrum()
        return self.from_spectrum(current_spectrum, x, y)


@dataclass 
class MultioutputPrior(Module):
    """Prior holding a multi-output kernel together with a static jitter value."""

    kernel: MultioutputSphereMaternKernel = param_field()
    jitter: Float = static_field(1e-12)

    @property 
    def num_outputs(self):
        """Number of outputs, taken from the kernel."""
        kernel = self.kernel
        return kernel.num_outputs


@jax.jit 
def grad_safe_norm(z: Float[Array, "... D"]) -> Float[Array, "..."]:
    """
    Euclidean norm over the last axis, with the squared norm clipped from below at
    1e-36 before the square root is taken.
    """
    squared_norm = jnp.sum(jnp.square(z), axis=-1)
    floored = jnp.clip(squared_norm, min=1e-36)
    return jnp.sqrt(floored)


@jax.jit 
def euclidean_matern32_kernel(z: Float, variance: Float) -> Float:
    """Matérn-3/2 kernel value for a scaled distance `z`: variance * (1 + √3 z) * exp(-√3 z)."""
    scaled = jnp.sqrt(3) * z
    polynomial = variance * (1 + scaled)
    return polynomial * jnp.exp(-scaled)


@dataclass 
class EuclideanMaternKernel32(Module):
    """Single-output Euclidean Matérn-3/2 kernel with a shared or per-input lengthscale `kappa`."""

    num_inputs: int = static_field(init=True)
    kappa: Float = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    variance: Float = param_field(jnp.array([1.0]), bijector=tfb.Softplus())

    def __post_init__(self):
        # Cast both hyperparameters to float64 arrays.
        for name in ("kappa", "variance"):
            setattr(self, name, jnp.asarray(getattr(self, name), dtype=jnp.float64))

        # `kappa` must hold either one value per input dimension or a single shared value.
        leading = self.kappa.shape[0]
        assert leading == self.num_inputs or leading == 1

    @jax.jit
    def prepare_inputs(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Norm of the difference `x - y` after dividing by `kappa`."""
        scaled_difference = (x - y) / self.kappa
        return grad_safe_norm(scaled_difference)

    @jax.jit 
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value between `x` and `y`."""
        distance = self.prepare_inputs(x, y)
        return euclidean_matern32_kernel(distance, self.variance)


@dataclass 
class MultioutputEuclideanMaternKernel32(Module):
    """Euclidean Matérn-3/2 kernel with one lengthscale row and one variance per output."""

    num_inputs: int = static_field()
    num_outputs: int = static_field()
    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())

    def _validate_params(self) -> None:
        """
        Cast the hyperparameters to float64, broadcast `kappa` against
        `(num_outputs, 1)` and `variance` to `(num_outputs,)`, then check that the
        second axis of `kappa` is either `num_inputs` or 1.
        """
        # float64 for numerical stability
        kappa = jnp.asarray(self.kappa, dtype=jnp.float64)
        variance = jnp.asarray(self.variance, dtype=jnp.float64)

        # shape for multioutput
        kappa_shape = jnp.broadcast_shapes(kappa.shape, (self.num_outputs, 1))
        self.kappa = jnp.broadcast_to(kappa, kappa_shape)
        self.variance = jnp.broadcast_to(variance, (self.num_outputs,))

        trailing = self.kappa.shape[1]
        assert trailing == self.num_inputs or trailing == 1

    def __post_init__(self):
        self._validate_params()

    @jax.jit
    def prepare_inputs(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output norm of the difference `x - y` after dividing by `kappa`."""
        scaled_difference = (x - y) / self.kappa
        return grad_safe_norm(scaled_difference)

    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel values between `x` and `y`, mapped over the per-output distances and variances."""
        distances = self.prepare_inputs(x, y)
        return jax.vmap(euclidean_matern32_kernel)(distances, self.variance)
    

@dataclass 
class Prior(Module):
    """Prior holding a kernel together with a static jitter value."""
    # Kernel of the prior; declared as a parameter field.
    kernel: SphereMaternKernel = param_field()
    # Static scalar, default 1e-12; read by `InducingPointsPosterior.jitter`.
    jitter: Float = static_field(1e-12)
    

@dataclass
class Posterior(Module):
    """Container pairing a `Prior` with a likelihood module."""
    # Prior whose kernel and jitter are read by `InducingPointsPosterior`.
    prior: Prior = param_field()
    # Likelihood module; stored here and not used by any method of this class.
    likelihood: Module = param_field()


@dataclass
class MultioutputPosterior(Module):
    """Container pairing a `MultioutputPrior` with a likelihood module."""

    prior: MultioutputPrior = param_field()
    likelihood: Module = param_field()

    @property 
    def num_outputs(self) -> int:
        """Number of outputs, taken from the prior."""
        prior = self.prior
        return prior.num_outputs
    

@Partial(jax.jit, static_argnames=('jitter',))
def inducing_points_moments(
    Kxx: Float[Array, ""], 
    Kzx: Float[Array, "M"], 
    Kzz: Float[Array, "M M"], 
    m: Float[Array, "M"], 
    sqrtS: Float[Array, "M M"], 
    jitter: float = 1e-12
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """
    Predictive mean and variance at one input for a whitened inducing-point posterior.

    With Lzz the lower Cholesky factor of `Kzz + jitter * I` and a = Lzz⁻¹ Kzx:

        mean     = aᵀ m
        variance = Kxx + ‖sqrtSᵀ a‖² − ‖a‖² + jitter

    Args:
        Kxx: Prior variance at the input.
        Kzx: Covariances between the inducing points and the input.
        Kzz: Covariance matrix of the inducing points.
        m: Variational mean.
        sqrtS: Square-root factor of the variational covariance (S = sqrtS sqrtSᵀ).
        jitter: Value added to the diagonal of `Kzz` and to the returned variance.
            Static under `jit`.

    Returns:
        The tuple `(mean, variance)`.
    """
    # Function-scope import keeps this block self-contained.
    from jax.scipy.linalg import solve_triangular

    Kzz = Kzz.at[jnp.diag_indices_from(Kzz)].add(jitter)

    Lzz = jnp.linalg.cholesky(Kzz)
    # Lzz is lower triangular, so use forward substitution rather than a general solve.
    Lzz_inv_Kzx = solve_triangular(Lzz, Kzx, lower=True) # [M M] \ [M] -> [M]
    sqrtS_T_Lzz_inv_Kzx = sqrtS.T @ Lzz_inv_Kzx # [M M] @ [M] -> [M]

    variance = (
        Kxx
        + jnp.sum(jnp.square(sqrtS_T_Lzz_inv_Kzx))
        - jnp.sum(jnp.square(Lzz_inv_Kzx))
    )
    variance += jitter

    mean = jnp.inner(Lzz_inv_Kzx, m) # [M] @ [M] -> []

    return mean, variance 


@Partial(jax.jit, static_argnames=('jitter',))
def spherical_harmonic_features_moments(
    Kxx: Float[Array, ""], 
    Kxz: Float[Array, "M"], 
    Kzz_inv_diag: Float[Array, "M"], 
    m: Float[Array, "M"], 
    sqrtS: Float[Array, "M M"], 
    jitter: float = 1e-12
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """
    Predictive mean and (co)variance at one input for features with a diagonal Kzz.

    `Kzz_inv_diag` holds the diagonal of Kzz⁻¹. With
    d = sqrt(Kzz_inv_diag) / sqrt(1 + jitter * Kzz_inv_diag) and a = Kxz * d:

        mean       = a · m
        covariance = Kxx + ‖a sqrtS‖²

    No `− ‖a‖²` term is subtracted here; per the original author's note it is
    absorbed into `Kxx`.

    Returns:
        The tuple `(mean, covariance)`.
    """
    inv_chol_diag = jnp.sqrt(Kzz_inv_diag) / jnp.sqrt(1 + jitter * Kzz_inv_diag)
    whitened = Kxz * inv_chol_diag
    projected = whitened @ sqrtS

    mean = whitened @ m
    covariance = Kxx + jnp.sum(jnp.square(projected))
    return mean, covariance


@jax.jit
def whitened_prior_kl(m: Float, sqrtS: Float) -> Float:
    """
    KL divergence from the variational distribution N(m, sqrtS sqrtSᵀ) to the
    standard normal N(0, I) of the same dimension.
    """
    variational = tfd.MultivariateNormalFullCovariance(
        loc=m,
        covariance_matrix=sqrtS @ sqrtS.T,
    )
    standard_normal = tfd.MultivariateNormalFullCovariance(
        loc=jnp.zeros(m.shape),
        covariance_matrix=jnp.eye(m.shape[0]),
    )
    return tfd.kl_divergence(variational, standard_normal)


def inducing_points_prior_kl(m: Float, sqrtS: Float) -> Float:
    """Prior KL term for inducing points; delegates to `whitened_prior_kl`."""
    kl = whitened_prior_kl(m, sqrtS)
    return kl


@dataclass 
class DummyPosterior(Module):
    """Posterior-like container that holds only a `Prior` (no likelihood field)."""
    # The wrapped prior, declared as a parameter field.
    prior: Prior = param_field()


@dataclass 
class MultioutputDummyPosterior(Module):
    """Minimal module that only carries a multi-output ``prior`` (no likelihood field)."""
    prior: MultioutputPrior = param_field()

    @property 
    def num_outputs(self):
        """Number of outputs, forwarded from the wrapped prior."""
        return self.prior.num_outputs


@dataclass 
class InducingPointsPosterior(Module):
    """Single-output variational posterior parameterised by inducing inputs ``z``.

    The variational parameters ``m`` and ``sqrtS`` are created in
    ``__post_init__`` (zeros and identity) and are compared against a standard
    normal in ``prior_kl``.
    """
    posterior: Posterior = param_field()
    # Inducing inputs; not trainable unless the caller flips the flag.
    z: Float[Array, "M D"] = param_field(trainable=False)
    # Variational mean and square-root covariance factor, set in __post_init__.
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())

    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        """Derive ``num_inducing`` from ``z`` and initialise ``m`` / ``sqrtS``."""
        self.num_inducing = self.z.shape[0]

        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)

    @property 
    def jitter(self):
        """Jitter value read from ``self.posterior.prior``."""
        return self.posterior.prior.jitter

    def prior_kl(self) -> Float:
        """KL term of the variational parameters, via ``inducing_points_prior_kl``."""
        return inducing_points_prior_kl(self.m, self.sqrtS)

    @jax.jit 
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, "1"], Float[Array, "1"]]:
        """Predictive mean and variance at a single input ``x``.

        Evaluates the kernel at ``(x, x)``, between every inducing input and
        ``x``, and between all pairs of inducing inputs, then delegates to
        ``inducing_points_moments``.
        NOTE(review): ``inducing_points_moments`` is annotated as returning
        scalars, so the ``"1"`` shapes in this return annotation look stale.
        """
        kernel = self.posterior.prior.kernel
        z = self.z

        Kxx = kernel(x, x)
        Kzx = jax.vmap(lambda t: kernel(t, x))(z)
        Kzz = jax.vmap(lambda t1: jax.vmap(lambda t2: kernel(t1, t2))(z))(z)

        return inducing_points_moments(Kxx, Kzx, Kzz, self.m, self.sqrtS, jitter=self.jitter)

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent normal marginals at each row of ``x`` (``moments`` vmapped over rows)."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))
    


@dataclass 
class MultioutputInducingPointsPosterior(Module):
    """Multi-output variational posterior with inducing inputs shared across outputs.

    Each output has its own ``m`` / ``sqrtS`` stacked along the leading axis;
    ``z`` is a single set of inducing inputs.
    """
    posterior: MultioutputPosterior = param_field()
    z: Float[Array, "M D"] = param_field(trainable=False)
    # NOTE(review): __post_init__ creates m with shape (num_outputs, num_inducing)
    # and sqrtS with shape (num_outputs, num_inducing, num_inducing), i.e. the
    # output axis is first, which disagrees with the "M O" / "M M O" annotations.
    m: Float[Array, "M O"] = param_field(init=False)
    sqrtS: Float[Array, "M M O"] = param_field(init=False, bijector=tfb.FillTriangular())

    num_outputs: int = static_field(init=False)
    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        """Read sizes from ``posterior`` / ``z`` and initialise per-output parameters."""
        self.num_outputs = self.posterior.num_outputs
        self.num_inducing = self.z.shape[0]

        self.m = jnp.zeros((self.num_outputs, self.num_inducing))
        # One identity matrix per output, stacked on axis 0.
        self.sqrtS = jnp.repeat(jnp.expand_dims(jnp.eye(self.num_inducing), axis=0), self.num_outputs, axis=0)

    @jax.jit 
    def prior_kl(self) -> Float:
        """Sum over outputs of the per-output KL terms."""
        return jnp.sum(jax.vmap(inducing_points_prior_kl)(self.m, self.sqrtS), axis=0)

    @jax.jit 
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, "O 1"], Float[Array, "O 1"]]:
        """Per-output predictive mean and variance at a single input ``x``.

        The ``in_axes`` below map over axis 0 of ``Kxx``, axis 1 of ``Kzx`` and
        axis 2 of ``Kzz``, i.e. the kernel's output axis is assumed to be the
        last axis of a kernel evaluation (TODO confirm against the kernel).
        ``inducing_points_moments`` is called with its default ``jitter`` here.
        """
        kernel = self.posterior.prior.kernel 
        z = self.z

        Kxx = kernel(x, x)
        Kzx = jax.vmap(lambda t: kernel(t, x))(z)
        Kzz = jax.vmap(lambda t1: jax.vmap(lambda t2: kernel(t1, t2))(z))(z)

        mean, variance = jax.vmap(inducing_points_moments, in_axes=(0, 1, 2, 0, 0))(Kxx, Kzx, Kzz, self.m, self.sqrtS)

        return mean, variance

    @jax.jit 
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent normal marginals per input row and output (``moments`` vmapped over rows)."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))
    

@dataclass
class SphericalHarmonicFeaturesPosterior(Module):
    """Single-output variational posterior using spherical harmonic features.

    The number of inducing features equals ``spherical_harmonics.num_phases``.
    ``sqrtS_augment`` holds one extra parameter per kernel frequency level.
    """
    posterior: Posterior = param_field()
    # spherical_harmonics: SphericalHarmonics = static_field()
    spherical_harmonics: SphericalHarmonics = static_field()
    # Variational mean and square-root covariance factor, set in __post_init__.
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    # Per-level scaling, squared and multiplied into the spectrum in `moments`.
    sqrtS_augment: Float[Array, "L"] = param_field(init=False)
    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        """Initialise ``m`` (zeros), ``sqrtS`` (identity) and ``sqrtS_augment``."""
        kernel = self.posterior.prior.kernel

        self.num_inducing = self.spherical_harmonics.num_phases
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)
        # Ones for every kernel level, except zeros for the levels covered by
        # the spherical harmonic features (indices 0..spherical_harmonics.max_ell).
        self.sqrtS_augment = jnp.ones(kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)

    @jax.jit 
    def Kzz_diag(self, spectrum: Float[Array, "L"]) -> Float[Array, "M"]:
        """Expand the per-level ``spectrum`` to one entry per feature.

        Takes the first ``max_ell + 1`` spectrum entries and repeats each one
        ``num_phases_per_frequency`` times, for a total length of ``num_phases``.
        """
        shf = self.spherical_harmonics
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        return jnp.repeat(
            spectrum[:shf.max_ell + 1], 
            repeats=repeats,
            total_repeat_length=total_repeat_length,
        )
    
    def Kxz(self, x: Float[Array, "D"]) -> Float[Array, "M"]:
        """Cross-covariance features at ``x``: the transposed polynomial expansion."""
        return self.spherical_harmonics.polynomial_expansion(x).T
    
    def prior_kl(self) -> Float[Array, ""]:
        """KL of ``N(m, sqrtS sqrtS^T)`` from the standard normal."""
        return whitened_prior_kl(self.m, self.sqrtS)

    @jax.jit
    def moments(self, x: Float[Array, "N D"]) -> tuple[Float[Array, ""], Float[Array, ""]]:
        """Predictive mean and variance at ``x``.

        NOTE(review): ``diag`` vmaps this method over rows, so ``x`` is
        presumably a single point here rather than the annotated "N D".
        """
        kernel = self.posterior.prior.kernel

        spectrum = kernel.spectrum()

        # This already accounts for the subtraction of the identity matrix from S'
        S_augment = jnp.square(self.sqrtS_augment)
        Kxx = kernel.from_spectrum(spectrum * S_augment, x, x)
        # NOTE: this value is passed as the `Kzz_inv_diag` argument below.
        Kzz_diag = self.Kzz_diag(spectrum)
        Kxz = self.Kxz(x)

        # `jitter` is not forwarded, so the callee's default is used.
        return spherical_harmonic_features_moments(Kxx, Kxz, Kzz_diag, self.m, self.sqrtS)
    
    @jax.jit 
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent normal marginals at each row of ``x``."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))


@dataclass
class MultioutputSphericalHarmonicFeaturesPosterior(Module):
    """Multi-output variational posterior using spherical harmonic features.

    Variational parameters are created for one output and then broadcast
    along a leading output axis in ``__post_init__``.
    """
    num_outputs: int = static_field(init=False)

    posterior: MultioutputPosterior = param_field()
    spherical_harmonics: SphericalHarmonics = static_field()
    # NOTE(review): after __post_init__ these carry a leading output axis
    # ((O, M), (O, M, M) and (O, L)), which the annotations below omit.
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    sqrtS_augment: Float[Array, "L"] = param_field(init=False)

    def __post_init__(self):
        """Initialise single-output parameters, then broadcast them over outputs."""
        kernel = self.posterior.prior.kernel

        self.num_outputs = self.posterior.num_outputs
        
        num_inducing = self.spherical_harmonics.num_phases
        self.m = jnp.zeros(num_inducing)
        self.sqrtS = jnp.eye(num_inducing)
        # Ones for every kernel level, zeros for the levels covered by the features.
        self.sqrtS_augment = jnp.ones(kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)

        # Add the leading output axis.
        self.m = jnp.broadcast_to(self.m, (self.num_outputs, num_inducing))
        self.sqrtS = jnp.broadcast_to(self.sqrtS, (self.num_outputs, num_inducing, num_inducing))
        self.sqrtS_augment = jnp.broadcast_to(self.sqrtS_augment, (self.num_outputs, kernel.max_ell + 1))

    @jax.jit
    def prior_kl(self) -> Float:
        """Sum over outputs of the per-output KL from the standard normal."""
        return jnp.sum(jax.vmap(whitened_prior_kl)(self.m, self.sqrtS), axis=0)

    @jax.jit 
    def Kzz_diag(self, spectrum: Float[Array, "O L"]) -> Float[Array, "O M"]:
        """Expand each output's per-level spectrum to one entry per feature.

        Uses the first ``max_ell + 1`` levels and repeats each one
        ``num_phases_per_frequency`` times, for ``num_phases`` entries per output.
        """
        shf = self.spherical_harmonics
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        return jax.vmap(
            lambda spectrum: jnp.repeat(spectrum, repeats=repeats, total_repeat_length=total_repeat_length)
        )(spectrum[:, :shf.max_ell + 1])
    

    def Kxz(self, x: Float[Array, "D"]) -> Float[Array, "O M"]:
        """Cross-covariance features at ``x``: the transposed polynomial expansion.

        NOTE(review): ``moments`` closes over this value without mapping it over
        outputs, so it is presumably shared across outputs rather than "O M".
        """
        return self.spherical_harmonics.polynomial_expansion(x).T
    
    
    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, "O"], Float[Array, "O"]]:
        """Per-output predictive mean and variance at a single input ``x``."""
        kernel = self.posterior.prior.kernel

        # prior covariance adjusted by the diagonal variational parameters 
        spectrum = kernel.spectrum() # [O L]
        S_augment = jnp.square(self.sqrtS_augment) # [O L]
        Kxx = kernel.from_spectrum(spectrum * S_augment, x, x) # [O N N]

        # variational covariance 
        Kzz_diag = self.Kzz_diag(spectrum) # [O M]
        Kxz = self.Kxz(x) # [O M]

        m = self.m
        sqrtS = self.sqrtS

        # Map over the leading output axis of Kxx, Kzz_diag, m and sqrtS; Kxz is
        # captured by the closure and reused for every output.
        return jax.vmap(
            lambda Kxx, Kzz_diag, m, sqrtS: spherical_harmonic_features_moments(Kxx, Kxz, Kzz_diag, m, sqrtS)
        )(Kxx, Kzz_diag, m, sqrtS)
    
    @jax.jit 
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent normal marginals per input row and output."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))


@jax.jit 
def expected_log_likelihood(y: Float, m: Float, f_var: Float, eps_var: Float) -> Float:
    """Expected Gaussian log-likelihood, summed over the last axis.

    Args:
        y: Observed targets.
        m: Mean of the latent function values.
        f_var: Variance of the latent function values.
        eps_var: Observation noise variance.
    """
    residual_sq = jnp.square(y - m)
    per_point = jnp.log(2 * jnp.pi) + jnp.log(eps_var) + (residual_sq + f_var) / eps_var
    return -0.5 * jnp.sum(per_point, axis=-1)


@jax.jit 
def negative_elbo(p: InducingPointsPosterior, x: Float, y: Float, *, key: Key) -> Float:
    """Negative ELBO: prior KL minus the expected log-likelihood.

    ``p`` must provide ``diag(x, key=...)``, ``prior_kl()`` and
    ``posterior.likelihood.noise_variance``.
    """
    predictive = p.diag(x, key=key)
    mean = predictive.mean()
    variance = predictive.variance()
    noise_variance = p.posterior.likelihood.noise_variance
    data_fit = expected_log_likelihood(y, mean, variance, noise_variance)
    return -(data_fit - p.prior_kl())


# TODO verify that this is correct 
@jax.jit
def sphere_expmap(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """Exponential map on the unit sphere: move from ``x`` along tangent ``v``.

    For non-negligible ``v`` this is ``cos(|v|) * x + sin(|v|) * v / |v|``.
    When ``|v| < 1e-12`` the normalised sum ``(x + v) / |x + v|`` is used.

    Args:
        x: Base points; the formula assumes unit norm.
        v: Displacements; the formula assumes they are tangent at ``x``.

    Returns:
        Array with the same shape as ``x``.
    """
    theta = jnp.linalg.norm(v, axis=-1, keepdims=True)
    is_small = theta < 1e-12

    # Double-`where`: the unselected branch of `jnp.where` is still
    # differentiated, so `v / theta` (and the norm itself) must not be
    # evaluated at v == 0 or the gradient becomes NaN. Rows with tiny `v` are
    # replaced by a harmless placeholder before the division.
    safe_v = jnp.where(is_small, jnp.ones_like(v), v)
    safe_theta = jnp.linalg.norm(safe_v, axis=-1, keepdims=True)
    true_expmap = jnp.cos(safe_theta) * x + jnp.sin(safe_theta) * safe_v / safe_theta

    t = x + v
    first_order_approx = t / jnp.linalg.norm(t, axis=-1, keepdims=True)

    return jnp.where(is_small, first_order_approx, true_expmap)


@jax.jit 
def sphere_to_tangent(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """Remove from ``v`` its component along ``x`` (row-wise ``v - <x, v> x``)."""
    along_x = jnp.sum(x * v, axis=-1, keepdims=True)
    normal_part = along_x * x
    return v - normal_part


@dataclass 
class SphereResidualDeepGP(Module):
    """Deep GP whose hidden layers displace inputs along the sphere.

    Each hidden layer's sample is projected to the tangent space at the current
    points and applied with ``sphere_expmap``; the output layer is evaluated at
    the resulting points.
    """
    hidden_layers: list[MultioutputInducingPointsPosterior] = param_field()
    output_layer: InducingPointsPosterior = param_field()
    # Number of Monte Carlo samples drawn through the hidden layers in `diag`.
    num_samples: int = static_field(1)

    @property 
    def posterior(self) -> Posterior:
        """Posterior object of the output layer."""
        return self.output_layer.posterior      
    
    def prior_kl(self) -> Float:
        """Sum of the KL terms of all hidden layers and the output layer."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()
    
    def sample_moments(self, x: Float[Array, "N D"], *, key: Key) -> tfd.Normal:
        """Propagate one sample through the hidden layers; return output-layer moments.

        NOTE(review): the return annotation says ``tfd.Normal`` but the value is
        whatever ``jax.vmap(self.output_layer.moments)`` returns, which callers
        in this file unpack as ``(mean, variance)``.
        """
        # One independent key per hidden layer.
        hidden_layer_keys = jr.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key)
            # Project the sampled displacement to the tangent space, then move.
            u = sphere_to_tangent(x, v)
            x = sphere_expmap(x, u)
        return jax.vmap(self.output_layer.moments)(x)

    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equally weighted mixture of normals over ``num_samples`` hidden-layer samples."""
        sample_keys = jr.split(key, self.num_samples)

        # In MixtureSameFamily batch size goes last; hence, out_axes = 1
        mean, variance = jax.vmap(lambda k: self.sample_moments(x, key=k), out_axes=1)(sample_keys) 

        return tfd.MixtureSameFamily(
            # Zero logits give uniform mixture weights.
            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)), 
            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)), 
        )


@dataclass
class DeepGaussianLikelihood(Module):
    """Gaussian observation noise applied to every component of a normal mixture."""
    # Observation noise variance; kept positive through the Softplus bijector.
    noise_variance: Float = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    
    @jax.jit 
    def diag(self, pf: tfd.MixtureSameFamily) -> tfd.MixtureSameFamily:
        """Return a mixture with the same weights and means as ``pf`` and
        ``noise_variance`` added to each component variance."""
        component_distribution = pf.components_distribution
        mean, variance = component_distribution.mean(), component_distribution.variance()
        variance += self.noise_variance
        return tfd.MixtureSameFamily(
            mixture_distribution=pf.mixture_distribution,
            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)),
        )


        # component_distribution = pf.components_distribution
        # mean, variance = component_distribution.mean(), component_distribution.variance()
        # variance += self.noise_variance
        # return tfd.MixtureSameFamily(
        #     mixture_distribution=pf.mixture_distribution,
        #     components_distribution=tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.sqrt(variance)),
        # )


@dataclass 
class EuclideanDeepGP(Module):
    """Deep GP whose hidden layers add their sample to the input (residual update)."""
    # NOTE(review): the factory functions in this file fill this list with
    # MultioutputInducingPointsPosterior instances, not InducingPointsPosterior.
    hidden_layers: list[InducingPointsPosterior] = param_field()
    output_layer: InducingPointsPosterior = param_field()
    # Number of Monte Carlo samples drawn through the hidden layers in `diag`.
    num_samples: int = static_field(1)

    @property 
    def posterior(self) -> MultioutputPosterior:
        """Posterior object of the output layer."""
        return self.output_layer.posterior        
    
    def prior_kl(self) -> Float:
        """Sum of the KL terms of all hidden layers and the output layer."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()

    def sample_moments(self, x: Float[Array, "N D"], *, key: Key) -> tfd.Normal:
        """Propagate one sample through the hidden layers; return output-layer moments.

        NOTE(review): the return annotation says ``tfd.Normal`` but the value is
        whatever ``jax.vmap(self.output_layer.moments)`` returns, which callers
        in this file unpack as ``(mean, variance)``.
        """
        # One independent key per hidden layer.
        hidden_layer_keys = jr.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key) # [N D]
            x = x + v # euclidean expmap
        return jax.vmap(self.output_layer.moments)(x)
    
    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equally weighted mixture of normals over ``num_samples`` hidden-layer samples."""
        sample_keys = jr.split(key, self.num_samples)

        # In MixtureSameFamily batch size goes last; hence, out_axes = 1
        mean, variance = jax.vmap(lambda k: self.sample_moments(x, key=k), out_axes=1)(sample_keys) 

        return tfd.MixtureSameFamily(
            # Zero logits give uniform mixture weights.
            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)), 
            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)), 
        )
    

# Type variable constrained to the two deep GP model classes defined above.
DeepGP = TypeVar("DeepGP", EuclideanDeepGP, SphereResidualDeepGP)
    

@jax.jit 
def negative_elbo(p: InducingPointsPosterior, x: Float, y: Float, *, key: Key) -> Float:
    """Negative ELBO: prior KL minus the expected log-likelihood.

    ``p`` must provide ``diag(x, key=...)``, ``prior_kl()`` and
    ``posterior.likelihood.noise_variance``.
    """
    marginals = p.diag(x, key=key)
    f_mean = marginals.mean()
    f_variance = marginals.variance()
    noise_variance = p.posterior.likelihood.noise_variance
    expected_fit = expected_log_likelihood(y, f_mean, f_variance, noise_variance)
    return -(expected_fit - p.prior_kl())


from scipy.cluster.vq import kmeans
from numpy.random import Generator


def jax_key_to_numpy_generator(key: Key) -> Generator:
    """Build a NumPy ``Generator`` seeded from the raw bits of a JAX PRNG key.

    Uses the public ``jax.random.key_data`` accessor instead of the private
    ``_base_array`` attribute, so both new-style typed keys and legacy
    ``uint32`` key arrays are accepted.
    """
    key_bits = np.asarray(jax.random.key_data(key)).ravel()
    return np.random.default_rng(key_bits)


def kmeans_inducing_points(key: Key, x: Float[Array, "N D"], num_inducing: int) -> Float[Array, "M D"]:
    """Choose inducing points as k-means centroids of ``x``.

    When more centroids are requested than there are data points, k-means is
    run with ``k = N`` repeatedly and once more for the remainder, and the
    centroids of all runs are concatenated.

    Args:
        key: JAX PRNG key; converted to a NumPy generator that seeds every
            k-means run, so the result is reproducible for a given key.
        x: Data points, one per row.
        num_inducing: Requested number of inducing points.

    Returns:
        Centroids as a float64 JAX array. NOTE(review): SciPy's ``kmeans`` may
        return fewer centroids than requested — confirm callers tolerate that.

    Raises:
        ValueError: If ``x`` has no rows (the loop below could not terminate).
    """
    rng = jax_key_to_numpy_generator(key)

    x = np.asarray(x, dtype=np.float64)
    n = x.shape[0]
    if n == 0:
        raise ValueError("kmeans_inducing_points requires at least one data point")

    remaining = num_inducing
    k_centers = []
    while remaining > n:
        # Seeded like the final call; the shared generator advances between runs.
        k_centers.append(kmeans(x, n, seed=rng)[0])
        remaining -= n
    k_centers.append(kmeans(x, remaining, seed=rng)[0])

    return jnp.asarray(np.concatenate(k_centers, axis=0), dtype=jnp.float64)


def create_residual_deep_gp_with_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = False, 
    nu: float = 1.5, kernel_max_ell: int | None = None
) -> SphereResidualDeepGP:
    """Build a ``SphereResidualDeepGP`` whose layers use inducing points.

    The model has ``num_layers - 1`` hidden layers plus one output layer. All
    layers share the same inducing inputs: k-means centroids of ``x``
    normalised to unit length. ``total_hidden_variance`` is divided evenly
    between the hidden layers. ``train_inducing`` only affects the output layer.
    """
    sphere_dim = x.shape[1] - 1
    num_hidden = num_layers - 1

    # Hidden and output kernels share the same smoothness and kappa arrays.
    smoothness = jnp.array(nu)
    kappa = jnp.array(1.0)
    per_layer_variance = total_hidden_variance / max(num_hidden, 1)

    centroids = kmeans_inducing_points(key, x, num_inducing)
    inducing_on_sphere = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)

    if kernel_max_ell is None:
        kernel_max_ell = num_phases_to_num_levels(num_inducing, sphere_dim=sphere_dim)

    def make_hidden_layer() -> MultioutputInducingPointsPosterior:
        hidden_kernel = MultioutputSphereMaternKernel(
            num_outputs=sphere_dim + 1,
            sphere_dim=sphere_dim,
            nu=smoothness,
            kappa=kappa,
            variance=per_layer_variance,
            max_ell=kernel_max_ell,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        # TODO set z to be trainable maybe
        return MultioutputInducingPointsPosterior(posterior=hidden_posterior, z=inducing_on_sphere)

    hidden_layers = [make_hidden_layer() for _ in range(num_hidden)]

    output_kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=smoothness,
        kappa=kappa,
        variance=jnp.array(1.0),
        max_ell=kernel_max_ell,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = InducingPointsPosterior(posterior=output_posterior, z=inducing_on_sphere)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def num_phases_to_num_levels(num_phases: int, *, sphere_dim: int) -> int:
    """Largest frequency level whose cumulative phase count fits in ``num_phases``.

    Levels are consumed from 0 upwards, subtracting
    ``num_phases_in_frequency`` for each, until the budget is used up or
    exceeded.
    """
    level = 0
    remaining = num_phases
    while remaining > 0:
        remaining -= num_phases_in_frequency(frequency=level, sphere_dim=sphere_dim)
        level += 1

    if remaining == 0:
        # The budget was met exactly by levels 0..level-1.
        return level - 1
    # The last level overshot the budget, so it is dropped.
    return level - 2


def create_residual_deep_gp_with_spherical_harmonic_features(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, 
    nu: float = 1.5, kernel_max_ell: int | None = None
) -> SphereResidualDeepGP:
    """Build a ``SphereResidualDeepGP`` whose layers use spherical harmonic features.

    The model has ``num_layers - 1`` hidden layers plus one output layer, all
    sharing one ``SphericalHarmonics`` object whose ``max_ell`` is derived from
    ``num_inducing``. ``x`` is only used for its dimensionality and ``key`` is
    not used by this function.
    """
    sphere_dim = x.shape[1] - 1
    num_hidden = num_layers - 1

    # Hidden and output kernels share the same smoothness and kappa arrays.
    smoothness = jnp.array(nu)
    kappa = jnp.array(1.0)
    per_layer_variance = jnp.array(total_hidden_variance / max(num_hidden, 1))

    shf_max_ell = num_phases_to_num_levels(num_inducing, sphere_dim=sphere_dim)
    if kernel_max_ell is None:
        kernel_max_ell = shf_max_ell
    harmonics = SphericalHarmonics(max_ell=shf_max_ell, sphere_dim=sphere_dim)

    def make_hidden_layer() -> MultioutputSphericalHarmonicFeaturesPosterior:
        hidden_kernel = MultioutputSphereMaternKernel(
            num_outputs=sphere_dim + 1,
            sphere_dim=sphere_dim,
            nu=smoothness,
            kappa=kappa,
            variance=per_layer_variance,
            max_ell=kernel_max_ell,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        return MultioutputSphericalHarmonicFeaturesPosterior(
            posterior=hidden_posterior, spherical_harmonics=harmonics
        )

    hidden_layers = [make_hidden_layer() for _ in range(num_hidden)]

    output_kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=smoothness,
        kappa=kappa,
        variance=jnp.array(1.0),
        max_ell=kernel_max_ell,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = SphericalHarmonicFeaturesPosterior(posterior=output_posterior, spherical_harmonics=harmonics)

    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_euclidean_deep_gp_with_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = True
) -> EuclideanDeepGP:
    """Build a ``EuclideanDeepGP`` with Matern-3/2 kernels and inducing points.

    The model has ``num_layers - 1`` hidden layers plus one output layer. All
    layers start from the same k-means centroids of ``x`` as inducing inputs;
    with ``train_inducing`` they are marked trainable in every layer.
    """
    ambient_dim = x.shape[1]  # equals sphere_dim + 1 in the sibling factories
    num_hidden = num_layers - 1

    per_layer_variance = total_hidden_variance / max(num_hidden, 1)
    # Hidden and output kernels share the same kappa array.
    kappa = jnp.ones(ambient_dim)

    inducing = kmeans_inducing_points(key, x, num_inducing)

    def make_hidden_layer() -> MultioutputInducingPointsPosterior:
        hidden_kernel = MultioutputEuclideanMaternKernel32(
            variance=per_layer_variance,
            kappa=kappa,
            num_outputs=ambient_dim,
            num_inputs=ambient_dim,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        hidden_layer = MultioutputInducingPointsPosterior(posterior=hidden_posterior, z=inducing)
        if train_inducing:
            hidden_layer = hidden_layer.replace_trainable(z=True)
        return hidden_layer

    hidden_layers = [make_hidden_layer() for _ in range(num_hidden)]

    output_kernel = EuclideanMaternKernel32(
        num_inputs=ambient_dim,
        variance=jnp.array(1.0),
        kappa=kappa,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = InducingPointsPosterior(posterior=output_posterior, z=inducing)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return EuclideanDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_euclidean_deep_gp_with_input_geometric_layer_and_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = True, 
) -> EuclideanDeepGP:
    """Build a ``EuclideanDeepGP`` whose first hidden layer uses a sphere Matern kernel.

    With ``num_layers == 1`` this delegates to
    ``create_residual_deep_gp_with_inducing_points`` (which returns a
    ``SphereResidualDeepGP``). Otherwise the model has one sphere-kernel input
    layer, ``num_layers - 2`` Euclidean hidden layers and a Euclidean output
    layer. The input layer's inducing inputs are never marked trainable here.
    """
    if num_layers == 1:
        return create_residual_deep_gp_with_inducing_points(
            num_layers, total_hidden_variance, num_inducing, x, num_samples=num_samples, key=key,
        )

    sphere_dim = x.shape[1] - 1
    ambient_dim = sphere_dim + 1

    # The input layer and the Euclidean hidden layers use the same variance.
    per_layer_variance = total_hidden_variance / max(num_layers - 1, 1)
    # Euclidean hidden and output kernels share the same kappa array.
    euclidean_kappa = jnp.ones(ambient_dim)

    centroids = kmeans_inducing_points(key, x, num_inducing)
    centroids_on_sphere = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)

    # input layer
    input_kernel = MultioutputSphereMaternKernel(
        num_outputs=ambient_dim,
        sphere_dim=sphere_dim,
        variance=per_layer_variance,
        kappa=jnp.array(1.0),
        nu=jnp.array(1.5),
    )
    input_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=input_kernel))
    hidden_layers = [MultioutputInducingPointsPosterior(posterior=input_posterior, z=centroids_on_sphere)]

    def make_euclidean_hidden_layer() -> MultioutputInducingPointsPosterior:
        hidden_kernel = MultioutputEuclideanMaternKernel32(
            kappa=euclidean_kappa,
            variance=per_layer_variance,
            num_inputs=ambient_dim,
            num_outputs=ambient_dim,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        hidden_layer = MultioutputInducingPointsPosterior(posterior=hidden_posterior, z=centroids)
        if train_inducing:
            hidden_layer = hidden_layer.replace_trainable(z=True)
        return hidden_layer

    hidden_layers.extend(make_euclidean_hidden_layer() for _ in range(num_layers - 2))

    output_kernel = EuclideanMaternKernel32(
        kappa=euclidean_kappa,
        variance=jnp.array(1.0),
        num_inputs=ambient_dim,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = InducingPointsPosterior(posterior=output_posterior, z=centroids)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return EuclideanDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_model(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, name: str, nu: float = 1.5, 
    kernel_max_ell: int | None = None
) -> DeepGP: 
    """Construct a deep GP by ``name``.

    ``nu`` and ``kernel_max_ell`` are only forwarded to the two ``residual+...``
    variants; the Euclidean variants use their own defaults.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    positional = (num_layers, total_hidden_variance, num_inducing, x)
    shared = {"num_samples": num_samples, "key": key}
    sphere_kernel_options = {"nu": nu, "kernel_max_ell": kernel_max_ell}

    if name == 'residual+inducing_points':
        return create_residual_deep_gp_with_inducing_points(*positional, **shared, **sphere_kernel_options)
    elif name == 'euclidean+inducing_points':
        return create_euclidean_deep_gp_with_inducing_points(*positional, **shared)
    elif name == 'euclidean_with_geometric_input+inducing_points':
        return create_euclidean_deep_gp_with_input_geometric_layer_and_inducing_points(*positional, **shared)
    elif name == 'residual+spherical_harmonic_features':
        return create_residual_deep_gp_with_spherical_harmonic_features(
            *positional, **shared, **sphere_kernel_options
        )
    else:
        raise ValueError(f"Unknown model name: {name}")



# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from beartype.typing import (
    Any,
    Callable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import jax
import jax.random as jr
import optax as ox

from gpjax.base import Module
from gpjax.dataset import Dataset
from gpjax.objectives import AbstractObjective
from gpjax.scan import vscan
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

ModuleModel = TypeVar("ModuleModel", bound=Module)


def fit(  # noqa: PLR0913
    *,
    model: ModuleModel,
    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
    x: Float, 
    y: Float,
    optim: ox.GradientTransformation,
    key: KeyArray,
    num_iters: Optional[int] = 100,
    batch_size: Optional[int] = None,
    verbose: Optional[bool] = True,
    unroll: Optional[int] = 1,
) -> Tuple[ModuleModel, Array]:
    r"""Train a Module model with respect to a supplied objective function.

    Optimisers used here should originate from Optax. The objective is called
    as ``objective(model, x, y, key=key)`` on the constrained model, so any
    extra arguments it needs must be bound beforehand.

    Example:
    ```python
        trained_model, history = fit(
            model=model,
            objective=Partial(deep_negative_elbo, n=x.shape[0]),
            x=x,
            y=y,
            optim=ox.adam(1e-2),
            key=jr.key(0),
            num_iters=1000,
            batch_size=256,
        )
    ```

    Args:
        model (Module): The model Module to be optimised.
        objective: Callable evaluated as ``objective(model, x, y, key=key)`` and
            returning the scalar loss to minimise.
        x (Float): Training inputs; rows are subsampled when mini-batching.
        y (Float): Training targets, indexed together with ``x``.
        optim (GradientTransformation): The Optax optimiser that is to be used for
            learning a parameter set.
        key (KeyArray): Random key; split into one key per iteration.
        num_iters (Optional[int]): The number of optimisation steps to run. Defaults
            to 100.
        batch_size (Optional[int]): The size of the mini-batch to use. Defaults to
            None (i.e. full batch).
        verbose (Optional[bool]): If True, ``vscan`` is used instead of
            ``jax.lax.scan``. Defaults to True.
        unroll (int): The number of unrolled steps to use for the optimisation.
            Defaults to 1.

    Returns
    -------
        Tuple[Module, Array]: A Tuple comprising the optimised model and training
            history respectively.
    """

    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
    def loss(model: Module, x: Float, y: Float, *, key: Key) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(model.constrain(), x, y, key=key)

    # Unconstrained space model.
    model = model.unconstrain()

    # Initialise optimiser state.
    state = optim.init(model)

    # One random key per iteration to scan over.
    iter_keys = jr.split(key, num_iters)

    # Optimisation step.
    def step(carry, key):
        model, opt_state = carry

        if batch_size is not None:
            # Use independent keys for mini-batch selection and for the
            # objective's own sampling instead of reusing one key for both.
            batch_key, loss_key = jr.split(key)
            batch_x, batch_y = get_batch(x, y, batch_size, batch_key)
        else:
            # Full batch: the iteration key is consumed only by the objective.
            batch_x, batch_y = x, y
            loss_key = key

        loss_val, loss_gradient = jax.value_and_grad(loss)(model, batch_x, batch_y, key=loss_key)
        updates, opt_state = optim.update(loss_gradient, opt_state, model)
        model = ox.apply_updates(model, updates)

        carry = model, opt_state
        return carry, loss_val

    # Optimisation scan.
    scan = vscan if verbose else jax.lax.scan

    # Optimisation loop.
    (model, _), history = scan(step, (model, state), iter_keys, unroll=unroll)

    # Constrained space.
    model = model.constrain()

    return model, history


def get_batch(x: Float, y: Float, batch_size: int, key: KeyArray) -> tuple[Float, Float]:
    """Draw a random mini-batch of matching rows from ``x`` and ``y``.

    Sampling is done with replacement.

    Args:
        x (Float): Inputs, indexed along the first axis.
        y (Float): Targets, indexed along the first axis with the same indices.
        batch_size (int): The batch size.
        key (KeyArray): The random key to use for the batch selection.

    Returns
    -------
        tuple[Float, Float]: The selected rows of ``x`` and ``y``.
    """
    num_points = x.shape[0]
    batch_indices = jr.choice(key, num_points, shape=(batch_size,), replace=True)
    batch_x = x[batch_indices]
    batch_y = y[batch_indices]
    return batch_x, batch_y


@Partial(jax.jit, static_argnames=('n',))
def deep_negative_elbo(p: EuclideanDeepGP, x: Float, y: Float, *, key: Key, n: int) -> Float:
    """Mini-batch negative ELBO for a deep GP.

    The expected log-likelihood is averaged over ``p.num_samples`` passes
    through ``p.sample_moments`` and rescaled by ``n / x.shape[0]`` before the
    prior KL is subtracted.

    Args:
        p: Model providing ``sample_moments``, ``prior_kl``, ``num_samples``
            and ``posterior.likelihood.noise_variance``.
        x: Batch inputs.
        y: Batch targets.
        key: Random key, split into one key per sample.
        n: Value the batch size is rescaled to; static under ``jit``.
    """
    noise_variance = p.posterior.likelihood.noise_variance

    def ell_for_sample(sample_key: Key) -> Float:
        mean, variance = p.sample_moments(x, key=sample_key)
        return expected_log_likelihood(y, mean, variance, noise_variance)

    per_sample_ell = jax.vmap(ell_for_sample)(jr.split(key, p.num_samples))
    averaged_ell = jnp.mean(per_sample_ell, axis=0)
    scale = n / x.shape[0]

    return -(averaged_ell * scale - p.prior_kl())


import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
from scipy.special import sph_harm


def sphere_meshgrid(n_colat=100, n_lon=100, epsilon=1e-32):
    """Regular (colatitude, longitude) grid, returned as two 2-D arrays.

    Colatitude runs over ``[epsilon, pi - epsilon]`` along axis 0 and longitude
    over ``[0, 2 * pi]`` along axis 1; both arrays have shape ``(n_colat, n_lon)``.
    NOTE: with the default ``epsilon``, ``np.pi - epsilon`` rounds to ``np.pi``
    in float64.
    """
    colat_axis = np.linspace(0 + epsilon, np.pi - epsilon, n_colat)
    lon_axis = np.linspace(0, 2 * np.pi, n_lon)
    colat_grid, lon_grid = np.meshgrid(colat_axis, lon_axis, indexing="ij")
    return colat_grid, lon_grid


def target_f__product_of_sines(sph: Float) -> Float:
    """
    Target function ``sin(3 * colat) * sin(3 * lon)``.

    ``sph`` holds colatitude and longitude in its last axis, in that order.
    """
    colat = sph[..., 0]
    lon = sph[..., 1]
    colat_factor = jnp.sin(colat * 3)
    lon_factor = jnp.sin(lon * 3)
    return colat_factor * lon_factor


def target_f__sum_of_sines(sph: Float) -> Float:
    """Target function ``(|sin(2 * colat)| + |sin(2 * lon)| - 1) / 2``, computed with NumPy.

    ``sph`` holds colatitude and longitude in its last axis, in that order.
    """
    colat = sph[..., 0]
    lon = sph[..., 1]
    colat_term = np.abs(np.sin(colat * 2))
    lon_term = np.abs(np.sin(lon * 2))
    return (colat_term + lon_term - 1.0) / 2



def rotate(sph, roll: float):
    """
    Rotate points given in spherical coordinates about the Cartesian x-axis.

    Parameters:
    sph (jnp.ndarray): Array whose last axis holds (colatitude, longitude).
    roll (float): Angle of rotation in radians.

    Returns:
    jnp.ndarray: Rotated (colatitude, longitude) stacked along the last axis.
    The longitude comes from ``arctan2`` and is not wrapped to [0, 2pi].
    """
    colat, lon = sph[..., 0], sph[..., 1]

    # Spherical -> Cartesian
    cart_x = jnp.sin(colat) * jnp.cos(lon)
    cart_y = jnp.sin(colat) * jnp.sin(lon)
    cart_z = jnp.cos(colat)

    # Rotation in the y-z plane; x is left unchanged
    rot_x = cart_x
    rot_y = jnp.cos(roll) * cart_y - jnp.sin(roll) * cart_z
    rot_z = jnp.sin(roll) * cart_y + jnp.cos(roll) * cart_z

    # Cartesian -> spherical
    radius = jnp.sqrt(rot_x**2 + rot_y**2 + rot_z**2)
    rot_colat = jnp.arccos(rot_z / radius)
    rot_lon = jnp.arctan2(rot_y, rot_x)

    return jnp.stack([rot_colat, rot_lon], axis=-1)


def reversed_spherical_harmonic(sph, m: int, n: int):
    """
    Real part of `scipy.special.sph_harm(m, n, colat, lon)` as a JAX array.

    The colatitude is passed as SciPy's third argument and the longitude as
    its fourth. NOTE(review): SciPy documents those slots as azimuthal and
    polar angle respectively, which is presumably the "reversed" in the name.
    """
    colat = np.asarray(sph[..., 0])
    lon = np.asarray(sph[..., 1])
    harmonic = sph_harm(m, n, colat, lon)
    return jnp.asarray(harmonic.real)


def target_f__reversed_spherical_harmonic(sph: Float) -> Float:
    """
    Sum of two `reversed_spherical_harmonic` terms: (m=1, n=2) at `sph`, plus
    (m=1, n=1) at `sph` rotated by a roll of pi / 2.
    """
    direct_term = reversed_spherical_harmonic(sph, m=1, n=2)
    rolled_points = rotate(sph, roll=jnp.pi / 2)
    rolled_term = reversed_spherical_harmonic(rolled_points, m=1, n=1)
    return direct_term + rolled_term


def target_f(sph: Float[Array, " N D"], *, name: str = "product_of_sines") -> Float[Array, " N"]:
    if name == "product_of_sines":
        return target_f__product_of_sines(sph)
    if name == "sum_of_sines":
        return target_f__sum_of_sines(sph)
    if name == "reversed_spherical_harmonic":
        return target_f__reversed_spherical_harmonic(sph)
    raise ValueError(f"Unknown target function: {name}")


def add_noise(f: Float[Array, " N"], noise_std: float = 0.01, *, key: Key) -> Float:
    """
    Return `f` plus standard-normal noise scaled by `noise_std`.

    The noise is drawn with `jax.random.normal` from `key` and has the same
    shape as `f`.
    """
    unit_noise = jax.random.normal(key=key, shape=f.shape)
    return f + unit_noise * noise_std


def evaluate_diag(model: DeepGP, x: Float[Array, "N D"], y: Float[Array, " N"], *, key: Key) -> dict[str, float]:
    """
    Score the model's diagonal predictive distribution at `x` against `y`.

    The output of `model.diag(x, key=key)` is passed through
    `model.posterior.likelihood.diag`, and the result is scored with `mse`
    and `nlpd`. Returns a dict with the Python-scalar entries "mse" and "nlpd".
    """
    latent = model.diag(x, key=key)
    py = model.posterior.likelihood.diag(latent)
    return {
        "mse": mse(y, py).item(),
        "nlpd": nlpd(y, py).item(),
    }


def mse_projected(y: Float[Array, "N"], py: tfd.MixtureSameFamily, x_norm: Float[Array, "N"]) -> Float:
    """
    Mean over points of the squared error between `y` and `py.mean()`,
    each term weighted by `x_norm ** 2`.
    """
    squared_error = jnp.square(y - py.mean())
    weights = jnp.square(x_norm)
    return jnp.mean(squared_error * weights)


def nlpd_projected(y: Float[Array, "N"], py: tfd.MixtureSameFamily, x_norm: Float[Array, "N"]) -> Float:
    """
    Negative mean of `py.log_prob(y) - log(x_norm)`, i.e. the log-density of
    each point shifted by the log of its norm before averaging.
    """
    shifted_log_density = py.log_prob(y) - jnp.log(x_norm)
    return -jnp.mean(shifted_log_density)


def evaluate_diag_projected(model: DeepGP, x: Float[Array, "N D"], y: Float[Array, " N"], *, key: Key, x_norm: Float[Array, "N"]) -> dict[str, float]:
    """
    Same pipeline as `evaluate_diag`, but scored with the norm-aware metrics.

    The output of `model.diag(x, key=key)` is passed through
    `model.posterior.likelihood.diag` and scored with `mse_projected` and
    `nlpd_projected` using `x_norm`. Returns a dict with the Python-scalar
    entries "mse" and "nlpd".
    """
    latent = model.diag(x, key=key)
    py = model.posterior.likelihood.diag(latent)
    return {
        "mse": mse_projected(y, py, x_norm).item(),
        "nlpd": nlpd_projected(y, py, x_norm).item(),
    }


from jax import config
config.update("jax_enable_x64", True)


from pathlib import Path 
import numpy as np 
import pandas as pd 
import jax 
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray
from jaxtyping import Array, Key, ArrayLike
from scipy.io import arff


# Directory the `read_*_data` loaders read their raw files from. The path is
# relative, so it is resolved against the current working directory.
experiment_dir = Path("../data/")


def read_yacht_data() -> tuple[NDArray, NDArray]:
    """
    Load the yacht dataset from `experiment_dir / "yacht.data"`.

    The file is parsed as header-less, whitespace-separated values and cast to
    float64. Returns `(X, y)`: `X` holds every column but the last, `y` the
    last column kept two-dimensional with shape `(n, 1)`.
    """
    # Raw string: "\s" is not a valid escape sequence in a normal string literal.
    yacht_df = pd.read_csv(experiment_dir / "yacht.data", header=None, sep=r"\s+").astype(np.float64)
    X, y = yacht_df.iloc[:, :-1].values, yacht_df.iloc[:, -1:].values
    return X, y


def read_energy_data() -> tuple[NDArray, NDArray]:
    """
    Load the energy dataset from `experiment_dir / "energy.xlsx"` as float64.

    The target is the Heating Load, stored in the second-to-last column; the
    features are all columns before it, so the last column is dropped.
    Returns `(X, y)` with `y` of shape `(n, 1)`.
    """
    table = pd.read_excel(experiment_dir / "energy.xlsx").astype(np.float64)
    features = table.iloc[:, :-2].values
    heating_load = table.iloc[:, -2:-1].values
    return features, heating_load


def read_concrete_data() -> tuple[NDArray, NDArray]:
    """
    Load the concrete dataset from `experiment_dir / "concrete.xls"` as float64.

    Returns `(X, y)`: all columns but the last, and the last column with
    shape `(n, 1)`.
    """
    table = pd.read_excel(experiment_dir / "concrete.xls").astype(np.float64)
    features = table.iloc[:, :-1].values
    target = table.iloc[:, -1:].values
    return features, target


def read_kin8mn_data() -> tuple[NDArray, NDArray]:
    """
    Load the ARFF file `experiment_dir / "kin8nm.arff"` as a float64 matrix.

    Returns `(X, y)`: all columns but the last, and the last column with
    shape `(n, 1)`.
    """
    records, _meta = arff.loadarff(experiment_dir / "kin8nm.arff")
    # The loader yields a structured array; go through tuples to get a plain 2-D matrix.
    matrix = np.array(records.tolist(), dtype=np.float64)
    return matrix[:, :-1], matrix[:, -1:]


def read_power_data() -> tuple[NDArray, NDArray]:
    """
    Load the power dataset from `experiment_dir / "power.xlsx"` as float64.

    Returns `(X, y)`: all columns but the last, and the last column with
    shape `(n, 1)`.
    """
    table = pd.read_excel(experiment_dir / "power.xlsx").astype(np.float64)
    features = table.iloc[:, :-1].values
    target = table.iloc[:, -1:].values
    return features, target


def to_jax(*args: ArrayLike) -> tuple[Array, ...]:
    """
    Convert every positional argument to a JAX array, requesting float64.

    Returns the converted arrays as a tuple in argument order (empty when
    called without arguments).
    """
    converted = []
    for arg in args:
        converted.append(jnp.asarray(arg, dtype=jnp.float64))
    return tuple(converted)


def train_test_split(X: Array, y: Array, key: Key) -> tuple[Array, Array, Array, Array]:
    """
    Randomly split rows into 90% training and 10% testing data.

    Rows of `X` and `y` are shuffled with the same permutation drawn from
    `key`; the first `int(0.9 * n)` shuffled rows form the training set.
    Returns `(X_train, X_test, y_train, y_test)`.
    """
    num_rows = X.shape[0]
    num_train = int(0.9 * num_rows)

    order = jax.random.permutation(key, num_rows)
    train_rows, test_rows = order[:num_train], order[num_train:]
    return X[train_rows], X[test_rows], y[train_rows], y[test_rows]


def _nonzero_scale(std: Array) -> Array:
    """Replace zero standard deviations by 1 so constant columns map to 0 instead of NaN."""
    return jnp.where(std == 0, 1.0, std)


def standardize(X_train: Array, X_test: Array, y_train: Array, y_test: Array) -> tuple[Array, Array, Array, Array]:
    """
    Standardize the training and testing data with the training statistics.

    Per-column mean and standard deviation are computed on the training arrays
    only and applied to both splits. Columns that are constant in the training
    set (zero standard deviation) are only centred, not scaled, so they do not
    produce NaN or inf. Does not return the mean and standard deviation, since
    the NLPD and MSE are reported on the standardized data.

    Returns `(X_train, X_test, y_train, y_test)`, standardized.
    """
    X_train_mean = X_train.mean(axis=0)
    X_train_std = _nonzero_scale(X_train.std(axis=0))
    y_train_mean = y_train.mean(axis=0)
    y_train_std = _nonzero_scale(y_train.std(axis=0))

    X_train = (X_train - X_train_mean) / X_train_std
    X_test = (X_test - X_train_mean) / X_train_std
    y_train = (y_train - y_train_mean) / y_train_std
    y_test = (y_test - y_train_mean) / y_train_std

    return X_train, X_test, y_train, y_test


def read_data(dataset: str) -> tuple[NDArray, NDArray]:
    if dataset == "yacht":
        return read_yacht_data()
    elif dataset == "energy":
        return read_energy_data()
    elif dataset == "concrete":
        return read_concrete_data()
    elif dataset == "kin8mn":
        return read_kin8mn_data()
    elif dataset == "power":
        return read_power_data()
    else:
        raise ValueError(f"Dataset {dataset} not found.")
    

def project_data_to_sphere(X: Array, y: Array, bias: float = 1.0):
    """
    Append a constant `bias` column to `X` and scale every row to unit norm.

    Returns `(X_unit, y_scaled, norm)`: the normalised augmented inputs, `y`
    divided by the same per-row norm, and the norms with shape `(n, 1)`.
    """
    bias_column = jnp.full((X.shape[0], 1), bias, dtype=X.dtype)
    augmented = jnp.concatenate([X, bias_column], axis=1)
    row_norm = jnp.linalg.norm(augmented, axis=1, keepdims=True)
    return augmented / row_norm, y / row_norm, row_norm


def train_and_eval(
    dataset_name: str, model_name: str, num_layers: int, seed: int, kernel_max_ell: int | None = None, num_iters: int = 1000, batch_size: int | None = None,
):
    """
    Train one model on one random split of a dataset and score it on the test split.

    Parameters:
    dataset_name: Dataset key passed to `read_data` and looked up in `dataset_name_to_num_inducing`.
    model_name: Forwarded to `create_model`. For names starting with "residual" the
        standardized data is first projected onto the sphere with `project_data_to_sphere`.
    num_layers: Forwarded to `create_model`.
    seed: Seed of the root PRNG key; data split, model creation, training and
        evaluation each receive their own key derived from it.
    kernel_max_ell: Forwarded to `create_model`.
    num_iters: Forwarded to `fit`.
    batch_size: Forwarded to `fit`; reset to None when it is at least the number of training points.

    Returns:
    Tuple `(model_opt, history, metrics_diag)`: the fitted model with `num_samples`
    replaced by `test_num_samples`, the history returned by `fit`, and the metrics
    dict from `evaluate_diag_projected`.
    """
    key = jax.random.key(seed)
    key, data_key = jax.random.split(key)

    # data
    x, y = read_data(dataset_name)
    x, y = to_jax(x, y)
    train_x, test_x, train_y, test_y = train_test_split(x, y, key=data_key)
    train_x, test_x, train_y, test_y = standardize(train_x, test_x, train_y, test_y)

    if model_name.startswith("residual"):
        train_x, train_y, _ = project_data_to_sphere(train_x, train_y)
        test_x, test_y, test_x_norm = project_data_to_sphere(test_x, test_y)
        test_x_norm = test_x_norm.squeeze(axis=-1)
    else:
        # Unit norms turn the projected metrics into the unweighted squared
        # error and the plain negative mean log-density.
        test_x_norm = jnp.ones((test_x.shape[0]))

    train_y, test_y = train_y.squeeze(axis=-1), test_y.squeeze(axis=-1)

    # model
    key, model_key = jax.random.split(key)

    num_inducing = dataset_name_to_num_inducing[dataset_name]
    model = create_model(
        num_layers=num_layers, total_hidden_variance=total_hidden_variance,
        num_inducing=num_inducing, x=train_x, num_samples=train_num_samples,
        name=model_name, key=model_key, kernel_max_ell=kernel_max_ell,
    )

    # training
    optim = ox.adam(learning_rate=lr)
    if batch_size is not None and batch_size >= train_x.shape[0]:
        batch_size = None
    n = train_x.shape[0]
    objective = Partial(deep_negative_elbo, n=n)

    # A PRNG key must not be consumed twice: training and evaluation get
    # independent keys instead of sharing the same one.
    fit_key, eval_key = jax.random.split(key)

    model_opt, history = fit(
        model=model,
        objective=objective,
        x=train_x,
        y=train_y,
        optim=optim,
        key=fit_key,
        num_iters=num_iters,
        batch_size=batch_size,
    )

    # testing
    model_opt = model_opt.replace(num_samples=test_num_samples)
    metrics_diag = evaluate_diag_projected(model_opt, test_x, test_y, key=eval_key, x_norm=test_x_norm)

    return model_opt, history, metrics_diag


import os 


def save_results(experiment_settings: dict[str, Any], metrics: dict[str, float], *, dir_path: str):
    """
    Write `metrics` as a one-row CSV into a directory named after the settings.

    The directory is `dir_path/<name>`, where `<name>` joins the `key=value`
    pairs of `experiment_settings` with "-"; it is created if missing. The
    metrics are stored there as "metrics.csv" with row index 0.
    """
    run_name = "-".join(f"{setting}={value}" for setting, value in experiment_settings.items())
    run_dir = Path(dir_path) / run_name
    run_dir.mkdir(parents=True, exist_ok=True)

    # save metrics
    metrics_table = pd.DataFrame(metrics, index=[0])
    metrics_table.to_csv(run_dir / "metrics.csv")
    return


def main(dataset_name: str, model_name: str, num_layers: int, seed: int, kernel_max_ell: int | None = None, num_iters: int = 1000, batch_size: int = None):
    """
    Run one experiment end to end: train, evaluate, print and save the metrics.

    The settings are printed, forwarded unchanged to `train_and_eval`, and used
    by `save_results` to name the output directory under "results/uci".
    Returns the trained model.
    """
    # Keys match the keyword arguments of `train_and_eval` one to one.
    experiment_settings = dict(
        dataset_name=dataset_name,
        model_name=model_name,
        num_layers=num_layers,
        seed=seed,
        kernel_max_ell=kernel_max_ell,
        num_iters=num_iters,
        batch_size=batch_size,
    )
    print(experiment_settings)

    model_opt, history, metrics = train_and_eval(**experiment_settings)
    print(metrics)

    save_results(experiment_settings=experiment_settings, metrics=metrics, dir_path="results/uci")
    return model_opt



import chex 
from typing import NamedTuple, Any


def tree_max(tree: Any) -> chex.Numeric:
  """Compute the maximum over all the elements in a pytree.

  Each leaf is reduced with `jnp.max`, then the per-leaf maxima are combined
  pairwise with `jax.lax.max`.
  """
  per_leaf_max = [jnp.max(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
  return jax.tree_util.tree_reduce(jax.lax.max, per_leaf_max)


def tree_linf_norm(tree: Any) -> chex.Numeric:
  """Compute the l-infinity norm of a pytree.

  This is the largest absolute value found in any leaf.
  """
  return tree_max(jax.tree_util.tree_map(jnp.abs, tree))


def run_lbfgs(init_params, fun, opt, max_iter, *, key: Key):
  """Run `opt` on `fun` inside a `jax.lax.while_loop` until a stopping test fires.

  Args:
    init_params: Initial parameters (a pytree) handed to `opt.init`.
    fun: Objective called as `fun(params, key=subkey)`.
    opt: Optax transformation whose `update` accepts `value`, `grad` and
      `value_fn`. Its state must expose the entries 'count', 'g', 'f', 'gtol'
      and 'ftol'. NOTE(review): the last four match the fields of
      `LBFGSTerminationCriteriaState`, so `opt` is presumably a chain that
      includes `lbfgs_termination_criteria()`. Confirm against the caller.
    max_iter: Upper bound on the iteration counter read from the state.
    key: PRNG key; a fresh subkey is split off for every iteration.

  Returns:
    Tuple `(final_params, final_state)`.
  """
  # The value-and-grad helper is rebuilt inside `step` because the objective is
  # re-bound to a new subkey on every iteration.

  def step(carry):
    params, state, key = carry
    key, subkey = jax.random.split(key)
    # Fix this iteration's randomness so `opt.update` can call `value_fn(params)`.
    value_fun = Partial(fun, key=subkey)

    # NOTE(review): each iteration binds a new subkey, so anything optax keeps
    # in `state` from the previous iteration was computed under a different
    # key. Confirm that this is intended.
    value_and_grad_fun = optax.value_and_grad_from_state(value_fun)
    value, grad = value_and_grad_fun(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=value_fun
    )
    params = optax.apply_updates(params, updates)
    return params, state, key
  
  def continuing_criterion(carry):
    _, state, _ = carry
    iter_num = optax.tree_utils.tree_get(state, 'count')
    gval = optax.tree_utils.tree_get(state, 'g')
    fval = optax.tree_utils.tree_get(state, 'f')
    gtol = optax.tree_utils.tree_get(state, 'gtol')
    ftol = optax.tree_utils.tree_get(state, 'ftol')

    # Always run the first iteration; afterwards continue only while the
    # iteration budget is not exhausted and both g and f exceed their tolerances.
    return (iter_num == 0) | ((iter_num < max_iter) & (gval > gtol) & (fval > ftol))

  init_carry = (init_params, opt.init(init_params), key)
  final_params, final_state, _ = jax.lax.while_loop(
      continuing_criterion, step, init_carry
  )
  return final_params, final_state


class InfoState(NamedTuple):
  """State carried by the `print_info` transformation."""
  # Counter that `print_info` initialises to 0 and increments on every update call.
  iter_num: chex.Numeric


def print_info():
  """Build a pass-through transformation that prints the objective value.

  The returned `optax.GradientTransformationExtraArgs` leaves `updates`
  unchanged. On every update call it prints the current iteration counter and
  the `value` keyword argument via `jax.debug.print`, then increments the
  counter stored in its `InfoState`.
  """
  def init_fn(params):
    del params
    return InfoState(iter_num=0)

  def update_fn(updates, state, params, *, value, grad, **extra_args):
    del params, grad, extra_args
    current_iter = state.iter_num
    jax.debug.print('Iteration: {i}, Value: {v} \r', i=current_iter, v=value)
    next_state = InfoState(iter_num=current_iter + 1)
    return updates, next_state

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)


class LBFGSTerminationCriteriaState(NamedTuple):
  """Quantities tracked by `lbfgs_termination_criteria` for the stopping test.

  `run_lbfgs` reads `g`, `f`, `gtol` and `ftol` from the optimizer state.
  """
  # l-infinity norm of the most recent gradient.
  g: chex.Numeric
  # Relative decrease of the objective between the last two updates (inf until two values exist).
  f: chex.Numeric
  # Objective value seen at the most recent update.
  f_kp1: chex.Numeric
  # Objective value seen at the update before that.
  f_k: chex.Numeric
  # Threshold compared against `g`.
  gtol: chex.Numeric = 1e-5
  # Threshold compared against `f`.
  # NOTE(review): looks like 1e7 * float64 machine epsilon; confirm the intended origin.
  ftol: chex.Numeric = 2.2204460492503131e-09

def lbfgs_termination_criteria():
  """Build a pass-through transformation that tracks stopping statistics.

  The returned `optax.GradientTransformationExtraArgs` leaves `updates`
  unchanged and stores, in a `LBFGSTerminationCriteriaState`, the l-infinity
  norm of `grad` and the relative decrease of `value` between two
  consecutive update calls.
  """
  def init_fn(params):
    del params
    # Everything starts at +inf, so no tolerance is met before the first update.
    return LBFGSTerminationCriteriaState(g=jnp.inf, f=jnp.inf, f_kp1=jnp.inf, f_k=jnp.inf)

  def update_fn(updates, state: LBFGSTerminationCriteriaState, params, *, value, grad, **extra_args):
    del params, extra_args

    # Gradient statistic: largest absolute gradient entry.
    new_g = tree_linf_norm(grad)

    # Shift the value history: the previous "latest" value becomes f_k.
    new_f_k = state.f_kp1
    new_f_kp1 = value

    # Relative decrease (f_k - f_kp1) / max(|f_k|, |f_kp1|, 1). It is reported
    # as inf while f_k is not finite, i.e. on the first update after init.
    denominator = jax.lax.max(jax.lax.max(jnp.abs(new_f_k), jnp.abs(new_f_kp1)), jnp.array(1.0, dtype=jnp.float64))
    new_f = jax.lax.cond(
      jax.lax.is_finite(new_f_k),
      lambda: (new_f_k - new_f_kp1) / denominator,
      lambda: jnp.inf,
    )

    # gtol / ftol are not forwarded, so the new state carries the class defaults.
    return updates, LBFGSTerminationCriteriaState(g=new_g, f=new_f, f_k=new_f_k, f_kp1=new_f_kp1)

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)


def fit_lbfgs(model: Module, train_x: Float[Array, "N D"], train_y: Float[Array, "N"], opt, max_iter: int, *, key: Key):
    """
    Optimise `model` with `run_lbfgs` on the per-datapoint negative ELBO.

    The model is moved to unconstrained space with `model.unconstrain()` and
    optimised there. The objective applies `stop_gradient().constrain()` before
    evaluating `deep_negative_elbo` on the full training set, and divides the
    result by the number of training points. Returns the optimised model
    after `.constrain()`.
    """
    num_train = train_x.shape[0]

    def objective(model: Module, *, key: Key):
        constrained = model.stop_gradient().constrain()
        return deep_negative_elbo(constrained, train_x, train_y, key=key) / num_train

    optimised, _ = run_lbfgs(init_params=model.unconstrain(), fun=objective, opt=opt, max_iter=max_iter, key=key)
    return optimised.constrain()


# Dataset keys; each one is accepted by `read_data`.
dataset_names = [
    "yacht",
    "concrete",
    "energy",
    "kin8mn",
    "power",
]


# Number of inducing points passed to `create_model`, per dataset.
dataset_name_to_num_inducing = {
    "yacht": 294,
    "concrete": 210, 
    "energy": 210,
    "kin8mn": 210,
    "power": 336,
}


# Per-dataset `kernel_max_ell` values.
# NOTE(review): not referenced in the visible part of this file; presumably
# supplied by the caller through `--kernel_max_ell`. Confirm.
dataset_name_to_kernel_max_ell = {
    "yacht": 12, 
    "concrete": 10, 
    "energy": 10,
    "kin8mn": 10,
    "power": 20,
}


# Model identifiers; names starting with "residual" make `train_and_eval`
# project the data onto the sphere.
model_names = [
    "residual+inducing_points", 
    "euclidean+inducing_points", 
    "residual+spherical_harmonic_features", 
]


# Depths to run.
nums_layers = [
    1, 2, 3, 4, 5,
]


# NOTE(review): not referenced in the visible part of this file.
num_test = 5000


# Hyperparameters read as globals by `train_and_eval`.
total_hidden_variance = 0.0001
train_num_samples = 3
test_num_samples = 10

# Adam learning rate used in `train_and_eval`.
lr = 0.01
# NOTE(review): `train_and_eval` and `main` take `num_iters` as a parameter, so
# this module-level value is not what they use.
num_iters = 5000


import argparse


if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--dataset", type=str, required=True)
    argparser.add_argument("--model", type=str, required=True)
    argparser.add_argument("--num_layers", type=int, required=True)
    argparser.add_argument("--seed", type=int, required=True)
    argparser.add_argument("--kernel_max_ell", type=lambda x: None if x.lower() == 'none' else int(x), default=None)
    argparser.add_argument("--num_iters", type=int, default=5000)
    argparser.add_argument("--batch_size", type=lambda x: None if x.lower() == 'none' else int(x), default=None)

    args = argparser.parse_args()
    main(
        dataset_name=args.dataset, 
        model_name=args.model, 
        num_layers=args.num_layers, 
        seed=args.seed, 
        kernel_max_ell=args.kernel_max_ell,
        num_iters=args.num_iters,
        batch_size=args.batch_size,
    )