from easydict import EasyDict
import copy


class Node:
    """
    Overview:
        Alpha-Beta-Pruning search node. Each node stores a board, the player to move,
        and the actions not yet turned into child nodes.
        https://mathspp.com/blog/minimax-algorithm-and-alpha-beta-pruning
    Arguments:
        - board: board state held by this node; it is handed to the env as ``init_state``.
        - legal_actions: actions available from ``board``. A deep copy is stored because
          ``expand`` consumes the list.
        - start_player_index: index of the player to move at this node (0 or 1).
        - parent: the node this one was expanded from, ``None`` for the root.
        - prev_action: the action that led from ``parent`` to this node.
        - env: simulator env instance, such as
          zoo.board_games.tictactoe.envs.tictactoe_env.TicTacToeEnv or
          zoo.board_games.gomoku.envs.gomoku_env.GomokuEnv. It must provide
          ``reset_v2``, ``get_done_reward`` and ``simulate_action_v2``.
    """

    def __init__(self, board, legal_actions, start_player_index=0, parent=None, prev_action=None, env=None):
        super().__init__()
        self.env = env
        self.board = board
        # Pending (not yet expanded) actions; drained by ``expand``.
        self.legal_actions = copy.deepcopy(legal_actions)
        self.children = []
        self.parent = parent
        self.prev_action = prev_action
        self.start_player_index = start_player_index
        self.tree_expanded = False

    def __str__(self):
        return f"Tree({', '.join(str(child) for child in self.children)})"

    def expand(self):
        """
        Overview:
            Create one child per pending legal action by simulating that action in the env.
            Terminal nodes are left untouched (no children, ``expanded`` stays False).
        """
        # Truthiness instead of ``is False``: an identity check silently fails when the
        # env reports ``done`` as something other than a Python bool (e.g. a numpy bool).
        if self.is_terminal_node:
            return

        next_start_player_index = 1 if self.start_player_index == 0 else 0
        # Actions are popped as they are consumed, so calling ``expand`` again
        # cannot create duplicate children.
        while len(self.legal_actions) > 0:
            action = self.legal_actions.pop(0)
            board, legal_actions = self.env.simulate_action_v2(self.board, self.start_player_index, action)
            self.children.append(
                Node(
                    board,
                    legal_actions,
                    start_player_index=next_start_player_index,
                    parent=self,
                    prev_action=action,
                    env=self.env
                )
            )
        self.tree_expanded = True

    @property
    def expanded(self):
        """Whether ``expand`` has completed on this (non-terminal) node."""
        return self.tree_expanded

    def is_fully_expanded(self):
        """
        Overview:
            Return True when no pending legal action is left to expand.
            ``expand`` drains ``legal_actions`` while building children, so comparing the
            number of children with the number of remaining actions (the previous
            implementation) reported False for every expanded node that had children.
        """
        return len(self.legal_actions) == 0

    @property
    def is_terminal_node(self):
        """
        Overview:
            Whether the game is over at this node's board.
            NOTE: this resets the shared env to this node's board as a side effect.
        """
        self.env.reset_v2(self.start_player_index, init_state=self.board)  # index
        # Coerce to a plain bool so callers comparing with ``is True`` / ``is False``
        # behave the same whatever flag type the env returns.
        return bool(self.env.get_done_reward()[0])

    @property
    def value(self):
        """
        Overview:
            Reward reported by ``env.get_done_reward()`` for this node's board.
            NOTE: this resets the shared env to this node's board as a side effect.
        Returns:
            The second element of ``[game_over, reward]``. According to the env contract
            noted here originally:
                if winner = 1  reward = 1
                if winner = 2  reward = -1
                if winner = -1 reward = 0
        """
        self.env.reset_v2(self.start_player_index, init_state=self.board)  # index
        return self.env.get_done_reward()[1]

    @property
    def estimated_value(self):
        """Heuristic value used when the search depth limit is hit; currently always 0."""
        return 0

    @property
    def state(self):
        """The board stored at this node."""
        return self.board


def pruning(tree, maximising_player, alpha=float("-inf"), beta=float("+inf"), depth=999, first_level=True):
    """
    Overview:
        Minimax search with alpha-beta pruning over ``tree``, expanding nodes lazily.
    Arguments:
        - tree: node exposing ``is_terminal_node``, ``value``, ``estimated_value``,
          ``expanded``, ``expand()`` and ``children``.
        - maximising_player: True if the player to move at ``tree`` maximises the value.
        - alpha: best value the maximising side can already guarantee.
        - beta: best value the minimising side can already guarantee.
        - depth: remaining search depth; at 0 ``tree.estimated_value`` is returned.
        - first_level: True for the root call, which also reports the best child.
    Returns:
        - ``tree.value`` if ``tree`` is terminal, or ``tree.estimated_value`` if ``depth == 0``
          (a bare value in both cases, even when ``first_level`` is set).
        - otherwise ``(val, best_subtree)`` when ``first_level`` is set, else ``val``.
          ``best_subtree`` is ``None`` if no child improved the search bound
          (e.g. the node has no children).
    """
    # Truthiness rather than ``is True``: a done flag that is not the Python ``True``
    # singleton (e.g. a numpy bool) would otherwise never be recognised as terminal.
    if tree.is_terminal_node:
        return tree.value
    # TODO(pu): use a limited search depth
    if depth == 0:
        return tree.estimated_value

    if not tree.expanded:
        tree.expand()

    # Initialised up front so the first-level return cannot hit an unbound local.
    best_subtree = None
    val = float("-inf") if maximising_player else float("+inf")
    for subtree in tree.children:
        sub_val = pruning(subtree, not maximising_player, alpha, beta, depth - 1, first_level=False)
        if maximising_player:
            val = max(sub_val, val)
            if val > alpha:
                best_subtree = subtree
                alpha = val
        else:
            val = min(sub_val, val)
            if val < beta:
                best_subtree = subtree
                beta = val
        # The opponent already has a better alternative elsewhere: stop exploring siblings.
        if beta <= alpha:
            break

    if first_level:
        return val, best_subtree
    return val


class AlphaBetaPruningBot:
    """
    Overview:
        Bot that picks its move with an alpha-beta pruning search on a private simulator env.
    Arguments:
        - ENV: the env class (instantiated as ``ENV(EasyDict(cfg))``); if that call fails,
          ``ENV`` itself is deep-copied and used as the simulator.
        - cfg: env config, wrapped in ``EasyDict`` when the env is constructed.
        - bot_name: display name of the bot.
    """

    def __init__(self, ENV, cfg, bot_name):
        self.name = bot_name
        self.ENV = ENV
        self.cfg = cfg

    def get_best_action(self, board, player_index, depth=999):
        """
        Overview:
            Search from ``board`` with ``player_index`` to move and return the best action.
        Arguments:
            - board: board state to search from; passed to the env as ``init_state``.
            - player_index: index of the player to move; player 0 is the maximising side.
            - depth: maximum search depth forwarded to ``pruning``.
        Returns:
            - The ``prev_action`` of the best child of the root found by the search.
        """
        try:
            simulator_env = copy.deepcopy(self.ENV(EasyDict(self.cfg)))
        except Exception:
            # Was a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
            # Presumably ``ENV`` is an already-built env instance here -- TODO confirm.
            simulator_env = copy.deepcopy(self.ENV)
        simulator_env.reset(start_player_index=player_index, init_state=board)
        root = Node(board, simulator_env.legal_actions, start_player_index=player_index, env=simulator_env)

        maximising_player = bool(player_index == 0)
        _, best_subtree = pruning(root, maximising_player, depth=depth, first_level=True)

        return best_subtree.prev_action


# Demo / smoke test: two alpha-beta bots play TicTacToe against each other from the empty
# board and the final result is checked. The Gomoku variant is kept disabled inside the
# trailing string literal.
if __name__ == "__main__":
    import time
    ##### TicTacToe #####
    from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
    cfg = dict(
        prob_random_agent=0,
        prob_expert_agent=0,
        battle_mode='self_play_mode',
        agent_vs_human=False,
        bot_action_type='alpha_beta_pruning',  # {'v0', 'alpha_beta_pruning'}
        channel_last=False,
        scale=True,
    )
    env = TicTacToeEnv(EasyDict(cfg))
    player_0 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 1')  # player_index = 0, player = 1
    player_1 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 2')  # player_index = 1, player = 2

    ### test from the init empty board ###
    player_index = 0  # player 1 first
    env.reset()

    ### test from the init specified board ###
    # player_index = 0  # player 1 first
    # init_state = [[1, 0, 1],
    #               [0, 0, 2],
    #               [2, 0, 1]]
    # env.reset(player_index, init_state)

    state = env.board
    print('-' * 15)
    print(state)

    # The two bots alternate moves until the env reports the game as over.
    while not env.get_done_reward()[0]:
        if player_index == 0:
            start = time.time()
            action = player_0.get_best_action(state, player_index=player_index)
            print('player 1 action time: ', time.time() - start)
            player_index = 1
        else:
            start = time.time()
            action = player_1.get_best_action(state, player_index=player_index)
            print('player 2 action time: ', time.time() - start)
            player_index = 0
        env.step(action)
        state = env.board
        print('-' * 15)
        print(state)
        # Only needed by the "specified board" assertions below.
        row, col = env.action_to_coord(action)

    ### test from the init empty board ###
    # These were written as ``assert a, b``, which uses ``b`` as the failure message and
    # never checks it; both conditions are now asserted.
    assert env.get_done_winner()[0] is False
    assert env.get_done_winner()[1] == -1

    ### test from the init specified board ###
    # assert (row, col) in ((0, 1), (1, 1))
    # assert env.get_done_winner()[0] is True
    # assert env.get_done_winner()[1] == 1
    """

    ##### Gomoku #####
    from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
    cfg = dict(
        board_size=5,
        prob_random_agent=0,
        prob_expert_agent=0,
        battle_mode='self_play_mode',
        scale=True,
        channel_last=False,
        agent_vs_human=False,
        bot_action_type='alpha_beta_pruning',  # {'v0', 'alpha_beta_pruning'}
        prob_random_action_in_bot=0.,
        check_action_to_connect4_in_bot_v0=False,
    )
    env = GomokuEnv(EasyDict(cfg))
    player_0 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 1')  # player_index = 0, player = 1
    player_1 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 2')  # player_index = 1, player = 2

    ### test from the init empty board ###
    player_index = 0  # player 1 fist
    env.reset()

    ### test from the init specified board ###
    # player_index = 1  # player 2 fist
    # init_state = [[1, 1, 1, 1, 0],
    #               [1, 0, 0, 0, 2],
    #               [0, 0, 2, 0, 2],
    #               [0, 2, 0, 0, 2],
    #               [2, 1, 1, 0, 0], ]
    # # init_state = [[1, 1, 1, 1, 2],
    # #               [1, 1, 2, 1, 2],
    # #               [2, 1, 2, 2, 2],
    # #               [0, 0, 0, 2, 2],
    # #               [2, 1, 1, 1, 0], ]
    # env.reset(player_index, init_state)

    state = env.board
    print('-' * 15)
    print(state)

    while not env.get_done_reward()[0]:
        if player_index == 0:
            start = time.time()
            action = player_0.get_best_action(state, player_index=player_index)
            print('player 1 action time: ', time.time() - start)
            player_index = 1
        else:
            start = time.time()
            action = player_1.get_best_action(state, player_index=player_index)
            print('player 2 action time: ', time.time() - start)
            player_index = 0
        env.step(action)
        state = env.board
        print('-' * 15)
        print(state)

    assert env.get_done_winner()[0] is False, env.get_done_winner()[1] == -1
    # assert env.get_done_winner()[0] is True, env.get_done_winner()[1] == 2
    """

