import numpy as np
from numpy.testing import assert_allclose
import pytest

from offline.lbp.utils import compute_discounted_returns


@pytest.mark.parametrize("gamma", np.linspace(0, 1, 11))
def test_compute_discounted_returns(gamma):
    """Check per-step discounted returns over several concatenated episodes.

    ``dones[t]`` marks step ``t`` as the last step of an episode. The expected
    return at step ``t`` is ``sum_k gamma**k * rewards[t + k]``, summed only up
    to and including the next step whose ``dones`` entry is True. Rewards from
    a following episode must not leak backwards across a terminal step.

    The fixture contains three episodes: two of 3 steps and one of 1 step.
    ``gamma`` sweeps 0.0 .. 1.0 inclusive, which covers both extremes:
    immediate reward only, and undiscounted returns.
    """
    rewards = np.asarray([1, 2, 3, 4, 5, 6, 7], dtype=np.float32)
    dones = np.asarray([False, False, True, False, False, True, True])
    actual = compute_discounted_returns(dones, gamma, rewards)
    desired = [
        # Episode 1: steps 0-2 (terminal at index 2).
        1 + 2 * gamma + 3 * gamma**2,
        2 + 3 * gamma,
        3,
        # Episode 2: steps 3-5 (terminal at index 5).
        4 + 5 * gamma + 6 * gamma**2,
        5 + 6 * gamma,
        6,
        # Episode 3: a single terminal step.
        7,
    ]
    # Rewards are float32, so the implementation may accumulate in float32.
    # numpy's default rtol=1e-7 is below float32 machine epsilon (~1.2e-7),
    # which would make the result hinge on rounding rather than correctness.
    # A float32-appropriate relative tolerance keeps the check meaningful.
    assert_allclose(actual, desired, rtol=1e-6)
