import unittest

import numpy as np

from src.searchlight.gameplay.simulators import GameSimulator, DialogueGameSimulator
from src.GOPS.baseline_models_GOPS import *
from src.searchlight.classic_models import RandomRolloutValueHeuristic
from src.searchlight.gameplay.agents import SearchAgent
from src.searchlight.algorithms.mcts_search import SMMonteCarlo
from src.searchlight.datastructures.graphs import ValueGraph
from src.searchlight.datastructures.adjusters import QValueAdjuster
from src.searchlight.datastructures.estimators import UtilityEstimatorLast

class TestSimulators(unittest.TestCase):
    """Smoke tests that run short GOPS matches through ``GameSimulator``.

    The tests only check that simulation completes without raising; average
    scores and trajectories are printed for manual inspection.
    """

    # Seed used for every random generator so runs are reproducible.
    SEED = 12
    # Number of games simulated per test.
    NUM_GAMES = 10

    @staticmethod
    def _make_gops_components():
        """Build fresh GOPS components.

        Returns:
            ``(action_enumerator, forward_transitor, actor_enumerator)``.
        """
        action_enumerator = GOPSActionEnumerator()
        forward_transitor = GOPSForwardTransitor2()
        actor_enumerator = GOPSActorEnumerator()
        return action_enumerator, forward_transitor, actor_enumerator

    def _make_search_agent(self, player, action_enumerator, forward_transitor, actor_enumerator):
        """Build an MCTS-based ``SearchAgent`` for ``player``.

        Every random component gets its own freshly seeded generator, so two
        agents built by this helper are configured identically apart from
        ``player``.
        """
        value_heuristic = RandomRolloutValueHeuristic(
            actor_enumerator, action_enumerator, forward_transitor,
            num_rollouts=10, rng=np.random.default_rng(self.SEED),
        )
        initial_inferencer = GOPSInitialInferencer2(
            forward_transitor, action_enumerator, PolicyPredictor(),
            actor_enumerator, value_heuristic,
        )
        search = SMMonteCarlo(
            initial_inferencer=initial_inferencer,
            rng=np.random.default_rng(self.SEED),
            node_budget=16, num_rollout=32,
        )
        graph = ValueGraph(
            adjuster=QValueAdjuster(), utility_estimator=UtilityEstimatorLast(),
            rng=np.random.default_rng(self.SEED), players={0, 1},
        )
        return SearchAgent(search, graph, player=player)

    def _simulate_and_report(self, agents, action_enumerator, forward_transitor, actor_enumerator):
        """Simulate ``NUM_GAMES`` games from ``GOPS_START_STATE_6`` and print the results.

        Returns:
            The ``(avg_scores, trajectories)`` pair from ``simulate_games``.
        """
        simulator = GameSimulator(forward_transitor, actor_enumerator, action_enumerator, GOPS_START_STATE_6)
        avg_scores, trajectories = simulator.simulate_games(agents, self.NUM_GAMES)

        print('Player 1 win rate:', avg_scores[0])
        print('Player 2 win rate:', avg_scores[1])
        print('trajectories', trajectories)
        return avg_scores, trajectories

    def test_simulate_games(self):
        """Random agent vs. random agent."""
        action_enumerator, forward_transitor, actor_enumerator = self._make_gops_components()

        # A single generator is deliberately shared by all three agents, so
        # their draws interleave. Actor -1 is presumably the environment/chance
        # actor -- confirm against GOPSActorEnumerator.
        rng = np.random.default_rng(self.SEED)
        random_agents = {player: RandomAgent(rng) for player in (0, 1, -1)}

        self._simulate_and_report(random_agents, action_enumerator, forward_transitor, actor_enumerator)

    def test_simulate_games2(self):
        """MCTS search agent vs. MCTS search agent, with a random actor -1."""
        action_enumerator, forward_transitor, actor_enumerator = self._make_gops_components()

        agents = {
            0: self._make_search_agent(0, action_enumerator, forward_transitor, actor_enumerator),
            1: self._make_search_agent(1, action_enumerator, forward_transitor, actor_enumerator),
            -1: RandomAgent(np.random.default_rng(self.SEED)),
        }

        self._simulate_and_report(agents, action_enumerator, forward_transitor, actor_enumerator)

        
