import random
import sys
import os

import numpy as np

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
import unittest
from MCTS import MCTS
from model import AlphaNet
from Enviroment import EnvManager
from StateActionTracker import StateActionStack
from util import softmax
from generate_combs import generate_all_combs
from import_functions import get_action_space, import_states


class TestMCTS(unittest.TestCase):
    """Regression test that replays one seeded MCTS game and checks its policies.

    The test loads a saved AlphaNet checkpoint, runs MCTS on one fixed crash
    setting of the primary-backup protocol and compares the first policy
    targets with values recorded from an earlier run.
    """

    # Checkpoint the recorded policies were produced with. The default is the
    # original author's location; set MCTS_TEST_CHECKPOINT to point elsewhere.
    CHECKPOINT_PATH = os.environ.get(
        "MCTS_TEST_CHECKPOINT",
        "/users/huiyujie/self-play-protocol/chkpt/mcts/pb_3p_init",
    )

    def setUp(self):
        """Build the per-test fixture: game parameters, model, settings, tracker."""
        # Only the parent directory is checked because the checkpoint name may
        # be a file, a directory or a prefix; without the directory the load
        # below cannot succeed, so skip with a clear reason instead of erroring.
        checkpoint_dir = os.path.dirname(os.path.abspath(self.CHECKPOINT_PATH))
        if not os.path.isdir(checkpoint_dir):
            self.skipTest("checkpoint directory not found: " + checkpoint_dir)

        self.players = 3
        self.round = 2
        self.protocol = "primary_backup"
        self.curr_model = AlphaNet(4, get_action_space(self.protocol))
        self.curr_model.load(self.CHECKPOINT_PATH)
        self.all_combs = generate_all_combs(players=self.players, num_round=self.round)
        self.dfs_tracker = StateActionStack()

    def assertListsAlmostEqual(self, list1, list2, places=7):
        """Assert that two sequences have equal length and near-equal elements.

        Args:
            list1: First sequence of numbers.
            list2: Second sequence of numbers.
            places: Decimal places passed through to ``assertAlmostEqual``.

        Raises:
            AssertionError: If the lengths differ or any pair of elements
                differs at the given precision.
        """
        self.assertEqual(len(list1), len(list2), "Lists are of different lengths")
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, places=places)

    def test_mcts(self):
        """Replay one seeded game and compare the first three policy targets."""
        print(os.path.dirname(__file__))
        # NOTE(review): only the stdlib `random` generator is seeded here. If
        # MCTS or the model draws from another RNG (e.g. numpy), the expected
        # values below are not guaranteed to reproduce — confirm.
        random.seed(15)
        # NOTE(review): assumes EnvManager's first and third positional
        # arguments are the player count and the number of rounds (the
        # original passed the literals 3 and 2) — confirm against EnvManager.
        env_mgr = EnvManager(self.players, self.protocol, self.round, is_history=False, encode_id=False)
        policy_outputs = []
        StateClass = import_states(self.protocol)
        setting = self.all_combs[256]
        env_mgr.init(setting)
        init_key = env_mgr.construct_crash_key(setting.get_crash_info(0))
        mcts = MCTS(self.curr_model, self.players, setting.get_crash_info(0), init_key, self.protocol)
        while not env_mgr.is_done():
            # 1000 is presumably the search budget per move — confirm against
            # MCTS.execute.
            action, pi_target = mcts.execute(env_mgr, self.dfs_tracker, 1000)
            for i in range(self.players):
                # Players whose chosen action is `Lost` contribute no policy.
                if action[i] == StateClass.Lost.value:
                    continue
                policy_outputs.append(pi_target[i])
            next_round = env_mgr.get_zero_based_round() + 1
            env_mgr.step(action, setting.get_crash_info(next_round))
            if env_mgr.is_done():
                break
            mcts.update_tree(env_mgr, action, setting.get_crash_info(next_round))

        # Expected simulation results (the same model is loaded and the random
        # seed is fixed, so the test should always produce these values):
        #   state: [0, 5, 6, 5], policy: [0.734 0.05  0.18  0.036], value_target: [0.415 0.119 0.408 0.057]
        #   state: [2, 5, 4, 5], policy: [0.722 0.061 0.174 0.043], value_target: [0.37  0.186 0.361 0.083]
        #   state: [0, 0, 4, 0], policy: [0.032 0.042 0.466 0.46 ], value_target: [0.358 0.182 0.232 0.229]
        # Fail with a readable message rather than an IndexError when the run
        # produced fewer policies than are checked below.
        self.assertGreaterEqual(len(policy_outputs), 3, "Expected at least three policy outputs")
        self.assertListsAlmostEqual(policy_outputs[0].tolist(), [0.734, 0.05, 0.18, 0.036], places=3)
        self.assertListsAlmostEqual(policy_outputs[1].tolist(), [0.722, 0.061, 0.174, 0.043], places=3)
        self.assertListsAlmostEqual(policy_outputs[2].tolist(), [0.032, 0.042, 0.466, 0.46], places=3)


if __name__ == "__main__":
    # Run only the MCTS regression test, with one verbose line per test.
    suite = unittest.TestSuite()
    suite.addTest(TestMCTS("test_mcts"))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    # runner.run() only prints the report and returns the result, so pass the
    # outcome on through the exit status for shells and CI to see.
    sys.exit(0 if result.wasSuccessful() else 1)
