import pytest
import math
import numpy as np
import scipy
import kdtpp.prob
import torch
import torch.testing
from scipy.special import erf
import einops


def test_log_normal_cdf():
    """
    Check kdtpp.prob.log_normal_cdf against known values and against scipy.

    Tests that:

        1. A table of hand-picked (x, mu, sigma) inputs gives the expected
           CDF values.
        2. Over a grid of (mu, sigma) pairs and x in [0, 20], the result
           agrees with scipy.stats.lognorm.cdf.
    """
    # Imported explicitly: a bare `import scipy` is not guaranteed to expose
    # the `scipy.stats` submodule as an attribute on every scipy version.
    import scipy.stats

    # fmt: off
    cases = [# (x, mu, sigma, result)
            [0,           0, 1,  0.0],
            [1,           0, 1,  0.5],
            [math.exp(7), 0, 1,  1.0],
            [0,           2, 1,  0.0],
            [math.exp(2), 2, 1,  0.5],
            [7,           2, 1,  0.5*(1+erf((math.log(7) - 2)/(math.sqrt(2))))],
            [0,           0, 2,  0.0],
            [1,           0, 2,  0.5],
          ]
    # fmt: on

    # 1. Hardcoded values
    # Transposing the table yields one row per column: x, mu, sigma, result.
    xs, mus, sigmas, expected = torch.tensor(cases).T
    res = kdtpp.prob.log_normal_cdf(xs, mus, sigmas)
    assert np.allclose(res.numpy(), expected.numpy(), atol=1e-8)

    # 2. scipy as ground truth
    xs = torch.linspace(0, 20, 100)
    mus = torch.tensor([0, 0.1, 5.2, 28, 501])
    sigmas = torch.tensor([0.1, 1, 2, 5, 10])
    for mu in mus:
        for sigma in sigmas:
            ss = torch.full_like(xs, sigma)
            mm = torch.full_like(xs, mu)
            res = kdtpp.prob.log_normal_cdf(xs, mm, ss)
            # scipy parameterizes the log-normal with shape s=sigma and
            # scale=exp(mu). Hand it numpy arrays rather than torch tensors.
            scipy_res = scipy.stats.lognorm.cdf(
                xs.numpy(), s=ss.numpy(), scale=mm.exp().numpy()
            )
            assert np.allclose(
                res.numpy(), scipy_res, atol=1e-7
            ), f"Mismatch with scipy for mu={mu.item()}, sigma={sigma.item()}"


def test_interval_prob2():
    """
    Checks interval_prob2 on a uniform distribution with 5 bins over [0, 5].

    For each of several interval lengths, the range starting at 0 is tiled
    with back-to-back query intervals of that length. Every such interval is
    expected to receive probability ``length / 5``.
    """
    lo, hi = 0, 5
    uniform_lprobs = torch.log(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]))

    for i, tl in enumerate([1, 2, 0.5, 0.1, 1 / 3, 1 / 100, 1 / 742]):
        # Interval boundaries spaced `tl` apart. The stop value is pushed half
        # a step past the right edge (presumably so that floating-point
        # rounding cannot drop a final boundary sitting at the right edge).
        edges = torch.arange(lo, hi + tl / 2, tl)
        starts, ends = edges[:-1], edges[1:]

        # One copy of the bin log-probabilities per query interval.
        batched_lprobs = einops.repeat(
            uniform_lprobs, "n -> b n", b=len(starts)
        )
        observed = torch.exp(
            kdtpp.prob.interval_prob2(batched_lprobs, lo, hi, starts, ends)
        )
        target = torch.ones_like(observed) * tl / 5
        assert torch.allclose(
            observed, target, atol=1e-7
        ), f"Incorrect probability. Iteration {i=}"


def test_interval_prob3(np_rng):
    """
    interval_prob3 has special first and last bins: the first extends to zero,
    and the last extends to infinity via an exponential tail.

    Tests that:

      1. Very large interval gets close to 1 probability mass.
      2. Zero length interval gets zero probability mass.
      3. Interval in first bin is scaled correctly.
         3a. Hard-coded test.
         3b. Randomized function parameters.
      4. Interval in last bin uses exponential distribution correctly.
        4a. Hard-coded test.
        4b. Randomized function parameters.
      5. Negative interval (a, b) with a > b throws an error.

    The np_rng argument is a pytest fixture defined outside this file; it is
    used here as a numpy random Generator (uniform, integers, dirichlet,
    exponential).
    """

    def test1():
        # Test 1. Very large interval gets close to 1 probability mass.
        # Setup
        probs = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        bin_lprobs = torch.log(probs)
        lhs_edge, rhs_edge = 5, 10
        froms = torch.tensor([0])
        tos = torch.tensor([1000])
        # Test
        lprobs = kdtpp.prob.interval_prob3(
            einops.repeat(bin_lprobs, "n -> b n", b=len(froms)),
            lhs_edge,
            rhs_edge,
            froms,
            tos,
        )
        probs = torch.exp(lprobs)
        assert torch.allclose(probs, torch.ones_like(probs), atol=1e-7)

    test1()

    def test2():
        # Test 2. Zero length interval gets zero probability mass.
        probs = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        bin_lprobs = torch.log(probs)
        lhs_edge, rhs_edge = 5, 10
        B = 128
        qs = torch.tensor(np_rng.uniform(0, 20, B), dtype=torch.float32)
        # Identical start and end points: every interval is empty.
        froms = tos = qs
        lprobs = kdtpp.prob.interval_prob3(
            einops.repeat(bin_lprobs, "n -> b n", b=B),
            lhs_edge,
            rhs_edge,
            froms,
            tos,
        )
        probs = torch.exp(lprobs)
        assert torch.allclose(probs, torch.zeros_like(probs), atol=1e-7)

    test2()

    def test3a():
        # Test 3a. Interval in first bin is scaled correctly (hard-coded).
        probs = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        bin_lprobs = torch.log(probs)
        lhs_edge, rhs_edge = 5, 10
        # Some hard-coded input-outputs.
        # The first bin is from [0, 6], and has an overall probability of 1/5.
        # Columns: (from, to, expected probability).
        froms_tos_ans = torch.tensor(
            [
                [0, 6, 1 / 5 * 6 / 6],
                [0, 1, 1 / 5 * 1 / 6],
                [0, 2, 1 / 5 * 2 / 6],
                [0, 3, 1 / 5 * 3 / 6],
                [0, 4, 1 / 5 * 4 / 6],
                [1, 2, 1 / 5 * 1 / 6],
                [2, 4, 1 / 5 * 2 / 6],
                [3.5, 4, 1 / 5 * 1 / 12],
            ]
        )
        froms = froms_tos_ans[:, 0]
        tos = froms_tos_ans[:, 1]
        lprobs = kdtpp.prob.interval_prob3(
            einops.repeat(bin_lprobs, "n -> b n", b=len(froms)),
            lhs_edge,
            rhs_edge,
            froms,
            tos,
        )
        probs = torch.exp(lprobs)
        expected_probs = froms_tos_ans[:, 2]
        assert torch.allclose(probs, expected_probs, atol=1e-7)

    test3a()

    def test3b():
        """Same as test3a, but with random probs and intervals."""
        N = 100
        for _ in range(N):
            n_bins = np_rng.integers(1, 200)
            probs = np_rng.dirichlet(np.ones(n_bins))
            bin_lprobs = torch.log(torch.tensor(probs, dtype=torch.float32))
            MAX_EDGE = 100
            two_edges = np_rng.uniform(0, MAX_EDGE, 2)
            lhs_edge, rhs_edge = min(two_edges), max(two_edges)
            B = 128
            second_edge = lhs_edge + (rhs_edge - lhs_edge) / n_bins
            # Only queries in the first bin (before the second edge).
            qs = torch.tensor(
                np_rng.uniform(0, second_edge, (2, B)), dtype=torch.float32
            )
            # Order each random pair so that from <= to.
            froms = torch.min(qs, axis=0).values
            tos = torch.max(qs, axis=0).values
            lprobs = kdtpp.prob.interval_prob3(
                einops.repeat(bin_lprobs, "n -> b n", b=B),
                lhs_edge,
                rhs_edge,
                froms,
                tos,
            )
            # The first bin's mass is spread uniformly over [0, second_edge].
            expected_prob = (tos - froms) * probs[0] / second_edge
            assert torch.allclose(torch.exp(lprobs), expected_prob, atol=1e-7)

    test3b()

    def test4a():
        # Test 4a. Last bin is handled correctly: as an exponential tail.
        probs = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        bin_lprobs = torch.log(probs)
        lhs_edge, rhs_edge = 5, 10
        # Some hard-coded input-outputs.
        # The last bin starts at 9 and extends to infinity, and has an overall
        # probability of 1/5. The expected values below correspond to a unit
        # rate exponential tail starting at 9.
        # Columns: (from, to, expected probability).
        # fmt: off
        froms_tos_ans = torch.tensor(
            [
                [9,         9,   1 / 5 * 0.0], # Zero length interval.
                [9,       1e8,   1 / 5 * 1.0], # Full length interval.
                [9,        10,   1 / 5 * (1 - math.exp(-1))],
                [9,        15,   1 / 5 * (1 - math.exp(-6))],
                [9,      13.2,   1 / 5 * (1 - math.exp(-4.2))],
                [10,       10,   1 / 5 * 0.0],
                [10,       11,   1 / 5 * (math.exp(-1) - math.exp(-2))],
                [10,     12.3,   1 / 5 * (math.exp(-1) - math.exp(-3.3))],
                [12.2,   13.7,   1 / 5 * (math.exp(-3.2) - math.exp(-4.7))],
            ]
        )
        # fmt: on
        froms = froms_tos_ans[:, 0]
        tos = froms_tos_ans[:, 1]
        batch_bin_lprobs = einops.repeat(bin_lprobs, "n -> b n", b=len(froms))
        lprobs = kdtpp.prob.interval_prob3(
            batch_bin_lprobs,
            lhs_edge,
            rhs_edge,
            froms,
            tos,
        )
        probs = torch.exp(lprobs)
        expected_probs = froms_tos_ans[:, 2]
        assert torch.allclose(probs, expected_probs, atol=1e-7)

    test4a()

    def test4b():
        """Same as test4a, but with random probs and intervals."""
        N = 100
        for _ in range(N):
            n_bins = np_rng.integers(1, 200)
            probs = np_rng.dirichlet(np.ones(n_bins))
            bin_lprobs = torch.log(torch.tensor(probs, dtype=torch.float32))
            MAX_EDGE = 100
            two_edges = np_rng.uniform(0, MAX_EDGE, 2)
            lhs_edge, rhs_edge = min(two_edges), max(two_edges)
            B = 128
            second_last_edge = rhs_edge - (rhs_edge - lhs_edge) / n_bins
            # Only queries in the last bin (after second last edge).
            qs = (
                torch.tensor(
                    np_rng.exponential(scale=3, size=(2, B)),
                    dtype=torch.float32,
                )
                + second_last_edge
            )
            # Order each random pair so that from <= to.
            froms = torch.min(qs, axis=0).values
            tos = torch.max(qs, axis=0).values
            lprobs = kdtpp.prob.interval_prob3(
                einops.repeat(bin_lprobs, "n -> b n", b=B),
                lhs_edge,
                rhs_edge,
                froms,
                tos,
            )
            # Mass of a unit-rate exponential (starting at the second last
            # edge) that falls inside [from, to].
            integral = torch.exp(second_last_edge - froms) - torch.exp(
                second_last_edge - tos
            )
            expected_prob = integral * probs[-1]
            assert torch.all(expected_prob >= 0)
            actual_prob = torch.exp(lprobs)
            # Report the entries that are NOT within tolerance, so a failure
            # points at the offending queries.
            mismatched = ~torch.isclose(actual_prob, expected_prob, atol=1e-7)
            assert torch.allclose(actual_prob, expected_prob, atol=1e-7), (
                "Mismatch at batch indices: "
                f"{torch.nonzero(mismatched).flatten().tolist()}"
            )

    test4b()

    def test5():
        # Test 5. Negative interval (a, b) with a > b throws an error.
        probs = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        bin_lprobs = torch.log(probs)
        lhs_edge, rhs_edge = 5, 10
        froms = torch.tensor([10])
        tos = torch.tensor([9])
        with pytest.raises(ValueError):
            kdtpp.prob.interval_prob3(
                einops.repeat(bin_lprobs, "n -> b n", b=len(froms)),
                lhs_edge,
                rhs_edge,
                froms,
                tos,
            )

    test5()


def test_auto_ll(np_rng):
    """Tests the likelihood calc for the auto-binned discrete distribution.

    The behaviour being tested here is very similar to that tested by
    test_interval_prob3. This test is a bit simpler, as intervals don't need
    to be considered—just point probability densities.

    Tests that:

        1. A few hard-coded cases are correct.
        2. Randomized cases are correct.

    The np_rng argument is a pytest fixture defined outside this file; it is
    used here as a numpy random Generator (integers, dirichlet, uniform).
    """

    def test1():
        # Test 1. Hard-coded cases.
        probs = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        bin_lprobs = torch.log(probs)
        lhs_edge, rhs_edge = 5, 10
        # Columns: (target, expected density).
        # Targets below 6 share the first bin's mass spread over [0, 6],
        # targets in the middle bins (width 1) get density 0.2, and targets
        # from 9 onwards follow an exponential decay starting at 9.
        # fmt: off
        target_ans = torch.tensor(
            [
                [0, 1/6 * 0.2],
                [1, 1/6 * 0.2],
                [2.5, 1/6 * 0.2],
                [5, 1/6 * 0.2],
                [6, 0.2],
                [6.5, 0.2],
                [7, 0.2],
                [9, 0.2],
                [10, math.exp(-1) * 0.2],
                [12, math.exp(-3) * 0.2],
                [14.5, math.exp(-5.5) * 0.2],
            ]
        )
        # fmt: on
        lprobs = kdtpp.prob.auto_ll(
            m_out=einops.repeat(bin_lprobs, "n -> b n", b=len(target_ans)),
            target_t=target_ans[:, 0],
            t_min=lhs_edge,
            t_max=rhs_edge,
            n_bins=len(probs),
        )
        probs = torch.exp(lprobs)
        expected_probs = target_ans[:, 1]
        assert torch.allclose(probs, expected_probs, atol=1e-7)

    test1()

    def test2():
        # Test 2. Randomized cases.
        N = 100
        for _ in range(N):
            n_bins = np_rng.integers(1, 200)
            probs = torch.tensor(
                np_rng.dirichlet(np.ones(n_bins)), dtype=torch.float32
            )
            # probs is already a float32 tensor; take the log directly rather
            # than re-wrapping it in torch.tensor (which emits a
            # copy-construct warning).
            bin_lprobs = torch.log(probs)
            MAX_EDGE = 100
            two_edges = np_rng.uniform(0, MAX_EDGE, 2)
            lhs_edge, rhs_edge = min(two_edges), max(two_edges)
            bin_len = (rhs_edge - lhs_edge) / n_bins
            B = 128
            # Query any point in [0, rhs*2]
            targets = torch.tensor(
                np_rng.uniform(0, rhs_edge * 2, (B,)),
                dtype=torch.float32,
            )
            lprobs = kdtpp.prob.auto_ll(
                m_out=einops.repeat(bin_lprobs, "n -> b n", b=B),
                target_t=targets,
                t_min=lhs_edge,
                t_max=rhs_edge,
                n_bins=n_bins,
            )
            # Bin index of each target; out-of-range targets are clamped into
            # the first/last bin.
            t_idx = torch.floor((targets - lhs_edge) / bin_len).long()
            t_idx = torch.clamp(t_idx, 0, n_bins - 1)
            # Three regimes: first bin (uniform over [0, lhs_edge + bin_len]),
            # last bin (exponential decay from rhs_edge - bin_len), and the
            # regular bins in between (uniform over bin_len).
            expected_lprobs = torch.where(
                targets < (lhs_edge + bin_len),
                torch.log(probs[0]) - math.log(bin_len + lhs_edge),
                torch.where(
                    targets > (rhs_edge - bin_len),
                    rhs_edge - bin_len - targets + torch.log(probs[-1]),
                    torch.log(probs[t_idx]) - math.log(bin_len),
                ),
            )
            assert torch.allclose(lprobs, expected_lprobs, atol=1e-7)

    test2()


def test_to_bin_idx():
    """Check that to_bin_idx maps query points to the expected bin index.

    Each table row pairs a query point with the index of the bin it should
    land in. The expected values treat a point sitting exactly on an edge as
    belonging to the bin that starts there (e.g. 5 maps to bin 1).
    """
    edges = torch.tensor([0, 5, 6, 8, 10, float("inf")])
    # fmt: off
    # Columns: (query point, expected bin index)
    table = torch.tensor(
        [
            [0, 0],
            [4.99, 0],
            [5, 1],
            [5.001, 1],
            [5.5, 1],
            [6, 2],
            [6.5, 2],
            [7.5, 2],
            [8, 3],
            [8.5, 3],
            [9.5, 3],
            [9.999, 3],
            [10, 4],
            [10.2, 4],
            [1000, 4],
        ]
    )
    # fmt: on
    points, wanted = table.unbind(dim=1)
    found = kdtpp.prob.to_bin_idx(points, edges)
    # The table is a float tensor, so cast the index column to integers.
    assert torch.allclose(found, wanted.long())


def test_var_bin_expected_val():
    """Test expected value calculation for a variable bin width distribution.

    Tests that:
        1. A few hard-coded cases are correct.

    The reference value is a probability-weighted sum of one representative
    point per bin: the midpoint for each finite bin, and 10 + 1 for the
    unbounded last bin.
    """
    edges = torch.tensor([0, 5, 6, 8, 10, float("inf")])

    # fmt: off
    # Unnormalized bin weights; the log turns them into logits. Zero weights
    # become -inf logits.
    weights = torch.tensor([
        [0.1, 0.2, 0.4, 0.2, 0.1],
        [0.2, 0.4, 0.8, 0.4, 0.2],
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 1.0],
        [2.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.5, 0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1e-7, 1e-7, 1e-7, 1e-7, 1e-7],
        [1e7, 1e7, 1.0, 1.0, 1.0],
    ])
    # fmt: on
    logits = weights.log()

    # Normalize, then split into one column of probabilities per bin.
    p0, p1, p2, p3, p4 = torch.nn.functional.softmax(logits, dim=-1).unbind(
        dim=-1
    )
    reference = 2.5 * p0 + 5.5 * p1 + 7 * p2 + 9 * p3 + (10 + 1) * p4

    computed = kdtpp.prob.var_bin_expected_val(logits, edges)
    assert torch.allclose(computed, reference, atol=1e-7)


def test_var_bin_mode():
    """Test the mode calculation for a variable bin width distribution.

    Tests that:

        1. A few hard-coded cases are correct.

    The expected modes are based on probability *density* (bin weight divided
    by bin width), not on raw bin weight. A finite bin is represented by its
    midpoint and the unbounded last bin by its left edge.
    """
    edges = torch.tensor([0, 5, 6, 8, 10, float("inf")])

    # fmt: off
    # (unnormalized bin weights, expected mode)
    cases = [
        ([0.1, 0.2, 0.5, 0.1, 0.1], 7.0),   # 0.5/2 in [6,8] -> 7
        ([0.2, 0.4, 1.0, 0.1, 0.1], 7.0),   # as above, with unnormalized weights
        ([0.8, 0.2, 0.1, 0.1, 0.1], 5.5),   # 0.2/1 in [5,6] has highest *density*
        ([1.0, 0.0, 0.0, 0.0, 0.0], 2.5),   # [0,5] -> 2.5
        ([0.0, 1.0, 0.0, 0.0, 0.0], 5.5),   # [5,6] -> 5.5
        ([0.0, 0.0, 1.0, 0.0, 0.0], 7.0),   # [6,8] -> 7
        ([0.0, 0.0, 0.0, 1.0, 0.0], 9.0),   # [8,10] -> 9
        ([0.0, 0.0, 0.0, 0.0, 1.0], 10.0),  # exp on [10,inf] -> 10
        ([2.0, 0.0, 0.0, 0.0, 0.0], 2.5),
        ([0.0, 0.5, 0.0, 0.0, 0.0], 5.5),
        ([1.0, 1.0, 1.0, 1.0, 1.0], 5.5),   # Not uniform density!
        ([5.0, 1.0, 2.0, 2.0, 0.0], 2.5),   # Uniform (except tail), take first mode
        ([0.0, 1.0, 2.0, 2.0, 0.0], 5.5),   # Uniform (except tail), take first mode
        ([1e-7, 1e-7, 1e-7, 1e-7, 1e-7], 5.5),
        ([1e7, 1e7, 1.0, 3.0, 1.0], 5.5),
        ([1e7, 1e7, 1.0, 3e7, 1.0], 9.0),
    ]
    # fmt: on

    # The log turns the weights into (not necessarily normalized) logits.
    logits = torch.tensor([row for row, _ in cases]).log()
    wanted = torch.tensor([mode for _, mode in cases])

    modes = kdtpp.prob.var_bin_mode(logits, edges)
    assert torch.allclose(modes, wanted, atol=1e-7)


def test_var_bin_median():
    """Test the median calculation for a variable bin width distribution.

    Tests that:

        1. A few hard-coded cases are correct: all mass in a single bin, mass
           split over two bins, and mass split over three bins.
    """

    def check(logits, wanted):
        # Fresh edges for every check; bins are [0,5], [5,6], [6,8], [8,10]
        # and [10,inf].
        edges = torch.tensor([0, 5, 6, 8, 10, float("inf")])
        found = kdtpp.prob.var_bin_median(logits, edges)
        torch.testing.assert_close(found, wanted)

    # fmt: off
    # Case 1: a single bin holds all of the mass (-inf logit = zero weight).
    # These are logits, not necessarily normalized.
    NEG = -float("inf")
    single_bin_logits = torch.tensor([
        [1.0, NEG, NEG, NEG, NEG],
        [NEG, 0.0, NEG, NEG, NEG],
        [NEG, NEG, 3.0, NEG, NEG],
        [NEG, NEG, NEG,-4.0, NEG],
        [NEG, NEG, NEG, NEG,  20],
    ])
    # Midpoint of each finite bin; ln(2) into the tail for the last bin.
    single_bin_wanted = torch.tensor([2.5, 5.5, 7.0, 9.0, 10.0 + math.log(2)])
    # fmt: on
    check(single_bin_logits, single_bin_wanted)

    # fmt: off
    # Case 2: two bins, and the balance between them.
    # Weights are not necessarily normalized (and are not logs).
    two_bin_weights = torch.tensor([
        [1.0, 1.0,   0,   0,   0], # 5      | median reached at end of first bin
        [  0, 1.0, 2.0,   0,   0], # 6.5    | needs 1/6 from the [6,8] bin, which holds 4/6
        [  0,   0, 6.0, 4.0,   0], # 7.667  | 5/6th of the way through [6,8]
        [  0,   0,   0, 1.0, 3.0], # 10.405 | 10 - ln(1-1/3); 1/3 of the tail mass on [10,inf]
        [1.0,   0,   0,   0, 3.0], # same!
    ])
    two_bin_wanted = torch.tensor([
        5.0,
        6.5,
        6 + 2*5/6,
        10 - math.log(1-1/3),
        10 - math.log(1-1/3),
    ])
    # fmt: on
    check(torch.log(two_bin_weights), two_bin_wanted)

    # fmt: off
    # Case 3: three bins.
    # Weights are not necessarily normalized (and are not logs).
    three_bin_weights = torch.tensor([
        [1.0, 1.0, 1.0,   0,   0], # 5.5     | half way through [5,6]
        [  0, 2.0, 1.0, 1.0,   0], # 6       | reached at end of [5,6]
        [1.0,   0, 2.0, 7.0,   0], # 8.571   | 8 + 2/7 * 2
        [  0, 1.0, 100,   0, 399], # 10.4675 | 149/399 into [10,inf]
        [0.1, 0.1,   0,   0, 0.3], # 10.1823 | 0.05/0.3 into [10,inf]
    ])
    three_bin_wanted = torch.tensor([
        5.5,
        6,
        8 + 2/7*2,
        10 - math.log(1 - 149/399),
        10 - math.log(1-0.05/0.3),
    ])
    # fmt: on
    check(torch.log(three_bin_weights), three_bin_wanted)


def test_var_bin_cdf():
    """Test cdf calculation for a variable bin width distribution.

    Tests that:
        1. A few hard-coded cases are correct.

    The expected values accumulate mass linearly inside each finite bin and
    follow ``1 - exp(-(x - 10))`` inside the unbounded last bin.
    """
    edges = torch.tensor([0, 5, 6, 8, 10, float("inf")])

    # Bin weights (they happen to sum to one here), turned into logits.
    logits = torch.tensor([0.1, 0.2, 0.4, 0.2, 0.1]).log()

    # Columns: (query point, expected cumulative probability)
    table = torch.tensor(
        [
            [0, 0.0],
            [0.05, 0.05/5 * 0.1],
            [0.25, 0.25/5 * 0.1],
            [2.5,  2.5/5 * 0.1],
            [4.99, 4.99/5 * 0.1],
            [5, 0.1],
            [5.5, 0.1 + 0.2 * 0.5 / 1],
            [6, 0.3],
            [6.2, 0.3 + 0.4 * 0.2 / 2],
            [9.8, 0.7 + 1.8 * 0.2 / 2],
            [10, 0.9],
            [10.1, 0.9 + 0.1*(1 - math.exp(-0.1))],
            [11, 0.9 + 0.1*(1 - math.exp(-1))],
            [15, 0.9 + 0.1*(1 - math.exp(-5))],
        ]
    )
    points, wanted = table.unbind(dim=1)

    # One copy of the logits per query point.
    batched_logits = einops.repeat(logits, "n -> b n", b=len(points))
    found = kdtpp.prob.var_bin_cdf(batched_logits, edges, points)
    torch.testing.assert_close(found, wanted)


def test_var_bin_ll():
    """
    Tests the log-likelihood calculation for a variable bin width distribution.

    Tests that:

        1. A few hard-coded cases are correct.

    The expected density inside a finite bin is the bin probability divided by
    the bin width; inside the unbounded last bin it is the bin probability
    times ``exp(-(t - 10))``.
    """
    bin_probs = torch.tensor([0.1, 0.2, 0.4, 0.2, 0.1])
    edges = torch.tensor([0, 5, 6, 8, 10, float("inf")])
    # fmt: off
    # Columns: (target point, expected density)
    table = torch.tensor(
        [
            [0, 0.1 * 1/5],
            [1, 0.1 * 1/5],
            [4.5, 0.1 * 1/5],
            [5, 0.2 * 1/1],
            [5.5, 0.2 * 1/1],
            [5.999, 0.2 * 1/1],
            [6, 0.4 * 1/2],
            [6.5, 0.4 * 1/2],
            [7.5, 0.4 * 1/2],
            [7.999, 0.4 * 1/2],
            [8, 0.2 * 1/2],
            [8.5, 0.2 * 1/2],
            [9.5, 0.2 * 1/2],
            [9.999, 0.2 * 1/2],
            [10, 0.1 * math.exp(-0)],
            [10.2, 0.1 * math.exp(-0.2)],
            [11, 0.1 * math.exp(-1)],
            [12, 0.1 * math.exp(-2)],
            [15, 0.1 * math.exp(-5)],
            [20.1, 0.1 * math.exp(-10.1)],
        ]
    )
    # fmt: on
    targets, wanted = table.unbind(dim=1)

    # One copy of the bin log-probabilities per target point.
    batched_lprobs = einops.repeat(
        torch.log(bin_probs), "n -> b n", b=len(table)
    )
    densities = torch.exp(
        kdtpp.prob.var_bin_ll(
            m_out=batched_lprobs,
            bin_edges=edges,
            target_t=targets,
        )
    )
    assert torch.allclose(densities, wanted, atol=1e-7)
