import unittest
from MultiSM import MultiStateMachine
from generate_combs import generate_all_combs
from primary_backup.PrimaryBackupEnv import PrimaryBackupEnv
from primary_backup.State import State
from primary_backup.PrimaryBackup import reset_envs, step_envs


class TestPrimaryBackup(unittest.TestCase):
    """Check ``reset_envs`` / ``step_envs`` on one fixed two-round scenario.

    The scenario is entry ``SETTING_INDEX`` of
    ``generate_all_combs(players=3, num_round=2)``; its expected evolution::

        initial: 5 [6, 5, 6]
        round: 1, crash: [2], alive: [0, 1], receive: ((0,), (1,))
        [0, 6, 7, 5, 7, 4, 7]
        [1, 6, 7, 5, 7, 6, 7] -> x
        round: 2, crash: [0], alive: [1], receive: ((0,),)
        [1, 4, 6, x, 5, 4, 6]

    ``x`` is the action player 1 takes in round 1 (``State.DoNothing_One``
    in ``test_step_envs``).
    """

    players = 3
    # Number of rounds in the scenario; used both for the combination
    # generator and (presumably as the round count -- TODO confirm) for the
    # third PrimaryBackupEnv constructor argument.
    NUM_ROUNDS = 2
    # Index into the generated combinations selecting the scenario above.
    SETTING_INDEX = 556

    def setUp(self):
        # Rebuilt for every test so no test depends on another having run
        # first or on the state another test left behind.
        multi_sm = MultiStateMachine(self.players, "primary_backup")
        self.envs = [
            PrimaryBackupEnv(
                self.players, i, self.NUM_ROUNDS, multi_sm, is_training=True
            )
            for i in range(self.players)
        ]
        all_combs = generate_all_combs(
            players=self.players, num_round=self.NUM_ROUNDS
        )
        self.setting = all_combs[self.SETTING_INDEX]

    def _step_and_advance(self, actions):
        """Apply one joint step, then advance every env's round counter.

        The round counter is incremented here rather than by ``step_envs``.
        """
        step_envs(self.envs, actions, self.players, self.setting)
        for env in self.envs:
            env.current_round += 1

    def test_reset_envs(self):
        """Resetting loads initial states, round-1 crashes and per-player states."""
        reset_envs(self.envs, self.players, self.setting)
        self.assertEqual(self.envs[0].initial_states, [6, 5, 6])
        self.assertEqual(self.envs[1].initial_states, [6, 5, 6])

        self.assertFalse(self.envs[0].is_crash)
        self.assertTrue(self.envs[2].is_crash)

        self.assertEqual(self.envs[0].states, [0, 6, 7, 5, 7, 4, 7])
        self.assertEqual(self.envs[1].states, [1, 6, 7, 5, 7, 6, 7])

        self.assertEqual(self.envs[0].crash_nodes, {2})

    def test_step_envs(self):
        """Run both rounds and check crashes, survivor state, termination and reward."""
        # Start from a known state instead of relying on test_reset_envs.
        reset_envs(self.envs, self.players, self.setting)

        # Round 1: player 2 is already crashed, so its action is Lost.
        self._step_and_advance(
            [State.DoNothing_Zero.value, State.DoNothing_One.value, State.Lost.value]
        )

        self.assertTrue(self.envs[0].is_crash)
        self.assertEqual(
            self.envs[1].states, [1, 4, 6, State.DoNothing_One.value, 5, 4, 6]
        )
        self.assertEqual(self.envs[0].crash_nodes, {0, 2})
        self.assertFalse(self.envs[0]._done)

        # Round 2: only player 1 is alive and it decides One.
        self._step_and_advance(
            [State.Lost.value, State.One.value, State.Lost.value]
        )
        rewards = [env.reward() for env in self.envs]

        self.assertTrue(self.envs[0]._done)
        self.assertEqual(rewards[1], 1)  # Only need to check the alive one


