import BestNEStatic
import GenerateRnRGame
from CFR import CFR
from ExtensiveGame import ExtensiveGame
from ExtensiveGame import Node
import SequenceNash
import copy
import numpy as np
from Combination import Combination

from GenerateRnRGameNotFixed import GenerateRNRGameNotFixed


def print_bold(text):
    """Print ``text`` wrapped in the ANSI "bold" and "reset" escape sequences."""
    bold_on, reset = '\033[1m', '\033[0m'
    print(bold_on + text + reset)


class RNR_decomposition:
    """Container for a trunk/subgame split of both trees of an RnR game.

    Every collection exists twice: once for the free tree (``free=True``) and
    once for the fixed tree (``free=False``). The ``free`` argument of each
    method selects which copy is read or written.

    * ``*_subgame_roots``: dict ``tuple(sequence) -> [node, ...]``.
    * ``*_subgame_nodes`` / ``*_subgame_infosets``: a two-element list indexed
      by player (0 or 1), each a dict ``tuple(sequence) -> [item, ...]``.
    * ``*_trunk_nodes`` / ``*_trunk_infosets``: a two-element list indexed by
      player, each a plain list.

    Sequences are converted with ``tuple()`` before use as keys, so callers may
    pass lists.
    """

    def __init__(self):
        # Fixed tree.
        self.fixed_tree_subgame_roots = {}
        self.fixed_tree_subgame_nodes = [{}, {}]
        self.fixed_tree_subgame_infosets = [{}, {}]

        self.fixed_tree_trunk_nodes = [[], []]
        self.fixed_tree_trunk_infosets = [[], []]

        # Free tree.
        self.free_tree_subgame_roots = {}
        self.free_tree_subgame_nodes = [{}, {}]
        self.free_tree_subgame_infosets = [{}, {}]

        self.free_tree_trunk_nodes = [[], []]
        self.free_tree_trunk_infosets = [[], []]

    def add_to_subgame_roots(self, sequence, free, node):
        """Append ``node`` to the subgame roots registered under ``sequence``."""
        key = tuple(sequence)
        roots = self.free_tree_subgame_roots if free else self.fixed_tree_subgame_roots
        roots.setdefault(key, []).append(node)

    def add_to_subgame_nodes(self, player, sequence, free, node):
        """Append ``node`` to ``player``'s subgame nodes under ``sequence``."""
        key = tuple(sequence)
        per_player = self.free_tree_subgame_nodes if free else self.fixed_tree_subgame_nodes
        per_player[player].setdefault(key, []).append(node)

    def add_to_subgame_infosets(self, player, sequence, free, infoset):
        """Append ``infoset`` to ``player``'s subgame infosets under ``sequence``."""
        key = tuple(sequence)
        per_player = self.free_tree_subgame_infosets if free else self.fixed_tree_subgame_infosets
        per_player[player].setdefault(key, []).append(infoset)

    def add_to_trunk_nodes(self, player, free, node):
        """Append ``node`` to ``player``'s trunk nodes."""
        per_player = self.free_tree_trunk_nodes if free else self.fixed_tree_trunk_nodes
        per_player[player].append(node)

    def add_to_trunk_infosets(self, player, free, infoset):
        """Append ``infoset`` to ``player``'s trunk infosets."""
        per_player = self.free_tree_trunk_infosets if free else self.fixed_tree_trunk_infosets
        per_player[player].append(infoset)

    def set_trunk_infosets(self, player, free, infosets):
        """Replace ``player``'s trunk infosets with ``infosets``."""
        per_player = self.free_tree_trunk_infosets if free else self.fixed_tree_trunk_infosets
        per_player[player] = infosets

    def set_subgame_infosets(self, player, free, sequence, infosets):
        """Replace ``player``'s subgame infosets under ``sequence`` with ``infosets``."""
        key = tuple(sequence)
        per_player = self.free_tree_subgame_infosets if free else self.fixed_tree_subgame_infosets
        per_player[player][key] = infosets

    def get_subgame_roots(self, free):
        """Return the subgame-roots dict of the selected tree."""
        return self.free_tree_subgame_roots if free else self.fixed_tree_subgame_roots

    def get_subgame_nodes(self, free):
        """Return the per-player subgame-nodes dicts of the selected tree."""
        return self.free_tree_subgame_nodes if free else self.fixed_tree_subgame_nodes

    def get_subgame_infosets(self, free):
        """Return the per-player subgame-infosets dicts of the selected tree."""
        return self.free_tree_subgame_infosets if free else self.fixed_tree_subgame_infosets

    def get_trunk_nodes(self, free):
        """Return the per-player trunk-node lists of the selected tree."""
        return self.free_tree_trunk_nodes if free else self.fixed_tree_trunk_nodes

    def get_trunk_infosets(self, free):
        """Return the per-player trunk-infoset lists of the selected tree."""
        return self.free_tree_trunk_infosets if free else self.fixed_tree_trunk_infosets


def get_all_info_from_game_by_lp(lp_solver):
    """Evaluate the LP solver's equilibrium profile with CFR on ``lp_solver.game``.

    Player 0's part of the profile is whatever ``lp_solver`` currently holds;
    in this file the function is always called right after ``lp_solver.solve(0)``.
    Player 1's part comes from a fresh ``lp_solver.solve(1)``, so the solver's
    state afterwards reflects that solve.

    Returns:
        The result of ``CFR.brs_to_cfbrs_and_cfvs`` for the profile; callers in
        this file unpack it as ``cfvs, reaches``.
    """
    profile = lp_solver.strategy_in_cfr_format()
    lp_solver.solve(1)
    profile[1] = lp_solver.strategy_in_cfr_format()[1]

    evaluator = CFR(None)
    evaluator.game = lp_solver.game
    evaluator.initialize()
    evaluator.strategy = profile
    evaluator.compute_counterfactual_values()
    return evaluator.brs_to_cfbrs_and_cfvs(profile)


def get_infosets_from_nodes(decomposition):
    """Derive the infoset lists of ``decomposition`` from its recorded nodes.

    For each player (0, then 1) and each tree (free, then fixed), the trunk
    infosets and every per-sequence subgame infoset list are replaced by the
    distinct ``i_set`` values of the corresponding nodes. Each result is a
    list with no particular order.
    """
    for player in (0, 1):
        for free in (True, False):
            trunk_nodes = decomposition.get_trunk_nodes(free)[player]
            decomposition.set_trunk_infosets(player, free, list({n.i_set for n in trunk_nodes}))
            subgames = decomposition.get_subgame_nodes(free)[player]
            for sequence, nodes in subgames.items():
                distinct = {n.i_set for n in nodes}
                decomposition.set_subgame_infosets(player, free, sequence, list(distinct))


def rnr_leduc_subgame_roots_and_nodes(game):
    """Split an RnR Leduc game into trunk nodes and subgames keyed by action history.

    The free tree (``game.root.children[1]``) is walked first, then the fixed
    tree (``game.root.children[0]``). Player nodes are recorded as trunk nodes
    and their action labels are appended to the history. This continues until
    the walk reaches a chance node (player 2) with a non-empty history.

    Each child of such a chance node becomes a subgame root keyed by
    ``history + [chance label]``. Every player node below it is recorded as a
    subgame node under that key. Chance nodes with an empty history are passed
    through. Terminal nodes (player 3) are not recorded.

    NOTE: keys are built by list concatenation, so labels are assumed to be
    immutable (they are used inside tuple dict keys).

    Returns:
        A populated ``RNR_decomposition`` whose infoset lists are derived from
        the recorded nodes.
    """
    decomposition = RNR_decomposition()

    def walk(node, history, in_trunk, free):
        if node.player == 2:
            if history:
                # Chance reached with a non-empty history: every outcome opens a subgame.
                for outcome, subgame_root in zip(node.labels, node.children):
                    key = history + [outcome]
                    decomposition.add_to_subgame_roots(key, free, subgame_root)
                    walk(subgame_root, key, False, free)
            else:
                for child in node.children:
                    walk(child, history, in_trunk, free)
            return
        if node.player == 3:
            return

        if in_trunk:
            decomposition.add_to_trunk_nodes(node.player, free, node)
        else:
            decomposition.add_to_subgame_nodes(node.player, history, free, node)
        for action, child in zip(node.labels, node.children):
            # Only trunk actions extend the key; inside a subgame it stays fixed.
            walk(child, history + [action] if in_trunk else history, in_trunk, free)

    for free, index in ((True, 1), (False, 0)):
        walk(game.root.children[index], [], True, free)

    get_infosets_from_nodes(decomposition)
    return decomposition


# NOTE: only valid for round-based games in which all information is revealed at the end of each round.
def step_subgame_roots_and_nodes(game, step):
    """Split an RnR game by depth: every node at depth ``step`` roots a subgame.

    Depth counts edges, chance edges included, from each tree root. The free
    tree (``game.root.children[1]``) is walked first, then the fixed tree
    (``game.root.children[0]``).

    Player nodes above depth ``step`` are trunk nodes. A node at depth
    ``step`` is registered as a subgame root under key ``(node.i_set,)``, and
    it and the player nodes below it are recorded as subgame nodes under that
    key. Terminal nodes (player 3) are never recorded as trunk or subgame nodes.

    Returns:
        A populated ``RNR_decomposition`` whose infoset lists are derived from
        the recorded nodes.
    """
    decomposition = RNR_decomposition()

    def visit(node, sequence, trunk, free, depth):
        if depth == step:
            decomposition.add_to_subgame_roots((node.i_set,), free, node)
            trunk, sequence = False, [node.i_set]
        if node.player == 3:
            return
        next_depth = depth + 1
        if node.player == 2:
            for child in node.children:
                visit(child, sequence, trunk, free, next_depth)
            return

        if trunk:
            decomposition.add_to_trunk_nodes(node.player, free, node)
        else:
            decomposition.add_to_subgame_nodes(node.player, sequence, free, node)
        for _label, child in zip(node.labels, node.children):
            visit(child, sequence, trunk, free, next_depth)

    for free, index in ((True, 1), (False, 0)):
        visit(game.root.children[index], [], True, free, 0)

    get_infosets_from_nodes(decomposition)
    return decomposition


def counterexample_subgame_roots_and_nodes(game):
    """Split an RnR game so that every node at depth 1 roots the single subgame ``(0,)``.

    Depth counts edges, chance edges included, from each tree root. The free
    tree (``game.root.children[1]``) is walked first, then the fixed tree
    (``game.root.children[0]``).

    Player nodes at depth 0 are trunk nodes. Every node at depth 1 is
    registered as a subgame root under key ``(0,)``, and the player nodes from
    there downwards are subgame nodes under that key. Terminal nodes
    (player 3) are never recorded as trunk or subgame nodes.

    Returns:
        A populated ``RNR_decomposition`` whose infoset lists are derived from
        the recorded nodes.
    """
    decomposition = RNR_decomposition()

    def visit(node, sequence, trunk, free, depth):
        if depth == 1:
            decomposition.add_to_subgame_roots((0,), free, node)
            trunk, sequence = False, [0]
        if node.player == 3:
            return
        next_depth = depth + 1
        if node.player == 2:
            for child in node.children:
                visit(child, sequence, trunk, free, next_depth)
            return

        if trunk:
            decomposition.add_to_trunk_nodes(node.player, free, node)
        else:
            decomposition.add_to_subgame_nodes(node.player, sequence, free, node)
        for _label, child in zip(node.labels, node.children):
            visit(child, sequence, trunk, free, next_depth)

    for free, index in ((True, 1), (False, 0)):
        visit(game.root.children[index], [], True, free, 0)

    get_infosets_from_nodes(decomposition)
    return decomposition


def create_resolving_gadget(free_root, cfvs, game, children, reaches):
    """Attach a re-solving gadget for the free-tree subgame roots under ``free_root``.

    For each subgame root ``r`` in ``children``, a new player-1 node is created
    under ``free_root`` with actions ``["F", "T"]``. "F" leads to ``r`` and "T"
    leads to a new terminal.

    Gadget nodes share a fresh player-1 infoset (allocated with
    ``game.get_actual_set(1)``) exactly when their roots' first children share
    an ``i_set``. The "T" terminal is worth minus the average of ``cfvs[r][1]``
    over the gadget infoset, weighted by ``reaches[r][1] * reaches[r][2]``. The
    weights are normalised only when their total is non-zero.

    ``free_root.chance`` receives one ``reaches[r][1] * reaches[r][2]`` entry per
    root. Mutates ``free_root`` and the roots' ``parent`` links, and allocates
    node ids from ``game``: all gadget nodes first, then all terminals.

    NOTE(review): roots whose first child is terminal are grouped by that
    terminal's ``i_set`` like any other root. Confirm this is intended.
    """
    # One fresh player-1 infoset per distinct first-child infoset, in order of first appearance.
    fresh_iset = {}
    for root in children:
        original_iset = root.children[0].i_set
        if original_iset not in fresh_iset:
            fresh_iset[original_iset] = game.get_actual_set(1)

    gadgets = []
    for root in children:
        gadgets.append(Node(1, fresh_iset[root.children[0].i_set], free_root, game.getid()))
    free_root.children = gadgets
    for root in children:
        free_root.chance.append(reaches[root][1] * reaches[root][2])

    for root, gadget in zip(children, gadgets):
        gadget.children = [root, Node(3, 0, gadget, game.getid())]
        root.parent = gadget
        gadget.labels = ["F", "T"]

    # Group cfvs and reach weights by gadget infoset.
    values = {}
    weights = {}
    for root, gadget in zip(children, gadgets):
        values.setdefault(gadget.i_set, []).append(cfvs[root][1])
        weights.setdefault(gadget.i_set, []).append(reaches[root][1] * reaches[root][2])
    for gadget_iset in list(weights):
        total = np.sum(weights[gadget_iset])
        if total != 0:
            weights[gadget_iset] = weights[gadget_iset] / total

    for gadget in gadgets:
        gadget.children[1].value = -np.dot(values[gadget.i_set], weights[gadget.i_set])


def adjust_terminals(node, value):
    """Add ``value`` to the ``value`` of every terminal node in the subtree rooted at ``node``.

    Nodes are visited in pre-order with an explicit stack, so deep trees do not
    hit the recursion limit.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.is_terminal():
            current.value += value
        else:
            # Reversed so that children are popped left-to-right.
            pending.extend(reversed(current.children))


def create_max_margin_gadget(free_root, cfvs, game, children, reaches):
    """Attach a max-margin gadget for the free-tree subgame roots under ``free_root``.

    The subgame roots in ``children`` are grouped under new chance nodes
    (player 2), which become the children of ``free_root``:

    * a root whose first child is terminal (``player == 3``) gets a chance
      node of its own;
    * all other roots are grouped by the ``i_set`` of their first child.

    Within a group, root ``r`` gets chance weight ``reaches[r][1] * reaches[r][2]``.
    Every terminal below a group is then shifted (``adjust_terminals``) by the
    group's average of ``cfvs[r][1]``, weighted by the same reach products.
    The weights are normalised only when their total is non-zero.

    Mutates ``free_root``, the roots' ``parent`` links, the terminal values of
    the subgames and the ``i_set`` of terminal first children. Allocates node
    ids from ``game`` in the order the chance nodes are created.
    """
    gadget_nodes = []   # new chance nodes, in creation order
    assigned = []       # assigned[k] is the chance node that receives children[k]
    by_iset = {}        # first-child i_set -> chance node (non-terminal first children only)
    terminal_index = 0
    for child in children:
        first = child.children[0]
        if first.player == 3:
            group = Node(2, 0, free_root, game.getid())
            gadget_nodes.append(group)
            # The relabelling 0, 1, 2, ... is kept in case other code observes it.
            # It is deliberately NOT used for grouping: those numbers can equal
            # real infoset ids, which would merge unrelated roots into one group
            # or orphan an already-created group.
            first.i_set = terminal_index
            terminal_index += 1
        elif first.i_set in by_iset:
            group = by_iset[first.i_set]
        else:
            group = Node(2, 0, free_root, game.getid())
            gadget_nodes.append(group)
            by_iset[first.i_set] = group
        assigned.append(group)

    free_root.children = gadget_nodes
    members = {id(group): [] for group in gadget_nodes}
    for child, group in zip(children, assigned):
        group.children.append(child)
        child.parent = group
        group.chance.append(reaches[child][1] * reaches[child][2])
        members[id(group)].append(child)

    for group in gadget_nodes:
        roots = members[id(group)]
        values = [cfvs[root][1] for root in roots]
        weights = np.asarray([reaches[root][1] * reaches[root][2] for root in roots])
        total = np.sum(weights)
        if total != 0:
            weights = weights / total
        adjust_terminals(group, np.dot(values, weights))


def get_reaches_from_strategy(strategy, game, player):
    """Return ``{node: reach}`` for every node under ``game.root.children[0]``.

    A node's reach is the product of ``player``'s action probabilities
    (``strategy[i_set][action_index]``) on the path from
    ``game.root.children[0]``. Chance moves and other players' moves contribute
    a factor of 1. In this file, ``children[0]`` is the subtree that the
    decomposition functions treat as the fixed tree.
    """
    subtree = game.root.children[0]
    result = dict()
    _get_reaches_from_strategy(strategy, subtree, 1., result, player)
    return result


def _get_reaches_from_strategy(strategy, node, reach, reaches, player):
    reaches[node] = reach
    if node.is_terminal():
        return
    elif node.player == player:
        for i, child in enumerate(node.children):
            _get_reaches_from_strategy(strategy, child, reach * strategy[node.i_set][i], reaches, player)
    else:
        for child in node.children:
            _get_reaches_from_strategy(strategy, child, reach, reaches, player)


_DECOMPOSITION_TYPES = ("leduc", "steps", "counterexample")
_GADGET_TYPES = ("full", "resolving", "max-margin", "ses", "exp-strat", "unsafe", "comb")


def _make_chance_nodes(nodes, strategy):
    """Turn each node in ``nodes`` into a chance node (player 2) playing ``strategy[node.i_set]``."""
    for node in nodes:
        node.player = 2
        node.chance = strategy[node.i_set]


def _solve_into(lp_solver, infosets, target):
    """Solve the LP for player 0 and copy its policy on ``infosets`` into ``target``."""
    lp_solver.solve(0)
    policy = lp_solver.strategy_in_cfr_format()[0]
    for infoset in infosets:
        target[infoset] = policy[infoset]


def _new_root_with_fixed_subgames(game, subgame_roots, weight):
    """Give ``game`` a fresh chance root and build the fixed-tree chance node below it.

    Returns ``(root, fixed_root)``. ``fixed_root`` is a chance node whose children
    are ``subgame_roots`` (the same list object), with probabilities
    ``weight(child)``. The caller sets ``root.children`` and ``root.chance``.
    """
    game.root = Node(2, 0, None, game.getid())
    root = game.root
    fixed_root = Node(2, 0, root, game.getid())
    fixed_root.children = subgame_roots
    for child in fixed_root.children:
        fixed_root.chance.append(weight(child))
    return root, fixed_root


def cd_rnr_one_step(file_name, opponent_strategy, p, gadget="full", decomposition_type="leduc", verbose=True):
    """Solve the trunk of an RnR game, re-solve each free subgame with ``gadget``, and evaluate.

    Args:
        file_name: Game file of the original game. It is read by ``SequenceNash``,
            ``GenerateRNRGameNotFixed``, ``ExtensiveGame.load``, ``CFR`` and
            ``BestNashStatic``.
        opponent_strategy: Player 1's strategy as a mapping
            ``i_set -> action probabilities``. It becomes the ``chance`` of
            player-1 nodes that are converted to chance nodes.
        p: Weight of the fixed-opponent tree. The RnR root uses chance ``[p, 1 - p]``.
        gadget: One of "full", "resolving", "max-margin", "ses", "exp-strat",
            "unsafe", "comb".
        decomposition_type: One of "leduc", "steps" (depth 2) or "counterexample".
        verbose: If true, also print gain/exploitability for the resulting
            strategy, a best response, the LP equilibrium and the best Nash
            equilibrium.

    Returns:
        ``(gain, expl)``, where
        ``gain = value_of_the_game - value([result, opponent_strategy])`` and
        ``expl = best_response(1, [result, {}])[0] - value_of_the_game``.

    Raises:
        AssertionError: If ``gadget`` or ``decomposition_type`` is unknown.
            Raised before any solving or file I/O.

    Side effects:
        Writes ``data/temp_<decomposition_type>_<gadget>.efg``. For every gadget
        except "full", also writes ``fixed_subgame.efg`` in the working directory.
    """
    # Validate before any solving or file I/O. An explicit raise, unlike ``assert``,
    # still fires under ``python -O``.
    if decomposition_type not in _DECOMPOSITION_TYPES:
        raise AssertionError("Invalid decomposition type.")
    if gadget not in _GADGET_TYPES:
        raise AssertionError("Wrong gadget type")

    # Equilibrium of the original game: reference strategy and game value for gain/exploitability.
    lp_solver = SequenceNash.SequenceNash(file_name)
    lp_solver.solve(0)
    equilibrium_strategy = lp_solver.strategy_in_cfr_format()
    value_of_the_game = lp_solver.solve(1)
    equilibrium_strategy[1] = lp_solver.strategy_in_cfr_format()[1]

    # Build the RnR game and split it into trunk and subgames.
    rnr_generator = GenerateRNRGameNotFixed(file_name)
    temp_fname = f"data/temp_{decomposition_type}_{gadget}.efg"
    rnr_generator.generate(temp_fname, [p, 1 - p], 1, [opponent_strategy])
    rnr_game = ExtensiveGame()
    rnr_game.load(temp_fname)
    if decomposition_type == "leduc":
        decomposition = rnr_leduc_subgame_roots_and_nodes(rnr_game)
    elif decomposition_type == "steps":
        decomposition = step_subgame_roots_and_nodes(rnr_game, 2)
    else:  # "counterexample" (validated above)
        decomposition = counterexample_subgame_roots_and_nodes(rnr_game)

    # Trunk strategy. Except for "ses"/"exp-strat", player 1 in the fixed trunk
    # is replaced by chance nodes playing opponent_strategy.
    resulting_strategy = {}
    if gadget not in ("ses", "exp-strat"):
        _make_chance_nodes(decomposition.get_trunk_nodes(False)[1], opponent_strategy)

    lp_solver.game = rnr_game
    lp_solver.solve(0)
    nash_policy = lp_solver.strategy_in_cfr_format()[0]

    if gadget == "comb":
        # Solve with all root probability on children[0] (the fixed tree) and
        # keep the free-trunk infosets...
        lp_solver.game.root.chance = [1, 0]
        lp_solver.solve(0)
        cdbr_policy = lp_solver.strategy_in_cfr_format()[0]
        cdbr = {infoset: cdbr_policy[infoset] for infoset in decomposition.get_trunk_infosets(True)[0]}
        # ...then with all root probability on children[1] (the free tree).
        lp_solver.game.root.chance = [0, 1]
        lp_solver.solve(0)
        full_nash = lp_solver.strategy_in_cfr_format()[0]
        # NOTE(review): root.chance stays [0, 1] here, so the cfvs/reaches computed
        # below for "comb" come from that game, not the [p, 1 - p] one. Confirm intended.

    for infoset in decomposition.get_trunk_infosets(True)[0]:
        resulting_strategy[infoset] = nash_policy[infoset]

    # Subgame strategies.
    cfvs, reaches = get_all_info_from_game_by_lp(lp_solver)
    free_roots = decomposition.get_subgame_roots(True)
    fixed_roots = decomposition.get_subgame_roots(False)
    fixed_opponent_nodes = decomposition.get_subgame_nodes(False)[1]
    free_infosets = decomposition.get_subgame_infosets(True)[0]

    def full_reach(child):
        # NOTE(review): if reaches already include the RnR root's chance probability,
        # gadgets that also set root.chance = [p, 1 - p] weight each tree twice. Confirm.
        return np.prod(reaches[child])

    if gadget == "full":
        for free in (True, False):
            _make_chance_nodes(decomposition.get_trunk_nodes(free)[0], resulting_strategy)
        for sequence in free_roots:
            _make_chance_nodes(fixed_opponent_nodes[sequence], opponent_strategy)
            _solve_into(lp_solver, free_infosets[sequence], resulting_strategy)
    elif gadget == "resolving":
        game = lp_solver.game
        for sequence in free_roots:
            root, fixed_root = _new_root_with_fixed_subgames(game, fixed_roots[sequence], full_reach)
            free_root = Node(2, 0, root, game.getid())
            create_resolving_gadget(free_root, cfvs, game, free_roots[sequence], reaches)
            root.children = [fixed_root, free_root]
            root.chance = [p, 1 - p]
            _make_chance_nodes(fixed_opponent_nodes[sequence], opponent_strategy)
            game.save_to_file("fixed_subgame.efg")
            _solve_into(lp_solver, free_infosets[sequence], resulting_strategy)
    elif gadget == "max-margin":
        game = lp_solver.game
        for sequence in free_roots:
            root, fixed_root = _new_root_with_fixed_subgames(game, fixed_roots[sequence], full_reach)
            free_root = Node(1, game.get_actual_set(1), root, game.getid())
            create_max_margin_gadget(free_root, cfvs, game, free_roots[sequence], reaches)
            root.children = [fixed_root, free_root]
            root.chance = [p, 1 - p]
            _make_chance_nodes(fixed_opponent_nodes[sequence], opponent_strategy)
            game.save_to_file("fixed_subgame.efg")
            _solve_into(lp_solver, free_infosets[sequence], resulting_strategy)
    elif gadget in ("ses", "exp-strat"):
        game = lp_solver.game
        # Player 1's reach under opponent_strategy in game.root.children[0].
        fixed_reaches = get_reaches_from_strategy(opponent_strategy, game, 1)

        def fixed_reach(child):
            return reaches[child][1] * reaches[child][2] * fixed_reaches[child]

        for sequence in free_roots:
            root, fixed_root = _new_root_with_fixed_subgames(game, fixed_roots[sequence], fixed_reach)
            if gadget == "exp-strat":
                _make_chance_nodes(fixed_opponent_nodes[sequence], opponent_strategy)
            free_root = Node(1, game.get_actual_set(1), root, game.getid())
            create_max_margin_gadget(free_root, cfvs, game, free_roots[sequence], reaches)
            root.children = [fixed_root, free_root]
            root.chance = [p, 1 - p]
            game.save_to_file("fixed_subgame.efg")
            _solve_into(lp_solver, free_infosets[sequence], resulting_strategy)
    elif gadget == "unsafe":
        game = lp_solver.game
        for sequence in free_roots:
            root, fixed_root = _new_root_with_fixed_subgames(game, fixed_roots[sequence], full_reach)
            _make_chance_nodes(fixed_opponent_nodes[sequence], opponent_strategy)
            free_root = Node(2, 0, root, game.getid())
            free_root.children = free_roots[sequence]
            for child in free_root.children:
                free_root.chance.append(full_reach(child))
            root.children = [fixed_root, free_root]
            root.chance = [p, 1 - p]
            game.save_to_file("fixed_subgame.efg")
            _solve_into(lp_solver, free_infosets[sequence], resulting_strategy)
    else:  # "comb" (validated above)
        game = lp_solver.game
        for sequence in free_roots:
            fixed_root = Node(2, 0, None, game.getid())
            game.root = fixed_root
            fixed_root.children = fixed_roots[sequence]
            for child in fixed_root.children:
                fixed_root.chance.append(full_reach(child))
            _make_chance_nodes(fixed_opponent_nodes[sequence], opponent_strategy)
            game.save_to_file("fixed_subgame.efg")
            _solve_into(lp_solver, free_infosets[sequence], cdbr)
        combine_game = ExtensiveGame()
        combine_game.load(file_name)
        combination = Combination([cdbr, opponent_strategy], [full_nash, opponent_strategy], combine_game)
        resulting_strategy = combination.combine_strategies(p)[0]

    # Evaluation on the original game.
    cfr = CFR(file_name)
    cfr.initialize()

    gain = value_of_the_game - cfr.compute_game_value([resulting_strategy, opponent_strategy])
    expl = -value_of_the_game + cfr.best_response(1, [resulting_strategy, {}])[0]

    if verbose:
        # The best Nash equilibrium is only needed for this report, so quiet runs skip its LP.
        best_ne_solver = BestNEStatic.BestNashStatic(file_name, opponent_strategy)
        best_ne_solver.solve()
        best_ne = best_ne_solver.strategy_in_cfr_format()

        print()
        print_bold("Strategy gain and exploitability:")
        print(gain)
        print(expl)

        print()
        print_bold("Best response gain and exploitability:")
        best_response = cfr.best_response(0, [{}, opponent_strategy])
        print(value_of_the_game - best_response[0])
        print(-value_of_the_game + cfr.best_response(1, best_response[1])[0])

        print()
        print_bold("Nash gain and exploitability:")
        print(value_of_the_game - cfr.compute_game_value([equilibrium_strategy[0], opponent_strategy]))
        print(-value_of_the_game + cfr.best_response(1, equilibrium_strategy)[0])

        print()
        print_bold("Best Nash gain and exploitability:")
        print(value_of_the_game - cfr.compute_game_value([best_ne[0], opponent_strategy]))
        print(-value_of_the_game + cfr.best_response(1, best_ne)[0])

    return gain, expl
