import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
import unittest

from Enviroment import EnvManager
from generate_combs import Info, generate_all_combs
from primary_backup.State import State
from primary_backup.PrimaryBackup import reset_envs, step_envs


class TestEnvMgr(unittest.TestCase):
    """Tests for EnvManager's step / store / restore / step_back behaviour.

    All tests share one EnvManager built in ``setUpClass``. Each test
    re-initialises it from the same fixed scenario via ``init_test``.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared EnvManager and select the fixed scenario under test."""
        cls.env_mgr = EnvManager(3, "primary_backup", 2, is_history=True, encode_id=True)
        all_combs = generate_all_combs(players=3, num_round=2)
        # Scenario 556 (3 players, 2 rounds):
        #   initial: 5 [6, 5, 6]
        #   round: 1, crash: [2], alive: [0, 1], receive: ((0,), (1,))
        #     [0, 0, 6, 7, 5, 7, 4, 7] -> x
        #     [1, 0, 6, 7, 5, 7, 6, 7] -> y
        #   round: 2, crash: [0], alive: [1], receive: ((0,),)
        #     [1, 1, 4, 6, y, 5, 4, 6]
        cls.setting: Info = all_combs[556]
        print(cls.setting)

    def init_test(self):
        """Reset the shared env manager to the start of the selected scenario."""
        self.env_mgr.init(self.setting)

    def _step_next_round(self, actions):
        """Advance the env manager one round, using the scenario's crash info for that round."""
        next_round = self.env_mgr.get_zero_based_round() + 1
        self.env_mgr.step(actions, self.setting.get_crash_info(next_round))

    def test_step_back(self):
        """Storing at round 0, stepping twice, then step_back(2) + restore() recovers round 0."""
        self.init_test()
        self.assertEqual(self.env_mgr.envs[0].states, [0, 0, 6, 7, 5, 7, 4, 7])
        self.assertEqual(self.env_mgr.envs[1].states, [1, 0, 6, 7, 5, 7, 6, 7])
        self.assertEqual(self.env_mgr.get_crash_nodes(), {2})
        self.env_mgr.store()

        self._step_next_round([State.DoNothing_Zero.value, State.DoNothing_One.value, State.Lost.value])
        self.assertEqual(self.env_mgr.get_rewards(), [0, 0, 0])

        self._step_next_round([State.Lost.value, State.Zero.value, State.Lost.value])
        self.assertEqual(self.env_mgr.envs[0].state_machine.get_cur_state(), State.Lost.value)
        self.assertEqual(self.env_mgr.envs[1].state_machine.final_decision, State.Zero.value)

        self.env_mgr.step_back(2)
        self.env_mgr.restore()
        self.assertEqual(self.env_mgr.envs[0].state_machine.get_cur_state(), State.LocalOne.value)
        self.assertEqual(self.env_mgr.envs[1].state_machine.final_decision, None)
        self.assertEqual(self.env_mgr.get_crash_nodes(), {2})

    def test_store_restore(self):
        """Storing at round 1, stepping, then step_back(1) + restore() recovers round 1 exactly."""
        # Round 0
        self.init_test()
        self.assertEqual(self.env_mgr.envs[0].states, [0, 0, 6, 7, 5, 7, 4, 7])
        self.assertEqual(self.env_mgr.envs[1].states, [1, 0, 6, 7, 5, 7, 6, 7])
        self.assertEqual(self.env_mgr.get_crash_nodes(), {2})

        # Round 0 -> Round 1
        self._step_next_round([State.DoNothing_Zero.value, State.DoNothing_One.value, State.Lost.value])
        self.assertEqual(self.env_mgr.envs[1].states, [1, 1, 4, 6, 3, 5, 4, 6])
        self.assertTrue(self.env_mgr.envs[0].is_crash)
        self.assertFalse(self.env_mgr.is_done())
        self.assertEqual(self.env_mgr.get_rewards(), [0, 0, 0])
        receive_comb = self.env_mgr.get_cur_receive_comb()
        self.assertEqual(self.env_mgr.get_crash_nodes(), {0, 2})

        # Round 1 -> Round 2
        self.env_mgr.store()
        actions = [State.Lost.value, State.Zero.value, State.Lost.value]
        self._step_next_round(actions)
        rewards = self.env_mgr.get_rewards()  # Final rewards

        # Round 2 -> Round 1
        self.env_mgr.step_back(1)
        self.env_mgr.restore()
        self.assertEqual(self.env_mgr.round, 1)
        self.assertEqual(self.env_mgr.get_cur_receive_comb(), receive_comb)
        self.assertEqual(self.env_mgr.envs[1].states, [1, 1, 4, 6, 3, 5, 4, 6])
        self.assertEqual(self.env_mgr.get_rewards(), [0, 0, 0])
        self.assertEqual(self.env_mgr.get_crash_nodes(), {0, 2})

        # Round 1 -> Round 2 again: replaying the same actions must reproduce the final rewards.
        self._step_next_round(actions)
        self.assertEqual(self.env_mgr.get_rewards(), rewards)
        self.assertEqual(self.env_mgr.round, 2)

    def test_zero_step_back(self):
        """step_back(0) leaves the freshly initialised state untouched."""
        self.init_test()

        self.env_mgr.step_back(0)
        self.assertEqual(self.env_mgr.envs[0].states, [0, 0, 6, 7, 5, 7, 4, 7])
        self.assertEqual(self.env_mgr.envs[1].states, [1, 0, 6, 7, 5, 7, 6, 7])
        _, all_transitions = self.env_mgr.get_current_states()
        self.assertEqual(all_transitions[0], [State.LocalOne.value])

    def test_crash_nodes(self):
        """The crash set is {2} after init and grows to {0, 2} after the first round."""
        self.init_test()
        self.assertEqual(self.env_mgr.get_crash_nodes(), {2})

        self._step_next_round([State.DoNothing_Zero.value, State.DoNothing_One.value, State.Lost.value])
        self.assertEqual(self.env_mgr.get_crash_nodes(), {0, 2})


if __name__ == "__main__":
    # Run the tests in a fixed, explicit order.
    suite = unittest.TestSuite()
    for test_name in (
        "test_step_back",
        "test_store_restore",
        "test_zero_step_back",
        "test_crash_nodes",
    ):
        suite.addTest(TestEnvMgr(test_name))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    # Propagate failures to the process exit status so scripts/CI notice them.
    sys.exit(0 if result.wasSuccessful() else 1)
