import numpy as np
import unittest as ut

from comab.reward_estimator.full_empirical_cdf import F_hat_combi, F_hat

class TestFHat(ut.TestCase):
    """Unit tests for the empirical-CDF helpers ``F_hat`` and ``F_hat_combi``."""

    @staticmethod
    def _observation_fixture():
        """Return a freshly built ``(z, t, X)`` triple shared by both tests.

        New arrays are created on every call, so one test can never see
        modifications another test might make to its inputs.

        ``X`` has shape (1, 6, 7): one outer entry holding six rows
        ``X_t^(0)`` .. ``X_t^(5)`` of seven values each. Each entry of ``t``
        equals the number of non-zero values in the matching row of ``X``.
        It presumably counts the observations per row; confirm against
        ``F_hat``.
        """
        z = 0.8
        t = np.array([[0, 0, 2, 3, 0, 1]])
        rows = [
            [0, 0, 0, 0, 0, 0, 0],  # X_t^(0)
            [0, 0, 0, 0, 0, 0, 0],  # X_t^(1)
            [0.9, 0, 0, 0, 0.3, 0, 0],  # X_t^(2)
            [0, 0.5, 0, 0.6, 0, 0.64, 0],  # X_t^(3)
            [0, 0, 0, 0, 0, 0, 0],  # X_t^(4)
            [0, 0, 0, 0.3, 0, 0, 0],  # X_t^(5)
        ]
        X = np.array([rows])
        return z, t, X

    def test_F_hat(self):
        """Check ``F_hat(z, X, t)`` against hand-computed per-row values.

        Rows with ``t == 0`` are expected to yield NaN.
        """
        z_value, t_counts, X_obs = self._observation_fixture()
        expected = np.array([[np.nan, np.nan, 1./2, 1., np.nan, 1.]])
        scenarios = [(z_value, t_counts, X_obs, expected)]

        for z, t, X, want in scenarios:
            with self.subTest(msg="Checking assignment", z=z, t=t, X=X, result=want):
                got = F_hat(z, X, t)
                np.testing.assert_allclose(got, want, equal_nan=True)

    def test_F_hat_combi(self):
        """Check ``F_hat_combi`` applied to the output of ``F_hat``.

        The input is the same fixture used by ``test_F_hat``.
        """
        z_value, t_counts, X_obs = self._observation_fixture()
        expected = np.array([[np.nan, 1 / np.sqrt(2), 4 / 5, (4 + 1 / np.sqrt(2)) / 6, 4.5 / 6,
                              (4 + 1 / 2 / np.sqrt(2)) / 6]])
        scenarios = [(z_value, t_counts, X_obs, expected)]

        for z, t, X, want in scenarios:
            with self.subTest(msg="Checking assignment", z=z, t=t, X=X, result=want):
                per_row = F_hat(z, X, t)
                got = F_hat_combi(per_row, t)
                np.testing.assert_allclose(got, want, equal_nan=True)