from jax import config 
config.update("jax_enable_x64", True)


import jax 
import jax.numpy as jnp
import gpjax 
import optax 
from gpjax.distributions import GaussianDistribution
from gpjax.typing import Float
from gpjax.base import static_field, param_field
from jax.tree_util import Partial
import tensorflow_probability.substrates.jax as tfp
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Key

from matplotlib import pyplot as plt 

from dataclasses import dataclass, InitVar
from abc import abstractmethod


@jax.jit
def sph_dot_product(sph1: Float, sph2: Float) -> Float:
    """
    Inner product in R^3 of two unit-sphere points given in spherical coordinates.

    The last axis of each input holds (colatitude, longitude); the elementwise
    arithmetic broadcasts any leading axes against each other.
    """
    theta_a, phi_a = sph1[..., 0], sph1[..., 1]
    theta_b, phi_b = sph2[..., 0], sph2[..., 1]
    in_plane = jnp.sin(theta_a) * jnp.sin(theta_b) * jnp.cos(phi_a - phi_b)
    polar = jnp.cos(theta_a) * jnp.cos(theta_b)
    return in_plane + polar


# NOTE jitting this doesn't help
def sph_gegenbauer(x, y, max_ell: int, alpha: float = 0.5):
    """Gegenbauer polynomials of degree 0..max_ell at the R^3 inner product of two spherical-coordinate points."""
    cos_angle = sph_dot_product(x, y)
    return gegenbauer(x=cos_angle, max_ell=max_ell, alpha=alpha)


# NOTE jitting this doesn't help
def sph_gegenbauer_single(x, y, ell: int, alpha: float = 0.5):
    """Single Gegenbauer polynomial of degree ``ell`` at the R^3 inner product of two spherical-coordinate points."""
    cos_angle = sph_dot_product(x, y)
    return gegenbauer_single(x=cos_angle, ell=ell, alpha=alpha)


def array(x):
    """Build a JAX array from ``x`` with the float64 dtype requested."""
    target_dtype = jnp.float64
    return jnp.array(x, dtype=target_dtype)


@jax.jit
def sph_to_car(sph):
    """
    Convert spherical coordinates (colatitude, longitude) stored in the last axis
    into Cartesian (x, y, z) on the unit sphere. Leading axes are preserved.
    """
    colat = sph[..., 0]
    lon = sph[..., 1]
    sin_colat = jnp.sin(colat)
    return jnp.stack(
        [sin_colat * jnp.cos(lon), sin_colat * jnp.sin(lon), jnp.cos(colat)],
        axis=-1,
    )


@jax.jit
def car_to_sph(car):
    """
    Convert Cartesian (x, y, z) points stored in the last axis into spherical
    coordinates (colatitude, longitude), with longitude wrapped into [0, 2π).

    ``z`` is clipped to [-1, 1] before ``arccos``. A unit vector whose z carries a
    tiny round-off excess (e.g. 1 + 1e-16 after normalisation) therefore maps to
    the pole instead of producing NaN.
    """
    x, y, z = car[..., 0], car[..., 1], car[..., 2]
    colat = jnp.arccos(jnp.clip(z, -1.0, 1.0))
    lon = jnp.arctan2(y, x)
    lon = (lon + 2 * jnp.pi) % (2 * jnp.pi)
    return jnp.stack([colat, lon], axis=-1)



from pathlib import Path
from typing import Callable

import numpy as np
from jax import Array


class FundamentalSystemNotPrecomputedError(ValueError):
    """Raised when no precomputed fundamental-system file is found for a dimension."""

    def __init__(self, dimension: int):
        super().__init__(
            f"Fundamental system for dimension {dimension} has not been precomputed."
        )


def fundamental_set_loader(
    dimension: int, load_dir="fundamental_system", base_dir="../"
) -> Callable[[int], Array]:
    """
    Load the precomputed fundamental system for a given dimension.

    The file ``<base_dir>/<load_dir>/fs_<dimension>D.npz`` is read eagerly into
    memory, and a lookup function over its contents is returned.

    Args:
        dimension: Dimension used to select the file ``fs_<dimension>D.npz``.
        load_dir: Directory holding the ``.npz`` files, relative to ``base_dir``.
        base_dir: Root that ``load_dir`` is resolved against. The default
            ``"../"`` keeps the previous behaviour (relative to the current
            working directory).

    Returns:
        ``load(degree)``, which returns the array stored under ``degree_<degree>``.

    Raises:
        FundamentalSystemNotPrecomputedError: if the file does not exist.
        ValueError: raised by ``load`` when the requested degree is not stored.
    """
    file_name = Path(base_dir) / load_dir / f"fs_{dimension}D.npz"
    if not file_name.exists():
        raise FundamentalSystemNotPrecomputedError(dimension)

    # Materialise every array now so the archive can be closed immediately.
    with np.load(file_name) as archive:
        cache = dict(archive.items())

    def load(degree: int) -> Array:
        key = f"degree_{degree}"
        try:
            return cache[key]
        except KeyError:
            raise ValueError(f"key: {key} not in cache.") from None

    return load


@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))
def gegenbauer(x: Float[Array, "..."], max_ell: int, alpha: float = 0.5) -> Float[Array, "... L"]:
    """
    Evaluate the Gegenbauer polynomials Cᵅₙ(x) for every degree n = 0, ..., max_ell.

    Uses the three-term recurrence
        Cᵅ₀(x) = 1
        Cᵅ₁(x) = 2αx
        Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array of any shape.
        max_ell: Highest degree to evaluate (static).
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static).

    Returns:
        Array of shape ``(*x.shape, max_ell + 1)`` whose last axis holds
        Cᵅ₀(x), ..., Cᵅ_max_ell(x).
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)
    res = jnp.zeros((*x.shape, max_ell + 1), dtype=x.dtype)
    res = res.at[..., 0].set(C_0)

    # max_ell is static, so branch in Python. This avoids tracing the recurrence,
    # and the write to index 1 that would be out of range, when max_ell == 0.
    if max_ell == 0:
        return res

    C_1 = 2 * alpha * x
    res = res.at[..., 1].set(C_1)

    def step(n: int, carry: tuple[Float, Float, Float]) -> tuple[Float, Float, Float]:
        res, C, C_prev = carry
        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C
        res = res.at[..., n].set(C)
        return res, C, C_prev

    return jax.lax.fori_loop(2, max_ell + 1, step, (res, C_1, C_0))[0]


@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call 
def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:
    """
    Evaluate the single Gegenbauer polynomial Cᵅ_ell(x) by recursion.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array.
        ell: Degree of the polynomial. It is traced (not static), so one compiled
            function serves every degree.
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere.

    Returns:
        Cᵅ_ell(x), with the same shape and dtype as ``x``.
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)
    C_1 = 2 * alpha * x

    def step(carry):
        C, C_prev, n = carry
        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C
        return C, C_prev, n + 1

    def cond(carry):
        return carry[2] <= ell

    # The degree counter shares x's dtype. A hard-coded float64 counter would
    # promote a float32 ``x`` to float64 inside ``step`` and violate the
    # while_loop requirement that the carry keeps its dtype.
    n_start = jnp.asarray(2, dtype=x.dtype)
    return jax.lax.cond(
        ell == 0,
        lambda: C_0,
        lambda: jax.lax.while_loop(cond, step, (C_1, C_0, n_start))[0],
    )


@dataclass
class SphericalHarmonics(gpjax.Module):
    """
    Spherical harmonics inducing features for sparse inference in Gaussian processes.

    The spherical harmonics Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the
    sphere and, as such, they form an orthogonal basis.

    To construct the harmonics, we use a fundamental set of points on the sphere {vᵢ}ᵢ and
    compute b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b forms a basis, which we orthogonalise via a Cholesky
    decomposition of its Gram matrix. The Cholesky decomposition only runs once, during
    initialisation.

    Attributes:
        max_ell: Highest frequency (inclusive) for which harmonics are computed.
        sphere_dim: Sphere dimension d; the fundamental set is loaded for dimension d + 1.
        alpha: Gegenbauer constant (d - 1) / 2, set in ``__post_init__``.
        orth_basis: Per-frequency lower Cholesky factors of the Gram matrices.
        Vs: Per-frequency fundamental-set points loaded from disk.
        num_phases_per_frequency: Number of harmonics (rows of ``Vs[n]``) per frequency.
        num_phases: Total number of harmonics over all frequencies.

    Returns:
        An instance of the spherical harmonics features.
    """

    max_ell: int = static_field()
    sphere_dim: int = static_field()
    alpha: float = static_field(init=False)
    orth_basis: list[Array] = static_field(init=False)
    Vs: list[Array] = static_field(init=False)
    num_phases_per_frequency: list[int] = static_field(init=False)
    num_phases: int = static_field(init=False)


    @property
    def levels(self):
        """Frequencies 0, ..., max_ell as an int32 array."""
        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)


    def __post_init__(self) -> None:
        """
        Load the fundamental set, derive ``alpha`` and precompute the orthogonal basis.

        Raises:
            FundamentalSystemNotPrecomputedError: if no fundamental set is stored for
                dimension ``sphere_dim + 1``.
        """
        dim = self.sphere_dim + 1

        # Try loading a pre-computed fundamental set.
        fund_set = fundamental_set_loader(dim)

        # Hyper-sphere constant used by the Gegenbauer polynomials.
        self.alpha = (dim - 2.0) / 2.0

        # One block of fundamental-set points per frequency. Plain Python ints give
        # stable "degree_<n>" lookup keys.
        self.Vs = [fund_set(n) for n in range(self.max_ell + 1)]

        # Pre-compute and save the orthogonal basis.
        self.orth_basis = self._orthogonalise_basis()

        # Cache the phase counts instead of recomputing them on every call.
        self.num_phases_per_frequency = [v.shape[0] for v in self.Vs]
        self.num_phases = sum(self.num_phases_per_frequency)


    @property
    def Ls(self) -> list[Array]:
        """
        Alias for the orthogonal basis at every frequency.
        """
        return self.orth_basis

    def _orthogonalise_basis(self) -> list[Array]:
        """
        Compute the basis from the fundamental set and orthogonalise it via Cholesky decomposition.

        Returns:
            For each frequency n, the lower Cholesky factor of
            α/(α + n) · Cᵅₙ(Vₙ Vₙᵀ) (plus a 1e-16 diagonal jitter).
        """
        alpha = self.alpha
        levels = jnp.split(self.levels, self.max_ell + 1)
        const = alpha / (alpha + self.levels.astype(jnp.float64))
        const = jnp.split(const, self.max_ell + 1)

        def _func(v, n, c):
            x = jnp.matmul(v, v.T)
            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)
            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))
            return L

        return jax.tree.map(_func, self.Vs, levels, const)

    def custom_gegenbauer_single(self, x, ell, alpha):
        """Evaluate Cᵅ_ell(x) by running the full recurrence up to ``max_ell`` and selecting degree ``ell``."""
        return gegenbauer(x, self.max_ell, alpha)[..., ell]

    @jax.jit
    def polynomial_expansion(self, X: Float[Array, "N D"]) -> Float[Array, "M N"]:
        """
        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.

        Args:
            X: Input Array.

        Returns:
            The harmonics evaluated at the input as a polynomial expansion of the basis,
            concatenated over frequencies along the first axis.
        """
        levels = jnp.split(self.levels, self.max_ell + 1)

        def _func(v, n, L):
            vxT = jnp.dot(v, X.T)
            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)
            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)
            return harmonic

        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)
        return jnp.concatenate(harmonics, axis=0)

    def __eq__(self, other: "SphericalHarmonics") -> bool:
        """
        Check if two spherical harmonic features are equal.

        Args:
            other: The other spherical harmonic features.

        Returns:
            A boolean indicating if the two features are equal, or NotImplemented
            when ``other`` is not a ``SphericalHarmonics``.
        """
        if not isinstance(other, SphericalHarmonics):
            return NotImplemented
        # Given the first two parameters, the rest are deterministic.
        # The user must not mutate all other fields, but that is not enforced for now.
        return (
            self.max_ell == other.max_ell
            and self.sphere_dim == other.sphere_dim
        )

    def __hash__(self) -> int:
        # Consistent with __eq__: only the defining configuration is hashed.
        return hash((self.max_ell, self.sphere_dim))


import warnings 
from typing import Optional
from jaxtyping import Num
from gpjax.typing import ScalarFloat



import jax 
import jax.numpy as jnp
import numpy as np
import pandas as pd
import netCDF4

from jaxtyping import Array
from jax.tree_util import Partial 
import plotly.express as px 

import pandas as pd 


from gpjax.base import static_field, param_field
from gpjax.kernels import AbstractKernel
from gpjax.likelihoods import AbstractLikelihood
from gpjax.gps import AbstractPosterior
import tensorflow_probability.substrates.jax.bijectors as tfb
from jax.scipy.special import gammaln
from jaxtyping import Int


@jax.jit
def comb(N, k) -> Int:
    """Binomial coefficient C(N, k) computed through log-gamma and rounded to int64."""
    log_binom = gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1)
    return jnp.round(jnp.exp(log_binom)).astype(jnp.int64)


@Partial(jax.jit, static_argnames=("sphere_dim",))
def num_phases_in_frequency(sphere_dim: int, frequency: Int) -> Int:
    """
    Number of harmonics (phases) at each frequency l, elementwise.

    For l > 0 this is C(l + d - 2, l - 1) + C(l + d - 1, l) with d = sphere_dim
    (for d = 2 this gives 2l + 1). Frequency 0 has exactly one phase.
    """
    positive_count = comb(frequency + sphere_dim - 2, frequency - 1) + comb(
        frequency + sphere_dim - 1, frequency
    )
    single_phase = jnp.ones_like(frequency, dtype=jnp.int64)
    return jnp.where(frequency == 0, single_phase, positive_count)


@Partial(jax.jit, static_argnames=("max_ell", "sphere_dim"))
def sphere_addition_theorem(x: Float[Array, "D"], y: Float[Array, "D"], *, max_ell: int, sphere_dim: int) -> Float:
    """
    Per-frequency addition-theorem terms for two points on the sphere.

    Returns an array of length max_ell + 1 whose l-th entry is
    N_l · Cᵅ_l(<x, y>) / Cᵅ_l(1), where N_l = num_phases_in_frequency(sphere_dim, l)
    and α = (sphere_dim - 1) / 2.
    """
    alpha = (sphere_dim - 1) / 2.0
    degrees = jnp.arange(max_ell + 1)
    multiplicity = num_phases_in_frequency(sphere_dim=sphere_dim, frequency=degrees)
    at_pole = gegenbauer(1.0, max_ell=max_ell, alpha=alpha)
    at_xy = gegenbauer(jnp.dot(x, y), max_ell=max_ell, alpha=alpha)
    return multiplicity / at_pole * at_xy


def addition_theorem_scalar_kernel(spectrum: Float[Array, "I"], z: Float[Array, "I"]) -> Float[Array, ""]:
    """Contract per-frequency spectrum weights with addition-theorem terms into one kernel value."""
    weighted_sum = jnp.dot(spectrum, z)
    return weighted_sum


@Partial(jax.jit, static_argnames=('dim',))
def matern_spectrum(ell: Float, kappa: Float, nu: Float, variance: Float, dim: int) -> Float:
    """
    Matérn spectral weight per frequency on the sphere.

    The unnormalised log-weight is -(ν + d/2) · log(1 + λ_l κ² / (2ν)), with
    λ_l = l (l + d - 1). Weights are rescaled so that Σ_l N_l · weight_l equals
    ``variance``, where N_l is the number of harmonics at frequency l.
    """
    eigenvalues = ell * (ell + dim - 1)
    log_weights = -(nu + dim / 2) * jnp.log1p((eigenvalues * kappa**2) / (2 * nu))

    # Shift by the maximum before exponentiating for stability; the shift
    # cancels in the normalisation below.
    weights = jnp.exp(log_weights - jnp.max(log_weights))

    multiplicities = num_phases_in_frequency(frequency=ell, sphere_dim=dim)
    return variance * weights / jnp.dot(multiplicities, weights)


from gpjax.base import Module


@dataclass
class SphereMaternKernel(Module):
    """
    Matérn kernel on the sphere, evaluated through its spectrum truncated at
    frequency ``max_ell`` and the addition theorem.
    """

    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array(1.5), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    max_ell: int = static_field(25)

    def __post_init__(self):
        # Store every hyperparameter as float64.
        for name in ("kappa", "nu", "variance"):
            setattr(self, name, jnp.asarray(getattr(self, name), dtype=jnp.float64))

    @property
    def ells(self):
        """Frequencies 0..max_ell as float64."""
        return jnp.arange(self.max_ell + 1, dtype=jnp.float64)

    def spectrum(self) -> Num[Array, "I"]:
        """Normalised Matérn spectral weight for each frequency in ``ells``."""
        return matern_spectrum(
            self.ells, self.kappa, self.nu, self.variance, dim=self.sphere_dim
        )

    @jax.jit
    def from_spectrum(self, spectrum: Float[Array, "M"], x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Evaluate k(x, y) = Σ_l spectrum_l · (addition-theorem term)_l."""
        terms = sphere_addition_theorem(x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim)
        return addition_theorem_scalar_kernel(spectrum, terms)

    @jax.jit
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value k(x, y) using the current spectrum."""
        current_spectrum = self.spectrum()
        return self.from_spectrum(current_spectrum, x, y)


@dataclass
class MultioutputSphereMaternKernel(Module):
    """
    Independent spherical Matérn kernels for ``num_outputs`` outputs; each
    hyperparameter holds one value per output.
    """

    num_outputs: int = static_field()
    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array([1.5]), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    max_ell: int = static_field(25)

    def _validate_params(self) -> None:
        """Cast each hyperparameter to float64 and broadcast it to shape (num_outputs,)."""
        for name in ("kappa", "nu", "variance"):
            as_f64 = jnp.asarray(getattr(self, name), dtype=jnp.float64)
            setattr(self, name, jnp.broadcast_to(as_f64, (self.num_outputs,)))

    def __post_init__(self):
        self._validate_params()

    @property
    def ells(self):
        """Frequencies 0..max_ell (integer dtype)."""
        return jnp.arange(self.max_ell + 1)

    @jax.jit
    def spectrum(self) -> Num[Array, "O L"]:
        """One normalised Matérn spectrum row per output."""
        def per_output(kappa, nu, variance):
            return matern_spectrum(self.ells, kappa, nu, variance, dim=self.sphere_dim)

        return jax.vmap(per_output)(self.kappa, self.nu, self.variance)

    @jax.jit
    def from_spectrum(self, spectrum: Float[Array, "O L"], x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Evaluate one kernel value per output from a per-output spectrum."""
        # The addition-theorem terms depend only on (x, y), not on the output.
        terms = sphere_addition_theorem(x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim)
        return jax.vmap(lambda row: addition_theorem_scalar_kernel(row, terms))(spectrum)

    @jax.jit
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output kernel values k_o(x, y)."""
        current_spectrum = self.spectrum()
        return self.from_spectrum(current_spectrum, x, y)


@dataclass 
class MultioutputPrior(Module):
    """Multi-output GP prior: a multi-output kernel and a jitter value."""
    # Kernel returning one covariance value per output.
    kernel: MultioutputSphereMaternKernel = param_field()
    # Static jitter magnitude. NOTE(review): confirm which consumers read it for the multi-output case.
    jitter: Float = static_field(1e-12)

    @property 
    def num_outputs(self):
        """Number of outputs, taken from the kernel."""
        return self.kernel.num_outputs


@jax.jit
def grad_safe_norm(z: Float[Array, "... D"]) -> Float[Array, "..."]:
    """
    Euclidean norm over the last axis, with the squared norm floored at 1e-36
    so that the square root and its gradient stay finite at z = 0.
    """
    squared = jnp.sum(jnp.square(z), axis=-1)
    return jnp.sqrt(jnp.clip(squared, min=1e-36))


@jax.jit
def euclidean_matern32_kernel(z: Float, variance: Float) -> Float:
    """Matérn-3/2 kernel of the scaled distance z: σ² (1 + √3 z) exp(-√3 z)."""
    sqrt3_z = jnp.sqrt(3) * z
    return variance * (1 + sqrt3_z) * jnp.exp(-sqrt3_z)


@dataclass
class EuclideanMaternKernel32(Module):
    """
    Matérn-3/2 kernel on R^num_inputs with lengthscale ``kappa`` (either shared,
    shape (1,), or per dimension, shape (num_inputs,)).
    """
    num_inputs: int = static_field(init=True)
    kappa: Float = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    variance: Float = param_field(jnp.array([1.0]), bijector=tfb.Softplus())

    def __post_init__(self):
        # float64 for numerical stability. A scalar kappa is promoted to shape (1,)
        # so the shape check below does not fail on a 0-d array.
        self.kappa = jnp.atleast_1d(jnp.asarray(self.kappa, dtype=jnp.float64))
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)

        assert self.kappa.shape[0] == self.num_inputs or self.kappa.shape[0] == 1

    @jax.jit
    def prepare_inputs(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Lengthscale-scaled Euclidean distance between ``x`` and ``y``."""
        return grad_safe_norm((x - y) / self.kappa)

    @jax.jit
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value k(x, y); it has the shape of ``variance`` (default (1,))."""
        return euclidean_matern32_kernel(self.prepare_inputs(x, y), self.variance)


@dataclass
class MultioutputEuclideanMaternKernel32(Module):
    """
    Independent Matérn-3/2 kernels for ``num_outputs`` outputs on R^num_inputs.

    ``kappa`` is broadcast to (num_outputs, 1) (shared lengthscale per output) or
    (num_outputs, num_inputs) (per-dimension lengthscales). ``variance`` is broadcast
    to (num_outputs,).
    """
    num_inputs: int = static_field()
    num_outputs: int = static_field()
    kappa: Float[Array, "O D"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    variance: Float[Array, "O"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus())

    def _validate_params(self) -> None:
        # float64 for numerical stability
        self.kappa = jnp.asarray(self.kappa, dtype=jnp.float64)
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)

        # shape for multioutput
        kappa_shape = jnp.broadcast_shapes(self.kappa.shape, (self.num_outputs, 1))
        self.kappa = jnp.broadcast_to(self.kappa, kappa_shape)
        self.variance = jnp.broadcast_to(self.variance, (self.num_outputs,))

        assert self.kappa.shape[1] == self.num_inputs or self.kappa.shape[1] == 1

    def __post_init__(self):
        self._validate_params()

    @jax.jit
    def prepare_inputs(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output lengthscale-scaled distance between ``x`` and ``y``."""
        return grad_safe_norm((x - y) / self.kappa)

    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """One kernel value per output."""
        return jax.vmap(euclidean_matern32_kernel)(self.prepare_inputs(x, y), self.variance)
    

@dataclass 
class Prior(Module):
    """Single-output GP prior: a kernel and a jitter value."""
    # Covariance function of the prior.
    kernel: SphereMaternKernel = param_field()
    # Static jitter; InducingPointsPosterior.jitter reads it and passes it to inducing_points_moments.
    jitter: Float = static_field(1e-12)
    

@dataclass
class Posterior(Module):
    """Container pairing a GP prior with a likelihood."""
    # GP prior (kernel and jitter).
    prior: Prior = param_field()
    # Observation model; negative_elbo reads ``likelihood.noise_variance`` from it.
    likelihood: Module = param_field()


@dataclass
class MultioutputPosterior(Module):
    """Multi-output counterpart of ``Posterior``: a multi-output prior and a likelihood."""
    # Multi-output GP prior.
    prior: MultioutputPrior = param_field()
    # Observation model.
    likelihood: Module = param_field()

    @property 
    def num_outputs(self) -> int:
        """Number of outputs, taken from the prior."""
        return self.prior.num_outputs
    

@Partial(jax.jit, static_argnames=('jitter',))
def inducing_points_moments(
    Kxx: Float[Array, ""], 
    Kzx: Float[Array, "M"], 
    Kzz: Float[Array, "M M"], 
    m: Float[Array, "M"], 
    sqrtS: Float[Array, "M M"], 
    jitter: float = 1e-12
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """
    Marginal mean and variance at one input under a whitened inducing-point posterior.

    With Lzz = chol(Kzz + jitter·I) and a = Lzz⁻¹ Kzx:
        mean     = aᵀ m
        variance = Kxx + ‖sqrtSᵀ a‖² - ‖a‖² + jitter

    Args:
        Kxx: Prior variance at the input.
        Kzx: Cross-covariances between inducing inputs and the input.
        Kzz: Covariance among inducing inputs.
        m: Whitened variational mean.
        sqrtS: Square root of the whitened variational covariance.
        jitter: Added to the diagonal of Kzz and to the returned variance (static).

    Returns:
        ``(mean, variance)`` as scalars.
    """
    from jax.scipy.linalg import solve_triangular

    Kzz = Kzz.at[jnp.diag_indices_from(Kzz)].add(jitter)

    Lzz = jnp.linalg.cholesky(Kzz)
    # Lzz is lower triangular, so a triangular solve (O(M²)) suffices instead of
    # the general LU-based jnp.linalg.solve (O(M³)).
    Lzz_inv_Kzx = solve_triangular(Lzz, Kzx, lower=True)  # [M]
    sqrtS_T_Lzz_inv_Kzx = sqrtS.T @ Lzz_inv_Kzx  # [M M] @ [M] -> [M]

    variance = (
        Kxx
        + jnp.sum(jnp.square(sqrtS_T_Lzz_inv_Kzx))
        - jnp.sum(jnp.square(Lzz_inv_Kzx))
    )
    variance += jitter

    mean = jnp.inner(Lzz_inv_Kzx, m)  # [M] @ [M] -> []

    return mean, variance


@Partial(jax.jit, static_argnames=('jitter',))
def spherical_harmonic_features_moments(
    Kxx: Float[Array, ""], 
    Kxz: Float[Array, "M"], 
    Kzz_inv_diag: Float[Array, "M"], 
    m: Float[Array, "M"], 
    sqrtS: Float[Array, "M M"], 
    jitter: float = 1e-12
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """
    Marginal mean and variance for whitened inducing features with a diagonal Kzz.

    Kzz is given through the diagonal of its inverse. The whitening factor is
    (Kzz + jitter·I)^{-1/2} = sqrt(Kzz⁻¹) / sqrt(1 + jitter · Kzz⁻¹), elementwise.
    The usual "- Kxz Kzz⁻¹ Kzx" term is not subtracted here; it is expected to be
    already absorbed into ``Kxx``.

    Returns:
        ``(mean, covariance)`` as scalars.
    """
    whitening = jnp.sqrt(Kzz_inv_diag) / jnp.sqrt(1 + jitter * Kzz_inv_diag)
    projected = Kxz * whitening
    projected_sqrtS = projected @ sqrtS

    covariance = Kxx + jnp.sum(jnp.square(projected_sqrtS))
    mean = projected @ m
    return mean, covariance


@jax.jit
def whitened_prior_kl(m: Float, sqrtS: Float) -> Float:
    """
    KL( N(m, sqrtS sqrtSᵀ) || N(0, I) ) for a whitened variational distribution.

    S is re-factorised with a Cholesky decomposition, so ``sqrtS`` need not itself
    be lower triangular. The variational distribution is a ``MultivariateNormalTriL``
    on that factor, which is equivalent to a full-covariance Gaussian with
    covariance S.
    """
    S = sqrtS @ sqrtS.T
    qz = tfd.MultivariateNormalTriL(loc=m, scale_tril=jnp.linalg.cholesky(S))

    # Standard-normal prior built in m's dtype so both distributions agree.
    pz = tfd.MultivariateNormalDiag(
        loc=jnp.zeros_like(m),
        scale_diag=jnp.ones_like(m),
    )
    return tfd.kl_divergence(qz, pz)


def inducing_points_prior_kl(m: Float, sqrtS: Float) -> Float:
    """KL of the inducing-point variational distribution from the whitened N(0, I) prior."""
    kl = whitened_prior_kl(m, sqrtS)
    return kl


@dataclass 
class DummyPosterior(Module):
    """Posterior-shaped container that holds only a prior (no likelihood field)."""
    # NOTE(review): intended use is not visible in this file — presumably for prior-only sampling; confirm.
    prior: Prior = param_field()


@dataclass 
class MultioutputDummyPosterior(Module):
    """Multi-output container that holds only a prior (no likelihood field)."""
    # NOTE(review): intended use is not visible in this file — confirm with callers.
    prior: MultioutputPrior = param_field()

    @property 
    def num_outputs(self):
        """Number of outputs, taken from the prior."""
        return self.prior.num_outputs


@dataclass
class InducingPointsPosterior(Module):
    """
    Single-output whitened sparse GP posterior with (non-trainable) inducing inputs ``z``.

    The variational distribution over whitened inducing values is N(m, sqrtS sqrtSᵀ),
    initialised to N(0, I).
    """
    posterior: Posterior = param_field()
    z: Float[Array, "M D"] = param_field(trainable=False)
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())

    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        self.num_inducing = self.z.shape[0]

        # Whitened initialisation: zero mean, identity square-root covariance.
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)

    @property
    def jitter(self):
        """Jitter configured on the prior."""
        return self.posterior.prior.jitter

    def prior_kl(self) -> Float:
        """KL of the variational distribution from the whitened prior."""
        return inducing_points_prior_kl(self.m, self.sqrtS)

    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, ""], Float[Array, ""]]:
        """Scalar marginal mean and variance of f at a single input ``x``."""
        kernel = self.posterior.prior.kernel
        z = self.z

        Kxx = kernel(x, x)
        Kzx = jax.vmap(lambda t: kernel(t, x))(z)
        Kzz = jax.vmap(lambda t1: jax.vmap(lambda t2: kernel(t1, t2))(z))(z)

        return inducing_points_moments(Kxx, Kzx, Kzz, self.m, self.sqrtS, jitter=self.jitter)

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent Normal marginals of f at each row of ``x``."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))
    


@dataclass
class MultioutputInducingPointsPosterior(Module):
    """
    Multi-output whitened sparse GP posterior. The inducing inputs ``z`` are shared
    across outputs; each output has its own variational mean and square-root covariance.
    """
    posterior: MultioutputPosterior = param_field()
    z: Float[Array, "M D"] = param_field(trainable=False)
    m: Float[Array, "O M"] = param_field(init=False)
    sqrtS: Float[Array, "O M M"] = param_field(init=False, bijector=tfb.FillTriangular())

    num_outputs: int = static_field(init=False)
    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        self.num_outputs = self.posterior.num_outputs
        self.num_inducing = self.z.shape[0]

        # Whitened initialisation per output: zero mean, identity square-root covariance.
        self.m = jnp.zeros((self.num_outputs, self.num_inducing))
        self.sqrtS = jnp.repeat(jnp.expand_dims(jnp.eye(self.num_inducing), axis=0), self.num_outputs, axis=0)

    @property
    def jitter(self):
        """Jitter configured on the multi-output prior."""
        return self.posterior.prior.jitter

    @jax.jit
    def prior_kl(self) -> Float:
        """Sum over outputs of the KL from the whitened prior."""
        return jnp.sum(jax.vmap(inducing_points_prior_kl)(self.m, self.sqrtS), axis=0)

    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, "O"], Float[Array, "O"]]:
        """Per-output marginal mean and variance of f at a single input ``x``."""
        kernel = self.posterior.prior.kernel
        z = self.z
        jitter = self.jitter

        # Kernel outputs carry a trailing output axis, hence in_axes (0, 1, 2) below.
        Kxx = kernel(x, x)
        Kzx = jax.vmap(lambda t: kernel(t, x))(z)
        Kzz = jax.vmap(lambda t1: jax.vmap(lambda t2: kernel(t1, t2))(z))(z)

        # Forward the prior's jitter, matching InducingPointsPosterior.moments.
        per_output_moments = jax.vmap(
            lambda Kxx_o, Kzx_o, Kzz_o, m_o, sqrtS_o: inducing_points_moments(
                Kxx_o, Kzx_o, Kzz_o, m_o, sqrtS_o, jitter=jitter
            ),
            in_axes=(0, 1, 2, 0, 0),
        )
        return per_output_moments(Kxx, Kzx, Kzz, self.m, self.sqrtS)

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent Normal marginals of f for each row of ``x`` and each output."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))
    

@dataclass
class SphericalHarmonicFeaturesPosterior(Module):
    """
    Single-output variational posterior with spherical harmonic inducing features.

    Frequencies up to ``spherical_harmonics.max_ell`` are handled by the whitened
    variational parameters ``m`` / ``sqrtS``. The kernel's other frequencies enter
    ``Kxx`` through the spectrum scaled by ``sqrtS_augment ** 2``, which is
    initialised to 0 on the covered frequencies and 1 above them.
    """
    posterior: Posterior = param_field()
    spherical_harmonics: SphericalHarmonics = static_field()
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    sqrtS_augment: Float[Array, "L"] = param_field(init=False)
    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        kernel = self.posterior.prior.kernel

        self.num_inducing = self.spherical_harmonics.num_phases
        # Whitened initialisation: zero mean, identity square-root covariance.
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)
        # 0 on frequencies covered by the features, 1 on the kernel's remaining ones.
        self.sqrtS_augment = jnp.ones(kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)

    @jax.jit
    def Kzz_diag(self, spectrum: Float[Array, "L"]) -> Float[Array, "M"]:
        """
        Repeat each covered frequency's spectrum value once per harmonic at that frequency.

        NOTE: ``moments`` passes this value as ``Kzz_inv_diag`` to
        ``spherical_harmonic_features_moments``, i.e. it is used as the diagonal of Kzz⁻¹.
        """
        shf = self.spherical_harmonics
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        return jnp.repeat(
            spectrum[:shf.max_ell + 1],
            repeats=repeats,
            total_repeat_length=total_repeat_length,
        )

    def Kxz(self, x: Float[Array, "D"]) -> Float[Array, "M"]:
        """Harmonic features evaluated at ``x``."""
        return self.spherical_harmonics.polynomial_expansion(x).T

    def prior_kl(self) -> Float[Array, ""]:
        """KL of the variational distribution from the whitened prior."""
        return whitened_prior_kl(self.m, self.sqrtS)

    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, ""], Float[Array, ""]]:
        """Scalar marginal mean and variance of f at a single input ``x``."""
        kernel = self.posterior.prior.kernel

        spectrum = kernel.spectrum()

        # This already accounts for the subtraction of the identity matrix from S'
        S_augment = jnp.square(self.sqrtS_augment)
        Kxx = kernel.from_spectrum(spectrum * S_augment, x, x)
        Kzz_diag = self.Kzz_diag(spectrum)
        Kxz = self.Kxz(x)

        return spherical_harmonic_features_moments(Kxx, Kxz, Kzz_diag, self.m, self.sqrtS)

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent Normal marginals of f at each row of ``x``."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))


@dataclass
class MultioutputSphericalHarmonicFeaturesPosterior(Module):
    """
    Multi-output variational posterior with spherical harmonic inducing features.

    The harmonic features are shared across outputs. Each output has its own whitened
    ``m`` / ``sqrtS`` and its own ``sqrtS_augment``, which is initialised to 0 on the
    frequencies covered by the features and 1 on the kernel's remaining frequencies.
    """
    num_outputs: int = static_field(init=False)

    posterior: MultioutputPosterior = param_field()
    spherical_harmonics: SphericalHarmonics = static_field()
    m: Float[Array, "O M"] = param_field(init=False)
    sqrtS: Float[Array, "O M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    sqrtS_augment: Float[Array, "O L"] = param_field(init=False)

    def __post_init__(self):
        kernel = self.posterior.prior.kernel

        self.num_outputs = self.posterior.num_outputs

        num_inducing = self.spherical_harmonics.num_phases
        self.m = jnp.zeros(num_inducing)
        self.sqrtS = jnp.eye(num_inducing)
        self.sqrtS_augment = jnp.ones(kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)

        # Replicate the single-output initialisation once per output.
        self.m = jnp.broadcast_to(self.m, (self.num_outputs, num_inducing))
        self.sqrtS = jnp.broadcast_to(self.sqrtS, (self.num_outputs, num_inducing, num_inducing))
        self.sqrtS_augment = jnp.broadcast_to(self.sqrtS_augment, (self.num_outputs, kernel.max_ell + 1))

    @jax.jit
    def prior_kl(self) -> Float:
        """Sum over outputs of the KL from the whitened prior."""
        return jnp.sum(jax.vmap(whitened_prior_kl)(self.m, self.sqrtS), axis=0)

    @jax.jit
    def Kzz_diag(self, spectrum: Float[Array, "O L"]) -> Float[Array, "O M"]:
        """
        Per output, repeat each covered frequency's spectrum value once per harmonic.

        NOTE: ``moments`` uses the result as the diagonal of Kzz⁻¹.
        """
        shf = self.spherical_harmonics
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        return jax.vmap(
            lambda spectrum: jnp.repeat(spectrum, repeats=repeats, total_repeat_length=total_repeat_length)
        )(spectrum[:, :shf.max_ell + 1])

    def Kxz(self, x: Float[Array, "D"]) -> Float[Array, "M"]:
        """Harmonic features at ``x``, shared by all outputs."""
        return self.spherical_harmonics.polynomial_expansion(x).T

    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, "O"], Float[Array, "O"]]:
        """Per-output marginal mean and variance of f at a single input ``x``."""
        kernel = self.posterior.prior.kernel

        # Prior variance at x, restricted by the diagonal variational parameters.
        spectrum = kernel.spectrum()  # [O L]
        S_augment = jnp.square(self.sqrtS_augment)  # [O L]
        Kxx = kernel.from_spectrum(spectrum * S_augment, x, x)  # [O]

        # Inducing-feature quantities.
        Kzz_diag = self.Kzz_diag(spectrum)  # [O M]
        Kxz = self.Kxz(x)  # [M], shared across outputs

        m = self.m
        sqrtS = self.sqrtS

        return jax.vmap(
            lambda Kxx, Kzz_diag, m, sqrtS: spherical_harmonic_features_moments(Kxx, Kxz, Kzz_diag, m, sqrtS)
        )(Kxx, Kzz_diag, m, sqrtS)

    @jax.jit
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent Normal marginals of f for each row of ``x`` and each output."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))


@jax.jit
def expected_log_likelihood(y: Float, m: Float, f_var: Float, eps_var: Float) -> Float:
    """
    E_{f ~ N(m, f_var)}[log N(y | f, eps_var)], summed over the last axis.

    Uses E[(y - f)²] = (y - m)² + f_var.
    """
    per_point = jnp.log(2 * jnp.pi) + jnp.log(eps_var) + (jnp.square(y - m) + f_var) / eps_var
    return -0.5 * jnp.sum(per_point, axis=-1)


# NOTE(review): `negative_elbo` is defined again further down in this module under
# the same name; that later definition shadows this one once the module is imported.
@jax.jit 
def negative_elbo(p: InducingPointsPosterior, x: Float, y: Float, *, key: Key) -> Float:
    """
    Negative evidence lower bound: -(expected log-likelihood - prior KL).

    NOTE(review): ``p.diag`` is called with ``key=``, which
    InducingPointsPosterior.diag does not accept; ``p`` is presumably a deep-GP
    model rather than the annotated type — confirm.
    """
    pf_diag = p.diag(x, key=key)
    # Moment summary of the predictive distribution of f.
    m, f_var = pf_diag.mean(), pf_diag.variance()
    eps_var = p.posterior.likelihood.noise_variance
    return -(expected_log_likelihood(y, m, f_var, eps_var) - p.prior_kl())


@jax.jit
def sphere_expmap(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """
    Exponential map on the unit sphere: exp_x(v) = cos(|v|) x + sin(|v|) v / |v|.

    Assumes each row of ``x`` has unit norm and ``v`` is tangent at ``x`` (see
    ``sphere_to_tangent``). For |v| < 1e-12, the retraction (x + v) / |x + v| is
    used instead.
    """
    # Clip the squared norm before the sqrt so that d|v|/dv stays finite at v = 0.
    # The clipped floor (1e-18) is below the 1e-12 threshold, so `small` is unaffected.
    sq_norm = jnp.sum(jnp.square(v), axis=-1, keepdims=True)
    theta = jnp.sqrt(jnp.clip(sq_norm, min=1e-36))
    small = theta < 1e-12

    # jnp.where differentiates through both branches. The division by theta must
    # therefore never see a near-zero value, or the gradient becomes NaN even
    # though the forward value comes from the other branch.
    safe_theta = jnp.where(small, 1.0, theta)

    t = x + v
    first_order_approx = t / jnp.linalg.norm(t, axis=-1, keepdims=True)
    true_expmap = jnp.cos(safe_theta) * x + jnp.sin(safe_theta) * v / safe_theta

    return jnp.where(small, first_order_approx, true_expmap)


@jax.jit
def sphere_to_tangent(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """
    Remove from ``v`` its component along ``x``, row-wise: v - <x, v> x.

    This is the tangent-space projection at ``x`` when rows of ``x`` have unit norm.
    """
    radial = jnp.sum(x * v, axis=-1, keepdims=True) * x
    return v - radial


@dataclass
class SphereResidualDeepGP(Module):
    """
    Deep GP on the sphere. Each hidden layer samples a vector field, projects it
    onto the tangent space and moves the inputs with the exponential map. The output
    layer then gives marginal moments.

    ``diag`` returns a uniform mixture over ``num_samples`` Monte-Carlo passes.
    """
    hidden_layers: list[MultioutputInducingPointsPosterior] = param_field()
    output_layer: InducingPointsPosterior = param_field()
    num_samples: int = static_field(1)

    @property
    def posterior(self) -> Posterior:
        """Posterior of the output layer."""
        return self.output_layer.posterior

    def prior_kl(self) -> Float:
        """Sum of the KL terms of all hidden layers and the output layer."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()

    def sample_moments(self, x: Float[Array, "N D"], *, key: Key) -> tuple[Float[Array, "N"], Float[Array, "N"]]:
        """Propagate one sample through the hidden layers and return output-layer moments."""
        # jax.random is reached through the already-imported `jax` module.
        hidden_layer_keys = jax.random.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key)
            u = sphere_to_tangent(x, v)
            x = sphere_expmap(x, u)
        return jax.vmap(self.output_layer.moments)(x)

    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equal-weight mixture of the per-sample Normal marginals at each input."""
        sample_keys = jax.random.split(key, self.num_samples)

        # In MixtureSameFamily batch size goes last; hence, out_axes = 1
        mean, variance = jax.vmap(lambda k: self.sample_moments(x, key=k), out_axes=1)(sample_keys)

        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)),
            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)),
        )


@dataclass
class DeepGaussianLikelihood(Module):
    """Gaussian observation noise applied to a mixture-of-Normals predictive distribution."""
    noise_variance: Float = param_field(jnp.array(1.0), bijector=tfb.Softplus())

    @jax.jit
    def diag(self, pf: tfd.MixtureSameFamily) -> tfd.MixtureSameFamily:
        """Return ``pf`` with ``noise_variance`` added to every component's variance."""
        components = pf.components_distribution
        component_mean = components.mean()
        noisy_scale = jnp.sqrt(components.variance() + self.noise_variance)
        return tfd.MixtureSameFamily(
            mixture_distribution=pf.mixture_distribution,
            components_distribution=tfd.Normal(loc=component_mean, scale=noisy_scale),
        )


@dataclass
class EuclideanDeepGP(Module):
    """
    Deep GP in Euclidean space. Each hidden layer adds a sampled displacement to
    its inputs (x ← x + v) before the single-output layer gives marginal moments.

    ``diag`` returns a uniform mixture over ``num_samples`` Monte-Carlo passes.
    """
    hidden_layers: list[MultioutputInducingPointsPosterior] = param_field()
    output_layer: InducingPointsPosterior = param_field()
    num_samples: int = static_field(1)

    @property
    def posterior(self) -> Posterior:
        """Posterior of the output layer."""
        return self.output_layer.posterior

    def prior_kl(self) -> Float:
        """Sum of the KL terms of all hidden layers and the output layer."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()

    def sample_moments(self, x: Float[Array, "N D"], *, key: Key) -> tuple[Float[Array, "N"], Float[Array, "N"]]:
        """Propagate one sample through the hidden layers and return output-layer moments."""
        # jax.random is reached through the already-imported `jax` module.
        hidden_layer_keys = jax.random.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key)  # [N D]
            x = x + v  # euclidean expmap
        return jax.vmap(self.output_layer.moments)(x)

    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equal-weight mixture of the per-sample Normal marginals at each input."""
        sample_keys = jax.random.split(key, self.num_samples)

        # In MixtureSameFamily batch size goes last; hence, out_axes = 1
        mean, variance = jax.vmap(lambda k: self.sample_moments(x, key=k), out_axes=1)(sample_keys)

        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)),
            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)),
        )
    

from typing import TypeVar

# Either deep-GP model type; used as the return annotation of `create_model`.
# NOTE(review): a constrained TypeVar that appears only in a return position does not
# link anything to the arguments; a plain union alias
# (`EuclideanDeepGP | SphereResidualDeepGP`) is what type checkers expect. Kept as a
# TypeVar in case other code relies on that.
DeepGP = TypeVar("DeepGP", EuclideanDeepGP, SphereResidualDeepGP)
    

@jax.jit 
def negative_elbo(p: InducingPointsPosterior, x: Float, y: Float, *, key: Key) -> Float:
    """Negative ELBO of a single-layer model with a stochastic predictive.

    Computes ``prior_kl() - expected_log_likelihood(...)``. The expected
    log-likelihood is evaluated at the mean and variance of ``p.diag(x, key=key)``
    with the likelihood's ``noise_variance``.
    """
    predictive = p.diag(x, key=key)
    noise_var = p.posterior.likelihood.noise_variance
    ell = expected_log_likelihood(y, predictive.mean(), predictive.variance(), noise_var)
    return p.prior_kl() - ell


from scipy.cluster.vq import kmeans
from numpy.random import Generator


def jax_key_to_numpy_generator(key: Key) -> Generator:
    """Create a NumPy ``Generator`` seeded deterministically from a JAX PRNG key.

    The raw key bits are read with the public ``jax.random.key_data`` accessor
    rather than the private ``_base_array`` attribute. This accepts both typed keys
    (``jax.random.key``) and legacy raw ``uint32`` keys (``jax.random.PRNGKey``);
    for a typed key it yields the same seed as before.

    Args:
        key: a single JAX PRNG key.

    Returns:
        ``np.random.Generator`` seeded with the key's underlying integer data.
    """
    key_bits = np.asarray(jax.random.key_data(key))
    return np.random.default_rng(key_bits)


def kmeans_inducing_points(key: Key, x: Float[Array, "N D"], num_inducing: int) -> Float[Array, "M D"]:
    """Choose inducing locations by k-means clustering of ``x``.

    If more inducing points than data points are requested, ``x`` is clustered
    with ``k = n`` as many times as needed. For distinct points this typically
    returns the data themselves, permuted. The remainder comes from one final
    k-means run.

    Every k-means call draws from a generator seeded by ``key``, so the result is
    reproducible for a given key.

    Args:
        key: JAX PRNG key used to seed k-means initialisation.
        x: data points, one per row.
        num_inducing: number of centroids requested in total.

    Returns:
        Concatenated centroids as a float64 JAX array. ``scipy.cluster.vq.kmeans``
        drops centroids that end up with no assigned observations, so fewer than
        ``num_inducing`` rows may be returned.
    """
    rng = jax_key_to_numpy_generator(key)

    x = np.asarray(x, dtype=np.float64)
    n = x.shape[0]

    centers = []
    remaining = num_inducing
    while remaining > n:
        # Seeded too, so the whole result depends only on `key`.
        centers.append(kmeans(x, n, seed=rng)[0])
        remaining -= n
    centers.append(kmeans(x, remaining, seed=rng)[0])
    return jnp.asarray(np.concatenate(centers, axis=0), dtype=jnp.float64)


def num_phases_to_num_levels(num_phases: int, *, sphere_dim: int) -> int:
    """Largest level ``L`` such that levels ``0..L`` hold at most ``num_phases`` phases.

    Phases per level come from ``num_phases_in_frequency``. A budget of ``0``
    gives ``-1`` (no complete level fits).
    """
    budget = num_phases
    levels_used = 0
    while budget > 0:
        budget -= num_phases_in_frequency(frequency=levels_used, sphere_dim=sphere_dim)
        levels_used += 1
    # The last level consumed either fit the budget exactly (keep it) or overshot it (drop it).
    overshoot = budget != 0
    return levels_used - 1 - int(overshoot)


def create_residual_deep_gp_with_spherical_harmonic_features(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, 
    nu: float = 1.5, kernel_max_ell: int | None = None
) -> SphereResidualDeepGP:
    """Build a residual deep GP on the sphere using spherical-harmonic features.

    Args:
        num_layers: total layers; ``num_layers - 1`` hidden layers plus one output layer.
        total_hidden_variance: kernel variance split evenly across the hidden layers.
        num_inducing: phase budget used to pick the number of harmonic levels.
        x: inputs, one per row; ``D - 1`` is taken as the sphere dimension.
        num_samples: Monte Carlo samples used by the returned model.
        key: not used by this factory (accepted for a shared factory signature).
        nu: Matern smoothness shared by all layers.
        kernel_max_ell: truncation level of the kernels; defaults to the feature level.

    Returns:
        A ``SphereResidualDeepGP`` whose output layer uses ``DeepGaussianLikelihood``.
    """
    sphere_dim = x.shape[1] - 1
    num_hidden = num_layers - 1

    # Even split of the hidden variance; max(..., 1) avoids dividing by zero with no hidden layers.
    per_layer_variance = jnp.array(total_hidden_variance / max(num_hidden, 1))
    smoothness = jnp.array(nu)
    kappa = jnp.array(1.0)

    feature_max_ell = num_phases_to_num_levels(num_inducing, sphere_dim=sphere_dim)
    if kernel_max_ell is None:
        kernel_max_ell = feature_max_ell
    harmonics = SphericalHarmonics(max_ell=feature_max_ell, sphere_dim=sphere_dim)

    def build_hidden_layer():
        # A fresh kernel/prior/posterior per layer; the harmonics object is shared.
        hidden_kernel = MultioutputSphereMaternKernel(
            num_outputs=sphere_dim + 1,
            sphere_dim=sphere_dim,
            nu=smoothness,
            kappa=kappa,
            variance=per_layer_variance,
            max_ell=kernel_max_ell,
        )
        hidden_posterior = MultioutputDummyPosterior(prior=MultioutputPrior(kernel=hidden_kernel))
        return MultioutputSphericalHarmonicFeaturesPosterior(
            posterior=hidden_posterior, spherical_harmonics=harmonics
        )

    hidden_layers = [build_hidden_layer() for _ in range(num_hidden)]

    output_kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=smoothness,
        kappa=kappa,
        variance=jnp.array(1.0),
        max_ell=kernel_max_ell,
    )
    output_posterior = Posterior(prior=Prior(kernel=output_kernel), likelihood=DeepGaussianLikelihood())
    output_layer = SphericalHarmonicFeaturesPosterior(posterior=output_posterior, spherical_harmonics=harmonics)

    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)


def create_euclidean_deep_gp_with_inducing_points(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, train_inducing: bool = True
) -> EuclideanDeepGP:
    """Build a Euclidean deep GP whose layers are inducing-point posteriors.

    All layers use ``EuclideanMaternKernel32``-family kernels and start from the same
    inducing locations, chosen by k-means on ``x``.

    Args:
        num_layers: total layers; ``num_layers - 1`` hidden layers plus one output layer.
        total_hidden_variance: kernel variance split evenly across the hidden layers.
        num_inducing: number of inducing points requested from k-means.
        x: inputs, one per row; the column count sets the hidden-layer width.
        num_samples: Monte Carlo samples used by the returned model.
        key: PRNG key seeding the k-means initialisation.
        train_inducing: if True, inducing locations are marked trainable.

    Returns:
        An ``EuclideanDeepGP`` whose output layer uses ``DeepGaussianLikelihood``.
    """
    sphere_dim = x.shape[1] - 1

    # Even split of the hidden variance (max(..., 1) guards num_layers == 1). Wrapped in an
    # array like the spherical factory and `output_variance`, so the kernel parameter is a
    # float64 array rather than a bare Python float.
    hidden_variance = jnp.array(total_hidden_variance / max(num_layers - 1, 1))
    output_variance = jnp.array(1.0)

    # One kappa entry per input column.
    hidden_kappa = jnp.ones(sphere_dim + 1)
    output_kappa = hidden_kappa

    z = kmeans_inducing_points(key, x, num_inducing)
    hidden_z = z
    output_z = z

    hidden_layers = []
    for _ in range(num_layers - 1):
        kernel = MultioutputEuclideanMaternKernel32(
            variance=hidden_variance,
            kappa=hidden_kappa, 
            num_outputs=sphere_dim + 1, 
            num_inputs=sphere_dim + 1,
        )
        prior = MultioutputPrior(kernel=kernel)
        posterior = MultioutputDummyPosterior(prior=prior)
        layer = MultioutputInducingPointsPosterior(posterior=posterior, z=hidden_z)
        if train_inducing:
            layer = layer.replace_trainable(z=True)
        hidden_layers.append(layer)

    kernel = EuclideanMaternKernel32(
        num_inputs=sphere_dim + 1,
        variance=output_variance,
        kappa=output_kappa,
    )
    prior = Prior(kernel=kernel)
    likelihood = DeepGaussianLikelihood()
    posterior = Posterior(prior=prior, likelihood=likelihood)
    output_layer = InducingPointsPosterior(posterior=posterior, z=output_z)
    if train_inducing:
        output_layer = output_layer.replace_trainable(z=True)

    return EuclideanDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)

def create_model(
    num_layers: int, total_hidden_variance: float, num_inducing: int, x: Float[Array, "N D"], num_samples: int = 3, *, key: Key, name: str, nu: float = 1.5, 
    kernel_max_ell: int | None = None
) -> DeepGP: 
    """Build the deep GP registered under ``name``.

    Only the ``residual+...`` factories receive ``nu`` and ``kernel_max_ell``; the
    Euclidean factories are called without them.

    Raises:
        ValueError: if ``name`` is not a known model name.
    """
    args = (num_layers, total_hidden_variance, num_inducing, x)
    shared = {"num_samples": num_samples, "key": key}
    residual_only = {"nu": nu, "kernel_max_ell": kernel_max_ell}

    # Lambdas defer resolving each factory name until that entry is actually selected.
    factories = {
        'residual+inducing_points':
            lambda: create_residual_deep_gp_with_inducing_points(*args, **shared, **residual_only),
        'euclidean+inducing_points':
            lambda: create_euclidean_deep_gp_with_inducing_points(*args, **shared),
        'euclidean_with_geometric_input+inducing_points':
            lambda: create_euclidean_deep_gp_with_input_geometric_layer_and_inducing_points(*args, **shared),
        'residual+spherical_harmonic_features':
            lambda: create_residual_deep_gp_with_spherical_harmonic_features(*args, **shared, **residual_only),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"Unknown model name: {name}")
    return factory()



# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from beartype.typing import (
    Any,
    Callable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import jax
import jax.random as jr
import optax as ox

from gpjax.base import Module
from gpjax.dataset import Dataset
from gpjax.objectives import AbstractObjective
from gpjax.scan import vscan
from gpjax.typing import (
    Array,
    KeyArray,
    ScalarFloat,
)

# Type variable for any gpjax `Module`; annotates the `model` argument of `time_fit`.
ModuleModel = TypeVar("ModuleModel", bound=Module)


def time_fit(  # noqa: PLR0913
    *,
    model: ModuleModel,
    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
    x: Float, 
    y: Float,
    optim: ox.GradientTransformation,
    key: KeyArray,
    num_iters: Optional[int] = 100,
    batch_size: Optional[int] = None,
    verbose: Optional[bool] = True,
    unroll: Optional[int] = 1,
) -> float:
    r"""Time ``num_iters`` Optax optimisation steps of ``model`` on ``objective``.

    Runs a ``gpjax.fit``-style training loop but discards the trained model and
    returns only the elapsed wall-clock time. The measurement covers tracing and
    compiling the scan as well as executing it. The result is blocked on before the
    clock stops, so JAX's asynchronous dispatch cannot cut the measurement short.

    Args:
        model: model to optimise; it is moved to unconstrained space internally.
        objective: called as ``objective(model, x, y, key=key)``; must return a scalar loss.
        x: training inputs (rows are data points).
        y: training targets, row-aligned with ``x``.
        optim: Optax optimiser.
        key: PRNG key; split into one key per iteration.
        num_iters: number of optimisation steps.
        batch_size: if given, each step uses a mini-batch of this size drawn with
            replacement; otherwise the full data.
        verbose: use gpjax's ``vscan`` (progress bar) instead of ``jax.lax.scan``.
        unroll: scan unroll factor.

    Returns:
        Elapsed wall-clock seconds for the whole optimisation loop.
    """
    import time

    # Unconstrained-space loss; `stop_gradient` blocks gradients to non-trainable params.
    def loss(model: Module, x: Float, y: Float, *, key: Key) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(model.constrain(), x, y, key=key)

    # Optimise in unconstrained space.
    model = model.unconstrain()
    state = optim.init(model)

    # One PRNG key per optimisation step.
    iter_keys = jr.split(key, num_iters)

    def step(carry, step_key):
        model, opt_state = carry
        # Independent keys for mini-batch selection and for the objective's own sampling.
        batch_key, loss_key = jr.split(step_key)

        if batch_size is not None:
            batch_x, batch_y = get_batch(x, y, batch_size, batch_key)
        else:
            batch_x, batch_y = x, y

        loss_val, loss_gradient = jax.value_and_grad(loss)(model, batch_x, batch_y, key=loss_key)
        updates, opt_state = optim.update(loss_gradient, opt_state, model)
        model = ox.apply_updates(model, updates)
        return (model, opt_state), loss_val

    scan = vscan if verbose else jax.lax.scan

    start = time.perf_counter()
    # JAX dispatches asynchronously; block so the clock stops only after the work is done.
    jax.block_until_ready(scan(step, (model, state), iter_keys, unroll=unroll))
    end = time.perf_counter()

    return end - start


def get_batch(x: Float, y: Float, batch_size: int, key: KeyArray) -> tuple[Float, Float]:
    """Draw a random mini-batch of ``batch_size`` rows from ``(x, y)``, with replacement.

    Args:
        x: inputs; rows (leading axis) are data points.
        y: targets, indexed with the same row indices as ``x``.
        batch_size: number of rows to draw.
        key: JAX PRNG key controlling which rows are drawn.

    Returns:
        ``(x_batch, y_batch)``, each with leading dimension ``batch_size``.
    """
    idx = jr.choice(key, x.shape[0], shape=(batch_size,), replace=True)
    return x[idx], y[idx]


@Partial(jax.jit, static_argnames=('n',))
def deep_negative_elbo(p: EuclideanDeepGP, x: Float, y: Float, *, key: Key, n: int) -> Float:
    """Monte Carlo negative ELBO of a deep GP, rescaled for mini-batching.

    The expected log-likelihood is averaged over ``p.num_samples`` hidden-layer
    samples (one key each) and multiplied by ``n / x.shape[0]``, so that a batch of
    rows stands in for ``n`` data points. ``n`` is a static (hashable) jit argument.
    """
    noise_var = p.posterior.likelihood.noise_variance

    def ell_for_key(sample_key: Key) -> Float:
        mean, f_var = p.sample_moments(x, key=sample_key)
        return expected_log_likelihood(y, mean, f_var, noise_var)

    per_sample = jax.vmap(ell_for_key)(jr.split(key, p.num_samples))
    scale = n / x.shape[0]
    return p.prior_kl() - scale * jnp.mean(per_sample, axis=0)


import timeit 
def time_function(func, globals_dict=None, number=1, repeat=100):
    """Benchmark a statement with ``timeit`` after one untimed warm-up run.

    Args:
        func: Python source of the statement to time. It must wait for JAX's
            asynchronous dispatch, i.e. start with ``"jax.block_until_ready("`` or
            end with ``".block_until_ready()"``. Otherwise only dispatch would be timed.
        globals_dict: namespace the statement runs in (empty dict if None).
        number: executions per timing sample.
        repeat: number of timing samples.

    Returns:
        List of ``repeat`` floats, each the total seconds for ``number`` executions.

    Raises:
        TypeError: if ``func`` is not a string.
        ValueError: if ``func`` does not block on its result.
    """
    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if not isinstance(func, str):
        raise TypeError(f"func must be a statement string, got {type(func).__name__}")
    if not (func.startswith("jax.block_until_ready(") or func.endswith(".block_until_ready()")):
        raise ValueError(
            "func must block on its result: start with 'jax.block_until_ready(' "
            "or end with '.block_until_ready()'"
        )

    if globals_dict is None:
        globals_dict = {}
    timer = timeit.Timer(func, globals=globals_dict)
    # Untimed warm-up run, presumably to absorb one-off costs such as JIT compilation.
    timer.timeit(number=1)
    return timer.repeat(repeat=repeat, number=number)



from enum import Enum 
from abc import ABC 
from dataclasses import field


def sphere_uniform(sphere_dim: int, n: int, *, key: Key) -> Float[Array, "N D"]:
    """Sample ``n`` points uniformly on the unit sphere in ``R^(sphere_dim + 1)``.

    Normalising isotropic Gaussian draws to unit length gives the uniform
    distribution on the sphere.
    """
    gaussian = jax.random.normal(key, (n, sphere_dim + 1))
    radii = jnp.linalg.norm(gaussian, axis=-1, keepdims=True)
    return gaussian / radii


@dataclass 
class Dataset(ABC):
    """Synthetic timing stand-in for a regression dataset.

    ``__post_init__`` fills ``x_train`` with ``batch_size`` points sampled uniformly on
    the sphere of dimension ``dim`` (fixed key 0) and ``y_train`` with zeros. Only the
    sizes, not the values, presumably matter for timing.
    """
    name: str 
    dim: int
    num: int 
    num_inducing: int
    kernel_max_ell: int
    # Set in __post_init__. Excluded from __init__, repr and == (arrays are large and
    # do not support boolean equality). Previously an unassigned `x` field was
    # declared here, which made repr() raise AttributeError.
    x_train: Float[Array, "N D"] = field(init=False, repr=False, compare=False)
    y_train: Float[Array, "N"] = field(init=False, repr=False, compare=False)
    batch_size: int | None = field(default=None)

    def __post_init__(self):
        # Default to the size of a 90% training split.
        if self.batch_size is None:
            self.batch_size = int(self.num * 0.9)
        self.x_train = sphere_uniform(self.dim, self.batch_size, key=jax.random.key(0))
        self.y_train = jnp.zeros(self.batch_size)
    

@dataclass
class Yacht(Dataset):
    """Timing configuration for the "yacht" dataset (308 points, 90% used as the batch)."""
    name: str = "yacht"
    dim: int = 6
    num: int = 308
    num_inducing: int = 294
    kernel_max_ell: int = 12


@dataclass
class Energy(Dataset):
    """Timing configuration for the "energy" dataset (768 points, 90% used as the batch)."""
    name: str = "energy"
    dim: int = 8
    num: int = 768
    num_inducing: int = 210
    kernel_max_ell: int = 10


@dataclass
class Concrete(Dataset):
    """Timing configuration for the "concrete" dataset (1030 points, 90% used as the batch)."""
    name: str = "concrete"
    dim: int = 8
    num: int = 1030
    num_inducing: int = 294
    kernel_max_ell: int = 12


@dataclass
class Kin8mn(Dataset):
    """Timing configuration for the kin8nm-sized dataset, with a fixed mini-batch of 1000."""
    # NOTE(review): the UCI dataset is usually spelled "kin8nm"; "kin8mn" is kept because
    # it is the CLI/enum name and appears in the results path.
    name: str = "kin8mn"
    dim: int = 8
    num: int = 8192
    num_inducing: int = 210
    kernel_max_ell: int = 10
    batch_size: int = 1000


@dataclass
class Power(Dataset):
    """Timing configuration for the "power" dataset, with a fixed mini-batch of 1000."""
    name: str = "power"
    dim: int = 4
    num: int = 9568
    num_inducing: int = 336
    kernel_max_ell: int = 20
    batch_size: int = 1000


class Datasets(Enum):
    """Registry of timing configurations, looked up by name (e.g. ``Datasets["yacht"]``).

    Member values are dataset instances, so each one's ``__post_init__`` (which
    samples the synthetic inputs) runs once when this module is imported.
    """
    yacht = Yacht()
    concrete = Concrete()
    energy = Energy()
    kin8mn = Kin8mn()
    power = Power()


# Dataset names, in the same order as the `Datasets` enum members.
dataset_names = ["yacht", "concrete", "energy", "kin8mn", "power"]

# Models benchmarked here (a subset of the names `create_model` accepts).
model_names = [
    "residual+spherical_harmonic_features",
    "euclidean+inducing_points",
]

# Depths (total number of layers) to benchmark.
nums_layers = list(range(1, 6))

# Shared model hyper-parameters.
total_hidden_variance = 1e-4
train_num_samples = 3

# Optimiser settings; `num_iters` is overridden by --num_iters when run as a script.
lr = 1e-2
num_iters = 10



if __name__ == "__main__":
    import argparse 
    import os 

    import pandas as pd  # used below to write the timing results to CSV

    parser = argparse.ArgumentParser()
    # `choices` makes an unknown dataset a clear CLI error instead of a KeyError from `Datasets[...]`.
    parser.add_argument("--dataset", type=str, default="yacht", choices=dataset_names)
    parser.add_argument("--model", type=str, default="residual+spherical_harmonic_features")
    parser.add_argument("--num_layers", type=int, default=1)
    parser.add_argument("--num_iters", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    dataset_name = args.dataset
    num_iters = args.num_iters
    num_layers = args.num_layers
    model_name = args.model
    seed = args.seed

    # Key passed to the timed gradient call.
    key = jax.random.key(seed)

    dataset = Datasets[dataset_name].value

    settings = {
        "dataset": dataset.name,
        "model": model_name,
        "num_layers": num_layers,
        "dataset_dim": dataset.dim,
        "batch_size": dataset.batch_size,
        "num_iters": num_iters,
        "num_inducing": dataset.num_inducing,
        "seed": seed,
    }
    print(settings)

    x_train = dataset.x_train
    y_train = dataset.y_train
    # NOTE(review): the model is initialised with a fixed key (0), so --seed does not
    # affect initialisation; confirm this is intended.
    model = create_model(
        num_layers, total_hidden_variance, dataset.num_inducing, x_train, num_samples=train_num_samples, key=jax.random.key(0), name=model_name
    )
    optim = ox.adam(lr)
    objective = Partial(deep_negative_elbo, n=dataset.batch_size)

    def loss(model: Module, x: Float, y: Float, *, key: Key) -> ScalarFloat:
        model = model.stop_gradient()
        return objective(model, x, y, key=key)
    grad_loss = jax.grad(loss)
    jit_grad_loss = jax.jit(grad_loss)

    # The statement is evaluated by timeit with this (module-level) namespace, which
    # provides `jax`, `jit_grad_loss`, `model`, `x_train`, `y_train` and `key`.
    func = "jax.block_until_ready(jit_grad_loss(model, x_train, y_train, key=key))"
    times = time_function(func, globals_dict=locals(), number=1, repeat=num_iters)

    # Save one CSV row per timing sample, each carrying the full settings.
    experiment_dir = Path("results/time_uci") / "-".join(f"{k}={v}" for k, v in settings.items())
    os.makedirs(experiment_dir, exist_ok=True)
    result_path = experiment_dir / "results.csv"
    result = settings | {"time": times}
    pd.DataFrame([result]).explode('time').to_csv(result_path, index=False)
