from typing import Callable, Dict, Tuple

import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

# Shorthand for the JAX array type, used in the type annotations of this module.
ndarray = jnp.ndarray
# Small constant added to the diagonal of covariance matrices in `predict`
# (the usual numerical-stability "jitter" applied before/after factorisation).
jitter = 1e-6

def predict(
    prior: Dict,
    params: Dict,
    train_data: Tuple[ndarray, ndarray],
    diag: bool = False,
) -> Callable[[ndarray], Dict[str, ndarray]]:
    """
    Build the GP posterior predictive function for a fixed training set.

    The training Gram matrix (plus ``jitter + noise_variance`` on its diagonal)
    is Cholesky-factorised once here; the returned closure reuses that factor
    for every batch of test inputs.

    Args:
        prior (Dict): A dictionary with 'mean_function' and 'kernel' keys.
            ``prior['mean_function']`` is called as ``f(mean_params, inputs)``;
            ``prior['kernel']`` must expose ``gram(kernel_params, inputs)`` and
            ``cross_covariance(kernel_params, x, t)``.
        params (Dict): A dictionary with 'mean_function', 'kernel',
            'noise_variance'.
        train_data (Tuple): Training data (x, y). ``x`` is indexed along its
            first axis (n points). ``y`` may be a 1-D vector of length n or a
            2-D array with n rows; a 1-D ``y`` is treated as a single column.
        diag (bool): Whether to return only the diagonal of the covariance.

    Returns:
        Callable: A function mapping test inputs to a dictionary with key
        'mean' (the predictive mean with singleton axes squeezed out) and
        either 'diag_cov' (when ``diag`` is True; no jitter added) or 'cov'
        (the full covariance with ``jitter`` added to its diagonal). The
        observation noise is not added to the returned covariance.
    """
    x, y = train_data
    n = x.shape[0]

    # A 1-D target vector would broadcast against the (n, 1) mean column into
    # an (n, n) matrix and corrupt the solve below, so lift it to a column.
    if jnp.ndim(y) == 1:
        y = jnp.reshape(y, (-1, 1))

    mean_function = prior["mean_function"]
    kernel = prior["kernel"]

    mean_params = params["mean_function"]
    kernel_params = params["kernel"]
    obs_noise = params["noise_variance"]

    mu_x = mean_function(mean_params, x).reshape(-1, 1)
    Kxx = kernel.gram(kernel_params, x) + jnp.eye(n) * (jitter + obs_noise)

    # L is lower triangular, so use the Cholesky-aware solver rather than a
    # general LU solve: alpha = Kxx^{-1} (y - mu_x).
    L = jnp.linalg.cholesky(Kxx)
    alpha = cho_solve((L, True), y - mu_x)

    def predict_fn(test_inputs: ndarray) -> Dict[str, ndarray]:
        """Evaluate the posterior predictive at ``test_inputs``."""
        t = test_inputs
        n_test = t.shape[0]

        mu_t = mean_function(mean_params, t).reshape(-1, 1)
        Ktt = kernel.gram(kernel_params, t)
        Kxt = kernel.cross_covariance(kernel_params, x, t)

        mean = (mu_t + Kxt.T @ alpha).squeeze()

        # v = L^{-1} Kxt, hence v.T @ v == Kxt.T @ Kxx^{-1} @ Kxt.
        v = solve_triangular(L, Kxt, lower=True)

        if diag:
            # diag(v.T @ v) is the column-wise sum of squares of v; this
            # avoids forming the full (n_test, n_test) product.
            return {
                "mean": mean,
                "diag_cov": jnp.diag(Ktt) - jnp.sum(v * v, axis=0),
            }

        covariance = Ktt - v.T @ v + jnp.eye(n_test) * jitter
        return {
            "mean": mean,
            "cov": covariance,
        }

    return predict_fn
