"""Tests for `seq_halving.py`."""

from absl.testing import absltest

from _mctx._src import seq_halving


class SeqHalvingTest(absltest.TestCase):
    """Pins the output of `seq_halving.get_sequence_of_considered_visits`.

    Each test fixes one `(max_num_considered_actions, num_simulations)` pair
    and compares the returned sequence with a hand-written expectation.
    """

    def _check_visits(
        self, expected_results, max_num_considered_actions, num_simulations
    ):
        """Asserts that the returned considered visits match the expectation.

        Args:
          expected_results: sized sequence holding one expected entry per
            simulation; its length must equal `num_simulations`.
          max_num_considered_actions: forwarded to the function under test.
          num_simulations: forwarded to the function under test.
        """
        # Catch a mis-sized expectation before looking at the actual output.
        self.assertLen(expected_results, num_simulations)
        actual = seq_halving.get_sequence_of_considered_visits(
            max_num_considered_actions, num_simulations
        )
        self.assertEqual(tuple(expected_results), actual)

    def test_considered_min_sims(self):
        # Exactly `num_simulations = max_num_considered_actions *
        #   log2(max_num_considered_actions)`, i.e. 8 * 3 = 24.
        expected = (
            [0] * 8  # Considering 8 actions.
            + [1] * 4  # Considering 4 actions.
            + [2] * 4  # Considering 4 actions, round 2.
            + [3, 3, 4, 4, 5, 5, 6, 6]  # Considering 2 actions.
        )
        self._check_visits(
            expected, max_num_considered_actions=8, num_simulations=24
        )

    def test_considered_extra_sims(self):
        # More simulations than `max_num_considered_actions *
        #   log2(max_num_considered_actions)`.
        expected = (
            [0] * 8  # Considering 8 actions.
            + [1] * 4  # Considering 4 actions.
            + [2] * 4  # Considering 4 actions, round 2.
            + [3] * 4  # Considering 4 actions, round 3.
            # Every value from 4 to 16 appears twice ...
            + [value for value in range(4, 17) for _ in range(2)]
            # ... and a single 17 fills the 47th and last slot.
            + [17]
        )
        self._check_visits(
            expected, max_num_considered_actions=8, num_simulations=47
        )

    def test_considered_less_sims(self):
        # A very small number of simulations.
        self._check_visits(
            [0, 0], max_num_considered_actions=8, num_simulations=2
        )

    def test_considered_less_sims2(self):
        # `num_simulations < max_num_considered_actions *
        #   log2(max_num_considered_actions)`.
        expected = (
            [0] * 8  # Considering 8 actions.
            + [1] * 4  # Considering 4 actions.
            + [2]
        )
        self._check_visits(
            expected, max_num_considered_actions=8, num_simulations=13
        )

    def test_considered_not_power_of_2(self):
        # `max_num_considered_actions` that is not a power of 2.
        expected = (
            [0] * 7  # Considering 7 actions.
            + [1, 1, 1, 2, 2, 2]  # Considering 3 actions.
            + [3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8]
        )
        self._check_visits(
            expected, max_num_considered_actions=7, num_simulations=24
        )

    def test_considered_action0(self):
        # With zero considered actions the expectation is 0, 1, ..., 15.
        total_sims = 16
        self._check_visits(
            range(total_sims),
            max_num_considered_actions=0,
            num_simulations=total_sims,
        )

    def test_considered_action1(self):
        # With one considered action the expectation is 0, 1, ..., 15.
        total_sims = 16
        self._check_visits(
            range(total_sims),
            max_num_considered_actions=1,
            num_simulations=total_sims,
        )


# Run this module's tests through absltest when executed as a script.
if __name__ == "__main__":
    absltest.main()
