"""1D elliptic equation -f'' + f = u with GP prior on u (https://proceedings.mlr.press/v178/weissmann22a.html Example 2.1)"""

from typing import Callable
from functools import partial
import numpy as np


def _score_fn(x, scale_x, scale_y, y, prior_cov_scale):
    """Evaluate ``-x @ scale_x + scale_y @ y - x * k**2 / prior_cov_scale``.

    ``k = 1, ..., Nx`` indexes the last axis of ``x``, so the final term is the
    diagonal prior precision ``k**2 / prior_cov_scale`` applied to ``x``.

    Shapes:
        x        -- (n_particles, Nx) or (Nx,)
        scale_x  -- (Nx, Nx)
        scale_y  -- (Nx, Ny)
        y        -- (Ny,)
        returns  -- same shape as ``x``
    """
    n_modes = x.shape[-1]
    squared_modes = np.square(np.arange(1, n_modes + 1))

    # Data-dependent part: linear pull towards the observations.
    data_part = scale_y @ y - x @ scale_x
    # Prior part: mode k is damped proportionally to k**2.
    prior_part = x * squared_modes / prior_cov_scale
    return data_part - prior_part


def _log_pi(x, y, G, noise_cov, prior_cov_scale):
    """Unnormalised log-density ``-0.5 * (misfit + prior)`` for one or many particles.

    ``misfit = (y - G x)^T noise_cov^{-1} (y - G x)`` and
    ``prior = sum_k x_k**2 * k**2 / prior_cov_scale`` with ``k = 1, ..., Nx``.

    Shapes:
        x         -- (Nx,) or (n_particles, Nx)
        y         -- (No,)
        G         -- (No, Nx)
        noise_cov -- (No, No)
        returns   -- scalar for a single vector, (n_particles,) for a batch
    """
    x = np.asarray(x)
    # ``x @ G.T`` (rather than ``G @ x``) keeps the particle axis first, so the
    # same expression serves a single vector and a batch of particles.
    residual = y - x @ G.T
    # ``solve`` treats a 2-D right-hand side column-wise, hence the transposes;
    # for a 1-D residual ``.T`` is a no-op.
    weighted_residual = np.linalg.solve(noise_cov, residual.T).T
    misfit = np.sum(residual * weighted_residual, axis=-1)

    prior_precision = np.square(np.arange(1, x.shape[-1] + 1)) / prior_cov_scale
    prior_term = np.sum(x * x * prior_precision, axis=-1)
    return -0.5 * (misfit + prior_term)


def generate_example(
    rng: np.random.Generator,
    Nx: int,
    Ny: int,
    No: int,
    noise_cov: np.ndarray,
    N_posterior: int = 1000,
    prior_cov_scale: float = 1.0,
    KL_length: int = 256,
) -> tuple[
    np.ndarray,
    Callable[[np.ndarray], np.ndarray],
    Callable[[np.ndarray], np.ndarray],
    np.ndarray,
    np.ndarray,
    Callable[[np.ndarray | float, np.ndarray], np.ndarray | float],
    np.ndarray,
    np.ndarray,
    np.ndarray,
]:
    """1D elliptic equation -f'' + f = u with GP prior on u (https://arxiv.org/pdf/2312.13889)

    Builds the linear forward map ``G = O @ H^{-1} @ A``, draws reference data
    from the prior, and computes the Gaussian posterior in closed form.

    Args:
        rng: generator used for the reference draw and the posterior samples.
        Nx: number of KL coefficients that are inferred.
        Ny: number of grid points of the finite-difference discretisation.
        No: number of point observations.
        noise_cov: (No, No) observation-noise covariance used in the likelihood.
        N_posterior: number of posterior samples to draw.
        prior_cov_scale: scalar multiplying the prior covariance ``diag(k**-2)``.
        KL_length: number of KL modes used for the reference data; raised to
            ``Nx`` if smaller.

    Returns:
        A 9-tuple ``(prior_cov, score, log_pi, y, true_process, fn,
        posterior_sample, posterior_cov, posterior_mean)``:

        * ``prior_cov`` -- (Nx, Nx) diagonal prior covariance.
        * ``score`` -- ``_score_fn`` with the problem data bound.
        * ``log_pi`` -- ``_log_pi`` with the problem data bound.
        * ``y`` -- (No,) observations of the reference draw; no noise is added,
          ``noise_cov`` only enters the likelihood/posterior.
        * ``true_process`` -- (Ny,) reference right-hand side ``A @ x_reference``.
        * ``fn(s, x)`` -- evaluates the KL expansion with coefficients ``x`` at ``s``.
        * ``posterior_sample`` -- (N_posterior, Nx) draws from the posterior.
        * ``posterior_cov`` -- (Nx, Nx) posterior covariance.
        * ``posterior_mean`` -- (Nx,) posterior mean.
    """
    # build the linear operator G = O*H*A, where
    # A (Ny x KL_length) contains the Ny evaluations of the KL basis functions sqrt(2)*sin(pi*k*s)
    # H (Ny x Ny) is the finite difference discretization (with grid size 1/(Ny+1)) of the negative Laplacian + identity
    # O (No x Ny) is the observation operator (point evaluations)
    KL_length = max(Nx, KL_length)
    modes = np.arange(1, KL_length + 1)
    nodes = np.arange(1, Ny + 1)
    # NOTE(review): the basis is sampled at i/Ny while H below uses spacing
    # 1/(Ny+1); left unchanged because `true_process` is returned on this grid --
    # confirm which grid is intended.
    A = np.sqrt(2) * np.sin(np.pi * modes[None, :] * nodes[:, None] / Ny)
    H = (Ny + 1) ** 2 * (-np.eye(Ny, k=-1) + 2 * np.eye(Ny) - np.eye(Ny, k=1))
    H += np.eye(Ny)
    obs_indices = np.linspace(0, Ny - 1, No, dtype=int)
    O = np.zeros((No, Ny))  # noqa: E741
    O[np.arange(No), obs_indices] = 1.0
    # Solve with H once for all KL_length modes; the inference operator below
    # is just the leading Nx columns of the same product.
    G_full = O @ np.linalg.solve(H, A)  # use this to compute observations

    # generate reference data
    x_reference = rng.multivariate_normal(
        mean=np.zeros(KL_length),
        cov=np.diag(np.power(np.arange(1, KL_length + 1), -2.0)) * prior_cov_scale,
    )
    y = G_full @ x_reference
    true_process = A @ x_reference
    # use this in the inference (only estimate Nx KL coefficients)
    G = G_full[:, :Nx]

    def fn(s, x: np.ndarray) -> np.ndarray:
        """Evaluate ``sum_k x[..., k-1] * sqrt(2) * sin(pi * k * s)``.

        s.shape = (no_grid_points,) or (); x.shape = (Nx,) or (n_particles, Nx);
        the result has shape ``x.shape[:-1] + s.shape``.
        """
        s = np.asarray(s, dtype=np.float64)
        x = np.asarray(x, dtype=np.float64)
        k = np.arange(1, x.shape[-1] + 1)
        # basis[k-1, j] = sqrt(2) * sin(pi * k * s_j) on the flattened grid
        basis = np.sqrt(2) * np.sin(np.pi * k[:, None] * s.reshape(-1))
        # Reshape explicitly instead of squeezing: squeezing also drops
        # length-1 axes that belong to `s` or to the particle axis.
        return (x @ basis).reshape(x.shape[:-1] + s.shape)

    # generate posterior samples
    prior_cov = np.diag(np.power(np.arange(1, Nx + 1), -2.0)) * prior_cov_scale
    cross_cov = prior_cov @ G.T  # (Nx, No)
    innovation_cov = G @ cross_cov + noise_cov  # (No, No)
    # Linear solves instead of forming the explicit inverse twice.
    posterior_mean = cross_cov @ np.linalg.solve(innovation_cov, y)
    posterior_cov = prior_cov - cross_cov @ np.linalg.solve(
        innovation_cov, cross_cov.T
    )
    # Round-off in the subtraction can leave the matrix slightly asymmetric;
    # symmetrise before handing it to the sampler.
    posterior_cov = 0.5 * (posterior_cov + posterior_cov.T)
    posterior_sample = rng.multivariate_normal(
        mean=posterior_mean, cov=posterior_cov, size=N_posterior
    )

    return (
        prior_cov,
        partial(
            _score_fn,
            scale_x=G.T @ np.linalg.solve(noise_cov, G),
            # equals G.T @ inv(noise_cov) without forming the inverse
            scale_y=np.linalg.solve(noise_cov.T, G).T,
            y=y,
            prior_cov_scale=prior_cov_scale,
        ),
        partial(
            _log_pi,
            y=y,
            G=G,
            noise_cov=noise_cov,
            prior_cov_scale=prior_cov_scale,
        ),
        y,
        true_process,
        fn,
        posterior_sample,
        posterior_cov,
        posterior_mean,
    )
