import typing

import equinox
import jax
import scipy.stats
from jax import numpy as jnp


class Normal(equinox.Module):
    """Multivariate normal distribution N(μ, Σ) stored as an equinox module.

    Instances are built through ``__init__``; every method that transforms
    the distribution returns a new ``Normal`` instead of modifying ``self``.
    """

    # Mean vector; its leading dimension defines ``n``.
    μ: jnp.ndarray
    # Covariance matrix; ``__init__`` asserts that its shape is ``(n, n)``.
    Σ: jnp.ndarray
    # Dimension of the distribution, taken from ``μ.shape[0]``.
    n: int

    @staticmethod
    def standard(n):
        """Return the ``n``-dimensional standard normal (zero mean, identity covariance)."""
        return Normal(μ=jnp.zeros(n), Σ=jnp.eye(n))

    @staticmethod
    def certain(μ):
        """Return a degenerate distribution concentrated on ``μ`` (all-zero covariance)."""
        # The covariance uses the default floating dtype: with an integer
        # dtype, a later ``Σ.at[...].add(cov)`` would cast ``cov`` to int and
        # silently truncate it.
        n = μ.shape[0]
        return Normal(μ=μ, Σ=jnp.zeros((n, n)))

    def __init__(self, μ, Σ):
        """Store ``μ`` and ``Σ`` and derive ``n`` from ``μ.shape[0]``.

        Asserts that ``Σ`` has shape ``(n, n)``.
        """
        self.μ = μ
        self.Σ = Σ
        self.n = μ.shape[0]
        assert self.Σ.shape == (self.n, self.n), self.Σ

    def check_finite(self):
        """Return ``self`` if μ and Σ are finite everywhere.

        Raises:
            ValueError: if μ or Σ contains a NaN or an infinity.
        """
        if not jnp.all(jnp.isfinite(self.μ)):
            raise ValueError(f"μ contains non-finite values: {self.μ}")
        if not jnp.all(jnp.isfinite(self.Σ)):
            raise ValueError(f"Σ contains non-finite values: {self.Σ}")
        return self

    def check_psd(self):
        """Return ``self`` if every eigenvalue of Σ (via ``eigvalsh``) is >= 0.

        Raises:
            ValueError: if any eigenvalue is negative.
        """
        if not jnp.all(jnp.linalg.eigvalsh(self.Σ) >= 0):
            raise ValueError(f"Σ is not PSD: {self.Σ}")
        return self

    def mean_field(self):
        """Return the distribution with all off-diagonal covariances set to zero."""
        return Normal(μ=self.μ, Σ=jnp.diag(jnp.diag(self.Σ)))

    def qmc(self, num_samples, seed=42):
        """Draw ``num_samples`` quasi-Monte Carlo points from this distribution.

        Uses ``scipy.stats.qmc.MultivariateNormalQMC`` driven by a scrambled
        Sobol engine of dimension ``n``; ``seed`` is passed to both as ``rng``.
        """
        return scipy.stats.qmc.MultivariateNormalQMC(
            mean=self.μ,
            cov=self.Σ,
            rng=seed,
            engine=scipy.stats.qmc.Sobol(rng=seed, scramble=True, d=self.n),
        ).random(num_samples)

    def samples(self, num_samples, key=None):
        """Draw pseudo-random samples from this distribution.

        Args:
            num_samples: number of samples as an int, or a batch shape that is
                forwarded unchanged to ``jax.random.multivariate_normal``.
            key: JAX PRNG key; defaults to ``jax.random.PRNGKey(42)``.
        """
        # The default key is created lazily so that importing this module
        # does not run any JAX computation.
        if key is None:
            key = jax.random.PRNGKey(42)
        # ``multivariate_normal`` expects a tuple batch shape; wrap a bare int.
        shape = (num_samples,) if isinstance(num_samples, int) else num_samples
        # uses the svd method to support degenerate covariance matrices
        return jax.random.multivariate_normal(
            key, mean=self.μ, cov=self.Σ, shape=shape, method="svd"
        )

    def pdf(self, x):
        """Return the probability density at ``x``."""
        return jax.scipy.stats.multivariate_normal.pdf(x, mean=self.μ, cov=self.Σ)

    def lpdf(self, x):
        """Return the log probability density at ``x``."""
        return jax.scipy.stats.multivariate_normal.logpdf(x, mean=self.μ, cov=self.Σ)

    @staticmethod
    def independent(*normals: "Normal") -> "Normal":
        """Creates a joint distribution with zero correlations from multiple Normal distributions."""
        μ = jnp.concatenate([normal.μ for normal in normals])
        Σ_blocks = [normal.Σ for normal in normals]
        Σ = jax.scipy.linalg.block_diag(*Σ_blocks)
        return Normal(μ, Σ)

    def add_covariance(self, cov, at=slice(None, None)):
        """Return a copy with ``cov`` added to the covariance block ``Σ[at, at]``."""
        return Normal(self.μ, self.Σ.at[at, at].add(cov))

    @equinox.filter_jit
    def __getitem__(self, index: typing.Union[int, slice]):
        """Return the marginal distribution for the specified index.

        An int selects a single component (negative values count from the
        end); a slice selects the corresponding sub-vector and sub-block.

        Raises:
            IndexError: if an int index is outside ``[-n, n)``.
            ValueError: if ``index`` is neither an int nor a slice.
        """
        if isinstance(index, int):
            # Normalise negative indices first: ``slice(-1, 0)`` would be empty.
            start = index + self.n if index < 0 else index
            if not 0 <= start < self.n:
                raise IndexError(
                    f"index {index} is out of range for dimension {self.n}"
                )
            index = slice(start, start + 1)
        if isinstance(index, slice):
            return Normal(self.μ[index], self.Σ[index, index])
        else:
            raise ValueError(
                f"index must be an int or a slice, got {type(index).__name__}"
            )

    def delete(self, index: int):
        """Return the distribution with component ``index`` removed from μ and Σ."""
        return Normal(
            jnp.delete(self.μ, index),
            jnp.delete(jnp.delete(self.Σ, index, 0), index, 1),
        )

    def condition_on_projection(self, P: jnp.ndarray, y: jnp.ndarray):
        """Return the conditional distribution given that P X = y."""
        PΣ = P @ self.Σ
        S = PΣ @ P.T
        # Gain K = Σ Pᵀ S⁻¹, obtained from one linear solve (Sᵀ Kᵀ = P Σᵀ)
        # instead of forming the explicit inverse of S twice.
        K = jnp.linalg.solve(S.T, P @ self.Σ.T).T
        return Normal(
            μ=self.μ + K @ (y - P @ self.μ),
            Σ=self.Σ - K @ PΣ,
        )

    def condition(
        self, target: slice, given: slice, equals: typing.Union[jnp.ndarray, "Normal"]
    ):
        """Return the distribution of the ``target`` components conditioned on ``given``.

        Args:
            target: selects the components whose distribution is returned.
            given: selects the components being conditioned on.
            equals: value of the ``given`` components. A plain array goes
                through ``schur_complement``; a ``Normal`` goes through
                ``schur_complement_2``, which also receives its covariance.
        """
        if isinstance(equals, Normal):
            μ, Σ = schur_complement_2(
                A=self.Σ[target, :][:, target],
                B=self.Σ[target, :][:, given],
                C=self.Σ[given, :][:, given],
                D=equals.Σ,
                x=self.μ[target],
                y=equals.μ - self.μ[given],
            )
        else:
            μ, Σ = schur_complement(
                A=self.Σ[target, :][:, target],
                B=self.Σ[target, :][:, given],
                C=self.Σ[given, :][:, given],
                x=self.μ[target],
                y=equals - self.μ[given],
            )
        return Normal(μ, Σ)

    def χ2(self, x):
        """Return ``(x - μ)ᵀ z`` where ``z`` is the least-squares solution of ``Σ z = x - μ``."""
        diff = x - self.μ
        # lstsq rather than solve — presumably to tolerate a singular Σ.
        return diff.T @ jnp.linalg.lstsq(self.Σ, diff)[0]

    def rectify(self):
        """Return a copy whose covariance has been passed through ``rectify_eigenvalues``."""
        return Normal(self.μ, rectify_eigenvalues(self.Σ))

    @staticmethod
    def from_samples(samples):
        """Fit a Normal to ``samples`` (one sample per row).

        Uses the sample mean over axis 0 and the ``ddof=1`` sample covariance;
        ``atleast_1d`` / ``atleast_2d`` keep μ and Σ array-shaped when the
        statistics come out as scalars.
        """
        return Normal(
            jnp.atleast_1d(jnp.mean(samples, axis=0)),
            jnp.atleast_2d(jnp.cov(samples, rowvar=False, ddof=1)),
        )

    @equinox.filter_jit
    def kl_divergence(self, other: "Normal") -> float:
        """Compute the Kullback-Leibler divergence D(other || self).

        Both covariances are assumed to be positive definite; only the log of
        the absolute determinant is used.
        """
        if self is other:
            return 0

        trace = jnp.trace(jnp.linalg.solve(self.Σ, other.Σ))
        # slogdet avoids the overflow/underflow that log(det / det) hits when
        # the determinants themselves leave the floating-point range.
        log_det = jnp.linalg.slogdet(self.Σ)[1] - jnp.linalg.slogdet(other.Σ)[1]
        mean_diff = self.μ - other.μ
        return 0.5 * (
            trace + log_det + mean_diff.T @ jnp.linalg.solve(self.Σ, mean_diff) - self.n
        )

    def __add__(self, other):
        """Return the Normal whose mean and covariance are the sums of both operands'."""
        return Normal(self.μ + other.μ, self.Σ + other.Σ)

    def __mul__(self, a):
        """Return the distribution scaled by scalar ``a``: mean times ``a``, covariance times ``a**2``."""
        assert jax.numpy.isscalar(a), "a must be a scalar"
        return Normal(self.μ * a, self.Σ * (a**2))

    def __rmul__(self, a):
        """Support ``a * normal`` by delegating to ``__mul__``."""
        return self * a

    def __str__(self):
        """Return a readable ``Normal(μ=..., Σ=...)`` representation."""
        return f"Normal(μ={self.μ}, Σ={self.Σ})"


@equinox.filter_jit
def schur_complement(A, B, C, x, y):
    """Returns a numerically stable evaluation of
    x + B C^(-1) y,
    A - B C^(-1) B^T.

    C must be symmetric positive definite, since it is factorised with a
    Cholesky decomposition. With C = U^T U, C^(-1) = U^(-1) U^(-T), so both
    results are assembled from B U^(-1) and U^(-T) y.
    """
    # C = U^T U with U upper triangular
    U = jax.scipy.linalg.cholesky(C)
    # B_tilde = B U^-1, obtained by solving U^T X = B^T and transposing
    B_tilde = jax.scipy.linalg.solve_triangular(U, B.T, trans=1, lower=False).T
    # y_tilde = U^-T y: the solve must be against U^T (trans=1), not U,
    # otherwise the mean becomes x + B U^-1 U^-1 y instead of x + B C^-1 y.
    y_tilde = jax.scipy.linalg.solve_triangular(U, y, trans=1, lower=False)
    return (
        x + B_tilde @ y_tilde,
        A - B_tilde.dot(B_tilde.T),
    )


@equinox.filter_jit
def schur_complement_2(A, B, C, D, x, y):
    """Evaluate the gain-based update built from ``G = B C^(-1)``.

    Returns the pair
        x + G y,
        A + G (D - C) G^T.

    G is computed as the transpose of the solution of ``C Z = B^T``, which
    equals ``B C^(-1)`` when C is symmetric.
    """
    # One linear solve against C instead of an explicit inverse.
    gain = jnp.transpose(jnp.linalg.solve(C, B.T))
    shifted = x + gain @ y
    spread = A + gain @ (D - C) @ gain.T
    return shifted, spread


def rectify_eigenvalues(P, min_eigenvalue=1e-8):
    """Return ``P`` with its eigenvalues clipped from below.

    ``P`` is symmetrised and eigendecomposed with ``jnp.linalg.eigh``; every
    eigenvalue smaller than ``min_eigenvalue`` is raised to that value and
    the matrix is rebuilt from the original eigenvectors.

    Args:
        P: square matrix to rectify.
        min_eigenvalue: lower bound applied to each eigenvalue (default 1e-8).
    """
    Λ, V = jnp.linalg.eigh(P, symmetrize_input=True)
    # Scaling the columns of V is V @ diag(Λ) without building the dense
    # diagonal matrix.
    return (V * jnp.maximum(Λ, min_eigenvalue)) @ V.T
