import unittest
from src.searchlight.algorithms.mcts_search import SMMonteCarlo
from src.searchlight.datastructures.graphs import ValueGraph2
from tests.Search.initial_inferencers import INFERENCERS 
from src.searchlight.headers import *
from src.searchlight.datastructures.adjusters import *
from src.searchlight.datastructures.estimators import *
import numpy as np
import tqdm

class TestSearch(unittest.TestCase):
    """Tests for ``SMMonteCarlo.expand`` run from ``State(0)`` on small inferencers.

    Every test builds a fresh ``ValueGraph2``, expands the root state once and
    then checks three things: the number of nodes stored in the graph, the
    number of children of the root node, and the root's estimated value for
    actor 0.
    """

    # Class-level generator: the same object is handed to every graph and
    # search built below, so all test methods share one random stream.
    rng = np.random.default_rng(12)

    def _assert_root_expansion(
        self,
        inferencer_key,
        *,
        node_budget,
        expected_nodes,
        expected_children,
        expected_value,
    ):
        """Expand ``State(0)`` in a new graph and check the resulting root.

        Args:
            inferencer_key: Key into ``INFERENCERS`` selecting the factory
                whose product is passed to the search.
            node_budget: Forwarded to ``SMMonteCarlo`` as ``node_budget``.
            expected_nodes: Expected number of entries in ``graph.id_to_node``.
            expected_children: Expected number of children of the root node.
            expected_value: Expected estimated value of the root for actor 0.
        """
        # Graph first, then the search, both fed from the shared generator.
        graph = ValueGraph2(
            adjuster=QValueAdjuster(),
            utility_estimator=UtilityEstimatorLast(),
            rng=self.rng,
            players={0},
        )
        search = SMMonteCarlo(
            INFERENCERS[inferencer_key](),
            rng=self.rng,
            node_budget=node_budget,
            num_rollout=10,
        )

        # Expand from the root state.
        root_state = State(0)
        search.expand(graph, root_state)
        root_node = graph.get_node(root_state)

        # Total number of nodes recorded in the graph.
        self.assertEqual(len(graph.id_to_node.items()), expected_nodes)

        # Number of children hanging off the root.
        self.assertEqual(
            len(graph.get_node(root_state).children), expected_children
        )

        # Estimated value of the root for actor 0.
        self.assertEqual(
            graph.get_estimated_value(root_node, 0), expected_value
        )

    def test_expand_three_state(self):
        """'three_state': 3 nodes, 2 root children, root value 1.0."""
        self._assert_root_expansion(
            'three_state',
            node_budget=10,
            expected_nodes=3,
            expected_children=2,
            expected_value=1.0,
        )

    def test_expand_five_state(self):
        """'five_state': 3 nodes, 2 root children, root value 2.0."""
        # Despite the inferencer's name, only 3 graph nodes are asserted here.
        self._assert_root_expansion(
            'five_state',
            node_budget=10,
            expected_nodes=3,
            expected_children=2,
            expected_value=2.0,
        )

    def test_expand_four_chain(self):
        """'four_chain': 4 nodes, 1 root child, root value 4.0."""
        self._assert_root_expansion(
            'four_chain',
            node_budget=10,
            expected_nodes=4,
            expected_children=1,
            expected_value=4.0,
        )

    def test_expand_seven_state(self):
        """'seven_state' with node_budget=1: 5 nodes, 2 root children, value 4.0."""
        # TODO: why is the node count 5 (and not 7) here?
        self._assert_root_expansion(
            'seven_state',
            node_budget=1,
            expected_nodes=5,
            expected_children=2,
            expected_value=4.0,
        )





