import torch
import torch.nn.functional as F
import einops
import math
import logging
import torch.distributions as D
import numpy as np
import pdb

# Module-level logger; the numerical helpers below report warnings (zero or
# non-finite probabilities, suspicious variances/medians) through it.
_logger = logging.getLogger(__name__)


def _throw_if_not_1d(*args):
    """Raise ValueError if any argument is not one-dimensional.

    Each argument only needs a ``shape`` attribute (torch tensor, numpy
    array, ...). Arguments are checked in order and the first offender
    raises.
    """
    for candidate in args:
        ndim = len(candidate.shape)
        if ndim == 1:
            continue
        raise ValueError(f"Expected 1D array. Got shape {candidate.shape}.")


def _throw_if_not_2d(*args):
    """Raise ValueError if any argument is not two-dimensional.

    Each argument only needs a ``shape`` attribute. Arguments are checked in
    order and the first offender raises.
    """
    for candidate in args:
        ndim = len(candidate.shape)
        if ndim == 2:
            continue
        raise ValueError(f"Expected 2D array. Got shape {candidate.shape}.")


def to_npy(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and on CPU."""
    detached = x.detach()
    on_host = detached.cpu()
    return on_host.numpy()


def const_hazard(log_c, t):
    """log(prob), log(hazard) and integral(hazard) for a constant hazard fn.

    For h(t) = c the inter-event time is exponential, f(t) = c e^{-c t}, so

        log f(t) = log c - c t,    log h(t) = log c,    H(t) = c t.

    Works with arbitrary shapes; let shape(log_c) = shape(t) = S.

    Args:
        log_c: log of the constant c. Shape: S
        t: the time since the last event. Shape: S
    Returns:
        log_prob: log prob of the inter-event time. Shape: S
        log_h: log intensity of the hazard function (log_c itself). Shape: S
        int_h: integral of the hazard function from 0 to t (not its log).
            Shape: S
    """
    assert log_c.shape == t.shape, f"{log_c.shape=} == {t.shape=}"
    rate = log_c.exp()
    cumulative_hazard = rate * t
    return log_c - cumulative_hazard, log_c, cumulative_hazard


def const_hazard_log_prob(log_c, t):
    """Log prob. of inter-event time given a constant hazard function.

    A constant hazard h(t) = c gives an exponential inter-event density

        f(t) = c e^{-c t},    so    log∘f(t) = log_c - c t.

    Args:
        log_c: log of the constant c. Shape: (batch_size,).
        t: the time since the last event. Shape: (batch_size,).
    Returns:
        log_prob: log prob of the inter-event time. Shape: (batch_size,).
    Raises:
        ValueError: if log_c or t is not 1D.
    """
    _throw_if_not_1d(log_c, t)
    rate = log_c.exp()
    return log_c - rate * t


def exp_hazard(a, d, t):
    """log(prob), log(hazard) and integral(hazard) for an exponential hazard fn.

    The hazard is h(t) = e^{a t + d}. Its integral from 0 to t is

        H(t) = e^d (e^{a t} - 1) / a      (a != 0)
             = e^d t                       (a == 0, constant-hazard limit)

    and log f(t) = log h(t) - H(t).

    H is evaluated as e^d * t * expm1(a t) / (a t), using a Taylor series
    for tiny |a t|. This keeps it accurate, finite and differentiable as
    a -> 0. The naive (e^{a t} - 1) / a is 0/0 = NaN at a == 0 and loses
    precision for small |a t|.

    Works with arbitrary shapes, provided a, d and t all share one shape.

    Args:
        a: the scale parameter. Shape: (batch_size,).
        d: the offset parameter. Shape: (batch_size,).
        t: the time since the last event. Shape: (batch_size,).
    Returns:
        log_prob: log prob of the inter-event time. Shape: (batch_size,).
        log_h: log intensity of the hazard function. Shape: (batch_size,).
        int_h: integral of the hazard function from 0 to t (not its log).
            Shape: (batch_size,).
    """
    assert a.shape == d.shape == t.shape, (a.shape, d.shape, t.shape)
    log_h = a * t + d
    x = a * t
    # expm1(x)/x -> 1 as x -> 0. Use its series near 0. Elsewhere divide by
    # a non-zero stand-in so neither torch.where branch produces NaN (or NaN
    # gradients).
    small = x.abs() < 1e-4
    x_safe = torch.where(small, torch.ones_like(x), x)
    expm1_ratio = torch.where(
        small, 1 + x / 2 + x * x / 6, torch.expm1(x_safe) / x_safe
    )
    int_h = torch.exp(d) * t * expm1_ratio
    log_prob = log_h - int_h
    return log_prob, log_h, int_h


def exp_hazard_log_prob(a, d, t):
    """
    Computes log prob. of inter-event time given an exponential hazard function.

    An exponential hazard function:

        h(t) = ke^{a t}

    or equivalently, with k = e^d,

        h(t) = e^{a t + d}

    Integrating this gives:

        int_{0}^{t} h(s) ds
            = k/a (e^{a t} - 1)

            = (e^{a t + d} - e^d) / a

    which tends to k t (the constant-hazard case) as a -> 0. The survival
    function is:

        S(t) = e^{-k/a (e^{a t} - 1)}

             = e^{-(e^{a t + d} - e^d) / a}

    For a > 0 the probability density function is a Gompertz distribution:

        f(t) = h(t) S(t)
             = ke^{a t} e^{-k/a (e^{a t} - 1)}

    The distribution has support [0, ∞). For a < 0, S(t) tends to
    e^{k/a} > 0, so f integrates to less than 1.

    The log probability evaluated at t is then:

    log∘f(t) = log∘h(t) + log∘S(t)
             = log(k) + a t - k/a (e^{a t} - 1)

             = d + a t - e^d/a (e^{a t} - 1)

    The integral is evaluated as e^d t expm1(a t) / (a t), using a Taylor
    series for tiny |a t|. So a = 0 gives the finite constant-hazard value
    instead of 0/0 = NaN.

    Args:
        a: the scale parameter. Shape: (batch_size,).
        d: the offset parameter. Shape: (batch_size,).
        t: the time since the last event. Shape: (batch_size,).
    Returns:
        log_prob: log prob of the inter-event time. Shape: (batch_size,).
    Raises:
        ValueError: if a, d or t is not 1D.
    """
    _throw_if_not_1d(a, d, t)
    log_h = a * t + d
    x = a * t
    # expm1(x)/x -> 1 as x -> 0. Use its series near 0. Elsewhere divide by
    # a non-zero stand-in so neither torch.where branch produces NaN (or NaN
    # gradients).
    small = x.abs() < 1e-4
    x_safe = torch.where(small, torch.ones_like(x), x)
    expm1_ratio = torch.where(
        small, 1 + x / 2 + x * x / 6, torch.expm1(x_safe) / x_safe
    )
    int_h = torch.exp(d) * t * expm1_ratio
    log_prob = log_h - int_h
    return log_prob


def piecewise_hazard_log_prob(log_c, t, piece_len):
    """Log prob. for a piecewise-constant hazard function.

    Computes log prob. of observing the given inter-event time `t` assuming
    that `log_c` are left-hand side values of a piecewise-constant hazard
    function. The hazard function is defined over the range
    [0, (num_pieces - 1)*piece_len).

    The hazard function:

        from [0, piece_len*1) is c[0]
        from [piece_len*1, piece_len*2) is c[1], etc.

    With k = floor(t / piece_len) and w = t / piece_len - k (the fraction of
    piece k lying before t):

        log f(t) = log h(t) - int_0^t h(s) ds
                 = log_c[k] - piece_len * (sum_{j<k} c[j] + w c[k])

    Args:
        log_c: log(c) values for h(0), h(piece_len), h(2*piece_len), ...
            h(n*piece_len). The values define the left-hand side of each piece
            and the right hand side of the last piece, so num_pieces values
            describe num_pieces - 1 usable pieces.
            Shape: [batch_size, num_pieces].
        t: the time since the last event, with
            0 <= t < (num_pieces - 1) * piece_len. Shape: [batch_size, 1].
        piece_len: the length of each piece.
    Returns:
        log_prob: log prob of the inter-event time. Shape: [batch_size, 1].
    Raises:
        ValueError: if log_c is not 2D, t is not [batch_size, 1], or t is
            outside the supported range.
    """
    if len(log_c.shape) != 2:
        raise ValueError(f"Expected 2D array. Got shape {log_c.shape}.")
    b, n_p = log_c.shape
    if t.shape != (b, 1):
        raise ValueError(
            f"t must have shape [batch_size, 1]. Got shape ({t.shape})."
        )
    t_idx = torch.floor(t / piece_len).long()  # [b, 1]
    if torch.any(t_idx < 0):
        raise ValueError(f"t must be non-negative. Got t=({t})")
    if torch.any(t_idx >= n_p - 1):
        raise ValueError(
            "t must be less than the right-hand side of the last piece. "
            f"Got t=({t})"
        )
    c = torch.exp(log_c)
    # Fraction of the piece containing t that lies before t, in [0, 1).
    weight = t / piece_len - t_idx  # [b, 1]
    # 1. log h(t): the constant of the piece containing t.
    log_h_t = torch.gather(log_c, 1, t_idx)  # [b, 1]
    # 2. H(t) = int_{0}^{t} h(s) ds: every piece before t_idx counts fully,
    # the piece containing t counts for the fraction `weight`.
    piece_idx = torch.arange(n_p, device=log_c.device).unsqueeze(0)  # [1, n_p]
    integral_weights = (piece_idx < t_idx).to(c.dtype) + (
        piece_idx == t_idx
    ).to(c.dtype) * weight
    int_h = (c * integral_weights).sum(dim=1, keepdim=True) * piece_len
    return log_h_t - int_h


def normal_cdf(x, mu, sigma):
    """Normal CDF evaluated through the error function.

        CDF(x) = 1/2 (1 + erf((x - mu) / (sigma sqrt(2))))

    Args:
        x: points at which to evaluate the CDF (1D).
        mu: mean of the normal distribution (1D).
        sigma: standard deviation of the normal distribution (1D).
    Returns:
        cdf: the CDF at x, elementwise.
    Raises:
        ValueError: if any input is not 1D.
    """
    _throw_if_not_1d(x, mu, sigma)
    scaled = (x - mu) / (sigma * math.sqrt(2))
    return 0.5 * (1 + torch.erf(scaled))


def log_normal_cdf(x, mu, sigma):
    """CDF of the log-normal distribution: P(exp(Z) <= x) with Z ~ N(mu, sigma).

    Evaluated as the normal CDF of log(x).

    Args:
        x: points at which to evaluate the CDF (1D).
        mu: mean of the underlying normal distribution (1D).
        sigma: standard deviation of the underlying normal distribution (1D).
    Returns:
        The CDF at x, elementwise.
    Raises:
        ValueError: if any input is not 1D.
    """
    _throw_if_not_1d(x, mu, sigma)
    log_x = torch.log(x)
    return normal_cdf(log_x, mu, sigma)


def log_normal_lcdf(x, mu, sigma):
    """Log of the log-normal CDF, computed with ``torch.special.log_ndtr``.

    Returns log P(exp(Z) <= x) for Z ~ N(mu, sigma). Working in log space
    keeps far-tail values representable where the plain CDF would underflow.

    Args:
        x: points at which to evaluate (1D).
        mu: mean of the underlying normal distribution (1D).
        sigma: standard deviation of the underlying normal distribution (1D).
    Raises:
        ValueError: if any input is not 1D.
    """
    _throw_if_not_1d(x, mu, sigma)
    standardized = (torch.log(x) - mu) / sigma
    return torch.special.log_ndtr(standardized)


def logmix_log_prob(log_tau, mu, log_sigma, t):
    """For a logmix hazard, calculates a batch of log probabilities.

    Important: mu and sigma refer to the underlying normal distributions,
    not the log-normal distributions nor the resulting mixture.

    LogMix
    ------
    Short for a mixture of log-normal distributions.

    LogNormal
    ---------
    If Z~N(mu, sigma), and X = exp(Z), then X is said to be log-normal
    distributed with parameters mu and sigma. Note that mu and sigma are
    _not_ the mean and standard deviation of the log-normal distribution,
    but of the normal distribution from which X is derived.

    p(t) = ∑_i τ_i log_normal(t; μ_i, σ_i)
         = ∑_i τ_i 1/(t σ_i sqrt(2π)) exp( -(ln t - μ_i)²/(2σ_i²) )
         = ∑_i exp( log A_i + C_i )
    with log A_i = log τ_i - ln t - log σ_i - ½ log 2π and
    C_i = -(ln t - μ_i)²/(2σ_i²); τ is softmax(log_tau).

    Args:
        log_tau: the log of the mixing coefficients. Shape: (B, M).
        mu: the mean of the underlying normal distribution. Shape: (B, M).
        log_sigma: the log of the standard deviation of the underlying
            normal distribution. Shape: (B, M).
        t: the values to evaluate for a log probability. Shape: (B,).
            Typically this will be times since the last event.
    Returns:
        log_prob: the log probability of the inter-event times. Shape: (B,).
    Raises:
        ValueError: if any t is non-positive or NaN.
    """
    n_batch, _ = log_tau.shape
    assert log_tau.shape == mu.shape == log_sigma.shape
    assert t.shape == (n_batch,)

    # There are two separate considerations around cases of t=0. Firstly,
    # do we allow simultaneous events? Let's say the answer is "No". Even
    # so, we must contend with the question of floating point precision.
    # x1 != x2 doesn't imply that (x1 - x2) != 0. Consequently, we must
    # accept that t=0 is a possibility. In such a case, the probability
    # assigned to t=0 should be a probability representative of the
    # smallest possible time delta, which will depend on the floating
    # point precision. For now though, we raise.
    if torch.any(t <= 0) or torch.any(torch.isnan(t)):
        raise ValueError(f"Inter-event times must be positive. {t=}")
    log_weights = torch.log_softmax(log_tau, dim=1)
    # (B, 1): broadcasts against the (B, M) component parameters.
    log_t = torch.log(t).unsqueeze(1)
    log_amplitude = (
        log_weights - log_t - log_sigma - 0.5 * math.log(2 * math.pi)
    )
    exponent = -((log_t - mu) ** 2) / (2 * torch.exp(log_sigma * 2))
    log_prob = torch.logsumexp(log_amplitude + exponent, dim=1)
    assert log_prob.shape == t.shape
    return log_prob


def logmix_interval_log_probv1(log_tau, mu, log_sigma, t0, t1):
    """Log prob. that a log-normal-mixture inter-event time lies in [t0, t1].

    Probability-space variant of ``logmix_interval_log_prob``. The mixture
    mass sum_i tau_i (CDF_i(t1) - CDF_i(t0)) is formed directly and the log
    is taken only at the end, so very small interval probabilities can round
    to 0 and give -inf.

    Args:
        log_tau: unnormalized log mixing weights. Shape: (B, M).
        mu: means of the underlying normal distributions. Shape: (B, M).
        log_sigma: log std of the underlying normal distributions.
            Shape: (B, M).
        t0: interval start, >= 0. Shape: (B,).
        t1: interval end, > t0. Shape: (B,).
    Returns:
        Log of the interval probability. Shape: (B,).
    Raises:
        AssertionError: on invalid shapes or intervals, or if the interval
            probability is non-finite or outside [0, 1.01].
    """
    B, M = log_tau.shape
    assert log_tau.shape == mu.shape == log_sigma.shape
    assert t0.shape == t1.shape == (B,)
    assert torch.all(t0 >= 0), t0
    assert torch.all(t1 > t0), (t1, t0)
    tau = torch.softmax(log_tau, dim=1)
    sigma = torch.exp(log_sigma)
    # Need double precision for the log_normal_cdf function.
    mu_flat = einops.rearrange(mu, "b m -> (b m)").double()
    sigma_flat = einops.rearrange(sigma, "b m -> (b m)").double()
    t0_repeat = einops.repeat(t0, "b -> (b m)", m=M)
    t1_repeat = einops.repeat(t1, "b -> (b m)", m=M)
    cdf_0 = log_normal_cdf(t0_repeat, mu_flat, sigma_flat)
    cdf_1 = log_normal_cdf(t1_repeat, mu_flat, sigma_flat)
    cdf_0 = einops.rearrange(cdf_0, "(b m) -> b m", b=B, m=M)
    cdf_1 = einops.rearrange(cdf_1, "(b m) -> b m", b=B, m=M)
    res = (tau * (cdf_1 - cdf_0)).sum(dim=1)
    # Report offending rows in the failure message; library code must not
    # drop into an interactive debugger.
    non_finite = ~torch.isfinite(res)
    assert not torch.any(non_finite), (
        "Non-finite interval prob at rows "
        f"{torch.nonzero(non_finite).flatten().tolist()}"
    )
    if torch.any(res == 0):
        _logger.warning("interval log prob is zero.")
    assert torch.all((res >= 0) & (res <= 1.01))
    return res.log()


def logmix_interval_log_prob(log_tau, mu, log_sigma, t0, t1):
    """Log prob. that a log-normal-mixture inter-event time lies in [t0, t1].

        log P(t0 < T <= t1)
            = logsumexp_i(log τ_i + log(F_i(t1) - F_i(t0))) - logsumexp_i(log τ_i)

    where F_i is the CDF of component i. Each component's CDF difference is
    computed in log space as

        log(F(t1) - F(t0)) = log F(t1) + log(-expm1(log F(t0) - log F(t1)))

    expm1 keeps precision when both CDFs are close to 1, i.e. when the
    interval lies far in the right tail. There, log1p(-exp(.)) would round
    exp(.) to 1 and return -inf.

    Args:
        log_tau: unnormalized log mixing weights. Shape: (B, M).
        mu: means of the underlying normal distributions. Shape: (B, M).
        log_sigma: log std of the underlying normal distributions.
            Shape: (B, M).
        t0: interval start, >= 0. Shape: (B,).
        t1: interval end, > t0. Shape: (B,).
    Returns:
        Log probabilities, shape (B,), computed in double precision.
        Non-finite results are logged as a warning, not raised.
    """
    B, M = log_tau.shape
    assert log_tau.shape == mu.shape == log_sigma.shape
    assert t0.shape == t1.shape == (B,)
    assert torch.all(t0 >= 0), t0
    assert torch.all(t1 > t0), (t1, t0)
    sigma = torch.exp(log_sigma)
    # Need double precision for the log_normal_cdf function.
    mu_flat = einops.rearrange(mu, "b m -> (b m)").double()
    sigma_flat = einops.rearrange(sigma, "b m -> (b m)").double()
    t0_repeat = einops.repeat(t0, "b -> (b m)", m=M)
    t1_repeat = einops.repeat(t1, "b -> (b m)", m=M)
    lcdf_0 = log_normal_lcdf(t0_repeat, mu_flat, sigma_flat)
    lcdf_1 = log_normal_lcdf(t1_repeat, mu_flat, sigma_flat)
    # log(CDF(t0) / CDF(t1)) <= 0 because t0 < t1.
    log_ratio = lcdf_0 - lcdf_1
    n_zero = int((log_ratio >= 0).sum())
    if n_zero:
        # The two CDFs are equal in floating point, so the component's
        # interval probability is 0 and its log is -inf. A model should score
        # the worst possible logprob here. If -inf must be avoided, one option
        # is to clamp to torch.finfo(torch.float64).tiny.
        _logger.warning(
            f"0 interval prob -> component logprob is -inf. ({n_zero=})"
        )
    lcdf_diff = lcdf_1 + torch.log(-torch.expm1(log_ratio))
    lcdf_diff = einops.rearrange(lcdf_diff, "(b m) -> b m", b=B, m=M)
    log_norm = torch.logsumexp(log_tau, dim=1)
    res = torch.logsumexp(log_tau + lcdf_diff, dim=1) - log_norm
    if not torch.all(torch.isfinite(res)):
        idxs = torch.nonzero(~torch.isfinite(res), as_tuple=False)
        error_values = to_npy(res[idxs])
        _logger.warning(
            f"Non-finite interval logprob: {error_values=}, {to_npy(idxs)=}, "
            f"{to_npy(t0[idxs])=}, {to_npy(t1[idxs])=}"
        )
    return res


def logmix_interval_log_prob2(log_tau, mu, log_sigma, t0, t1):
    """
    Same as logmix_interval_log_prob, but handles case of no gt spikes.

    No gt spikes is indicated by a negative t0, and means that the interval
    from 0 to -t0 is free of spikes. For those rows t1 is ignored and the
    result is log P(T > -t0) = log(1 - P(0 < T <= -t0)).

    Note: currently unused.

    Args:
        log_tau: unnormalized log mixing weights. Shape: (B, M).
        mu: means of the underlying normal distributions. Shape: (B, M).
        log_sigma: log std of the underlying normal distributions.
            Shape: (B, M).
        t0: interval start, or minus the spike-free length when negative.
            Shape: (B,).
        t1: interval end (ignored where t0 < 0). Shape: (B,).
    Returns:
        Log probabilities, shape (B,). Non-finite values are logged as a
        warning rather than raised.
    """
    B, M = log_tau.shape
    # t0, when <0, encodes steps until end of recording.
    # When t0 < 0, we will calculate prob(-t0, infty) by calculating
    # prob(0, -t0) and taking the complement.
    assert not torch.any(t0 == t1), "t0 == t1 is not allowed."
    invert = t0 < 0
    t1 = torch.where(invert, -t0, t1)
    t0 = torch.where(invert, torch.tensor(0.0, device=t0.device), t0)
    assert log_tau.shape == mu.shape == log_sigma.shape
    assert t0.shape == t1.shape == (B,)
    sigma = torch.exp(log_sigma)
    # Need double precision for the log_normal_cdf function.
    mu_flat = einops.rearrange(mu, "b m -> (b m)").double()
    sigma_flat = einops.rearrange(sigma, "b m -> (b m)").double()
    t0_repeat = einops.repeat(t0, "b -> (b m)", m=M)
    t1_repeat = einops.repeat(t1, "b -> (b m)", m=M)
    lcdf_0 = log_normal_lcdf(t0_repeat, mu_flat, sigma_flat)
    lcdf_1 = log_normal_lcdf(t1_repeat, mu_flat, sigma_flat)
    # log(CDF(t1) - CDF(t0)) = lcdf_1 + log(1 - exp(lcdf_0 - lcdf_1)); the
    # expm1 form keeps precision when both CDFs are close to 1.
    lcdf_diff = lcdf_1 + torch.log(-torch.expm1(lcdf_0 - lcdf_1))
    lcdf_diff = einops.rearrange(lcdf_diff, "(b m) -> b m", b=B, m=M)
    log_norm = torch.logsumexp(log_tau, dim=1)
    res = torch.logsumexp(log_tau + lcdf_diff, dim=1) - log_norm
    # res is a log-probability, so the complement is log(1 - exp(res)), not
    # 1 - res. Rows that don't need the complement get a harmless stand-in
    # value so torch.where cannot leak inf/NaN gradients from that branch.
    res_for_complement = torch.where(invert, res, torch.full_like(res, -1.0))
    log_complement = torch.log(-torch.expm1(res_for_complement))
    res = torch.where(invert, log_complement, res)
    if not torch.all(torch.isfinite(res)):
        idxs = torch.nonzero(~torch.isfinite(res), as_tuple=False)
        error_values = to_npy(res[idxs])
        _logger.warning(
            f"Non-finite interval logprob: {error_values=}, {to_npy(idxs)=}, "
            f"{to_npy(t0[idxs])=}, {to_npy(t1[idxs])=}"
        )
    return res


def logmix_expected_val(log_tau, mu, log_sigma):
    """Mean of a batch of log-normal mixtures.

    p(t) = ∑_i τ_i log_normal(t; μ_i, σ_i), hence

        E[T] = ∑_i τ_i exp(μ_i + σ_i² / 2)

    with τ = softmax(log_tau) and σ = exp(log_sigma); μ and σ describe the
    underlying normal distributions.

    Args:
        log_tau: unnormalized log mixing weights. Shape: (B, M).
        mu: means of the underlying normals. Shape: (B, M).
        log_sigma: log std of the underlying normals. Shape: (B, M).
    Returns:
        Expected values. Shape: (B,).
    """
    batch, _ = log_tau.shape
    assert log_tau.shape == mu.shape
    weights = torch.softmax(log_tau, dim=1)
    component_means = torch.exp(mu + 0.5 * torch.exp(log_sigma) ** 2)
    expected = torch.sum(weights * component_means, dim=1)
    assert expected.shape == (batch,)
    return expected


def logmix_sample(log_tau, mu, log_sigma, n_samples):
    """Draw samples from a batch of log-normal mixtures.

    Args:
        log_tau: unnormalized log mixing weights, used as Categorical logits.
            Shape: (B, M).
        mu: means of the underlying normal distributions. Shape: (B, M).
        log_sigma: log std of the underlying normal distributions.
            Shape: (B, M).
        n_samples: number of samples to draw per mixture.
    Returns:
        Samples of shape (n_samples, B).
    Raises:
        RuntimeError: if torch fails to sample. The message names the input
            shape, and the original error is chained as the cause.
    """
    _batch, _n_components = log_tau.shape  # rejects non-2D log_tau
    weights = D.Categorical(logits=log_tau)
    components = D.LogNormal(mu, torch.exp(log_sigma))
    mixture = D.MixtureSameFamily(weights, components)
    try:
        return mixture.sample((n_samples,))
    except RuntimeError as err:
        # Surface the failure with context instead of swallowing it.
        raise RuntimeError(
            f"logmix_sample failed for log_tau of shape {tuple(log_tau.shape)} "
            f"with {n_samples=}"
        ) from err


def logmix_variance(log_tau, mu, log_sigma):
    """
    Computes the variance of a mixture of log-normal distributions.

    With weights τ = softmax(log_tau) and σ = exp(log_sigma):

        E[T]   = ∑_i τ_i exp(μ_i + σ_i²/2)
        E[T²]  = ∑_i τ_i exp(2μ_i + 2σ_i²)
        Var[T] = E[T²] - E[T]² = E[T²] (1 - exp(2 log E[T] - log E[T²]))

    Both moments are accumulated in log space with logsumexp in float64. The
    final subtraction uses expm1, which avoids cancellation when the variance
    is small relative to E[T²].

    Args:
        log_tau: Tensor of shape (B, M) - log of (unnormalized) mixture
            weights.
        mu: Tensor of shape (B, M) - means of the underlying normal
            distributions (not of the log-normal components).
        log_sigma: Tensor of shape (B, M) - log of the standard deviations of
            the underlying normal distributions.

    Returns:
        float64 tensor of shape (B,) with the variances of the mixtures.
        Non-finite or negative entries are logged as a warning, not raised.
    """
    B, M = log_tau.shape
    assert log_tau.shape == mu.shape == log_sigma.shape

    log_weights = torch.log_softmax(log_tau.double(), dim=1)
    mu = mu.double()
    sigma_sq = torch.exp(2 * log_sigma.double())

    # log E[T²] and log E[T] of the mixture, per batch row.
    log_second_moment = torch.logsumexp(
        log_weights + 2 * mu + 2 * sigma_sq, dim=1
    )
    log_first_moment = torch.logsumexp(
        log_weights + mu + 0.5 * sigma_sq, dim=1
    )

    # E[T²] - E[T]² = E[T²] * (1 - E[T]² / E[T²]); Jensen gives E[T]² <= E[T²].
    variance = torch.exp(log_second_moment) * -torch.expm1(
        2 * log_first_moment - log_second_moment
    )

    # Sanity check
    assert variance.shape == (
        B,
    ), f"Unexpected variance shape: {variance.shape}"

    bad = ~torch.isfinite(variance) | (variance < 0)
    if torch.any(bad):
        idxs = torch.nonzero(bad, as_tuple=False)
        error_values = variance[idxs].cpu().numpy()
        _logger.warning(
            f"Non-finite or negative variance: {error_values=}, {idxs=}"
        )

    return variance


def logmix_median(log_tau, mu, log_sigma, n_samples=2048):
    """Monte-Carlo estimate of the median of a batch of log-normal mixtures.

    Draws ``n_samples`` samples per mixture (via ``logmix_sample``) and
    returns the per-mixture sample median, shape (B,).

    As a sanity check, a median outside [mean - std, mean + std] is logged as
    a warning. For any distribution with finite variance the true median
    satisfies |median - mean| <= std; the sample median only approximately
    does. Rows whose bounds are non-finite are not checked.
    """
    # Calculate mean and variance, just for creating an expected bound around
    # the median so as to catch errors.
    expected_val = logmix_expected_val(log_tau, mu, log_sigma)
    samples = logmix_sample(log_tau, mu, log_sigma, n_samples)
    median = samples.median(dim=0).values

    def sanity_check():
        var = logmix_variance(log_tau, mu, log_sigma)
        min_median = expected_val - var.sqrt()
        max_median = expected_val + var.sqrt()
        # The variance can be massive, and its calculation can result in nans
        # or infs. So, only bother to check the cases where the bounds are
        # finite. Both bounds must hold, hence `&`.
        within_expected_range = (
            ~torch.isfinite(min_median)  # don't bother testing these cases
            | ~torch.isfinite(max_median)  # don't bother testing these cases
            | ((median >= min_median) & (median <= max_median))
        )

        if not torch.all(within_expected_range):
            idxs = torch.nonzero(~within_expected_range, as_tuple=False)
            low_val_high = (
                torch.stack(
                    [
                        min_median[idxs],
                        median[idxs],
                        max_median[idxs],
                    ]
                )
                .cpu()
                .numpy()
            )
            _logger.warning(f"Median outside expected range. {low_val_high=}")

    sanity_check()

    return median


def logsubexp(a, b):
    """
    Numerically stable version of log(exp(a) - exp(b)), for a >= b.
    Like torch.logsumexp(), but elementwise and for a difference.

    The following example highlights the strategy:

        2^-3 - 2^-4 = 2^-3 * (1 - 2^-1)


    log(exp(a) - exp(b))
      = log(exp(a) * (1 - exp(b - a)))
      = log(exp(a)) + log(1 - exp(b - a))
      = a + log(1 - exp(b - a))
    And using torch's expm1() function:
      = a + log(-expm1(b - a))

    expm1 is preferred over log1p(-exp(b - a)). When b is very close to a,
    exp(b - a) rounds to 1 (e.g. b - a = -1e-8 in float32) and
    log1p(-1) = -inf, whereas -expm1(b - a) keeps the small difference.

    Entries with a == b give -inf (logged as a warning).

    Raises:
        ValueError: if any a < b, or if the result contains NaN.
    """
    if torch.any(a < b):
        raise ValueError(f"a must be greater than b ({a=}, {b=})")
    res = a + torch.log(-torch.expm1(b - a))
    # a == b is exactly log(0). This also covers a == b == -inf, where b - a
    # is NaN.
    res = torch.where(a == b, torch.full_like(res, -float("inf")), res)
    if torch.any(torch.isinf(res)):
        num_inf = torch.nonzero(torch.isinf(res)).shape[0]
        _logger.warning(f"logsubexp: res is infinite. Num inf: {num_inf}")
    if torch.any(torch.isnan(res)):
        raise ValueError(f"logsubexp: res contains NaN: {res}")
    return res


def interval_prob(probs, bin_edges, x_from, x_to):
    """
    Log of the probability mass that falls within [x_from, x_to].

    Probability is assumed to be spread uniformly within each bin. Bins may
    have different widths. Any part of the interval outside
    [bin_edges[0], bin_edges[-1]] receives no mass.

    Args:
        probs: per-bin probabilities. Shape: (B, n_bins).
        bin_edges: increasing bin edges. Shape: (n_bins + 1,).
        x_from: interval starts, in the same units as bin_edges. Shape: (B,).
        x_to: interval ends. Shape: (B,).
    Returns:
        Log of the mass in each interval, shape (B,). It is -inf when the
        interval overlaps no bin.
    """
    _, n_bins = probs.shape
    assert n_bins == len(bin_edges) - 1, f"{n_bins=} != {len(bin_edges)=}"
    edges = torch.as_tensor(bin_edges, dtype=probs.dtype, device=probs.device)
    bin_lo = edges[:-1]
    bin_hi = edges[1:]
    lo_query = torch.as_tensor(x_from, dtype=probs.dtype, device=probs.device)
    hi_query = torch.as_tensor(x_to, dtype=probs.dtype, device=probs.device)
    # Length of the overlap between [x_from, x_to] and every bin, (B, n_bins).
    overlap = torch.minimum(bin_hi, hi_query.unsqueeze(1)) - torch.maximum(
        bin_lo, lo_query.unsqueeze(1)
    )
    # Fraction of each bin (hence of its mass) covered by the interval.
    weights = overlap.clamp(min=0) / (bin_hi - bin_lo)
    res_prob = (weights * probs).sum(dim=-1)
    return res_prob.log()


def interval_prob2(log_probs, leftmost_edge, rightmost_edge, x_from, x_to):
    """
    Log of the normalized probability mass within [x_from, x_to].

    The range [leftmost_edge, rightmost_edge] is split into n_bins bins of
    equal width dx, and mass is spread uniformly within each bin. The
    interval is first intersected with the range, so mass outside the range
    is ignored. An interval entirely outside the range gives -inf.

    Parameters:
    -----------
    log_probs : tensor of shape (B, N)
        Unnormalized log probabilities for each bin
    leftmost_edge, rightmost_edge : float
        Start and end of the full range
    x_from, x_to: (B, )
        The interval to calculate the probability mass for, in the same
        units as the edges.

    Returns:
    --------
    tensor of shape (B,)
        Log of the normalized probability mass within the specified interval
    """
    _, n_bins = log_probs.shape
    dx = (rightmost_edge - leftmost_edge) / n_bins

    # It's possible we are out of range of the log_probs. Clip to range.
    # We take the intersection of the interval with the range, possibly
    # resulting in a smaller (maybe even empty) interval.
    x_from = torch.clamp(
        x_from.to(log_probs.dtype), min=leftmost_edge, max=rightmost_edge
    )
    x_to = torch.clamp(
        x_to.to(log_probs.dtype), min=leftmost_edge, max=rightmost_edge
    )

    bin_edges = torch.linspace(
        leftmost_edge,
        rightmost_edge,
        n_bins + 1,
        device=log_probs.device,
        dtype=log_probs.dtype,
    )
    # Overlap of the interval with each bin, measured in the same units as
    # the edges (not bin indices), then converted to a fraction of the bin.
    overlap = torch.minimum(bin_edges[1:], x_to.unsqueeze(1)) - torch.maximum(
        bin_edges[:-1], x_from.unsqueeze(1)
    )
    weights = torch.clamp(overlap / dx, min=0.0, max=1.0)
    # log(0) = -inf for untouched bins; logsumexp ignores them.
    unnormalized = (torch.log(weights) + log_probs).logsumexp(dim=-1)
    return unnormalized - log_probs.logsumexp(dim=-1)


def interval_prob3(log_probs, leftmost_edge, rightmost_edge, x_from, x_to):
    """
    Log of the normalized probability mass within [x_from, x_to].

    The range [leftmost_edge, rightmost_edge] is divided into n_bins bins of
    width dx = (rightmost_edge - leftmost_edge) / n_bins. Mass is spread
    uniformly within each bin, except that the first and last bins get
    special treatment so as to cover the full positive real line:
      - the first bin extends from 0 to leftmost_edge + dx (width dx0).
      - the last bin is not a bin but an exponential tail with rate 1,
        starting at rightmost_edge - dx and extending to infinity.

    Parameters:
    -----------
    log_probs : tensor of shape (B, N)
        Unnormalized log probabilities for each bin. They are normalized
        here with a logsumexp over the last dimension.
    leftmost_edge, rightmost_edge : float
        Start and end of the binned range, before the first and last bins
        are stretched to 0 and infinity.
    x_from, x_to: (B, )
        The interval to calculate the probability mass for, in the same
        units as the edges (not bin indices).

    Returns:
    --------
    tensor of shape (B,)
        Log of the normalized probability mass within the interval. It is
        -inf when the interval carries no mass.

    Raises:
    -------
    ValueError
        If any x_from is greater than the matching x_to.
    """
    b, n_bins = log_probs.shape
    if torch.any(x_from > x_to):
        raise ValueError("x_from must be less than x_to.")

    # Width of every interior bin.
    dx = (rightmost_edge - leftmost_edge) / n_bins
    # The first bin extends from 0 to the next edge, so its width is dx0.
    dx0 = leftmost_edge + dx

    # Queries stay in the same units as the edges; no conversion to bin
    # indices is done here.
    query_left = x_from
    query_right = x_to

    bin_edges = torch.linspace(
        leftmost_edge, rightmost_edge, n_bins + 1, device=log_probs.device
    )
    # Stretch the outer bins: the first starts at 0, the last never ends.
    bin_edges[0] = 0.0
    bin_edges[-1] = float("inf")

    # Intersect [query_left, query_right] with each bin [lo, hi]: both ends
    # are clipped into [lo, hi], so a bin the query misses gets a
    # zero-length overlap.
    each_bin_left = torch.maximum(
        bin_edges[:-1], einops.rearrange(query_left, "b -> b 1")
    )
    # Don't let the left query be greater than the right edge of the bin.
    each_bin_left = torch.minimum(each_bin_left, bin_edges[1:])
    each_bin_right = torch.minimum(
        bin_edges[1:], einops.rearrange(query_right, "b -> b 1")
    )
    # Don't let the right query be less than the left edge of the bin.
    each_bin_right = torch.maximum(each_bin_right, bin_edges[:-1])
    # Overlap length per bin, in x units. It is converted below into the
    # fraction of each bin's mass that the interval covers.
    bin_frac = each_bin_right - each_bin_left

    # RHS bin: exponential tail
    # The last bin is a tail extending to infinity. So the fraction of its
    # mass will not just be the simple difference between right and left.
    # With a rate-1 exponential starting at s = exp_start, the mass between
    # a and b is exp(-(a - s)) - exp(-(b - s)).
    exp_start = rightmost_edge - dx
    last_bin_weight = torch.exp(exp_start - each_bin_left[:, -1]) - torch.exp(
        exp_start - each_bin_right[:, -1]
    )
    bin_frac[:, -1] = last_bin_weight
    # LHS bin: stretched to 0, so shrink the mass accordingly.
    # NOTE(review): assumes n_bins >= 2. With a single bin, column 0 is also
    # the tail column, so the tail weight just written gets divided by dx0.
    bin_frac[:, 0] = bin_frac[:, 0] / dx0
    # Center bins: scale by dx
    bin_frac[:, 1:-1] = bin_frac[:, 1:-1] / dx

    # Copy the fractions into a tensor with log_probs' dtype and device.
    weights = torch.zeros(
        (b, n_bins), dtype=log_probs.dtype, device=log_probs.device
    )
    batch_idx = torch.arange(b, device=log_probs.device)
    weights[batch_idx] = bin_frac
    # Bins with zero weight contribute log(0) = -inf, which logsumexp ignores.
    unnormalized = (torch.log(weights) + log_probs).logsumexp(dim=-1)
    # Normalize, since log_probs are unnormalized.
    res = unnormalized - log_probs.logsumexp(dim=-1)
    return res


def auto_ll(m_out, target_t, t_min, t_max, n_bins):
    """Log-likelihood for the auto bin mode.

    Takes in torch tensors.

    This mode needs special treatment as the first and last bin are not
    the same as the others. The first bin is extended to 0, and the last
    is an exponential tail. With p = softmax(m_out), the density used for
    the bin containing target_t is:
      - first bin, [0, t_min + bin_width): p_0 / (t_min + bin_width)
      - middle bins: p_k / bin_width
      - last bin, a rate-1 exponential tail starting at t_max - bin_width:
        p_{n-1} * exp(-(t - (t_max - bin_width)))

    Args:
        m_out: logits. Shape: (B, n_bins).
        target_t: target times. Shape: (B,).
        t_min, t_max: the binned range.
        n_bins: number of bins (at least 2: one first and one last bin).
    Returns:
        Log-likelihood of each target time. Shape: (B,).
    """
    bin_width = (t_max - t_min) / n_bins
    B, _ = m_out.shape
    device, dtype = m_out.device, m_out.dtype
    t_idx = torch.floor((target_t - t_min) / bin_width).long()
    t_idx = torch.clamp(t_idx, 0, n_bins - 1)
    # First bin has a standard bin's worth, plus the length to 0.
    first_bin_scale = torch.full(
        (B, 1), -math.log(t_min + bin_width), dtype=dtype, device=device
    )
    last_bin_scale = -(target_t - (t_max - bin_width))
    last_bin_scale = last_bin_scale.reshape(B, 1).to(dtype)
    mid_bin_scale = torch.full(
        (B, n_bins - 2), -math.log(bin_width), dtype=dtype, device=device
    )
    scale = torch.cat([first_bin_scale, mid_bin_scale, last_bin_scale], dim=1)
    lprobs = F.log_softmax(m_out, dim=1)
    assert scale.shape == lprobs.shape
    # Log space, so add. Select each row's own bin with gather. A one-hot
    # product would turn any -inf log prob into NaN (0 * -inf), and it would
    # need dtype matching.
    ll = (lprobs + scale).gather(1, t_idx.unsqueeze(1)).squeeze(1)
    return ll


def to_bin_idx(t, bin_edges):
    """Map times to 0-based bin indices for variable-width bins.

    Thin wrapper around ``torch.bucketize``, so its easily-misread flags are
    set in one place.

    right=True
    ----------
    bucketize then returns the i satisfying
    ``bin_edges[i-1] <= t < bin_edges[i]``. A t equal to a bin's left edge
    therefore lands in that bin. For example, t = 0 with edges [0, 1, 2]
    selects the bin [0, 1).

    -1
    --
    That i is one past the 0-based bin index, hence the subtraction.

    Out-of-range values are not clipped: t < bin_edges[0] gives -1 and
    t >= bin_edges[-1] gives len(bin_edges) - 1.
    """
    one_based = torch.bucketize(t, bin_edges, right=True)
    return one_based - 1


def var_bin_expected_val(m_out, bin_edges, last_bin_lambda=1.0):
    """Expected value of a variable-width-bin distribution with an exponential tail.

    With p = softmax(m_out), bins 0..L-2 contribute p_k times their center.
    The last bin is an exponential tail with rate last_bin_lambda starting at
    bin_edges[-2], whose mean is bin_edges[-2] + 1 / last_bin_lambda.

    Args:
        m_out: logits. Shape: (B, L).
        bin_edges: bin edges. Shape: (L + 1,).
        last_bin_lambda: rate of the exponential tail.
    Returns:
        Expected values. Shape: (B,).
    """
    _, n_bins = m_out.shape
    assert bin_edges.shape == (n_bins + 1,), f"{bin_edges.shape=} ({bin_edges=})"
    probs = F.softmax(m_out, dim=1)
    centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    finite_part = torch.einsum("b r, r -> b", probs[:, :-1], centers[:-1])
    tail_mean = 1 / last_bin_lambda + bin_edges[-2]
    return finite_part + probs[:, -1] * tail_mean


@torch.no_grad()
def var_bin_mode(m_out, bin_edges, last_bin_lambda=1.0):
    """Calc the mode (first only, if multiple) of a variable bin distribution.

    With p = softmax(m_out), bins 0..L-2 have uniform density p_k / width_k.
    The last bin is an exponential tail with density
    p_{L-1} * λ * exp(-λ (t - bin_edges[-2])), which peaks at its start,
    bin_edges[-2], with value λ p_{L-1}.

    The mode is the center of the densest finite bin, or bin_edges[-2] when
    the tail's peak is strictly higher.

    Args:
        m_out: logits. Shape: (B, L).
        bin_edges: bin edges. Shape: (L + 1,).
        last_bin_lambda: rate λ of the exponential tail.
    Returns:
        Modes, shape (B,), with bin_edges' dtype.
    """
    B, L = m_out.shape
    assert bin_edges.shape == (L + 1,)
    lprobs = F.log_softmax(m_out, dim=1)
    # Keep the edges' float dtype (a long tensor would truncate this value).
    tail_start = bin_edges[-2].expand(B)
    if L == 1:
        # Only the tail exists; its density peaks at its start.
        return tail_start.clone()
    # Widths and centers of the finite bins 0..L-2 (the tail is excluded).
    bin_widths = bin_edges[1:-1] - bin_edges[:-2]
    bin_centers = 0.5 * (bin_edges[1:-1] + bin_edges[:-2])
    lprob_density = lprobs[:, :-1] - torch.log(bin_widths)
    max_ldensity, mode_bin_standard = torch.max(lprob_density, dim=1)
    max_tail_ldensity = math.log(last_bin_lambda) + lprobs[:, -1]
    mode = torch.where(
        max_ldensity >= max_tail_ldensity,
        bin_centers[mode_bin_standard],
        tail_start,
    )
    return mode


@torch.no_grad()
def var_bin_median(m_out, bin_edges, last_bin_lambda=1.0):
    """Median of a variable-width-bin distribution with an exponential tail.

    Finds the first bin whose CDF exceeds 0.5 (searchsorted with
    right=True), then locates the median inside it:
      - finite bins: linear interpolation of the uniform mass;
      - last bin: inverse CDF of an exponential with rate last_bin_lambda
        starting at bin_edges[-2].

    Args:
        m_out: logits. Shape: (B, N).
        bin_edges: bin edges. Shape: (N + 1,).
        last_bin_lambda: rate of the exponential tail.
    Returns:
        Medians. Shape: (B,).
    """
    B, N = m_out.shape
    device = m_out.device
    log_cdf = torch.logcumsumexp(m_out, dim=1) - torch.logsumexp(
        m_out, dim=1, keepdim=True
    )
    # searchsorted needs the query to match log_cdf's dtype (e.g. float64).
    query = torch.full((B, 1), math.log(0.5), dtype=log_cdf.dtype, device=device)
    median_idx = torch.searchsorted(log_cdf, query, right=True).squeeze(1)
    # Guard against rounding pushing the index one past the last bin.
    median_idx = torch.clamp(median_idx, max=N - 1)

    batch_indices = torch.arange(B, device=device)
    left_cdf = torch.where(
        median_idx > 0,
        torch.exp(log_cdf[batch_indices, median_idx - 1]),
        torch.zeros(B, dtype=log_cdf.dtype, device=device),
    )
    right_cdf = torch.exp(log_cdf[batch_indices, median_idx])
    mass_in_right_bin = right_cdf - left_cdf
    right_bin_len = bin_edges[median_idx + 1] - bin_edges[median_idx]
    remaining_mass = 0.5 - left_cdf
    bin_frac_needed = remaining_mass / mass_in_right_bin
    plus_x = torch.where(
        median_idx < N - 1,
        # linear interpolation
        right_bin_len * bin_frac_needed,
        # Last bin: inverse cdf for exponential distribution.
        -torch.log(1 - bin_frac_needed) / last_bin_lambda,
    )
    median = bin_edges[median_idx] + plus_x
    return median


@torch.no_grad()
def var_bin_cdf(m_out, bin_edges, x, last_bin_lambda=1.0):
    """CDF of a variable bin distribution evaluated at ``x``.

    The CDF equals the mass of all bins before the bin holding ``x`` plus a
    partial share of that bin's mass:
      * regular bins (all but the last): uniform, so the share is linear in
        the offset from the bin's left edge;
      * last bin: exponential tail with rate ``last_bin_lambda``, so the share
        is ``1 - exp(-λ · offset)``.

    NOTE(review): the bin lookup is delegated to ``to_bin_idx`` (defined
    elsewhere). Behavior for ``x`` below ``bin_edges[0]`` depends on it —
    confirm.

    Args:
        m_out: unnormalized bin logits. Shape: (B, N).
        bin_edges: bin edges. Shape: (N + 1,).
        x: evaluation points. Shape: (B,).
        last_bin_lambda: rate of the exponential tail.
    Returns:
        cdf: P(X <= x) for each row. Shape: (B,).
    """
    B, N = m_out.shape
    device = m_out.device
    assert bin_edges.shape == (N + 1,)
    assert x.shape == (B,)
    rows = torch.arange(B, device=device)
    log_total = torch.logsumexp(m_out, dim=1, keepdim=True)
    lcdf = torch.logcumsumexp(m_out, dim=1) - log_total
    x_bin = to_bin_idx(x, bin_edges)
    upper_cdf = torch.exp(lcdf[rows, x_bin])
    lower_cdf = torch.where(
        x_bin > 0,
        torch.exp(lcdf[rows, x_bin - 1]),
        torch.zeros(B, device=device),
    )
    bin_mass = upper_cdf - lower_cdf
    bin_start = bin_edges[x_bin]
    bin_len = bin_edges[x_bin + 1] - bin_start
    offset = x - bin_start
    is_tail = x_bin >= N - 1
    within_bin = torch.where(
        ~is_tail,
        bin_mass * offset / bin_len,
        bin_mass * (1 - torch.exp(-last_bin_lambda * offset)),
    )
    return lower_cdf + within_bin


@torch.no_grad()
def var_bin_interval_prob(m_out, bin_edges, x_from, x_to, last_bin_lambda=1.0):
    """Probability mass between ``x_from`` and ``x_to``.

    Computed as ``var_bin_cdf(x_to) - var_bin_cdf(x_from)``, with the same
    logits, edges and tail rate for both evaluations.
    """
    cdf_from, cdf_to = (
        var_bin_cdf(m_out, bin_edges, bound, last_bin_lambda)
        for bound in (x_from, x_to)
    )
    return cdf_to - cdf_from


def var_bin_ll(m_out, bin_edges, target_t, last_bin_lambda=1.0):
    """Log-likelihood of ``target_t`` under a variable bin distribution.

    Bins 0..N-2 are uniform, so their log density is
    ``log p_i - log(width_i)``. The last bin is an exponential tail with rate
    ``last_bin_lambda`` starting at ``bin_edges[-2]``:

        log p(x) = log p_last + log λ - λ (x - bin_edges[-2])

    NOTE(review): an earlier docstring said the first bin is extended to 0.
    Nothing here does that explicitly; presumably it is handled by
    ``to_bin_idx`` or by ``bin_edges[0] == 0`` — confirm.

    Args:
        m_out: unnormalized bin logits. Shape: (B, N).
        bin_edges: bin edges. Shape: (N + 1,).
        target_t: targets. Shape: (B,).
        last_bin_lambda: rate of the exponential tail (must be > 0).
    Returns:
        ll: per-row log-likelihood. Shape: (B,).
    """
    B, N = m_out.shape
    assert bin_edges.shape == (N + 1,)
    assert target_t.shape == (B,)
    bin_widths = bin_edges[1:] - bin_edges[:-1]
    t_idx = to_bin_idx(target_t, bin_edges)
    # p(x) = λ exp(-λ x) => log p(x) = log λ - λ x
    last_bin_log_scale = math.log(last_bin_lambda) - last_bin_lambda * (
        target_t - bin_edges[-2]
    )
    bin_log_scale = torch.where(
        t_idx < N - 1, -torch.log(bin_widths[t_idx]), last_bin_log_scale
    )
    lprobs = F.log_softmax(m_out, dim=1)
    # Build the row index on the logits' device (as the sibling functions
    # do) to avoid a host->device copy on every call.
    batch_indices = torch.arange(B, device=m_out.device)
    ll = lprobs[batch_indices, t_idx] + bin_log_scale
    return ll


def estimate_diff_entropy(samples, n_bins, coverage=0.999):
    """Estimate differential entropy (nats) with a fixed-width histogram.

    Only the central ``coverage`` fraction of the samples, by quantile, is
    binned. ``np.histogram`` ignores values outside that range. The counts
    are normalized to a density over the kept samples, and the entropy is
    the Riemann sum ``-sum_i d_i log(d_i) * width``. Empty bins are skipped,
    using the convention 0 log 0 = 0.
    """
    tail = (1 - coverage) / 2
    lo, hi = np.quantile(samples, [tail, 1 - tail])
    counts, edges = np.histogram(samples, bins=n_bins, range=(lo, hi))
    width = edges[1] - edges[0]
    density = counts / (np.sum(counts) * width)
    occupied = density[density > 0]
    return -np.sum(occupied * np.log(occupied)) * width


def _trapezoidal_hazard_log_prob(log_c, t, piece_len):
    """Log prob. for a sampled hazard function with linear interpolation.

    Note: unused

    ``exp(log_c[:, i])`` are samples of the hazard function at
    ``s = i * piece_len``. Between samples the hazard is linearly
    interpolated in h (not in log h):

        i = floor(s / piece_len), w = s / piece_len - i
        h(s) = (1 - w) * c[i] + w * c[i + 1]

    The cumulative hazard ``H(t) = int_0^t h(s) ds`` is computed exactly for
    this piecewise-linear h:
      * full trapezoids over the pieces before ``t``;
      * a partial trapezoid from the left edge of t's piece to ``t``.
    The result is ``log f(t) = log h(t) - H(t)``.

    Args:
        log_c: log(c) values for h(0), h(piece_len), ..., h((n-1)*piece_len).
            Shape: [batch_size, num_samples].
        t: the time since the last event. Must satisfy
            0 <= t < (num_samples - 1) * piece_len. Shape: [batch_size, 1].
        piece_len: the spacing between samples.
    Returns:
        log_prob: log prob of the inter-event time. Shape: [batch_size, 1].
    Raises:
        ValueError: if log_c is not 2D, t has the wrong shape, or t is out
            of range.

    Questions: is this the best way to define the samples? How about the
    centers? The right-hand side could be defined as the left-hand side
    of the next piece, and then use a Taylor series to extrapolate into
    that piece (instead of linear interpolation between pieces).
    """
    if log_c.dim() != 2:
        raise ValueError(f"Expected 2D array. Got shape {log_c.shape}.")
    b, n_p = log_c.shape
    if t.shape != (b, 1):
        raise ValueError(
            f"t must have shape [batch_size, 1]. Got shape ({t.shape})."
        )
    c = torch.exp(log_c)
    t_idx = torch.floor(t / piece_len).long()  # [b, 1]
    # t_idx + 1 must be a valid sample index.
    if torch.any(t_idx < 0) or torch.any(t_idx >= n_p - 1):
        raise ValueError(
            "t must be non-negative and less than the right-hand side of the "
            f"last piece. Got t=({t})"
        )
    weight = t / piece_len - t_idx  # [b, 1], in [0, 1)

    # 1. log h(t): linear interpolation of h between the bracketing samples.
    c_left = torch.gather(c, 1, t_idx)
    c_right = torch.gather(c, 1, t_idx + 1)
    log_h_t = torch.log((1 - weight) * c_left + weight * c_right)

    # 2. H(t) = sum_j c[j] * integral_weights[j] * piece_len.
    indices = torch.arange(n_p, device=log_c.device).unsqueeze(0)  # [1, n_p]
    upto = (indices <= t_idx).to(c.dtype)
    first = (indices == 0).to(c.dtype)
    cur = (indices == t_idx).to(c.dtype)
    nxt = (indices == t_idx + 1).to(c.dtype)
    # Full trapezoids over pieces [0, t_idx): weight 1/2 on the two end
    # samples and 1 on the interior ones. All weights are 0 when t_idx == 0.
    full_weights = upto - 0.5 * first - 0.5 * cur
    # Partial piece [t_idx, t]: area = w * (h_i + h(t)) / 2
    #   = h_i * w * (1 - w/2) + h_{i+1} * w^2 / 2
    partial_weights = cur * weight * (1 - weight / 2) + nxt * weight**2 / 2
    integral_weights = full_weights + partial_weights
    int_h = torch.sum(c * integral_weights, dim=1, keepdim=True) * piece_len
    return log_h_t - int_h
