import copy
import math


class ExtensiveGame:
    """Two-player zero-sum extensive-form game tree.

    Node type is encoded in ``Node.player``: 0 and 1 are the two players,
    2 is a chance node and 3 is a terminal node. Terminal values are stored
    from the point of view of player index 0 ("Player 1" in Gambit files);
    the other player's payoff is the negation.
    """

    def __init__(self):
        self.root = None
        # Numbering counter for chance nodes written by save_to_file.
        self.chance_count = 0
        # Incremented once per node read by load_children; reset and reused
        # as the outcome counter by save_to_file.
        self.node_count = 0
        # Map the raw Gambit infoset token of each player to a dense
        # 0-based index.
        self.isets_p1 = {}
        self.isets_p2 = {}
        # Next unused dense infoset index per player.
        self.actual_set_p1 = 0
        self.actual_set_p2 = 0
        # Next node id handed out by the loader and getid().
        self.ids = 0
        # Optional terminal-value -> outcome-number mapping consulted by
        # save_node_to_file; not populated anywhere in this class.
        self.node_names = {}
        # Cache for max_depth(); not invalidated when the tree changes.
        self.saved_max_depth = None

    def load(self, fname):
        """Load a game tree from a Gambit ``.efg`` file.

        Expects a two-player zero-sum game. Every line before the first tree
        line (the first line starting with ``c``, ``t`` or ``p``) is ignored.
        Loader state from earlier loads (infoset tables, id counter) is not
        reset.

        Raises:
            ValueError: if the file contains no tree line at all.
        """
        with open(fname, 'r') as file:
            line = file.readline()
            while not line.startswith(("p", "c", "t")):
                # readline() returns "" only at end of file; without this
                # check a file with no tree line would loop forever.
                if not line:
                    raise ValueError(f"no game tree found in {fname!r}")
                line = file.readline()
            self.root, n_children = self._parse_tree_line(line, None)
            self.load_children(self.root, n_children, file)

    def _parse_tree_line(self, line, parent):
        """Build the node described by one Gambit tree line.

        ``line`` must start with ``c``, ``p`` or ``t``. Returns a
        ``(node, n_children)`` pair; the caller is responsible for reading
        the ``n_children`` subtrees that follow. Consumes one id.
        """
        split = line.split()
        if line.startswith("c"):
            # c "name" infoset "infoset name" { "a1" p1 "a2" p2 ... } outcome
            node = Node(2, int(split[2]), parent, self.ids)
            n_children = max(0, (len(split) - 7) // 2)
            for i in range(n_children):
                node.chance.append(float(split[6 + 2 * i]))
                node.labels.append(split[5 + 2 * i].strip("\""))
        elif line.startswith("p"):
            # p "name" player infoset "infoset name" { "a1" "a2" ... } outcome
            player_number = int(split[2])
            i_set = self._lookup_infoset(player_number, split[3])
            node = Node(player_number - 1, i_set, parent, self.ids)
            n_children = self.get_children_line(line)
            for i in range(n_children):
                node.labels.append(split[6 + i].strip("\""))
        else:
            # t "name" outcome "outcome name" { v1, v2 } -- only the first
            # payoff is kept (zero-sum assumption).
            node = Node(3, None, parent, self.ids, value=float(split[5].strip(",")))
            n_children = 0
        self.ids += 1
        return node, n_children

    def _lookup_infoset(self, player_number, key):
        """Return the dense 0-based infoset index for a raw infoset token.

        ``player_number`` is the 1-based player from the file: 1 uses player
        1's table, any other value uses player 2's. Unseen tokens are given
        the next free index for that player.
        """
        if player_number == 1:
            if key not in self.isets_p1:
                self.isets_p1[key] = self.actual_set_p1
                self.actual_set_p1 += 1
            return self.isets_p1[key]
        if key not in self.isets_p2:
            self.isets_p2[key] = self.actual_set_p2
            self.actual_set_p2 += 1
        return self.isets_p2[key]

    def get_children_line(self, line):
        """Return the number of quoted action labels inside ``{ ... }``."""
        line = line[line.find("{") + 1: line.rfind("}")]
        return line.count("\"") // 2

    def load_children(self, node, n_children, file):
        """Recursively read ``n_children`` subtrees of ``node`` from ``file``.

        Lines are consumed in the file's prefix (depth-first) order. Reading
        stops early at a line that is not a tree line, e.g. a blank line or
        end of file.
        """
        self.node_count += 1
        for _ in range(n_children):
            line = file.readline().strip()
            if not line.startswith(("c", "p", "t")):
                return
            next_node, next_children = self._parse_tree_line(line, node)
            node.children.append(next_node)
            self.load_children(next_node, next_children, file)

    def load_from_nfg(self, nfg):
        """Build a two-level tree from a normal-form game.

        Player 0 moves at the root; player 1 then moves in a single
        information set (index 0), so it does not observe player 0's action.
        Leaf ``(i, j)`` holds ``nfg.payoffs[0][i][j]`` -- presumably payoffs
        are indexed ``[player][row][col]``; confirm against the nfg type.
        """
        payoffs = nfg.payoffs
        player_one_action_size = len(payoffs[0])
        player_two_action_size = len(payoffs[0][0])
        self.root = Node(0, 0, None, self.getid(), value=None)
        for i in range(player_one_action_size):
            new_node = Node(1, 0, self.root, self.getid())
            self.root.children.append(new_node)
            for j in range(player_two_action_size):
                terminal_node = Node(3, 0, new_node, self.getid(), payoffs[0][i][j])
                new_node.children.append(terminal_node)

    def print_tree(self):
        """Print the game tree to stdout, one node per line."""
        self.print_node(self.root, 0)

    def print_node(self, node, depth):
        """Print ``node`` prefixed by ``depth`` dashes, then its subtree."""
        print("-" * depth, end='')
        node.print()
        if not node.player == 3:
            for next_node in node.children:
                self.print_node(next_node, depth + 1)

    def save_to_file(self, fname):
        """Write the game tree to ``fname`` in Gambit ``.efg`` format."""
        self.node_count = 0
        self.chance_count = 0
        with open(fname, 'w') as file:
            file.write("EFG 2 R \"My tree\" { \"Player 1\" \"Player 2\" }\n")
            self.save_node_to_file(self.root, file, 0)

    def save_node_to_file(self, node, file, level):
        """Write ``node`` and its subtree, one line per node indented by depth.

        Terminal lines carry the payoffs ``value`` and ``-value``; chance and
        player lines end with outcome ``0`` (no outcome attached).
        """
        file.write(" " * level)
        if node.player == 3:
            if node.value in self.node_names:
                # NOTE(review): offset kept for callers that fill node_names
                # themselves; confirm the intended numbering.
                outcome = self.node_names[node.value] - 1
            else:
                # 0 is reserved for "no outcome" (see the trailing 0 written
                # for inner nodes below), so fresh outcomes are 1, 2, 3, ...
                self.node_count += 1
                outcome = self.node_count
            file.write(f't "" {outcome} "" {{ {node.value} {-node.value} }}\n')
        elif node.player == 2:
            file.write(f'c "" {self.chance_count} "" {{ ')
            self.chance_count += 1
            for i in range(len(node.children)):
                label = str(i) if len(node.labels) == 0 else node.labels[i]
                file.write(f'"{label}" {node.chance[i]} ')
            file.write("} 0\n")
        else:
            file.write(f'p "" {node.player + 1} {node.i_set} "" {{')
            for i in range(len(node.children)):
                label = str(i) if len(node.labels) == 0 else node.labels[i]
                file.write(f' "{label}"')
            file.write(" } 0\n")
        if node.children is not None:
            for next_node in node.children:
                self.save_node_to_file(next_node, file, level + 1)

    def getid(self):
        """Return a fresh, strictly increasing node id (for building trees
        from outside the loader)."""
        self.ids += 1
        return self.ids - 1

    def get_actual_set(self, player):
        """Allocate and return a fresh infoset index for ``player`` (0 or 1).

        Raises:
            ValueError: if ``player`` is not 0 or 1.
        """
        if player == 0:
            self.actual_set_p1 += 1
            return self.actual_set_p1 - 1
        if player == 1:
            self.actual_set_p2 += 1
            return self.actual_set_p2 - 1
        else:
            raise ValueError("Player must be 0 or 1")

    def switch_players(self):
        """Swap the players in place: players 0 and 1 exchange nodes and
        terminal values are negated. Infoset indices are left unchanged."""
        self.switch_players_step(self.root)

    def switch_players_step(self, node):
        if node.player == 3:
            node.value = -node.value
        if node.player < 2:
            node.player = 1 - node.player
        for child in node.children:
            self.switch_players_step(child)

    def get_terminal_nodes(self):
        """Return all terminal nodes in depth-first order."""
        terminal_nodes = []
        self._get_terminal_nodes_step(self.root, terminal_nodes)
        return terminal_nodes

    def _get_terminal_nodes_step(self, node, terminal_nodes):
        if node.is_terminal():
            terminal_nodes.append(node)
        for child in node.children:
            self._get_terminal_nodes_step(child, terminal_nodes)

    def get_sequences(self, player):
        """Return ``{index: sequence}`` for ``player``'s distinct sequences.

        A sequence is the tuple of ``(i_set, action_index)`` pairs ``player``
        chose on the path to a terminal node; only sequences that reach a
        terminal node are collected.
        """
        sequences = set()
        self._get_sequences(self.root, sequences, tuple(), player)
        return dict(enumerate(sequences))

    def _get_sequences(self, node, sequences, sequence, player):
        current_player = node.player
        if current_player == 3:
            sequences.add(sequence)
        if current_player == player:
            for i, child in enumerate(node.children):
                new_sequence = sequence + ((node.i_set, i),)
                self._get_sequences(child, sequences, new_sequence, player)
        else:
            for child in node.children:
                self._get_sequences(child, sequences, sequence, player)

    def fix_strategy_until_depth(self, player, strategy, fix_depth):
        """Turn ``player``'s nodes shallower than ``fix_depth`` (root is depth
        0) into chance nodes playing ``strategy[i_set]``.

        The same ``strategy[i_set]`` object is shared by every converted node
        of that infoset. ``strategy`` is presumably indexable by infoset and
        yields per-action probabilities -- confirm against callers.
        """
        self._fix_strategy_until_depth_step(player, strategy, fix_depth, self.root, 0)

    def fix_strategy_until_depth_from_node(self, player, strategy, fix_depth, node):
        """Like ``fix_strategy_until_depth`` but counting depth from ``node``."""
        self._fix_strategy_until_depth_step(player, strategy, fix_depth, node, 0)

    def _fix_strategy_until_depth_step(self, player, strategy, fix_depth, node, depth):
        if node.is_terminal():
            return
        if fix_depth == depth:
            return
        if node.player == player:
            node.player = 2
            node.chance = strategy[node.i_set]
        for child in node.children:
            self._fix_strategy_until_depth_step(player, strategy, fix_depth, child, depth + 1)

    def max_depth(self):
        """Return the longest root-to-terminal path length (edges).

        Cached after the first call; the cache is not invalidated if the
        tree is modified afterwards.
        """
        if self.saved_max_depth is None:
            self.saved_max_depth = self._max_depth(self.root, 0)
        return self.saved_max_depth

    def _max_depth(self, node, depth):
        if node.is_terminal():
            return depth
        else:
            ret_depth = 0
            for child in node.children:
                ret_depth = max(self._max_depth(child, depth + 1), ret_depth)
            return ret_depth

    def print_strategy_at_depth(self, strategy, depth, player):
        """Print ``strategy[i_set]`` once for each of ``player``'s infosets
        that has a node at exactly ``depth``."""
        self._print_strategy_at_depth(strategy, depth, player, self.root, 0, set())

    def _print_strategy_at_depth(self, strategy, depth, player, node, current_depth, printed_i_sets):
        if node.is_terminal():
            return
        if depth == current_depth:
            if node.player == player and node.i_set not in printed_i_sets:
                print(node.i_set, strategy[node.i_set])
                printed_i_sets.add(node.i_set)
            return
        for child in node.children:
            self._print_strategy_at_depth(strategy, depth, player, child, current_depth + 1, printed_i_sets)

    def infoset_to_observation_history_only_for_poker(self):
        """Map each player's infoset to its observation-history string.

        Poker-specific: the root's child labels are assumed to encode the
        private cards, character 0 for player 0 and character 1 for player 1.
        Every later action label is appended to both players' histories.
        Asserts that all nodes of an infoset share the same history.
        """
        result_dicts = [{}, {}]
        for child, action in zip(self.root.children, self.root.labels):
            self._infoset_to_observation_history_only_for_poker(result_dicts, child, [action[0], action[1]])
        return result_dicts

    def _infoset_to_observation_history_only_for_poker(self, result_dicts, node, histories):
        if node.is_terminal():
            return
        if node.player == 0 or node.player == 1:
            player = node.player
            if node.i_set in result_dicts[player]:
                assert result_dicts[player][node.i_set] == histories[player]
            else:
                result_dicts[player][node.i_set] = histories[player]
        for child, action in zip(node.children, node.labels):
            new_histories = copy.deepcopy(histories)
            new_histories[0] += action
            new_histories[1] += action
            self._infoset_to_observation_history_only_for_poker(result_dicts, child, new_histories)

    def infoset_to_action_labels(self):
        """Return ``[labels_p0, labels_p1]`` mapping each infoset to its
        action labels, asserting all nodes of an infoset agree.

        The traversal starts at the root so a player-owned root's infoset is
        included; chance and terminal nodes contribute nothing.
        """
        result_dicts = [{}, {}]
        self._infoset_to_action_labels(result_dicts, self.root)
        return result_dicts

    def _infoset_to_action_labels(self, result_dicts, node):
        if node.is_terminal():
            return
        if node.player == 0 or node.player == 1:
            player = node.player
            if node.i_set in result_dicts[player]:
                assert result_dicts[player][node.i_set] == node.labels
            else:
                result_dicts[player][node.i_set] = node.labels
        for child in node.children:
            self._infoset_to_action_labels(result_dicts, child)


class Node:
    """A single node of an ``ExtensiveGame`` tree.

    ``player`` encodes the node type: 0 and 1 are the two players, 2 is a
    chance node and 3 is a terminal node.
    """

    def __init__(self, player, i_set, parent, id_val, value=None):
        self.id_val = id_val
        self.player = player
        # Information-set index for player/chance nodes (None for terminal
        # nodes created by the loader).
        self.i_set = i_set
        self.parent = parent
        self.children = []
        # Per-child probabilities; used by chance nodes.
        self.chance = []
        # Terminal payoff for player index 0; None for inner nodes by default.
        self.value = value
        # Per-child action labels.
        self.labels = []

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        ret = str(self.player)
        if self.player < 2:
            ret += " " + str(self.i_set)
        elif self.player == 3:
            ret += " Val:" + str(self.value)
        return "N(" + ret + " " + str(self.id_val) + ")"

    def print(self):
        """Print id, player and infoset (player nodes) or value (terminals)."""
        print(f"id {self.id_val}", end=" ")
        print(self.player, end=" ")
        if self.player < 2:
            print(self.i_set)
        elif self.player == 3:
            print("Val:", self.value)
        else:
            print()

    def __eq__(self, other):
        # Returning NotImplemented lets comparisons with non-Node objects
        # (e.g. ``node == None``) fall back to Python's default handling
        # instead of raising AttributeError on ``other.player``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.player == other.player and self.id_val == other.id_val and self.i_set == other.i_set

    def __hash__(self):
        # Must hash the same fields __eq__ compares.
        return hash((self.player, self.id_val, self.i_set))

    def is_terminal(self):
        """Return True for terminal (player 3) nodes."""
        return self.player == 3

    def is_chance(self):
        """Return True for chance (player 2) nodes."""
        return self.player == 2
