import unittest as ut

import numpy as np

from comab.reward_estimator.reward_estimator import F_hat_combi


class TestFHatCombi(ut.TestCase):
    """Unit tests for ``F_hat_combi``."""

    def test_one_update(self):
        """Check ``F_hat_combi(F_hat, t)`` against hand-computed values.

        Each scenario is a ``(F_hat, t, expected)`` triple of 1x7 arrays.
        The comparison uses ``assert_allclose`` with ``equal_nan=True``, so
        a NaN in the expected output must be matched by a NaN in the result.
        """
        f_hat_input = np.array([[0., 0.5, 0, 0, 0, 0.8, 0]])
        t_input = np.array([[0, 1, 0, 0, 0, 2, 0]])

        # Expected output, one entry per column of the inputs.
        expected_row = [
            np.nan,
            0.5,
            0.25,
            (0.125 + 2 * np.power(0.8, 3. / 5)) / 3,
            (np.power(0.5, 4.) + 2 * np.power(0.8, 4. / 5)) / 3,
            0.8,
            np.power(0.8, 6. / 5),
        ]

        scenarios = [
            (f_hat_input, t_input, np.array([expected_row])),
        ]

        for f_hat_in, t_in, want in scenarios:
            with self.subTest(msg="Checking assignment", F_hat=f_hat_in, t=t_in, expected=want):
                np.testing.assert_allclose(
                    F_hat_combi(f_hat_in, t_in),
                    want,
                    equal_nan=True,
                )
