import unittest as ut

import numpy as np

from comab.reward_estimator.fixed_grid_estimator import WithFixedGridEstimation


class TestFHat(ut.TestCase):
    """Checks on the array returned by ``WithFixedGridEstimation.r_hat()``."""

    @staticmethod
    def _new_estimator():
        """Return an estimator built with the configuration shared by all tests."""
        # NOTE(review): the meaning of K, N, p, D and R is defined by
        # WithFixedGridEstimation itself; the values are kept exactly as given.
        p = np.array([2, 3], dtype=int)
        return WithFixedGridEstimation(K=2, N=5, p=p, D=10, R=2)

    def test_init(self):
        """Before any update, every entry of ``r_hat()`` compares equal to NaN."""
        fresh = self._new_estimator()

        # Comparing against the scalar NaN broadcasts over the whole result.
        np.testing.assert_allclose(fresh.r_hat(), np.nan, equal_nan=True)

    def test_one_update(self):
        """After one ``update_estimator`` call, ``r_hat()`` matches the table."""
        nan = np.nan
        scenarios = [
            {
                "n": np.array([2, 3]),
                "arms": np.array([False, True]),
                "gains": np.array([0.0, 1.3]),
                "result": np.array([
                    [nan] * 8,
                    [nan] * 3 + [0.0] * 5,
                ]),
            },
            {
                "n": np.array([3, 1]),
                "arms": np.array([True, True]),
                "gains": np.array([0.54, 0.39]),
                "result": np.array([
                    [nan] * 3 + [0.0] * 5,
                    [nan] * 2 + [0.0] * 6,
                ]),
            },
        ]

        for scenario in scenarios:
            # The dict keys double as the sub-test parameters shown on failure.
            with self.subTest(msg="Checking assignment", **scenario):
                # A fresh estimator per scenario keeps the cases independent.
                under_test = self._new_estimator()

                # The two trailing arguments are passed as None, as before.
                under_test.update_estimator(
                    scenario["n"],
                    scenario["arms"],
                    scenario["gains"],
                    None,
                    None,
                )

                np.testing.assert_allclose(
                    under_test.r_hat(),
                    scenario["result"],
                    equal_nan=True,
                )
