import unittest
from src.Avalon.baseline_models_Avalon import *
from src.searchlight.headers import *
from typing import Optional
from src.Avalon.engine import *
import itertools


class TestAvalonTransistor(unittest.TestCase):
    """Single-step transition checks for ``AvalonTransitor``.

    One state is created from a fresh environment; its ``phase`` attribute is
    then overwritten to exercise phases 0 through 3 in turn, each with
    randomly drawn actions passed to ``AvalonTransitor._transition``.
    """

    # TODO: this needs to be updated

    @staticmethod
    def _print_phase_banner(phase):
        """Print the separator block announcing the phase under test."""
        print('-------------------------------------------------------------------')
        print('Phase {}'.format(phase))
        print('-------------------------------------------------------------------')

    def test_single(self):
        """Run one random transition per phase and check the resulting state.

        The draws are unseeded, so every run exercises different actions.
        NOTE(review): the same ``state`` object is mutated and reused for all
        phases rather than chaining ``next_state`` -- confirm this is intended.
        """
        num_player = 6  # try 5, 6, 7, 8, 9, 10

        env = AvalonGameEnvironment.from_num_players(num_player)
        config = env.config
        state = AvalonState.init_from_env(env)
        transitor = AvalonTransitor(env)

        # ---- phase 0: the quest leader submits a team ----
        self._print_phase_banner(state.phase)
        quest_size = config.num_players_for_quest[state.turn]
        actions = list(itertools.combinations(range(num_player), quest_size))
        print(actions)
        rd = np.random.randint(len(actions))
        print(actions[rd])
        actions_dict = {state.quest_leader: frozenset(actions[rd])}
        next_state, _, _ = transitor._transition(state, actions_dict)
        print(state)
        print(next_state)
        # NOTE(review): a check that quest_leader advances by one was
        # previously disabled here; confirm the expected leader behaviour.
        self.assertEqual(next_state.phase, 1)
        self.assertEqual(next_state.quest_team, tuple(sorted(actions[rd])))

        # ---- phase 1: every player casts a 0/1 vote on the team ----
        state.phase = 1
        self._print_phase_banner(state.phase)
        actions_dict = {
            player: 1 if np.random.rand() > 0.5 else 0
            for player in range(num_player)
        }
        next_state, _, _ = transitor._transition(state, actions_dict)
        print(state)
        print(next_state)
        # Strict majority, compared in integers so that an exact tie can never
        # be misclassified by floating-point rounding.
        approvals = sum(actions_dict.values())
        expected_phase = 2 if 2 * approvals > num_player else 0
        self.assertEqual(next_state.phase, expected_phase)

        # ---- phase 2: a random team of quest_size players votes 0/1 ----
        state.phase = 2
        self._print_phase_banner(state.phase)
        quest_size = config.num_players_for_quest[state.turn]
        players = list(range(num_player))
        random_quest = np.random.choice(players, size=quest_size, replace=False)
        actions_dict = {
            member: 0 if np.random.rand() > 0.5 else 1
            for member in random_quest
        }
        next_state, _, _ = transitor._transition(state, actions_dict)
        print(state)
        print(next_state)
        # The first recorded quest result must be truthy only when every
        # recorded quest vote is 1.
        quest_succeeded = sum(next_state.quest_votes) == quest_size
        self.assertEqual(list(next_state.quest_results)[0], quest_succeeded)

        # ---- phase 3: the assassin picks a target ----
        state.phase = 3
        # NOTE(review): this fabricates num_player quest results; confirm the
        # engine accepts a history of that length.
        state.quest_results = [
            1 if np.random.rand() > 0.5 else 0 for _ in range(num_player)
        ]
        print('quest_results', state.quest_results)
        self._print_phase_banner(state.phase)

        players = list(range(num_player))
        # NOTE(review): the target is a 1-element array, not a plain int --
        # confirm that is the action format _transition expects.
        assassin_target = np.random.choice(players, size=1, replace=False)
        actions = {state.get_assassin(): assassin_target}
        next_state, reward, _ = transitor._transition(state, actions)
        print(state)
        print(next_state)
        print()

        # NOTE(review): the outcome is read from ``env`` although the
        # transition was given ``state``; verify _transition keeps env in sync.
        print('Good victory is {}'.format(env.good_victory))
        is_good = state.get_is_good()
        print(is_good)
        print('reward = ', reward)

        # A player is rewarded +1 when their side matches the winning side,
        # and -1 otherwise.
        good_won = bool(env.good_victory)
        for i in range(num_player):
            expected_reward = 1 if bool(is_good[i]) == good_won else -1
            self.assertEqual(reward[i], expected_reward)

# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()