import unittest
from src.Avalon.baseline_models_Avalon import *
from src.searchlight.headers import *
from typing import Optional
from src.Avalon.engine import *


class TestAvalonAction(unittest.TestCase):
    """Exercise the Avalon actor and action enumerators phase by phase.

    A single ``AvalonState`` is built from a fresh ``AvalonGameEnvironment``
    and then mutated in place (``phase``, ``turn``, ``quest_team``) to walk
    through phases 0-3. Settings carry over between phases: the ``turn``
    set for phase 1 and the ``quest_team`` set for phase 2 remain in effect
    for the phases after them.
    """

    @staticmethod
    def _actions(env, state, actor):
        """Return the actions ``AvalonActionEnumerator`` lists for ``actor``.

        A fresh enumerator is built on every call so that no enumerator
        instance is reused after ``state`` has been mutated.
        """
        return AvalonActionEnumerator(env)._enumerate(state, actor)

    def test_single(self):
        """Check enumerated actors and actions for phases 0-3 of one state."""
        # Imported locally: ``itertools`` is not imported explicitly at the
        # top of this file and was previously only reachable if one of the
        # wildcard project imports happened to re-export it.
        import itertools

        num_player = 5  # try 5, 6, 7, 8, 9, 10
        env = AvalonGameEnvironment.from_num_players(num_player)
        config = env.config
        all_players = set(range(num_player))

        quest_leader = env.quest_leader
        phase = env.phase
        turn = env.turn
        round_num = env.round  # not named ``round`` to avoid shadowing the builtin
        done = env.done

        # make the following tuples so that they are immutable
        quest_team = tuple(sorted(env.quest_team))
        team_votes = tuple(env.team_votes)
        quest_votes = tuple(env.quest_votes)
        quest_results = tuple(env.quest_results)
        roles = tuple(env.roles)

        state = AvalonState(
            config,
            quest_leader,
            phase,
            turn,
            round_num,
            done,
            env.good_victory,
            quest_team,
            team_votes,
            quest_votes,
            quest_results,
            roles,
            acting_player=quest_leader,
            simultaneous_actions=tuple(),
        )

        # Phase 0: the only actor is the quest leader, whose actions are
        # every team (as a frozenset) of the size configured for this turn.
        state.phase = 0
        state.turn = 0  # try 0, 1, 2, 3, 4
        actors = AvalonActorEnumerator()._enumerate(state)
        self.assertEqual(actors, {state.quest_leader})

        team_size = config.num_players_for_quest[state.turn]
        expected_teams = {
            frozenset(team)
            for team in itertools.combinations(range(num_player), team_size)
        }
        self.assertEqual(
            self._actions(env, state, state.quest_leader), expected_teams
        )

        # Phase 1: the test expects one actor per player, each with
        # actions {0, 1}.
        state.phase = 1
        state.turn = 4  # try 0, 1, 2, 3, 4
        actors = list(AvalonActorEnumerator()._enumerate(state))
        # Assert the count explicitly instead of relying on an IndexError
        # (too few actors) or silently ignoring extras (too many).
        self.assertEqual(len(actors), num_player)
        for actor in actors:
            self.assertEqual(self._actions(env, state, actor), {0, 1})

        # Phase 2: with a fixed quest team, every enumerated actor has
        # actions {0, 1}.
        state.phase = 2
        state.quest_team = tuple([0, 4])  # try different quest team
        for actor in AvalonActorEnumerator()._enumerate(state):
            self.assertEqual(self._actions(env, state, actor), {0, 1})

        # Phase 3: every enumerated actor may choose any player index.
        state.phase = 3
        for actor in AvalonActorEnumerator()._enumerate(state):
            self.assertEqual(self._actions(env, state, actor), all_players)


if __name__ == "__main__":
    # Run this file's tests when it is executed directly as a script;
    # "__main__" is unittest.main's default module, spelled out here.
    unittest.main(module="__main__")