# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from functools import partial, update_wrapper
import math
import warnings

import numpy as np

import jax
from jax import jit, lax, random, vmap
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import digamma

# Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3.
# The tuple is populated by `_get_tr_params` and read by `_binomial_btrs`:
#   c, b, a, alpha, u_r, v_r -- constants of the proposal / squeeze test,
#   m                        -- floor((n + 1) * p),
#   log_p, log1_p            -- log(p) and log1p(-p),
#   log_h                    -- the part of log(f(k)) that does not depend on k.
_tr_params = namedtuple(
    "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
)


def _get_tr_params(n, p):
    """
    Precompute the constants used by the BTRS binomial sampler.

    See Table 1 of the reference cited in ``_binomial_btrs``. In addition to
    the proposal constants, ``log(p)``, ``log1p(-p)`` and the terms of
    ``log(f(k))`` that depend only on ``(n, p, m)`` (bottom of page 5) are
    computed once here.

    :param n: total count; its dtype is reused for ``m``.
    :param p: success probability.
    :return: a ``_tr_params`` namedtuple.
    """
    mean = n * p
    std = jnp.sqrt(mean * (1 - p))
    b = 1.15 + 2.53 * std
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    # k-independent pieces of log(f(k)).
    log_ratio_term = jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p
    tail_term = stirling_approx_tail(m) + stirling_approx_tail(n - m)
    return _tr_params(
        c=mean + 0.5,
        b=b,
        a=-0.0873 + 0.0248 * b + 0.01 * p,
        alpha=(2.83 + 5.1 / b) * std,
        u_r=0.43,
        v_r=0.92 - 4.2 / b,
        m=m,
        log_p=log_p,
        log1_p=log1_p,
        log_h=(m + 0.5) * log_ratio_term + tail_term,
    )


def stirling_approx_tail(k):
    """
    Tail term used by the BTRS sampler when evaluating ``log(f(k))``.

    For ``k < 10`` the value is read from a ten-entry table; otherwise a
    truncated series in ``1 / (k + 1)`` is evaluated.

    :param k: integer (array) argument, used directly as a table index.
    :return: array of the same shape as ``k``.
    """
    table = jnp.array(
        [
            0.08106146679532726,
            0.04134069595540929,
            0.02767792568499834,
            0.02079067210376509,
            0.01664469118982119,
            0.01387612882307075,
            0.01189670994589177,
            0.01041126526197209,
            0.009255462182712733,
            0.008330563433362871,
        ]
    )
    kp1 = k + 1
    kp1sq = kp1 * kp1
    series = (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1
    # Both operands are evaluated; `where` only selects between them.
    return jnp.where(k < 10, table[k], series)


# Threshold on the mean `n * min(p, 1 - p)`: `_binomial_dispatch` uses the
# inversion sampler below this value and the BTRS sampler at or above it.
_binomial_mu_thresh = 10


def _binomial_btrs(key, p, n):
    """
    Based on the transformed rejection sampling algorithm (BTRS) from the
    following reference:

    Hormann, "The Generation of Binomial Random Variates"
    (https://core.ac.uk/download/pdf/11007254.pdf)

    :param key: PRNG key consumed by the rejection loop.
    :param p: success probability (scalar).
    :param n: total count (scalar array; its dtype is used for the proposal ``k``).
    :return: the accepted proposal ``k``. When ``p * n`` is below
        ``_binomial_mu_thresh`` the loop never runs and the initial ``-1`` is
        returned.
    """

    def _btrs_body_fn(val):
        # Draw a fresh (u, v) pair and map u to a candidate k.
        _, key, _, _ = val
        key, key_u, key_v = random.split(key, 3)
        u = random.uniform(key_u)
        v = random.uniform(key_v)
        u = u - 0.5
        k = jnp.floor(
            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
        ).astype(n.dtype)
        return k, key, u, v

    def _btrs_cond_fn(val):
        def accept_fn(k, u, v):
            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
            # v <= f(k) * g_grad(u) / alpha

            m = tr_params.m
            log_p = tr_params.log_p
            log1_p = tr_params.log1_p
            # See: formula for log(f(k)) at bottom of Page 5.
            log_f = (
                (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
                + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
                + (stirling_approx_tail(k) - stirling_approx_tail(n - k))
                + tr_params.log_h
            )
            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
            return jnp.log((v * tr_params.alpha) / g) <= log_f

        k, key, u, v = val
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # when vmapped _binomial_dispatch will convert the cond condition into
        # a HLO select that will execute both branches. This is a workaround
        # that avoids the resulting infinite loop when p=0. This should also
        # improve performance in less catastrophic cases.
        cond_exclude_small_mu = p * n >= _binomial_mu_thresh
        # Modern `lax.cond(pred, true_fun, false_fun, *operands)` form: both
        # branches receive the same operand; the squeeze branch ignores it.
        cond_main = lax.cond(
            early_accept | early_reject,
            lambda _: ~early_accept,
            lambda x: ~accept_fn(*x),
            (k, u, v),
        )
        return cond_exclude_small_mu & cond_main

    tr_params = _get_tr_params(n, p)
    ret = lax.while_loop(
        _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
    )  # use k=-1 initially so that cond_fn returns True
    return ret[0]


def _binomial_inversion(key, p, n):
    """
    Sample a binomial count by accumulating geometric waiting times.

    Geometric gaps are summed until the running total exceeds ``n``; the
    number of gaps drawn before that happens (counter starts at ``-1``) is
    the sample.

    :param key: PRNG key consumed by the loop.
    :param p: success probability (scalar).
    :param n: total count (scalar).
    :return: the loop counter at exit.
    """
    log1_p = jnp.log1p(-p)

    def _keep_drawing(state):
        _, _, waited = state
        # see the note on cond_exclude_small_mu in _binomial_btrs
        # this guard is unnecessary for correctness but will
        # still improve performance.
        mean_is_small = p * n < _binomial_mu_thresh
        return mean_is_small & (waited <= n)

    def _draw_gap(state):
        count, rng, waited = state
        rng, rng_u = random.split(rng)
        u = random.uniform(rng_u)
        gap = jnp.floor(jnp.log1p(-u) / log1_p) + 1
        return count + 1, rng, waited + gap

    count, _, _ = lax.while_loop(_keep_drawing, _draw_gap, (-1, key, 0.0))
    return count


def _binomial_dispatch(key, p, n):
    """
    Draw one binomial sample, choosing the algorithm from the mean.

    Degenerate inputs are handled without sampling: non-finite ``p``,
    ``p <= 0`` or ``n <= 0`` give ``0`` and ``p >= 1`` gives ``n``.

    :param key: PRNG key.
    :param p: success probability (scalar).
    :param n: total count (scalar).
    :return: the sampled count.
    """

    def dispatch(key, p, n):
        # Sample with min(p, 1 - p) and mirror the result when p > 0.5.
        is_le_mid = p <= 0.5
        pq = jnp.where(is_le_mid, p, 1 - p)
        mu = n * pq
        # Modern `lax.cond(pred, true_fun, false_fun, *operands)` form.
        k = lax.cond(
            mu < _binomial_mu_thresh,
            lambda x: _binomial_inversion(*x),
            lambda x: _binomial_btrs(*x),
            (key, pq, n),
        )
        return jnp.where(is_le_mid, k, n - k)

    # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
    cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
    return lax.cond(
        cond0 & (p < 1),
        lambda x: dispatch(*x),
        lambda _: jnp.where(cond0, n, 0),
        (key, p, n),
    )


@partial(jit, static_argnums=(3,))
def _binomial(key, p, n, shape):
    """
    Jitted binomial sampler: one independent draw per broadcast element.

    :param key: PRNG key, split into one key per element.
    :param p: success probabilities.
    :param n: total counts.
    :param shape: static output shape; when empty, the broadcast shape of
        ``p`` and ``n`` is used.
    :return: samples reshaped to ``shape``.
    """
    if not shape:
        shape = lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # flatten so that every element can be mapped over along axis 0
    flat_p = jnp.broadcast_to(p, shape).reshape(-1)
    flat_n = jnp.broadcast_to(n, shape).reshape(-1)
    keys = random.split(key, flat_p.size)
    if jax.default_backend() == "cpu":
        draws = lax.map(lambda args: _binomial_dispatch(*args), (keys, flat_p, flat_n))
    else:
        draws = vmap(_binomial_dispatch)(keys, flat_p, flat_n)
    return draws.reshape(shape)


def binomial(key, p, n=1, shape=()):
    """
    Sample from ``Binomial(n, p)``.

    :param key: PRNG key.
    :param p: success probabilities.
    :param n: total counts (defaults to 1).
    :param shape: output shape; defaults to the broadcast shape of ``p`` and ``n``.
    :return: array of sampled counts.
    """
    return _binomial(key, p, n, shape)


@partial(jit, static_argnums=(2,))
def _categorical(key, p, shape):
    """
    Jitted categorical sampler by inverse-CDF lookup along the last axis.

    :param key: PRNG key.
    :param p: (possibly unnormalised) probabilities; the last axis indexes categories.
    :param shape: static sample shape; when empty, ``p.shape[:-1]`` is used.
    :return: integer category indices of shape ``shape``.
    """
    # this implementation is fast when event shape is small, and slow otherwise
    # Ref: https://stackoverflow.com/a/34190035
    if not shape:
        shape = p.shape[:-1]
    cdf = jnp.cumsum(p, axis=-1)
    # Normalize the CDF to deal with numerical issues; the last entry (== 1)
    # is dropped because it can never be exceeded.
    thresholds = cdf[..., :-1] / cdf[..., -1:]
    draws = random.uniform(key, shape=shape + (1,))
    # FIXME: replace this computation by using binary search as suggested in the above
    # reference. A while_loop + vmap for a reshaped 2D array would be enough.
    # The sampled index is the number of thresholds strictly below the draw.
    return jnp.sum(thresholds < draws, axis=-1)


def categorical(key, p, shape=()):
    """
    Sample category indices from probabilities ``p``.

    :param key: PRNG key.
    :param p: probabilities over the last axis.
    :param shape: sample shape; defaults to ``p.shape[:-1]``.
    :return: integer array of sampled indices.
    """
    return _categorical(key, p, shape)


def _scatter_add_one(operand, indices, updates):
    """
    Scatter-add ``updates`` into a 1D ``operand`` at the given positions.

    :param operand: 1D array that receives the additions.
    :param indices: array of shape ``(k, 1)`` holding positions in ``operand``.
    :param updates: array of shape ``(k,)`` with the values to add.
    :return: ``operand`` with the updates accumulated (repeated indices add up).
    """
    dimension_numbers = lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter_add(operand, indices, updates, dimension_numbers)


@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """
    Jitted multinomial sampler built from ``n_max`` categorical draws.

    :param key: PRNG key.
    :param p: probabilities; the last axis indexes categories.
    :param n: total counts, broadcast against ``p.shape[:-1]``.
    :param n_max: static upper bound on ``n``; this many categorical draws
        are made for every batch element.
    :param shape: static sample shape; defaults to the batch shape of ``p``.
    :return: per-category counts of shape ``shape + p.shape[-1:]``.
    """
    # Align the batch shapes of the counts and the probabilities.
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # Nothing to draw: every count is zero.
    if n_max == 0:
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # Draws beyond each element's own count are redirected to category 0
        # (index * 0) and then removed again through `excess`.
        mask = promote_shapes(
            jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,)
        )[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # NOTE(review): the zeros below use the default float dtype, so the
        # returned counts appear to be promoted to float on this path while
        # the scalar-`n` path stays integer -- confirm whether that is intended.
        excess = jnp.concatenate(
            [
                jnp.expand_dims(n_max - n, -1),
                jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,)),
            ],
            -1,
        )
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    # Count occurrences of every category per batch row by scatter-adding ones.
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype),
    )
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess


def multinomial(key, p, n, shape=(), total_count_max=None):
    """
    Sample multinomial counts.

    :param key: PRNG key.
    :param p: probabilities; the last axis indexes categories.
    :param n: total counts.
    :param shape: sample shape.
    :param total_count_max: static upper bound on ``n``. Required when ``n``
        is a traced value, since the bound cannot then be read from ``n``.
    :return: array of per-category counts.
    :raises ValueError: if ``n`` is a tracer and ``total_count_max`` is not given.
    """
    n_max = total_count_max
    if n_max is None:
        if isinstance(n, jax.core.Tracer):
            raise ValueError(
                "Please specify total_count_max in Multinomial distribution."
            )
        # Concrete counts: pull them to the host and take the largest.
        n_max = int(np.max(jax.device_get(n)))
    return _multinomial(key, p, n, n_max, shape)


def cholesky_of_inverse(matrix):
    """
    Return the lower Cholesky factor of the inverse of ``matrix``.

    This formulation only takes the inverse of a triangular matrix, which is
    more numerically stable. Refer to:
    https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril

    :param matrix: positive-definite matrix (batch dimensions allowed).
    :return: lower-triangular ``L`` with ``L @ L.T == inv(matrix)``.
    """
    # Cholesky of the matrix with both axes reversed, reversed back and
    # transposed, is a lower factor `T` with `matrix == T.T @ T`.
    reversed_matrix = matrix[..., ::-1, ::-1]
    reversed_chol = jnp.linalg.cholesky(reversed_matrix)
    tril_inv = jnp.swapaxes(reversed_chol[..., ::-1, ::-1], -2, -1)
    eye = jnp.broadcast_to(jnp.identity(matrix.shape[-1]), tril_inv.shape)
    return solve_triangular(tril_inv, eye, lower=True)


# TODO: move upstream to jax.nn
def binary_cross_entropy_with_logits(x, y):
    """
    Binary cross entropy between labels ``y`` and logits ``x``.

    Computes ``-y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))`` in the
    overflow-safe form ``max(x, 0) + log1p(exp(-|x|)) - x * y``.
    Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    :param x: logits.
    :param y: labels, broadcastable against ``x``.
    :return: elementwise loss.
    """
    positive_part = jnp.clip(x, 0)
    log_term = jnp.log1p(jnp.exp(-jnp.abs(x)))
    return positive_part + log_term - x * y


def _reshape(x, shape):
    """
    Reshape ``x`` with NumPy when it is a Python/NumPy value, else with JAX.

    Keeping host values on the NumPy side avoids turning them into JAX arrays.
    """
    host_types = (int, float, np.ndarray, np.generic)
    reshape_fn = np.reshape if isinstance(x, host_types) else jnp.reshape
    return reshape_fn(x, shape)


def promote_shapes(*args, shape=()):
    """
    Left-pad the shapes of ``args`` with ones so that all have the same rank.

    The target rank is that of the broadcast of ``shape`` with all argument
    shapes. Arguments already at that rank are returned untouched.

    :param args: arrays or scalars.
    :param shape: extra shape taking part in the rank computation.
    :return: ``args`` unchanged (a tuple) when there are fewer than two
        arguments and no ``shape``; otherwise a list of reshaped arguments.
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    shapes = [jnp.shape(arg) for arg in args]
    num_dims = len(lax.broadcast_shapes(shape, *shapes))
    promoted = []
    for arg, s in zip(args, shapes):
        missing = num_dims - len(s)
        promoted.append(_reshape(arg, (1,) * missing + s) if missing > 0 else arg)
    return promoted


def sum_rightmost(x, dim):
    """
    Sum out ``dim`` many rightmost dimensions of a given tensor.

    :param x: input array.
    :param dim: number of trailing dimensions to reduce; ``0`` returns the
        values of ``x`` unchanged.
    :return: array with the trailing ``dim`` dimensions summed away.
    """
    kept = jnp.ndim(x) - dim
    leading_shape = jnp.shape(x)[:kept]
    # The extra trailing axis guarantees there is something to flatten and
    # sum even when `dim == 0`.
    flattened = jnp.expand_dims(x, -1).reshape(leading_shape + (-1,))
    return flattened.sum(axis=-1)


def matrix_to_tril_vec(x, diagonal=0):
    """
    Collect the lower-triangular entries of the trailing square matrix.

    :param x: array whose last two dimensions form a square matrix.
    :param diagonal: diagonal offset passed to ``jnp.tril_indices``
        (``0`` includes the main diagonal, ``-1`` excludes it).
    :return: array of shape ``x.shape[:-2] + (num_entries,)`` in row-major order.
    """
    rows, cols = jnp.tril_indices(x.shape[-1], diagonal)
    return x[..., rows, cols]


def vec_to_tril_matrix(t, diagonal=0):
    """
    Arrange the last axis of ``t`` into a lower-triangular matrix.

    Entries are placed in row-major order at the positions given by
    ``jnp.tril_indices(n, diagonal)``; all other entries are zero.

    :param t: array whose last axis holds the triangular entries.
    :param diagonal: diagonal offset; must be ``<= 0`` (see note below).
    :return: array of shape ``t.shape[:-1] + (n, n)`` with the dtype of ``t``.
    """
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # Flat positions (into the n*n row-major layout) of the tril entries.
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    x = lax.scatter_add(
        # `lax.scatter_add` needs operand and updates of the same dtype, so
        # the zero canvas follows `t` instead of the default float dtype.
        jnp.zeros(t.shape[:-1] + (n2,), dtype=jnp.result_type(t)),
        jnp.expand_dims(idx, axis=-1),
        t,
        lax.ScatterDimensionNumbers(
            update_window_dims=tuple(range(t.ndim - 1)),
            inserted_window_dims=(t.ndim - 1,),
            scatter_dims_to_operand_dims=(t.ndim - 1,),
        ),
    )
    return jnp.reshape(x, x.shape[:-1] + (n, n))


def cholesky_update(L, x, coef=1):
    """
    Finds cholesky of L @ L.T + coef * x @ x.T.

    **References:**

        1. A more efficient rank-one covariance matrix update for evolution strategies,
           Oswin Krause and Christian Igel

    :param L: lower-triangular Cholesky factor, shape ``batch + (n, n)``.
    :param x: rank-one update vector, shape ``batch + (n,)``; the batch
        shapes of ``L`` and ``x`` are broadcast together.
    :param coef: scalar weight of the rank-one term.
    :return: lower-triangular factor of the updated matrix.
    """
    batch_shape = lax.broadcast_shapes(L.shape[:-2], x.shape[:-1])
    L = jnp.broadcast_to(L, batch_shape + L.shape[-2:])
    x = jnp.broadcast_to(x, batch_shape + x.shape[-1:])
    diag = jnp.diagonal(L, axis1=-2, axis2=-1)
    # convert to unit diagonal triangular matrix: L @ D @ L.T
    L = L / diag[..., None, :]
    D = jnp.square(diag)

    def scan_fn(carry, val):
        # One column of the LDL^T update; `b` accumulates the scaling of the
        # remaining rank-one term and `w` is the partially eliminated vector.
        b, w = carry
        j, Dj, L_j = val
        wj = w[..., j]
        gamma = b * Dj + coef * jnp.square(wj)
        Dj_new = gamma / b
        # b <- b + coef * wj^2 / Dj, i.e. gamma divided by the *old* Dj.
        # (Dividing by Dj_new would leave b unchanged.)
        b = gamma / Dj

        # update vectors w and L_j
        w = w - wj[..., None] * L_j
        L_j = L_j + (coef * wj / gamma)[..., None] * w
        return (b, w), (Dj_new, L_j)

    D, L = jnp.moveaxis(D, -1, 0), jnp.moveaxis(L, -1, 0)  # move scan dim to front
    _, (D, L) = lax.scan(
        scan_fn, (jnp.ones(batch_shape), x), (jnp.arange(D.shape[0]), D, L)
    )
    D, L = jnp.moveaxis(D, 0, -1), jnp.moveaxis(L, 0, -1)  # move scan dim back
    return L * jnp.sqrt(D)[..., None, :]


def signed_stick_breaking_tril(t):
    """
    Map values in ``(-1, 1)`` to a lower-triangular matrix via signed
    stick-breaking.

    :param t: array whose last axis holds the strictly-lower-triangular
        entries (row-major); values are clipped into the open interval.
    :return: lower-triangular matrix built from ``t``.
    """
    # make sure that t in (-1, 1)
    eps = jnp.finfo(t.dtype).eps
    # Bounds are passed positionally: the `a_min`/`a_max` keywords of
    # `jnp.clip` are deprecated in recent JAX releases.
    t = jnp.clip(t, -1 + eps, 1 - eps)
    # transform t to tril matrix with identity diagonal
    r = vec_to_tril_matrix(t, diagonal=-1)

    # apply stick-breaking on the squared values;
    # we omit the step of computing s = z * z_cumprod by using the fact:
    #     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
    z = r**2
    z1m_cumprod_sqrt = jnp.cumprod(jnp.sqrt(1 - z), axis=-1)

    # Shift the cumulative product one step to the right (first column = 1)
    # so that each entry is scaled by the product over the preceding columns.
    pad_width = [(0, 0)] * z.ndim
    pad_width[-1] = (1, 0)
    z1m_cumprod_sqrt_shifted = jnp.pad(
        z1m_cumprod_sqrt[..., :-1], pad_width, mode="constant", constant_values=1.0
    )
    y = add_diag(r, 1) * z1m_cumprod_sqrt_shifted
    return y


def logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.

    The row maxima of ``x`` and column maxima of ``y`` are subtracted before
    exponentiating and added back afterwards; gradients are not propagated
    through these shifts.

    :param x: log-space left operand, shape ``(..., n, k)``.
    :param y: log-space right operand, shape ``(..., k, m)``.
    :return: log of the matrix product of ``exp(x)`` and ``exp(y)``.
    """
    row_max = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    col_max = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    shifted_product = jnp.exp(x - row_max) @ jnp.exp(y - col_max)
    return jnp.log(shifted_product) + row_max + col_max


def clamp_probs(probs):
    """
    Clip probabilities into the open interval ``(0, 1)``.

    The bounds are ``finfo.tiny`` and ``1 - finfo.eps`` of the floating
    dtype that ``probs`` promotes to.

    :param probs: probabilities (array or scalar).
    :return: clipped probabilities.
    """
    finfo = jnp.finfo(jnp.result_type(probs, float))
    # Positional bounds: the `a_min`/`a_max` keywords of `jnp.clip` are
    # deprecated in recent JAX releases.
    return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)


def betainc(a, b, x):
    """
    Regularized incomplete beta function.

    Uses the TensorFlow Probability (JAX substrate) implementation when it
    can be imported and falls back to ``jax.scipy.special.betainc``.

    :param a: first shape parameter.
    :param b: second shape parameter.
    :param x: evaluation point.
    :return: value of the regularized incomplete beta function.
    """
    try:
        from tensorflow_probability.substrates.jax.math import betainc as impl
    except ImportError:
        from jax.scipy.special import betainc as impl

    # All arguments are converted to the default floating dtype.
    float_dtype = jnp.result_type(float)
    a, b, x = (jnp.array(arg, dtype=float_dtype) for arg in (a, b, x))
    return impl(a, b, x)


def betaincinv(a, b, y):
    """
    Inverse of the regularized incomplete beta function.

    Requires TensorFlow Probability (JAX substrate).

    :param a: first shape parameter.
    :param b: second shape parameter.
    :param y: value of the regularized incomplete beta function.
    :return: the ``x`` at which ``betainc(a, b, x) == y``.
    :raises ImportError: if TensorFlow Probability cannot be imported.
    """
    try:
        from tensorflow_probability.substrates.jax.math import special as tfp_special
    except ImportError as e:
        raise ImportError(
            "Please install `tensorflow_probability>=0.18` for betaincinv."
        ) from e

    # All arguments are converted to the default floating dtype.
    float_dtype = jnp.result_type(float)
    a, b, y = (jnp.array(arg, dtype=float_dtype) for arg in (a, b, y))
    return tfp_special.betaincinv(a, b, y)


def gammaincinv(a, y):
    """
    Inverse of the regularized lower incomplete gamma function.

    Requires TensorFlow Probability (JAX substrate).

    :param a: shape parameter.
    :param y: value of the regularized lower incomplete gamma function.
    :return: result of ``tfp.math.igammainv(a, y)``.
    :raises ImportError: if TensorFlow Probability cannot be imported.
    """
    try:
        from tensorflow_probability.substrates.jax import math as tfp_math

        result = tfp_math.igammainv(jnp.array(a), jnp.array(y))
    except ImportError as e:
        raise ImportError(
            "Please install `tensorflow_probability>=0.18` for gammaincinv."
        ) from e
    return result


def is_identically_zero(x):
    """
    Check if argument is exactly the number zero. True for the number zero;
    false for other numbers; false for ndarrays.
    """
    # Only plain Python numbers qualify; arrays are never "identically" zero.
    return isinstance(x, (int, float)) and x == 0


def is_identically_one(x):
    """
    Check if argument is exactly the number one. True for the number one;
    false for other numbers; false for ndarrays.
    """
    # Only plain Python numbers qualify; arrays are never "identically" one.
    return isinstance(x, (int, float)) and x == 1


def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
    """Compute centered von Mises samples using rejection sampling from [1] with wrapped Cauchy proposal.

    **References**

    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
        Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    :param key: random number generator key
    :param concentration: concentration of distribution
    :param shape: shape of samples; defaults to the shape of ``concentration``
    :param dtype: float precision, used for choosing the correct ``s`` cutoff
    :return: centered samples from von Mises
    """
    shape = shape or jnp.shape(concentration)
    dtype = jnp.result_type(dtype)
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)
    return _von_mises_centered(key, concentration, shape, dtype)


@partial(jit, static_argnums=(2, 3))
def _von_mises_centered(key, concentration, shape, dtype):
    """
    Jitted rejection loop behind ``von_mises_centered``.

    :param key: PRNG key.
    :param concentration: concentrations, already broadcast to ``shape``.
    :param shape: static sample shape.
    :param dtype: static dtype used to look up the ``s`` cutoff.
    :return: samples with the dtype of ``concentration``.
    """
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    s_cutoff_map = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    s_cutoff = s_cutoff_map.get(dtype)

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho**2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    # Below the cutoff the exact expression is replaced by 1 / concentration.
    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = random.split(key, 3)

        u = random.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done

        y = concentration * (s - w)
        v = random.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)

        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)

        return i + 1, key, accept | done, u, w

    init_done = jnp.zeros(shape, dtype=bool)
    # The loop carry must keep the same dtype across iterations, so the
    # initial values follow `concentration` rather than the default float.
    init_u = jnp.zeros(shape, dtype=concentration.dtype)
    init_w = jnp.zeros(shape, dtype=concentration.dtype)

    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(u) * jnp.arccos(w)


def scale_and_mask(x, scale=None, mask=None):
    """
    Scale and mask a tensor, broadcasting and avoiding unnecessary ops.

    :param x: value to transform; the Python number zero is returned as is.
    :param scale: optional multiplier; skipped when ``None`` or the Python
        number one.
    :param mask: optional boolean mask; entries where it is false become ``0.0``.
    :return: the scaled and masked value.
    """
    if is_identically_zero(x):
        return x
    needs_scaling = scale is not None and not is_identically_one(scale)
    if needs_scaling:
        x = x * scale
    return x if mask is None else jnp.where(mask, x, 0.0)


# TODO: use funsor implementation
def periodic_repeat(x, size, dim):
    """
    Repeat a ``period``-sized array up to given ``size``.

    The array is tiled along ``dim`` (``[a, b, c] -> [a, b, c, a, b, ...]``)
    and then truncated to ``size`` entries along that dimension.

    :param x: array whose extent along ``dim`` is the period.
    :param int size: desired extent along ``dim``.
    :param int dim: dimension along which to repeat.
    :return: array with extent ``size`` along ``dim``.
    """
    assert isinstance(size, int) and size >= 0
    assert isinstance(dim, int)
    # Normalise to a negative dimension so it can be used for trailing slices.
    if dim >= 0:
        dim -= jnp.ndim(x)

    period = jnp.shape(x)[dim]
    repeats = (size + period - 1) // period
    # Tile whole periods; `jnp.repeat` would duplicate each element in place
    # ([a, a, b, b, ...]) and break the periodic pattern.
    reps = [1] * jnp.ndim(x)
    reps[dim] = repeats
    result = jnp.tile(x, reps)
    result = result[(Ellipsis, slice(None, size)) + (slice(None),) * (-1 - dim)]
    return result


def safe_normalize(x, *, p=2):
    """
    Safely project a vector onto the sphere wrt the ``p``-norm. This avoids the
    singularity at zero by mapping zero to the uniform unit vector proportional
    to ``[1, 1, ..., 1]``.

    :param numpy.ndarray x: A vector
    :param float p: The norm exponent, defaults to 2 i.e. the Euclidean norm.
    :returns: A normalized version ``x / ||x||_p``.
    :rtype: numpy.ndarray
    """
    assert isinstance(p, (float, int))
    assert p >= 0
    norm = jnp.linalg.norm(x, p, axis=-1, keepdims=True)
    # Lower-bound the norm by the smallest normal float to avoid dividing by
    # zero. The bound is passed positionally because the `a_min` keyword of
    # `jnp.clip` is deprecated in recent JAX releases.
    x = x / jnp.clip(norm, jnp.finfo(norm.dtype).tiny)
    # Avoid the singularity.
    mask = jnp.all(x == 0, axis=-1, keepdims=True)
    x = jnp.where(mask, x.shape[-1] ** (-1 / p), x)
    return x


def is_prng_key(key):
    """
    Deprecated check for a single JAX PRNG key.

    :param key: object to test.
    :return: ``True`` for a scalar typed key, or for a ``uint32`` array of
        shape ``(2,)``; ``False`` otherwise (including objects without
        ``dtype``/``shape``).
    """
    # stacklevel=2 attributes the warning to the caller, not to this module.
    warnings.warn(
        "Please use numpyro.util.is_prng_key.", DeprecationWarning, stacklevel=2
    )
    try:
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key.shape == ()
        return key.shape == (2,) and key.dtype == np.uint32
    except AttributeError:
        return False


def assert_one_of(**kwargs):
    """
    Assert that exactly one of the keyword arguments is not None.

    :raises ValueError: if zero or more than one argument is not ``None``.
    """
    specified = []
    for name, value in kwargs.items():
        if value is not None:
            specified.append(name)
    if len(specified) == 1:
        return
    raise ValueError(
        f"Exactly one of {list(kwargs)} must be specified; got {specified}."
    )


def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
    """
    Derivative of the log of multivariate gamma.

    Evaluated as the sum of ``digamma(a - i / 2)`` for ``i = 0, ..., d - 1``.

    :param a: argument of the multivariate gamma function.
    :param d: dimension; passed to ``jnp.arange``.
    :return: array with the shape of ``a``.
    """
    half_offsets = 0.5 * jnp.arange(d)
    terms = digamma(a[..., None] - half_offsets)
    return jnp.sum(terms, axis=-1)


def tri_logabsdet(a: jnp.ndarray) -> jnp.ndarray:
    """
    Evaluate the `logabsdet` of a triangular positive-definite matrix.

    Computed as the sum of the logs of the diagonal entries of the trailing
    two dimensions.
    """
    diagonal = jnp.diagonal(a, axis1=-1, axis2=-2)
    return jnp.sum(jnp.log(diagonal), axis=-1)


# This is sourced from: torch.distributions.util.py
#
# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU                      (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
class lazy_property(object):
    r"""
    Decorator for lazily computed instance attributes.

    This is a non-data descriptor: the first attribute access on an instance
    calls the wrapped method and stores the result on the instance under the
    method's name, so later accesses read the stored value directly.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        # Copy the wrapped method's metadata (name, docstring, ...) onto self.
        update_wrapper(self, wrapped)

    # This is to prevent warnings from sphinx
    def __call__(self, *args, **kwargs):
        return self.wrapped(*args, **kwargs)

    def __get__(self, instance, obj_type=None):
        if instance is not None:
            computed = self.wrapped(instance)
            # Shadow the descriptor with the computed value on the instance.
            setattr(instance, self.wrapped.__name__, computed)
            return computed
        # Accessed on the class: expose the descriptor itself.
        return self


def validate_sample(log_prob_fn):
    """
    Decorate a ``log_prob`` method so that invalid samples get ``-inf``.

    When ``self._validate_args`` is truthy, the value (keyword ``value`` or
    the first positional argument) is passed to ``self._validate_sample`` and
    the log probability is replaced by ``-inf`` wherever the returned mask
    is false.

    :param log_prob_fn: the method to wrap.
    :return: the wrapping method, carrying the metadata of ``log_prob_fn``.
    """

    def wrapper(self, *args, **kwargs):
        log_prob = log_prob_fn(self, *args, **kwargs)
        if self._validate_args:
            value = kwargs["value"] if "value" in kwargs else args[0]
            mask = self._validate_sample(value)
            log_prob = jnp.where(mask, log_prob, -jnp.inf)
        return log_prob

    # Preserve __name__, __doc__, etc. of the wrapped method.
    update_wrapper(wrapper, log_prob_fn)
    return wrapper


def add_diag(matrix: jnp.ndarray, diag: jnp.ndarray) -> jnp.ndarray:
    """
    Add `diag` to the trailing diagonal of `matrix`.

    :param matrix: array whose last two dimensions are indexed along their diagonal.
    :param diag: value(s) added to the diagonal entries.
    :return: a new array; `matrix` itself is not modified.
    """
    size = matrix.shape[-1]
    positions = jnp.arange(size)
    updated = matrix.at[..., positions, positions].add(diag)
    return updated
