import unittest
from src.searchlight.gameplay.simulators import DialogueGameSimulator
from src.Avalon.baseline_models_Avalon import *
from src.Avalon.engine import AvalonBasicConfig, AvalonGameEnvironment
from src.searchlight.gameplay.agents import MCTSAgent, MuteMCTSAgent, HumanDialogueAgent
from src.searchlight.classic_models import RandomRolloutValueHeuristic, ZeroValueHeuristic
# from src.searchlight.gameplay.agents import SearchAgent
# from src.searchlight.algorithms.mcts_search import SMMonteCarlo
# from src.searchlight.datastructures.graphs import ValueGraph
# from src.searchlight.datastructures.adjusters import QValueAdjuster
# from src.searchlight.datastructures.estimators import UtilityEstimatorLast

class TestSimulators(unittest.TestCase):
    """Smoke tests that play complete Avalon games through ``DialogueGameSimulator``.

    Neither test asserts on the game outcome: a test passes when
    ``simulate_games`` finishes without raising and its return value unpacks
    into the ``(avg_scores, trajectories)`` pair.
    """

    NUM_PLAYERS = 5  # table size handed to AvalonBasicConfig.from_num_players
    SEED = 12        # seed of the numpy Generator shared by all agents in a test
    NUM_GAMES = 10   # number of games simulated per test

    def setUp(self):
        """Build a fresh Avalon environment, simulator and RNG for every test."""
        self.avalon_config = AvalonBasicConfig.from_num_players(self.NUM_PLAYERS)
        self.avalon_env = AvalonGameEnvironment(self.avalon_config)
        start_state = AvalonState.init_from_env(self.avalon_env)

        self.actor_action_enumerator = AvalonActorActionEnumerator(
            avalon_env=self.avalon_env,
        )
        self.forward_transitor = AvalonTransitor(env=self.avalon_env)
        speaker_enumerator = AvalonSpeakerEnumerator(avalon_env=self.avalon_env)
        self.information_function = AvalonInformationFunction(
            config=self.avalon_config,
        )

        self.simulator = DialogueGameSimulator(
            transitor=self.forward_transitor,
            actor_action_enumerator=self.actor_action_enumerator,
            speaker_enumerator=speaker_enumerator,
            information_function=self.information_function,
            start_state=start_state,
        )

        # One generator per test; every agent created in the test shares it.
        self.rng = np.random.default_rng(self.SEED)

    def test_simulate_games_random_agent(self):
        """Simulate games in which every seat is a ``RandomDialogueAgent``."""
        random_agents = {
            player: RandomDialogueAgent(self.rng)
            for player in range(self.avalon_config.num_players)
        }
        # Extra agent registered under key -1, in addition to the player seats.
        random_agents[-1] = RandomDialogueAgent(self.rng)

        # The unpack is the only check: it fails if the return is not a pair.
        avg_scores, trajectories = self.simulator.simulate_games(
            random_agents, self.NUM_GAMES, display=False,
        )

        # Uncomment to inspect the results:
        # print('Player 1 win rate:', avg_scores[0])
        # print('Player 2 win rate:', avg_scores[1])
        # print('trajectories', trajectories)

    # NOTE(review): disabled test kept for reference; presumably commented out
    # because HumanDialogueAgent needs interactive input -- confirm before
    # re-enabling.
    # def test_simulate_games_human_agent(self):
    #     # create config
    #     avalon_config = AvalonBasicConfig.from_num_players(5)
    #     avalon_env = AvalonGameEnvironment(avalon_config)
    #     start_state = AvalonState.init_from_env(avalon_env)

    #     actor_action_enumerator = AvalonActorActionEnumerator(avalon_env=avalon_env)
    #     forward_transitor = AvalonTransitor(env=avalon_env)
    #     speaker_enumerator = AvalonSpeakerEnumerator(avalon_env=avalon_env)
    #     information_function = AvalonInformationFunction(config=avalon_config)

    #     # create game simulator
    #     simulator = DialogueGameSimulator(transitor=forward_transitor, actor_action_enumerator=actor_action_enumerator, speaker_enumerator=speaker_enumerator, information_function=information_function, start_state=start_state)

    #     # create 1 human agent, and fill the rest with random agents
    #     rng = np.random.default_rng(12)
    #     human_agents = dict()
    #     for player in list(range(avalon_config.num_players)):
    #         if player == 0:
    #             human_agents[player] = HumanDialogueAgent(player)
    #         else:
    #             human_agents[player] = RandomDialogueAgent(rng)

    #     # simulate games
    #     num_games = 1
    #     avg_scores, trajectories = simulator.simulate_games(human_agents, num_games, display=False)

    def test_simulate_games_mcts_agent(self):
        """Simulate games in which every seat is a ``MuteMCTSAgent``."""
        players = set(range(self.avalon_config.num_players))

        # Alternative heuristic, kept for reference:
        # value_heuristic = RandomRolloutValueHeuristic(players=players, actor_action_enumerator=actor_action_enumerator, forward_transitor=forward_transitor, num_rollouts=2, rng=rng)
        value_heuristic = ZeroValueHeuristic()
        information_prior = AvalonInformationPrior(
            config=self.avalon_config, rng=self.rng,
        )

        # All agents share the heuristic, the prior and the RNG; unlike the
        # random-agent test, no agent is registered under key -1.
        mcts_agents = {
            player: MuteMCTSAgent(
                players=players,
                player=player,
                forward_transitor=self.forward_transitor,
                actor_action_enumerator=self.actor_action_enumerator,
                value_heuristic=value_heuristic,
                information_function=self.information_function,
                information_prior=information_prior,
                num_rollout=10,
                node_budget=10,
                rng=self.rng,
            )
            for player in range(self.avalon_config.num_players)
        }

        # The unpack is the only check: it fails if the return is not a pair.
        avg_scores, trajectories = self.simulator.simulate_games(
            mcts_agents, self.NUM_GAMES, display=False,
        )

        # Uncomment to inspect the results:
        # print('Player 1 win rate:', avg_scores[0])
        # print('Player 2 win rate:', avg_scores[1])
        # print('trajectories', trajectories)

    # def test_simulate_games2(self):
    #     # create config
    #     action_enumerator = GOPSActionEnumerator()
    #     forward_transitor = GOPSForwardTransitor2()
    #     actor_enumerator = GOPSActorEnumerator()
    #     value_heuristic_1 = RandomRolloutValueHeuristic(actor_enumerator, action_enumerator,
    #                                                     forward_transitor, num_rollouts=10, 
    #                                                     rng=np.random.default_rng(12))
    #     value_heuristic_2 = RandomRolloutValueHeuristic(actor_enumerator, action_enumerator,
    #                                                     forward_transitor, num_rollouts=10, 
    #                                                     rng=np.random.default_rng(12))

    #     initial_inferencer_1 = GOPSInitialInferencer2(forward_transitor, action_enumerator, 
    #                                                 PolicyPredictor(), actor_enumerator, 
    #                                                 value_heuristic_1)
        
    #     initial_inferencer_2 = GOPSInitialInferencer2(forward_transitor, action_enumerator,
    #                                                 PolicyPredictor(), actor_enumerator,
    #                                                 value_heuristic_2)
        
    #     # create search
    #     search_1 = SMMonteCarlo(initial_inferencer=initial_inferencer_1, rng=np.random.default_rng(12), node_budget=16, num_rollout=32)
    #     search_2 = SMMonteCarlo(initial_inferencer=initial_inferencer_2, rng=np.random.default_rng(12), node_budget=16, num_rollout=32)

    #     # create graphs
    #     graph_1 = ValueGraph(adjuster=QValueAdjuster(), utility_estimator=UtilityEstimatorLast(), rng=np.random.default_rng(12), players={0,1})
    #     graph_2 = ValueGraph(adjuster=QValueAdjuster(), utility_estimator=UtilityEstimatorLast(), rng=np.random.default_rng(12), players={0,1})

    #     # create agents
    #     agent1 = SearchAgent(search_1, graph_1, player=0)
    #     agent2 = SearchAgent(search_2, graph_2, player=1)

    #     agents = {0: agent1, 1: agent2, -1: RandomAgent(np.random.default_rng(12))}

    #     # create game simulator
    #     simulator = GameSimulator(forward_transitor, actor_enumerator, action_enumerator, GOPS_START_STATE_6)

    #     # simulate games
    #     num_games = 10
    #     avg_scores, trajectories = simulator.simulate_games(agents, num_games)

    #     # print results
    #     print('Player 1 win rate:', avg_scores[0])
    #     print('Player 2 win rate:', avg_scores[1])
    #     print('trajectories', trajectories)

        
