import itertools
from typing import Dict
import numpy as np
import random
from Enviroment import EnvManager
from StateActionTracker import StateActionStack
from config import EPSILON
from generate_combs import CrashInfo, _get_receiver_row_patterns
from import_functions import import_states
from primary_backup.State import State as PBState
from atomic_commit.State import State as ACState
from util import get_avg, get_min, masked_softmax, softmax
from utils import hash_observation, log


class MCTSNode:
    """
    Action node of the Monte Carlo tree.

    Holds the action that leads to this node, the search statistics collected
    for it, and the CrashNode children reachable from it (keyed by crash key).
    """

    def __init__(self, action, p, noise, players, crashed_nodes_map: Dict[str, "CrashNode"]):
        # Identity of the node
        self.action = action  # action taken at this node
        self.crashed_nodes_map = crashed_nodes_map  # crash key -> CrashNode
        # Search statistics
        self.p = p  # prior probability supplied by the caller
        self.dirichlet_noise = noise  # Dirichlet noise used for exploration
        self.n = 0  # visit count
        self.w = np.zeros(players)  # accumulated win value, one float slot per player

    def __str__(self):
        return f"action: {self.action}, n: {self.n}, p: {self.p}, w: {self.w}"


class CrashNode:
    """
    Node inserted between two MCTSNode levels to track one crash scenario.

    Attributes:
        n: visit count of this crash scenario.
        crash_info: the CrashInfo describing the scenario.
        crash_key: key under which this node is stored in the parent's
            crashed_nodes_map.
        score: accumulated score of this crash scenario.
        parent: the MCTSNode this crash scenario hangs off.
        childs: MCTSNode children expanded below this crash scenario.
    """

    def __init__(self, parent: "MCTSNode", crash_info: "CrashInfo", crash_key: str):
        self.n = 0
        self.crash_info = crash_info
        self.crash_key = crash_key
        self.score = 0  # Track the score of this crash scenario
        self.parent = parent
        self.childs: list[MCTSNode] = []

    def __str__(self):
        # An unvisited node has no average yet; report 0.0 instead of raising
        # ZeroDivisionError so the node can always be logged/printed.
        average = self.score / self.n if self.n else 0.0
        return (
            f"crashed nodes: {self.crash_key}, crashed cnt: {self.n}, "
            f"score: {self.score}, average score: {average:.3f}"
        )


class MCTS:
    """
    Implement Monte Carlo Tree Search
    """

    def __init__(self, model, players, init_crash_info, init_key, protocol, num_round):
        self.root = MCTSNode(None, 0, 0, players, dict())
        self.root.crashed_nodes_map[init_key] = CrashNode(self.root, init_crash_info, init_key)
        self.sub_root = self.root.crashed_nodes_map[init_key]
        self.model = model
        self.unfix_list = []
        self.StateClass = import_states(protocol)
        self.num_round = num_round

    def compute_mcts_wp(self, env_mgr, actions):
        """
        actions: [[a1, a2, a3, ...], [], [], ...]
        values: win/q value of each action based current state of each env
        calculated action values by taking average of all states
        policies: action distribution based on current state of each env
        """
        policy_output = []
        for i in range(env_mgr.players):
            policy = self.model.predict_policy(env_mgr.get_states(i))
            policy = np.squeeze(policy)
            softmaxed_policy = softmax(policy)
            policy_output.append(softmaxed_policy)

        policy_output = np.array(policy_output)

        policies = []
        for action in actions:
            p = []
            for i in range(env_mgr.players):
                if action[i] == self.StateClass.Lost.value:
                    p.append(0)
                else:
                    p.append(policy_output[i][action[i]])

            policies.append(np.array(p))
        return policies

    # Update Q value and visit count of all visited nodes
    def backprop(self, visited_nodes, rewards, players):
        for node in visited_nodes:
            node.n += 1
            if isinstance(node, CrashNode):
                node.score -= min(rewards)  # Minimax
            elif isinstance(node, MCTSNode):
                if node.action is None:  # The root node
                    continue
                for i in range(players):
                    if node.action[i] != self.StateClass.Lost.value:
                        node.w[i] += rewards[i]

    def get_ucb(self, parent: CrashNode, add_noise, ucb_c):
        # sum_n = np.sqrt(parent.parent.n)
        pnw = np.array([[get_min(child.p), child.n, get_min(child.w), child.dirichlet_noise] for child in parent.childs])
        sum_n = np.sqrt(np.sum(pnw[:, 1]))
        p = pnw[:, 0]  # priors
        n = pnw[:, 1]  # visit count
        w = pnw[:, 2]  # accumulated win
        noise = pnw[:, 3]  # dirichlet noise
        if add_noise:  # add noise for root's priors
            p = 0.75 * p + 0.25 * noise
        q = w / (n + np.finfo(np.float32).eps)  # value based on win values (exploitation)
        u = p * (sum_n / (1 + n))  # value based on prior * visit count (exploration)

        val = q + u * ucb_c  # exploitation + exploration
        return val

    # Get the policy target for training the policy network.
    # It is calculated by summing the visit counts of each action per player and normalizing each row.
    # Return: [player1: [p1, p2, p3, ...], player2: [...], ...]
    def get_policy_train_dnn(self, env_mgr, actions, pis):
        if env_mgr.get_zero_based_round() + 1 == env_mgr.num_rounds:
            active_nodes = set(env_mgr.get_alive_nodes()).union(env_mgr.newly_crash)
        else:
            active_nodes = env_mgr.get_alive_nodes()
        target = np.zeros((env_mgr.players, env_mgr.action_space))
        for action_set, pi in zip(actions, pis):
            for i in active_nodes:
                action_value = action_set[i]
                target[i][action_value] += pi
        row_sums = np.sum(target, axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1
        target = target / row_sums
        return target

    def select_move(self, childs, is_inference=False):
        n = np.array([child.n for child in childs])
        log("MCTS: select_move", 1, f"n: {n}")
        pi = n / np.sum(n)
        if is_inference:
            idx = np.argmax(pi)
        else:
            idx = np.random.choice(np.arange(len(pi), dtype=int), 1, p=pi)[0]
        return idx, n

    # TODO: This function is duplicated with the one in generate_combs.py.
    #       Need to refactor
    def create_crash_info(self, alive_nodes, current_round):
        """
        Create all possible crash info in this round
        """
        all_crash_sets = itertools.chain.from_iterable(itertools.combinations(alive_nodes, r) for r in range(len(alive_nodes)))
        all_crash_infos = []
        for crash_tuple in all_crash_sets:
            crash = sorted(list(crash_tuple))
            new_alive = [n for n in alive_nodes if n not in crash]
            if current_round == self.num_round - 1:
                receivers = sorted(new_alive + crash)
            else:
                receivers = sorted(new_alive)
            if len(receivers) == 0:
                receive_mat = np.empty((len(receivers), 0), dtype=bool)
                all_crash_infos.append(CrashInfo(crash, new_alive, receive_mat, current_round))
                continue

            row_options = [
                _get_receiver_row_patterns(crash, rcv) for rcv in receivers
            ]  # each element: list[np.ndarray bool shape(len(crash))]

            for row_choice in itertools.product(*row_options):
                receive_mat = np.stack(row_choice, axis=0)  # (rows, cols)
                all_crash_infos.append(CrashInfo(crash, new_alive, receive_mat, current_round))
        return all_crash_infos

    def expand_crash_node(self, env_mgr: EnvManager, node: MCTSNode):
        alive_nodes = env_mgr.get_alive_nodes()
        # We are creating CrashNode for the next round before calling step_envs()
        # So pass env_mgr.get_zero_based_round() + 1
        all_crash_infos = self.create_crash_info(alive_nodes, env_mgr.get_zero_based_round() + 1)
        for crash_info in all_crash_infos:
            key = env_mgr.construct_crash_key(crash_info)
            assert key not in node.crashed_nodes_map
            crash_node = CrashNode(node, crash_info, key)
            node.crashed_nodes_map[key] = crash_node

    def sample_crash_node(self, env_mgr: EnvManager, node: MCTSNode) -> CrashNode:
        """
        Sample a CrashNode from current MCTSNode
        """
        # If this is the last round, we don't need to sample a new crash node
        if env_mgr.get_zero_based_round() + 1 == env_mgr.num_rounds:
            return None
        if not node.crashed_nodes_map:
            self.expand_crash_node(env_mgr, node)
        # Compute total parent visit count
        N_parent = sum(c.n for c in node.crashed_nodes_map.values())
        best_crashnode = None
        best_score = -float("inf")
        for crash_node in node.crashed_nodes_map.values():
            if crash_node.n == 0:
                return crash_node  # Return the first unvisited CrashNode
            # UCB value for each CrashNode
            mean = crash_node.score / crash_node.n
            ucb = mean + 0 * np.sqrt(np.log(N_parent) / crash_node.n)
            if ucb > best_score:
                best_score = ucb
                best_crashnode = crash_node

        return best_crashnode

    def policy(self, env_mgr: EnvManager, dfs_tracker: StateActionStack, visited_nodes):
        func = "MCTS: policy"
        node = self.sub_root  # Start from CrashNode because the crash state is already determined
        visited_nodes.append(node.parent)  # Track visited MCTSNode during search
        visited_nodes.append(node)  # Track visited CrashNode during search
        while not env_mgr.is_done():
            # env_mgr.get_current_states()
            if len(node.childs) == 0:
                actions = env_mgr.get_actions(env_mgr.get_zero_based_round())  # get possible actions
                actions = self.filter_action(env_mgr, dfs_tracker, actions)
                log(func, 2, actions)
                ps = self.compute_mcts_wp(env_mgr, actions)  # compute policy and value
                dirichlet_noise = np.random.dirichlet([0.1] * len(actions))
                node.childs = [MCTSNode(a, ps[i], dirichlet_noise[i], env_mgr.players, {}) for i, a in enumerate(actions)]

            """ Select MCTSNode (select action) """
            if random.random() < EPSILON:
                idx = random.choice(range(len(node.childs)))
            else:
                # select the child node with the highest UCB value
                ucb = self.get_ucb(node, False, ucb_c=5)
                idx = np.argmax(ucb)
            node = node.childs[idx]  # Get MCTSNode from CrahsNode
            visited_nodes.append(node)  # Append MCTSNode for backpropagation

            """ Select CrashNode (select environment) """
            crash_node = self.sample_crash_node(env_mgr, node)  # Create new CrashNode if needed
            env_mgr.step(
                list(node.action), crash_node.crash_info if crash_node is not None else None
            )  # uncertainty is injected here
            # We don't need to insert another crashnode
            if env_mgr.is_done():
                break
            visited_nodes.append(crash_node)  # Append CrashNode for backpropagation
            node = crash_node

        rewards = env_mgr.get_rewards()
        return node, rewards

    # Filter action set by checking those fixed actions
    def filter_action(self, env_mgr: EnvManager, dfs_tracker: StateActionStack, actions):
        filtered_action = []
        simulated_actions = []  # action that is already simulated that should be fixed
        for i in range(env_mgr.players):
            if actions[0][i] == self.StateClass.Lost.value:
                simulated_actions.append(self.StateClass.Lost.value)
            else:
                simulated_actions.append(
                    dfs_tracker.get_fixing_action(env_mgr.get_states(i))
                )  # Append None if this state is not fixed

        for action in actions:
            ok_to_add = True
            for i in range(env_mgr.players):
                if simulated_actions[i] is None:
                    continue
                else:
                    if action[i] != simulated_actions[i]:
                        ok_to_add = False
                        break
            if ok_to_add:
                filtered_action.append(action)
        return filtered_action

    def execute(self, env_mgr: EnvManager, dfs_tracker):
        func = "MCTS: execute"
        # Determine number of iterations based on number of rounds and players
        iterations = env_mgr.num_rounds * env_mgr.players * 1000
        for i in range(iterations):
            log(func, 2, f"####Iteration {i}#####")
            policy_nodes = []  # visited node during current iteration
            env_mgr.store()
            _, rewards = self.policy(env_mgr, dfs_tracker, policy_nodes)
            steps = sum(1 for node in policy_nodes if isinstance(node, MCTSNode)) - 1
            env_mgr.step_back(steps)
            env_mgr.restore()
            self.backprop(policy_nodes, rewards, env_mgr.players)
            for node in policy_nodes:
                log(func, 2, node)

        # select move, and return visited counts of all child nodes
        idx, ns = self.select_move(self.sub_root.childs, True)
        policy_target = self.get_policy_train_dnn(env_mgr, [child.action for child in self.sub_root.childs], ns)

        return self.sub_root.childs[idx].action, policy_target

    def update_tree(self, env_mgr: EnvManager, action, crash_info: CrashInfo = None):
        """
        self.root is MCTSNode
        self.sub_root is CrashNode
        This function updates the tree based on the action taken by the environment after calling step_envs()
        """
        # Get MCTSNode based on action
        if len(self.sub_root.childs) == 0:
            log("MCTS: update_tree", 1, "Action not found in tree, expand the tree")
            self.sub_root.childs = [MCTSNode(action, 1, env_mgr.players, {})]
            self.sub_root = self.sub_root.childs[0]
        else:
            for child in self.sub_root.childs:
                if child.action == action:
                    log("MCTS: update_tree", 1, "Action found in tree")
                    self.sub_root = child

        key = env_mgr.construct_crash_key(crash_info)
        if key in self.sub_root.crashed_nodes_map:
            log("MCTS: update_tree", 1, "Find next CrashNode")
            self.sub_root = self.sub_root.crashed_nodes_map[key]
            return
        else:
            log("MCTS: update_tree", 1, "Create new CrashNode")
            print(f"Cannot find key: {key}")
            crash_node = CrashNode(self.sub_root, key)
            self.sub_root.crashed_nodes_map[key] = crash_node
            self.sub_root = crash_node
            return
        log("MCTS: update_tree", "Something wrong")
        exit(-1)

    def visualize_tree(self, root, level=0):
        if isinstance(root, MCTSNode):
            if root.n > 0:
                print("\t" * level, root)
            for crash_child in root.crashed_nodes_map.values():
                self.visualize_tree(crash_child, level + 1)
        elif isinstance(root, CrashNode):
            if root.n > 0:
                print("\t" * level, root)
            for child in root.childs:
                self.visualize_tree(child, level + 1)

    def check_list(self, lst, state_value):
        # filter out LocalOne (6) or LocalZero (5) first
        return all(item == state_value for item in lst)

    def get_filter_idx(self, env_mgr):
        idx_to_restrict = []
        for i in range(env_mgr.players):
            if i not in env_mgr.get_alive_nodes():
                continue
            msg = env_mgr.get_received_msg(i)
            if self.check_list(msg, ACState.LocalCommit.value):
                idx_to_restrict.append(i)
        return idx_to_restrict, ACState.Commit.value
