from CFR import CFR
from ExtensiveGame import ExtensiveGame, Node
from ExtensiveSubgame import ExtensiveSubgame
from GenerateRnRGameNotFixed import GenerateRNRGameNotFixed
from DataManipulation import load_from_file, save_to_file
from SubgameCFR import SubgameCFR
import copy
import numpy as np

# Main idea:
#   Fix the depth of the tree to n.
#   Run the usual CFR up to the fixed depth.
#   When CFR reaches that depth, create a subgame out of the current info set.
#   Solve this subgame with a different (usual) CFR and use the result as a value function.
#   Repeat this for i iterations.
#   The result is a strategy for player 1 up to depth n.
#   Now fix this strategy for player 1 up to depth n.
#   Follow this strategy up to some infoset I.
#   Change the game in the following way:
#       The right subtree becomes just one chance node (both players are fixed) which results in I (other infosets are neglected).
#       The left subtree changes all action nodes of P1 to chance nodes.
#   Now we run CFR in this new tree.
#   When we reach a different infoset in the left tree, we run a new CFR from this infoset to the end.
#   Otherwise the CFR continues until the terminal.
class CFRRNR(CFR):
    # TODO USE CFR Plus
    def __init__(self, fname="data/leduc_holdem.efg", p=0.5, vf_iterations=10, strat_file="data/bad_strategies/leduc_holdem/leduc_holdem_21_iterations.strat", use_vf=True):
        """Generate the modified game file, load it, and set up bookkeeping.

        A game file is generated with ``GenerateRNRGameNotFixed`` from
        ``fname`` and the opponent strategy loaded from ``strat_file``; the
        parent ``CFR`` is then initialised on the generated file. Afterwards
        the public chance nodes and the trunk/subgame info sets are indexed.

        Args:
            fname: Path of the source game file handed to the generator.
            p: Passed to the generator as the weights ``[p, 1 - p]`` and used
                (formatted to one decimal place) in the generated file name.
                Presumably the weight of the left (best-response) subtree --
                TODO confirm against the generator.
            vf_iterations: Stored on the instance; not used in this
                constructor.
            strat_file: File read with ``load_from_file``; element ``[1]`` of
                the result is kept as the opponent strategy.
            use_vf: Stored flag that other methods consult to decide whether
                the value function is used.

        Note:
            The generated file is always written to
            ``data/cdrnr/leduc_holdem_<p>.efg``, independent of ``fname``.
        """
        rnr_generator = GenerateRNRGameNotFixed(fname)
        temp_fname = "data/cdrnr/leduc_holdem_" + "{:.1f}".format(p) + ".efg"
        self.opponent_strategy = load_from_file(strat_file)[1]
        rnr_generator.generate(temp_fname, [p, 1 - p], 1, [self.opponent_strategy])
        super().__init__(temp_fname)
        self.use_vf = use_vf
        self.p = p
        self.vf_iterations = vf_iterations
        self.public_chances = []  # This is here to quickly check if we reached the depth we wanted.
        self.find_public_chances(self.game.root, 0)  # Finds 300 chances, which is expected
        self.reaches = {}  # From average strategy
        self.reaches_both = {}  # From immediate strategy
        self.average_reaches = {}
        self.public_sets = []
        self.i_set_to_public = {}
        self.public_to_i_sets = {}
        self.value_function = {}
        self.exp_val_function = {}
        self.decompose_leduc()  # Finds 30 public sets. Seems like 5 possible continuations from initial to second round * 6 cards on river
        # Only those infosets will be taken into account when computing strategies and regrets
        # (to avoid using something outside of current scope)
        self.i_sets_to_update = [set(), set()]
        self.right_subtree_id = self.game.root.children[1].id_val
        self.current_subgame = 0
        self.final_strategies = []
        self.final_reaches = []
        self.final_average_strategies = []
        self.final_regrets = []
        self.final_values = []
        self.subgame_values = []
        # 18 infosets to player 1 (6 cards * 3 decisions before river) and
        # 36 to player 2 (6 cards * 3 decision before river * 2 trees)
        self.trunk_i_sets = [set(), set()]
        # Each: 15 infosets to player 1 (5 cards since 1 is already in river * 3 decision before end)
        # Each: 30 infosets to player 2 (5 cards since 1 is already in river * 3 decision * 2 trees)
        self.subgame_i_sets = [[set(), set()] for _ in self.public_sets]
        self.organize_i_sets()
        self.trunk_solved = False
        self.fixed = [{}, {}]
        # Mapping node -> reach. Every later use treats it as a dict
        # (``.items()`` in load_back_trunk_values, dict assignment in
        # save_trunk), so it must not start out as a list.
        self.trunk_reaches = {}

    def organize_i_sets(self):
        self.organize_trunk(self.game.root)
        for i, public in enumerate(self.public_sets):
            for root in public:
                self.organize_subgames(root, i)

    def organize_subgames(self, node, id):
        if node.player == 3:
            return
        if node.player == 0 or node.player == 1:
            self.subgame_i_sets[id][node.player].add(node.i_set)
        for child in node.children:
            self.organize_subgames(child, id)


    def organize_trunk(self, node):
        if node.player == 3:
            return
        if node.player == 2 and node.id_val in self.public_chances:
            return
        if node.player == 0 or node.player == 1:
            self.trunk_i_sets[node.player].add(node.i_set)
        for child in node.children:
            self.organize_trunk(child)


    def compute_strategy(self):
        self.compute_strategy_for_player(0)
        self.compute_strategy_for_player(1)

    def compute_strategy_for_player(self, player):
        for iset in self.i_sets_to_update[player]:
            if self.fixed is not None and iset in self.fixed[player]:
                # print(iset)
                continue
            sumed = np.sum(np.clip(self.counterfactual_regret[player][iset], 0, np.inf))
            if sumed > 0:
                for i, regret in enumerate(self.counterfactual_regret[player][iset]):
                    self.strategy[player][iset][i] = max(0, regret) / sumed
            else:
                self.set_uniform_strategy(player, iset, self.strategy)


    def compute_regret(self):
        self.compute_regret_for_player(0)
        self.compute_regret_for_player(1)

    def compute_regret_for_player(self, player):
        for iset in self.i_sets_to_update[player]:
            # print("Iset", iset)
            cfv = [0] * len(self.i_sets_to_nodes[player][iset][0].children)
            for node in self.i_sets_to_nodes[player][iset]:
                for i, child in enumerate(node.children):
                    cfv[i] += self.counterfactual_values[child][player]
                    # print(self.counterfactual_values[child][player], child)
            # if iset == 70 and player == 0:
            #     print(cfv)
            #     print(self.strategy[player][iset])
            value = 0
            for i in range(len(cfv)):
                value += cfv[i] * self.strategy[player][iset][i]
            # print(player, cfv, value)
            # if iset == 70 and player == 0:
            #     print(value)
            for i, cfv_value in enumerate(cfv):
                if self.cfr_plus:
                    self.counterfactual_regret[player][iset][i] += max(0, cfv_value - value)
                else:
                    self.counterfactual_regret[player][iset][i] += cfv_value - value


    def compute_exp_val(self, node):
        exp_val = 0
        if node.player == 3:
            return node.value 
        elif node.player == 2:
            for chance, child in zip(node.chance, node.children):
                exp_val += self.compute_exp_val(child) * chance
        else:
            for i, child in enumerate(node.children):
                exp_val += self.compute_exp_val(child) * self.average_strategy[node.player][node.i_set][i]
        return exp_val

    # def average_strategies(self):
    #     if self.iteration > 0:
    #         self.avg_visited = []
    #         self.avg_strat_at_node(self.game.root, [1, 1], [1, 1])
    #     else:
    #         for p, strat in enumerate(self.strategy):
    #             for i_set, s in strat.items():
    #                 if i_set in self.i_sets_to_update[p]:
    #                     self.average_strategy[p][i_set] = copy.deepcopy(s)
    #                     for l in range(len(self.average_strategy[p][i_set])):
    #                         self.average_strategy[p][i_set][l] = (0.9 * self.average_strategy[p][i_set][l]) + (0.1 / len(self.average_strategy[p][i_set]))
    #         # self.average_strategy = copy.deepcopy(self.strategy)

    # def avg_strat_at_node(self, node, player_reach, player_average_reach):
    #     player = node.player
    #     if player == 3:
    #         return
    #     elif player == 2:
    #         for child in node.children:
    #             self.avg_strat_at_node(child, player_reach, player_average_reach)
    #     else:
    #         if not node.i_set in self.i_sets_to_update[node.player]:
    #             return
    #         for i, child in enumerate(node.children):
    #             new_player_reach = [0, 0]
    #             new_player_average_reach = [0, 0]
    #             new_player_reach[1 - player] = player_reach[1 - player]
    #             new_player_reach[player] = player_reach[player] * self.strategy[player][node.i_set][i]

    #             new_player_average_reach[1 - player] = player_average_reach[1 - player]
    #             new_player_average_reach[player] = player_average_reach[player] * \
    #                                                self.average_strategy[player][node.i_set][i]
    #             # if child.i_set == 70 and child.player == 0:
    #             #     print(player_reach[1 - player])
    #             #     print(player_reach[player])
    #             #     print(self.strategy[player][node.i_set][i])
    #             #     print(player_average_reach[1 - player])
    #             #     print(player_average_reach[player])
    #             #     print(self.average_strategy[player][node.i_set][i])
    #             #     print("-----")
    #             self.avg_strat_at_node(child, new_player_reach, new_player_average_reach)
    #         if (player, node.i_set) not in self.avg_visited:
    #             self.avg_visited.append((player, node.i_set))
    #             self.average_in_is(player, node.i_set, player_reach[player], player_average_reach[player])
    #             # if node.i_set == 70 and player == 0:
    #             #     print(player_reach[player])
    #             #     print(player_average_reach[player])
    #             #     print(self.average_strategy[0][70])

    def find_public_chances(self, node, depth):
        if depth > 1 and node.player == 2:
            self.public_chances.append(node.id_val)
        elif node.player == 3:
            return
        else:
            for child in node.children:
                self.find_public_chances(child, depth+1)

    # Always fixes the second player (with id 1)
    def fix_opponent_strategy_first(self, node):
        if node.is_terminal():
            return
        if node.player == 2 and node.id_val in self.public_chances:
            return
        if node.player == 1:
            # node.player = 2
            # In the right subtree the infosets are duplicated therefore their ID is len(inf_sets) + i_set
            # So we remove the len(inf_set) to access the original strategy
            # Copy because we may be overwritting the strategy somewhere sometime, so this make sure we don't
            # Probably not necessary, I just don't want to think about it
            # node.chance = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
            self.fixed[1][node.i_set] = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
            self.strategy[1][node.i_set] = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
            self.average_strategy[1][node.i_set] = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
        for child in node.children:
            self.fix_opponent_strategy_first(child)
            

    # Could use fix_opponent_strategy_first instead, it should work the same, but just to be sure.
    def fix_opponent_strategy_second(self, node):
        if node.is_terminal():
            return
        if node.player == 1:
            # node.player = 2
            # In the right subtree the infosets are duplicated therefore their ID is len(inf_sets) + i_set
            # So we remove the len(inf_set) to access the original strategy
            # Copy because we may be overwritting the strategy somewhere sometime, so this make sure we don't
            # node.chance = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
            self.fixed[1][node.i_set] = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
            self.strategy[1][node.i_set] = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
            self.average_strategy[1][node.i_set] = copy.deepcopy(self.opponent_strategy[node.i_set - len(self.opponent_strategy)])
        for child in node.children:
            self.fix_opponent_strategy_second(child)

    def fix_my_trunk_strategy(self, node):
        if node.is_terminal():
            return
        if node.player == 2 and node.id_val in self.public_chances:
            return
        if node.player == 0:
            # node.player = 2
            # Rather copy it, to be sure nothing bad happens to it
            # node.chance = copy.deepcopy(self.average_strategy[0][node.i_set])
            self.fixed[0][node.i_set] = copy.deepcopy(self.average_strategy[0][node.i_set])
            self.strategy[0][node.i_set] = copy.deepcopy(self.average_strategy[0][node.i_set])
            self.average_strategy[0][node.i_set] = copy.deepcopy(self.average_strategy[0][node.i_set])
        for child in node.children:
            self.fix_my_trunk_strategy(child)

    # Inserts given subgame to the root as a right subtree.
    # It uses reach probability as a chance node above.
    def add_chance_to_right(self, subgame):
        """Replace the root's second child with a chance node over ``subgame``.

        Every node in ``subgame`` is re-parented under a newly created chance
        node. Its chance weight is ``reaches[node] / (1 - p)``, or ``0.0``
        when ``p >= 0.99999`` (which avoids dividing by roughly zero). The
        weights are not normalised here.
        """
        chance_root = Node(2, 0, self.game.root, self.game.getid())
        for subgame_root in subgame:
            subgame_root.parent = chance_root
            chance_root.children.append(subgame_root)
            # Think about normalizing here; it is unclear whether the sum of
            # the chances can end up being 0.
            weight = 0.0 if self.p >= 0.99999 else self.reaches[subgame_root] / (1 - self.p)
            chance_root.chance.append(weight)
        self.game.root.children[1] = chance_root

    def add_chance_left(self, subgame):
        """Replace the root's first child with a chance node over ``subgame``.

        Every node in ``subgame`` is re-parented under a newly created chance
        node. Its chance weight is ``reaches[node] / p``, or ``0.0`` when
        ``p <= 0.00001`` (which avoids dividing by roughly zero). The weights
        are not normalised here.
        """
        chance_root = Node(2, 0, self.game.root, self.game.getid())
        for subgame_root in subgame:
            subgame_root.parent = chance_root
            chance_root.children.append(subgame_root)
            # Think about normalizing here; it is unclear whether the sum of
            # the chances can end up being 0.
            weight = 0.0 if self.p <= 0.00001 else self.reaches[subgame_root] / self.p
            chance_root.chance.append(weight)
        self.game.root.children[0] = chance_root

    def add_gadget_left(self, subgame):
        """Replace the root's first child with a gadget over ``subgame``.

        A new chance node is created. Below it every node of ``subgame`` gets
        a new player-1 decision node with two children: a terminal node and
        the original subgame node. Subgame nodes whose first child shares an
        info set share one newly allocated gadget info set, which is
        registered in ``game.isets_p2``, ``i_sets_to_nodes[1]`` and
        ``i_sets[1]``.

        The terminal value of a gadget info set is the average of the nodes'
        expected values weighted by ``average_reaches[node][1]``. The expected
        value comes from ``exp_val_function`` when ``use_vf`` is set and from
        ``compute_exp_val`` otherwise. Finally the strategy of every new info
        set is reset with ``set_uniform_strategy``.

        NOTE(review): the weighted average divides by the summed reach of the
        info set; a zero sum is not handled here.
        NOTE(review): ``cfvs``, ``average_cfv``, ``e`` and ``v`` are filled but
        never read in this method. Their lookups still require the
        corresponding ``counterfactual_values`` / ``value_function`` entries
        to exist for every node of ``subgame``.
        """
        new_node = Node(2, 0, self.game.root, self.game.getid())
        # assert sum(chances) !=0
        # isets = [150, 150]
        iset_dict = {}  # first-child info set of a subgame node -> new gadget info set id
        cfvs = []
        cfvs_dict = {}  # gadget info set -> reach-weighted sum of expected values
        cfvs_sum_dict = {}  # gadget info set -> sum of the reach weights
        e = {}
        v = {}
        for node in subgame:
            if not node.children[0].i_set in iset_dict:
                # First time this info set is seen: allocate the next free
                # player-2 info set id and register it everywhere it is needed.
                iset_dict[node.children[0].i_set] = self.game.actual_set_p2
                self.game.isets_p2[str(self.game.actual_set_p2)] = self.game.actual_set_p2
                self.i_sets_to_nodes[1][self.game.actual_set_p2] = []
                self.i_sets[1].append(self.game.actual_set_p2)
                cfvs_dict[self.game.actual_set_p2] = 0
                cfvs_sum_dict[self.game.actual_set_p2] = 0
                e[self.game.actual_set_p2] = []
                v[self.game.actual_set_p2] = []
                self.game.actual_set_p2+=1
            # NOTE(review): an older comment here warned that counterfactual values may be
            # missing for this node because the value function was always used; the lookup
            # below still requires them -- confirm they are populated by the caller.
            # cfvs.append(self.counterfactual_values[node][0])
            # We shall compute value function for average reaches I guess.
            # cfvs.append(self.counterfactual_values[node.parent][0] * 0.25)
            e[iset_dict[node.children[0].i_set]].append(self.counterfactual_values[node])
            v[iset_dict[node.children[0].i_set]].append(self.average_reaches[node])
            if self.use_vf:
                cfvs.append(self.value_function[node][1])
                # NOTE(review): the original comment mentioned a minus sign (the counterfactual value for
                # the second player is negative in CFR, see CFRD line 372); no sign flip is applied here.
                cfvs_dict[iset_dict[node.children[0].i_set]] += self.exp_val_function[node]* self.average_reaches[node][1]
                cfvs_sum_dict[iset_dict[node.children[0].i_set]] += self.average_reaches[node][1]
            else:
                cfvs.append(self.counterfactual_values[node][1])
                cfvs_dict[iset_dict[node.children[0].i_set]] += self.compute_exp_val(node) * self.average_reaches[node][1]
                cfvs_sum_dict[iset_dict[node.children[0].i_set]] += self.average_reaches[node][1]
            
        # Not sure if we should use sum or average here... or anything else
        # This should be value of the infoset 
        # for item, val in e.items():
        #     print(item)
        #     print(val)
        #     print(v[item])
        #     print("--------")
        average_cfv = np.sum(cfvs)
        for node in subgame:

            decision_node = Node(1, iset_dict[node.children[0].i_set], new_node, self.game.getid())
            # Terminal value is the reach-weighted average over the gadget info set.
            # NOTE(review): the original comment spoke of a "division by 4"; the code divides by the summed reach.
            terminate_node = Node(3, 0, decision_node, self.game.getid(), cfvs_dict[iset_dict[node.children[0].i_set]] / cfvs_sum_dict[iset_dict[node.children[0].i_set]])
            # terminate_node = Node(3, 0, decision_node, self.game.getid(), -self.counterfactual_values[node][1])
            node.parent = decision_node
            # Child 0 is the terminal node, child 1 the original subgame node.
            decision_node.children.append(terminate_node)
            decision_node.children.append(node)
            self.average_reaches[decision_node] = copy.deepcopy(self.average_reaches[node])
            new_node.children.append(decision_node)
            # Think about normalizing here, I am not sure if it could happen that sum of chances is 0
            # new_node.chance.append(self.average_reaches[node][0] / (1 - self.p))
            # new_node.chance.append(self.reaches[node] / (1 - self.p))
            if self.p <= 0.0000001:
                
                new_node.chance.append(0.0)
            else:
                new_node.chance.append(self.reaches[node] / self.p)
            # new_node.chance.append(self.reaches_both[node][0] / (1 - self.p))
            self.i_sets_to_nodes[1][iset_dict[node.children[0].i_set]].append(decision_node)
        self.game.root.children[0] = new_node
        for iset in iset_dict.values():
            self.set_uniform_strategy(player=1, iset=iset, strategy=self.strategy)


    def compute_reaches(self):
        self.compute_reaches_step(self.game.root, reach=1)

    def compute_reaches_step(self, node, reach):
        self.reaches[node] = reach
        if node.player == 3:
            return
        elif node.player == 2:
            for chance, child in zip(node.chance, node.children):
                self.compute_reaches_step(child, reach * chance)
        else:
            for i, child in enumerate(node.children):
                self.compute_reaches_step(child, reach * self.average_strategy[node.player][node.i_set][i])
            
            

    def compute_average_reaches(self):
        self.compute_average_reaches_both(self.game.root, p1_reach=1, p2_reach=1)

    def compute_average_reaches_both(self, node, p1_reach, p2_reach):
        self.average_reaches[node] = (p1_reach, p2_reach)
        if node.player == 3:
            return
        else:
            if node.player == 2:
                for chance, child in zip(node.chance, node.children):
                    self.compute_average_reaches_both(child, p1_reach * chance, p2_reach * chance)
            elif node.player == 1:
                for i, child in enumerate(node.children):
                    self.compute_average_reaches_both(child, p1_reach * self.average_strategy[1][node.i_set][i], p2_reach)
            elif node.player == 0:
                for i, child in enumerate(node.children):
                    self.compute_average_reaches_both(child, p1_reach, p2_reach * self.average_strategy[0][node.i_set][i])

    def compute_reaches_both(self):
        self.compute_reaches_both_step(self.game.root, p1_reach=1, p2_reach=1)

    def compute_reaches_both_step(self, node, p1_reach, p2_reach):
        self.reaches_both[node] = (p1_reach, p2_reach)
        if node.player == 3:
            return
        else:
            if node.player == 2:
                for chance, child in zip(node.chance, node.children):
                    if self.trunk_solved and node.id_val == self.game.root.children[1].id_val:
                        # print("eeee")
                        if self.p >= 0.999999:
                            self.compute_reaches_both_step(child, 0.0, 0.0)
                        else:
                            self.compute_reaches_both_step(child, p1_reach * self.average_reaches[child][0] / (1 - self.p), p2_reach * self.average_reaches[child][0] / (1 - self.p))
                    else:
                        self.compute_reaches_both_step(child, p1_reach * chance, p2_reach * chance)
            elif node.player == 1:
                for i, child in enumerate(node.children):
                    self.compute_reaches_both_step(child, p1_reach * self.strategy[1][node.i_set][i], p2_reach)
            elif node.player == 0:
                for i, child in enumerate(node.children):
                    self.compute_reaches_both_step(child, p1_reach, p2_reach * self.strategy[0][node.i_set][i])

    def solve_leduc_trunk(self, iterations=1000):
        """Run the trunk iterations and freeze the resulting trunk strategy.

        Steps, in order: initialise the solver, fix player 1 in the root's
        second child to the loaded opponent strategy, run ``iterations``
        calls of ``perform_iteration``, freeze player 0's trunk strategy in
        the root's first child, recompute all reach tables, snapshot the
        trunk state with ``save_trunk``, replace the current strategy with a
        copy of the average strategy, and recompute counterfactual values.

        Args:
            iterations: Number of trunk iterations to run.
        """
        self.initialize()

        # First child is always the BR one. Others are fixed. We just use it for one strategy
        self.fix_opponent_strategy_first(self.game.root.children[1])
        # self.check_right_tree_initial()
        # The loop variable is stored on the instance, so other methods can read self.iteration.
        for self.iteration in range(iterations):
            self.iter_count += self.iteration
            print(self.iteration)
            self.perform_iteration()
            
        # Right subtree is going to be changed, therefore it does not make sense to care about it.
        print("Solved trunk")
        # From this point it could probably be in the solve_leduc_rest, but we do it for each subgame, so it seems logical to make it just once.
        self.fix_my_trunk_strategy(self.game.root.children[0])
        # self.check_left_tree_trunk()
        # Store it just in case it is needed later, I am not sure about it now.
        self.compute_reaches()
        self.compute_reaches_both()
        self.compute_average_reaches()
        self.save_trunk()
        # NOTE(review): this overwrites the trunk_reaches just stored by save_trunk (taken from
        # self.reaches) with a deep copy of reaches_both -- confirm which one is intended.
        self.trunk_reaches = copy.deepcopy(self.reaches_both)
        self.trunk_solved = True
        self.strategy = copy.deepcopy(self.average_strategy)
        self.compute_counterfactual_values_first()
        
        

    def solve_leduc_rest_with_gadget(self, public_set, i_public_set, iterations=1000):
        """Re-solve one subgame with a gadget placed on the left subtree.

        The roots in ``public_set`` are split by ``id_val`` relative to
        ``right_subtree_id``. The right roots are hung below a new chance
        node and player 1 is fixed there; the left roots are wrapped by
        ``add_gadget_left``. Regrets and the cached reach/value tables are
        then reset and ``iterations`` calls of
        ``perform_iteration_with_gadget`` are run.

        Args:
            public_set: Root nodes of the subgame (both subtrees).
            i_public_set: Index of the subgame; stored as ``current_subgame``.
            iterations: Number of iterations to run.
        """
        # for i_public_set, public_set in enumerate(self.public_sets):
        self.i_sets_to_update = [set(), set()]
        self.iter_count = 0
        # Change game tree in such a way that the right one is just one chance node and then one infoset.
        # Take only the right part from root.
        # self.value_function_cache = [[] for i in range(len(self.public_sets) * 2)]
        # I am not sure about the right_subtree_id, it does not seem robust, but nothing is.
        only_left_root = [i for i in public_set if i.id_val < self.right_subtree_id]
        only_right_root = [i for i in public_set if i.id_val > self.right_subtree_id]
        self.current_subgame = i_public_set
        
        # else:
        # for a1 in 

        self.add_chance_to_right(only_right_root)
            
        for root_node in only_right_root:
            self.fix_opponent_strategy_second(root_node)
        # We use value function to get counterfactual values for each subgame.
        if self.use_vf:
            self.presolve_value_function_gadget()
        self.add_gadget_left(only_left_root)
        self.initialize_regrets()

        # The tables are cleared only after the tree surgery above, which still reads them.
        self.reaches = {} # From average strategy
        self.reaches_both = {} # From immediate strategy
        self.value_function = {}
        self.counterfactual_values = {}

        # The loop variable is stored on the instance, so other methods can read self.iteration.
        for self.iteration in range(iterations):
            # print(self.iteration)
            self.iter_count += self.iteration # Does this even make sense in this scenario?
            self.perform_iteration_with_gadget()
        # for iset in range(936, 941):
            # print(self.strategy[1][node.i_set])
            # print(self.average_strategy[1][iset])
        # self.subgame_values.append(self.value_of_the_game(self.game.root, 1))
        # print("Value function cache: " + str(self.already_found))
        print("Solved subgame with id: " + str(i_public_set))


    def solve_leduc_rest_without_gadget(self, public_set, i_public_set, iterations=1000):
        """Re-solve one subgame with plain chance nodes on both subtrees.

        The roots in ``public_set`` are split by ``id_val`` relative to
        ``right_subtree_id``. The right roots are hung below a new chance
        node and player 1 is fixed there; the left roots are hung below
        another new chance node via ``add_chance_left``. Regrets and the
        cached reach/value tables are then reset and ``iterations`` calls of
        ``perform_iteration_without_gadget`` are run.

        Args:
            public_set: Root nodes of the subgame (both subtrees).
            i_public_set: Index of the subgame; stored as ``current_subgame``.
            iterations: Number of iterations to run.
        """
        # for i_public_set, public_set in enumerate(self.public_sets):
        self.i_sets_to_update = [set(), set()]
        self.iter_count = 0
        # Change game tree in such a way that the right one is just one chance node and then one infoset.
        # Take only the right part from root.
        # self.value_function_cache = [[] for i in range(len(self.public_sets) * 2)]
        # I am not sure about the right_subtree_id, it does not seem robust, but nothing is.
        only_left_root = [i for i in public_set if i.id_val < self.right_subtree_id]
        only_right_root = [i for i in public_set if i.id_val > self.right_subtree_id]
        self.current_subgame = i_public_set
        self.add_chance_to_right(only_right_root)
            
        for root_node in only_right_root:
            self.fix_opponent_strategy_second(root_node)
        # We use value function to get counterfactual values for each subgame.
        if self.use_vf:
            self.presolve_value_function_gadget()
        self.add_chance_left(only_left_root)
        self.initialize_regrets()

        # The tables are cleared only after the tree surgery above, which still reads them.
        self.reaches = {} # From average strategy
        self.reaches_both = {} # From immediate strategy
        self.value_function = {}
        self.counterfactual_values = {}

        # The loop variable is stored on the instance, so other methods can read self.iteration.
        for self.iteration in range(iterations):
            # print(self.iteration)
            self.iter_count += self.iteration # Does this even make sense in this scenario?
            self.perform_iteration_without_gadget()
        # self.subgame_values.append(self.value_of_the_game(self.game.root, 1))
        # print("Value function cache: " + str(self.already_found))
        print("Solved subgame with id: " + str(i_public_set))


    def solve_leduc_rest(self, public_set, i_public_set, iterations=1000):
        """Re-solve one subgame below the fixed trunk and store the result.

        Only the right subtree is rewired here: the roots of ``public_set``
        whose ``id_val`` exceeds ``right_subtree_id`` are hung below a new
        chance node and player 1 is fixed there. After ``iterations`` calls
        of ``perform_iteration_with_fixed_trunk`` the game value is appended
        to ``subgame_values`` and the solver state is snapshotted with
        ``save_subgame_result``.

        Args:
            public_set: Root nodes of the subgame (both subtrees).
            i_public_set: Index of the subgame; stored as ``current_subgame``.
            iterations: Number of iterations to run.
        """
        # for i_public_set, public_set in enumerate(self.public_sets):
        self.i_sets_to_update = [set(), set()]
        self.initialize_regrets()
        # NOTE(review): the sibling solve_leduc_rest_* methods start iter_count at 0 -- confirm the 1 is intended.
        self.iter_count = 1
        # Change game tree in such a way that the right one is just one chance node and then one infoset.
        # Take only the right part from root.
        # self.value_function_cache = [[] for i in range(len(self.public_sets) * 2)]
        # I am not sure about the right_subtree_id, it does not seem robust, but nothing is.
        only_right_root = [i for i in public_set if i.id_val > self.right_subtree_id]
        self.current_subgame = i_public_set
        self.add_chance_to_right(only_right_root)
            
        for root_node in only_right_root:
            self.fix_opponent_strategy_second(root_node)

        # The tables are cleared only after add_chance_to_right, which still reads self.reaches.
        self.reaches = {} # From average strategy
        self.reaches_both = {} # From immediate strategy
        self.value_function = {}
        self.counterfactual_values = {}

        # The loop variable is stored on the instance, so other methods can read self.iteration.
        for self.iteration in range(iterations):
            # print(self.iteration)
            self.iter_count += self.iteration # Does this even make sense in this scenario?
            self.perform_iteration_with_fixed_trunk()
        self.subgame_values.append(self.value_of_the_game(self.game.root, 1))
        self.save_subgame_result()
        # print("Value function cache: " + str(self.already_found))
        print("Solved subgame with id: " + str(i_public_set))

    def save_subgame_result(self):
        strategy = [{}, {}]
        average_strategy = [{}, {}]
        reaches = {}
        regrets = [{}, {}]
        values = {}
        for i in range(2):
            for iset, strat in self.strategy[i].items():
                strategy[i][iset] = copy.deepcopy(strat)
            for iset, strat in self.average_strategy[i].items():
                average_strategy[i][iset] = copy.deepcopy(strat)
            for iset, regret in self.counterfactual_regret[i].items():
                regrets[i][iset] = copy.deepcopy(regret)
            for node, reach in self.reaches.items():
                reaches[node] = reach
            for node, value in self.counterfactual_values.items():
                values[node] = copy.deepcopy(value)
        self.final_strategies.append(strategy)
        self.final_average_strategies.append(average_strategy)
        self.final_reaches.append(reaches)
        self.final_regrets.append(regrets)
        self.final_values.append(values)


    def load_back_trunk_values(self):
        self.strategy = [{}, {}]
        self.average_strategy = [{}, {}]
        self.reaches = {}
        self.counterfactual_regret = [{}, {}]
        self.counterfactual_values = {}
        for i in range(2):
            for iset, strat in self.trunk_strategy[i].items():
                self.strategy[i][iset] = copy.deepcopy(strat)
            for iset, strat in self.trunk_average_strategy[i].items():
                self.average_strategy[i][iset] = copy.deepcopy(strat)
            for iset, regret in self.trunk_regret[i].items():
                self.counterfactual_regret[i][iset] = copy.deepcopy(regret)
            for node, reach in self.trunk_reaches.items():
                self.reaches[node] = reach # May be useless
            for node, value in self.trunk_values.items():
                self.counterfactual_values[node] = copy.deepcopy(value)

    def save_trunk(self):
        self.trunk_strategy = [{}, {}]
        self.trunk_average_strategy = [{}, {}]
        self.trunk_reaches = {}
        self.trunk_regret = [{}, {}]
        self.trunk_values = {}
        for i in range(2):
            for iset, strat in self.strategy[i].items():
                self.trunk_strategy[i][iset] = copy.deepcopy(strat)
            for iset, strat in self.average_strategy[i].items():
                self.trunk_average_strategy[i][iset] = copy.deepcopy(strat)
            for iset, regret in self.counterfactual_regret[i].items():
                self.trunk_regret[i][iset] = copy.deepcopy(regret)
            for node, reach in self.reaches.items():
                self.trunk_reaches[node] = reach
            for node, value in self.counterfactual_values.items():
                self.trunk_values[node] = copy.deepcopy(value)

    def perform_iteration_without_gadget(self):
        """Run one subgame iteration for the variant without a gadget.

        Counterfactual values are recomputed, player 0's active info sets are
        set to a copy of the current subgame's player-0 info sets, and the
        regret, strategy and average-strategy updates are applied. When
        ``save_strategy`` is set, copies of both strategies are appended to
        ``strategy_accumulator``.
        """
        self.compute_counterfactual_values_without_gadget()
        player_zero_i_sets = self.subgame_i_sets[self.current_subgame][0]
        self.i_sets_to_update[0] = copy.deepcopy(player_zero_i_sets)
        self.compute_regret()
        self.compute_strategy()
        self.average_strategies()

        if not self.save_strategy:
            return
        self.strategy_accumulator['cur'].append(copy.deepcopy(self.strategy))
        self.strategy_accumulator['avg'].append(copy.deepcopy(self.average_strategy))

    def perform_iteration_with_gadget(self):
        """Run one subgame iteration for the gadget variant.

        Counterfactual values are recomputed for the gadget tree, player 0's
        active info sets are set to a copy of the current subgame's player-0
        info sets, and the regret, strategy and average-strategy updates are
        applied. When ``save_strategy`` is set, copies of both strategies are
        appended to ``strategy_accumulator``.
        """
        self.compute_counterfactual_values_gadget()
        player_zero_i_sets = self.subgame_i_sets[self.current_subgame][0]
        self.i_sets_to_update[0] = copy.deepcopy(player_zero_i_sets)
        self.compute_regret()
        self.compute_strategy()
        self.average_strategies()

        if not self.save_strategy:
            return
        self.strategy_accumulator['cur'].append(copy.deepcopy(self.strategy))
        self.strategy_accumulator['avg'].append(copy.deepcopy(self.average_strategy))
    

    def perform_iteration_with_fixed_trunk(self):
        """Run one CFR iteration of the second phase (``*_second`` variants).

        Steps, in order: recompute both players' reaches, optionally
        pre-solve the value function (only when ``self.use_vf``), recompute
        counterfactual values, replace the player-0 update set by a copy of
        the current subgame's info sets, then update regrets, the current
        strategy and the average strategy. When ``self.save_strategy`` is
        set, deep copies of both strategies are appended to
        ``self.strategy_accumulator``.
        """
        self.compute_reaches_both()
        if self.use_vf:
            self.presolve_value_function_second()
        self.compute_counterfactual_values_second()

        # Overrides whatever the traversal collected for player 0.
        subgame_sets = self.subgame_i_sets[self.current_subgame]
        self.i_sets_to_update[0] = copy.deepcopy(subgame_sets[0])

        self.compute_regret()
        self.compute_strategy()
        self.average_strategies()

        if not self.save_strategy:
            return
        history = self.strategy_accumulator
        history['cur'].append(copy.deepcopy(self.strategy))
        history['avg'].append(copy.deepcopy(self.average_strategy))
    

    def perform_iteration(self):
        """Run one CFR iteration of the first phase (``*_first`` variants).

        Steps, in order: recompute both players' reaches, optionally
        pre-solve the value function (only when ``self.use_vf``), recompute
        counterfactual values, then update regrets, the current strategy and
        the average strategy. When ``self.save_strategy`` is set, deep copies
        of both strategies are appended to ``self.strategy_accumulator``.
        """
        self.compute_reaches_both()
        if self.use_vf:
            self.presolve_value_function()
        self.compute_counterfactual_values_first()

        self.compute_regret()
        self.compute_strategy()
        self.average_strategies()

        if not self.save_strategy:
            return
        history = self.strategy_accumulator
        history['cur'].append(copy.deepcopy(self.strategy))
        history['avg'].append(copy.deepcopy(self.average_strategy))
        


    def compute_counterfactual_values_first(self):
        """Traverse the game from the root with ``compute_cfvs_first``.

        Both reach arguments start at 1; the returned value pair is stored
        under the root node in ``self.counterfactual_values``.
        """
        root_values = self.compute_cfvs_first(self.game.root, 1, 1)
        self.counterfactual_values[self.game.root] = root_values

    def compute_counterfactual_values_second(self):
        """Traverse the game from the root with ``compute_cfvs_second``.

        Both reach arguments start at 1; the returned value pair is stored
        under the root node in ``self.counterfactual_values``.
        """
        root_values = self.compute_cfvs_second(self.game.root, 1, 1)
        self.counterfactual_values[self.game.root] = root_values

    def compute_counterfactual_values_gadget(self):
        """Traverse the game from the root with ``compute_cfvs_gadget``.

        Both reach arguments start at 1; the returned value pair is stored
        under the root node in ``self.counterfactual_values``.
        """
        root_values = self.compute_cfvs_gadget(self.game.root, 1, 1)
        self.counterfactual_values[self.game.root] = root_values

    def compute_counterfactual_values_without_gadget(self):
        """Traverse the game from the root with ``compute_cfvs_without_gadget``.

        Both reach arguments start at 1; the returned value pair is stored
        under the root node in ``self.counterfactual_values``.
        """
        root_values = self.compute_cfvs_without_gadget(self.game.root, 1, 1)
        self.counterfactual_values[self.game.root] = root_values

    def compute_cfvs_first(self, node, reach0, reach1):
        """Recursively compute the value pair of ``node`` for the first phase.

        ``reach0`` weights the first component and ``reach1`` the second one.
        A player-1 action probability is folded into ``reach0`` on the way
        down and applied to the second component on the way up; player 0 is
        handled symmetrically. Nodes with ``player == 3`` are leaves valued
        at ``node.value``; every other non-decision node is treated as a
        chance node.

        At a chance node whose ``id_val`` is in ``self.public_chances`` (and
        when ``self.use_vf`` is set) the children are not traversed; their
        values come from ``self.value_function`` scaled by reach and chance.

        Side effects: fills ``self.counterfactual_values`` for every visited
        child (and leaf) and registers the info sets of visited decision
        nodes in ``self.i_sets_to_update``.

        Returns:
            The pair ``(value for player 0, value for player 1)``.
        """
        player = node.player
        if player == 3:
            leaf_values = (node.value * reach0, -node.value * reach1)
            self.counterfactual_values[node] = leaf_values
            return leaf_values

        total0 = 0
        total1 = 0

        if player == 1:
            # Only reachable info sets get their strategies and regrets updated.
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[1][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_first(child, reach0 * prob, reach1)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0
                total1 += cfv1 * prob
            return total0, total1

        if player == 0:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[0][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_first(child, reach0, reach1 * prob)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0 * prob
                total1 += cfv1
            return total0, total1

        # Chance node: either cut the traversal with the value function or recurse.
        cut_here = self.use_vf and node.id_val in self.public_chances
        for child, chance in zip(node.children, node.chance):
            if cut_here:
                # The stored value is already counterfactual; scale it like CFR-D does with chance.
                cfv0, cfv1 = self.value_function[child]
                cfv0 *= reach0 * chance
                cfv1 *= reach1 * chance
            else:
                cfv0, cfv1 = self.compute_cfvs_first(child, reach0 * chance, reach1 * chance)
            self.counterfactual_values[child] = (cfv0, cfv1)
            total0 += cfv0
            total1 += cfv1
        return total0, total1

    def compute_cfvs_second(self, node, reach0, reach1):
        """Recursively compute the value pair of ``node`` for the second phase.

        Leaves (``player == 3``) and decision nodes (players 0 and 1) are
        handled as in the first-phase traversal. Every other node is treated
        as a chance node, and each child falls into one of three cases:

        * ``node`` is the second child of the game root (the right subtree):
          the child is entered with reaches rescaled by
          ``self.average_reaches[child] / (1 - p)``, or with zero reaches
          when ``p`` is numerically 1 (which also avoids dividing by
          ``1 - p``);
        * ``node`` is a public chance node, the child's info set is not part
          of the current subgame and ``self.use_vf`` is set: the child is
          not traversed, its value comes from ``self.value_function``;
        * otherwise the child is traversed with the chance probability
          folded into both reaches.

        Side effects: fills ``self.counterfactual_values`` for every visited
        child (and leaf) and registers the info sets of visited decision
        nodes in ``self.i_sets_to_update``.

        Returns:
            The pair ``(value for player 0, value for player 1)``.
        """
        player = node.player
        if player == 3:
            leaf_values = (node.value * reach0, -node.value * reach1)
            self.counterfactual_values[node] = leaf_values
            return leaf_values

        total0 = 0
        total1 = 0

        if player == 1:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[1][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_second(child, reach0 * prob, reach1)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0
                total1 += cfv1 * prob
            return total0, total1

        if player == 0:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[0][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_second(child, reach0, reach1 * prob)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0 * prob
                total1 += cfv1
            return total0, total1

        for child, chance in zip(node.children, node.chance):
            if node.id_val == self.game.root.children[1].id_val:
                # Root of the right subtree: chance is replaced by the average reaches.
                if self.p >= 0.999999:
                    child_reach0, child_reach1 = 0.0, 0.0
                else:
                    averaged = self.average_reaches[child]
                    child_reach0 = reach0 * averaged[0] / (1 - self.p)
                    child_reach1 = reach1 * averaged[1] / (1 - self.p)
                cfv0, cfv1 = self.compute_cfvs_second(child, child_reach0, child_reach1)
            elif (
                node.id_val in self.public_chances
                and child.i_set not in self.public_to_i_sets[self.current_subgame]
                and self.use_vf
            ):
                # Outside the current subgame: use the pre-solved value instead of recursing.
                cfv0, cfv1 = self.value_function[child]
                cfv0 *= reach0 * chance
                cfv1 *= reach1 * chance
            else:
                cfv0, cfv1 = self.compute_cfvs_second(child, reach0 * chance, reach1 * chance)
            self.counterfactual_values[child] = (cfv0, cfv1)
            total0 += cfv0
            total1 += cfv1
        return total0, total1

    def compute_cfvs_gadget(self, node, reach0, reach1):
        """Recursively compute the value pair of ``node`` for the gadget game.

        Leaves (``player == 3``) and decision nodes (players 0 and 1) are
        handled as in the other traversals. Every other node is treated as a
        chance node, and each child falls into one of three cases:

        * ``node`` is the second child of the game root (right subtree): the
          child is entered with reaches rescaled by
          ``self.average_reaches[child] / (1 - p)``, or with zero reaches
          when ``p`` is numerically 1;
        * ``node`` is the first child of the game root (left subtree):
          ``reach0`` is scaled by the constant ``1 / 120`` and ``reach1`` by
          ``self.average_reaches[child][1] / p`` (zero when ``p`` is
          numerically 0);
        * otherwise the chance probability is folded into both reaches.

        NOTE(review): the cutoff for "p is 1" here is ``0.99999`` while the
        sibling traversals use ``0.999999`` - confirm whether that is
        intended. The ``1 / 120`` constant is presumably the number of
        initial deals in Leduc - TODO confirm.

        Side effects: fills ``self.counterfactual_values`` for every visited
        child (and leaf) and registers the info sets of visited decision
        nodes in ``self.i_sets_to_update``.

        Returns:
            The pair ``(value for player 0, value for player 1)``.
        """
        player = node.player
        if player == 3:
            leaf_values = (node.value * reach0, -node.value * reach1)
            self.counterfactual_values[node] = leaf_values
            return leaf_values

        total0 = 0
        total1 = 0

        if player == 1:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[1][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_gadget(child, reach0 * prob, reach1)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0
                total1 += cfv1 * prob
            return total0, total1

        if player == 0:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[0][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_gadget(child, reach0, reach1 * prob)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0 * prob
                total1 += cfv1
            return total0, total1

        for child, chance in zip(node.children, node.chance):
            if node.id_val == self.game.root.children[1].id_val:
                # Root of the right subtree.
                if self.p >= 0.99999:
                    child_reach0, child_reach1 = 0.0, 0.0
                else:
                    averaged = self.average_reaches[child]
                    child_reach0 = reach0 * averaged[0] / (1 - self.p)
                    child_reach1 = reach1 * averaged[1] / (1 - self.p)
            elif node.id_val == self.game.root.children[0].id_val:
                # Root of the left subtree. Original author's note: "This seems weird
                # but it is here to prevail the chances after root."
                child_reach0 = reach0 * (1 / 120)
                if self.p <= 0.0000001:
                    child_reach1 = 0.0
                else:
                    child_reach1 = reach1 * self.average_reaches[child][1] / self.p
            else:
                child_reach0 = reach0 * chance
                child_reach1 = reach1 * chance
            cfv0, cfv1 = self.compute_cfvs_gadget(child, child_reach0, child_reach1)
            self.counterfactual_values[child] = (cfv0, cfv1)
            total0 += cfv0
            total1 += cfv1
        return total0, total1


    def compute_cfvs_without_gadget(self, node, reach0, reach1):
        """Recursively compute the value pair of ``node`` without the gadget.

        Leaves (``player == 3``) and decision nodes (players 0 and 1) are
        handled as in the other traversals. Every other node is treated as a
        chance node, and each child falls into one of three cases:

        * ``node`` is the second child of the game root (right subtree): the
          child is entered with reaches rescaled by
          ``self.average_reaches[child] / (1 - p)``, or with zero reaches
          when ``p`` is numerically 1;
        * ``node`` is the first child of the game root (left subtree): the
          child is entered with reaches rescaled by
          ``self.average_reaches[child] / p``, or with zero reaches when
          ``p`` is numerically 0;
        * otherwise the chance probability is folded into both reaches.

        The zero-reach shortcuts also avoid dividing by ``1 - p`` or ``p``
        when those are (close to) zero.

        Side effects: fills ``self.counterfactual_values`` for every visited
        child (and leaf) and registers the info sets of visited decision
        nodes in ``self.i_sets_to_update``.

        Returns:
            The pair ``(value for player 0, value for player 1)``.
        """
        player = node.player
        if player == 3:
            leaf_values = (node.value * reach0, -node.value * reach1)
            self.counterfactual_values[node] = leaf_values
            return leaf_values

        total0 = 0
        total1 = 0

        if player == 1:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[1][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_without_gadget(child, reach0 * prob, reach1)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0
                total1 += cfv1 * prob
            return total0, total1

        if player == 0:
            self.i_sets_to_update[player].add(node.i_set)
            for idx, child in enumerate(node.children):
                prob = self.strategy[0][node.i_set][idx]
                cfv0, cfv1 = self.compute_cfvs_without_gadget(child, reach0, reach1 * prob)
                self.counterfactual_values[child] = (cfv0, cfv1)
                total0 += cfv0 * prob
                total1 += cfv1
            return total0, total1

        for child, chance in zip(node.children, node.chance):
            if node.id_val == self.game.root.children[1].id_val:
                # Root of the right subtree.
                if self.p >= 0.999999:
                    child_reach0, child_reach1 = 0.0, 0.0
                else:
                    averaged = self.average_reaches[child]
                    child_reach0 = reach0 * averaged[0] / (1 - self.p)
                    child_reach1 = reach1 * averaged[1] / (1 - self.p)
            elif node.id_val == self.game.root.children[0].id_val:
                # Root of the left subtree. Original author's note: "This seems weird
                # but it is here to prevail the chances after root."
                if self.p <= 0.00000001:
                    child_reach0, child_reach1 = 0.0, 0.0
                else:
                    averaged = self.average_reaches[child]
                    child_reach0 = reach0 * averaged[0] / (self.p)
                    child_reach1 = reach1 * averaged[1] / (self.p)
            else:
                child_reach0 = reach0 * chance
                child_reach1 = reach1 * chance
            cfv0, cfv1 = self.compute_cfvs_without_gadget(child, child_reach0, child_reach1)
            self.counterfactual_values[child] = (cfv0, cfv1)
            total0 += cfv0
            total1 += cfv1
        return total0, total1

    def presolve_value_function(self):
        """Pre-solve the value function for both halves of every public set.

        Each public set is split by node id: nodes with ``id_val`` below
        ``self.right_subtree_id`` form the left part, the remaining nodes the
        right part. ``perform_subgame_cfr`` is run on the left part first and
        then on the right part, even when a part is empty.
        """
        for public_set in self.public_sets:
            left_part = []
            right_part = []
            for root_node in public_set:
                if root_node.id_val < self.right_subtree_id:
                    left_part.append(root_node)
                if root_node.id_val >= self.right_subtree_id:
                    right_part.append(root_node)
            self.perform_subgame_cfr(left_part)
            self.perform_subgame_cfr(right_part)

    def presolve_value_function_second(self):
        """Pre-solve the value function for the left part of every public set.

        Only nodes whose ``id_val`` is below ``self.right_subtree_id`` are
        kept; ``perform_subgame_cfr`` is called on the filtered list (even
        when it is empty).
        """
        # Original author's note: maybe both parts should be solved together.
        boundary_is_left = lambda root_node: root_node.id_val < self.right_subtree_id
        for public_set in self.public_sets:
            left_part = list(filter(boundary_is_left, public_set))
            self.perform_subgame_cfr(left_part)

    def presolve_value_function_gadget(self):
        """Pre-solve the gadget value function for the left part of every public set.

        Only nodes whose ``id_val`` is below ``self.right_subtree_id`` are
        kept; ``perform_average_subgame_cfr`` is called on the filtered list
        (even when it is empty).
        """
        # Original author's note: maybe both parts should be solved together.
        boundary_is_left = lambda root_node: root_node.id_val < self.right_subtree_id
        for public_set in self.public_sets:
            left_part = list(filter(boundary_is_left, public_set))
            self.perform_average_subgame_cfr(left_part)

    def perform_average_subgame_cfr(self, public_set):
        """Solve the subgame rooted at ``public_set`` under the average reaches.

        The subgame is built from the given root nodes, re-weighted with
        ``self.average_reaches`` and solved by ``SubgameCFR``
        (``cfr_player=0``) for ``self.vf_iterations`` iterations. For every
        root node the solver's counterfactual value pair is stored in
        ``self.value_function`` and the result of ``compute_expected`` in
        ``self.exp_val_function``.
        """
        subgame = ExtensiveSubgame(public_set)
        subgame.change_reaches(self.average_reaches)
        solver = SubgameCFR(subgame, cfr_player=0)
        solver.solve(self.vf_iterations)

        for root_node in public_set:
            root_values = solver.counterfactual_values[root_node]
            self.value_function[root_node] = (root_values[0], root_values[1])
            self.exp_val_function[root_node] = solver.compute_expected(root_node)
            

    def perform_subgame_cfr(self, public_set):
        """Solve the subgame rooted at ``public_set`` and cache its root values.

        The subgame is built from the given root nodes, re-weighted with
        ``self.reaches_both`` and solved by ``SubgameCFR`` (``cfr_player=0``)
        for ``self.vf_iterations`` iterations. The solver's counterfactual
        value pair of every root node is then stored in
        ``self.value_function``.
        """
        subgame = ExtensiveSubgame(public_set)
        subgame.change_reaches(self.reaches_both)
        solver = SubgameCFR(subgame, cfr_player=0)
        solver.solve(self.vf_iterations)

        # Original author's caveat: the stored value is the counterfactual value
        # under the solver's strategy, which may not be the same as treating the
        # subgame root as a terminal. The author believed the counterfactual value
        # is the value of the game from this state multiplied by the reach
        # probability of both players, but was not sure.
        for root_node in public_set:
            root_values = solver.counterfactual_values[root_node]
            self.value_function[root_node] = (root_values[0], root_values[1])
            

    # Finds all subgames in Leduc that start after a chance node (the dealing of the card)
    def decompose_leduc(self):
        """Group subgame root nodes into public sets and index them.

        ``decompose_leduc_step`` first collects the candidate root nodes per
        info set. Every root that was not handled yet starts a new public
        set, which ``decompose_dfs`` fills. Afterwards two indexes are built:
        ``self.public_to_i_sets`` maps a public-set index to the distinct
        info sets of its nodes (in first-seen order) and
        ``self.i_set_to_public`` maps an info set back to its public-set
        index.
        """
        p1_infosets, p2_infosets = {}, {}
        self.decompose_leduc_step(self.game.root, p1_infosets, p2_infosets, 0)

        already_handled = set()
        for infoset_nodes in p1_infosets.values():
            for start in infoset_nodes:
                if start.id_val in already_handled:
                    continue
                members = []
                self.decompose_dfs(start, members, already_handled, p1_infosets, p2_infosets)
                self.public_sets.append(members)

        for index, members in enumerate(self.public_sets):
            # dict.fromkeys de-duplicates while keeping first-seen order.
            self.public_to_i_sets[index] = list(dict.fromkeys(member.i_set for member in members))
            for member in members:
                self.i_set_to_public[member.i_set] = index

    def decompose_dfs(self, node, public_set, already_handled, p1_infosets, p2_infosets):
        if node.id_val in already_handled:
            return
        already_handled.add(node.id_val)
        if node.player == 0:
            public_set.append(node)
            for i_node in p1_infosets[node.i_set]:
                self.decompose_dfs(i_node, public_set, already_handled, p1_infosets, p2_infosets)
            for child in node.children:
                if child.player == 1:
                    self.decompose_dfs(child, public_set, already_handled, p1_infosets, p2_infosets)
        else:
            for i_node in p2_infosets[node.i_set]:
                self.decompose_dfs(i_node.parent, public_set, already_handled, p1_infosets, p2_infosets)
                   


    def decompose_leduc_step(self, node, p1_infosets, p2_infosets, depth):
        """Collect the nodes that follow the first chance node below depth 0.

        The walk descends from ``node``. Passing through a node with
        ``player == 2`` at depth 0 keeps the depth; passing through any
        other non-terminal node increases it. At the first ``player == 2``
        node with ``depth > 0`` the walk stops: each child is added to
        ``p1_infosets`` under its info set, and each grandchild with
        ``player == 1`` is added to ``p2_infosets`` under its info set.
        Nodes with ``player == 3`` end the walk.
        """
        player = node.player
        if player == 3 and not (depth > 0 and player == 2):
            return

        if player == 2 and depth > 0:
            for root_candidate in node.children:
                p1_infosets.setdefault(root_candidate.i_set, []).append(root_candidate)
                for reply in root_candidate.children:
                    if reply.player == 1:
                        p2_infosets.setdefault(reply.i_set, []).append(reply)
            return

        # Chance nodes at depth 0 keep the depth, every other node deepens it.
        next_depth = depth if player == 2 else depth + 1
        for child in node.children:
            self.decompose_leduc_step(child, p1_infosets, p2_infosets, next_depth)


    def check_right_tree_initial(self):
        """Run ``check_right_tree`` on the second child of the game root."""
        right_root = self.game.root.children[1]
        self.check_right_tree(right_root)
    
    def check_right_tree(self, node):
        """Assert that no player-1 node occurs in the subtree above the public chances.

        The walk stops at leaves (``player == 3``) and at chance nodes whose
        ``id_val`` is in ``self.public_chances``; the node itself is always
        checked before stopping.
        """
        assert node.player != 1
        is_leaf = node.player == 3
        if is_leaf or (node.player == 2 and node.id_val in self.public_chances):
            return
        for successor in node.children:
            self.check_right_tree(successor)

    def check_left_tree_trunk(self):
        """Run ``check_left_tree`` on the first child of the game root."""
        left_root = self.game.root.children[0]
        self.check_left_tree(left_root)
    
    def check_left_tree(self, node):
        """Assert that no player-0 node occurs in the subtree above the public chances.

        The walk stops at leaves (``player == 3``) and at chance nodes whose
        ``id_val`` is in ``self.public_chances``; the node itself is always
        checked before stopping.
        """
        assert node.player != 0
        is_leaf = node.player == 3
        if is_leaf or (node.player == 2 and node.id_val in self.public_chances):
            return
        for successor in node.children:
            self.check_left_tree(successor)

    def find_isets_before_public(self, node, public):
        """Collect the info sets of decision nodes above the public chance nodes.

        ``public`` is indexed by player (0 or 1) and each entry must support
        ``add``; the info set of every visited decision node is added to its
        player's entry. The walk stops at leaves (``player == 3``) and at
        chance nodes whose ``id_val`` is in ``self.public_chances``.
        """
        player = node.player
        if player == 3 or (player == 2 and node.id_val in self.public_chances):
            return
        if player in (0, 1):
            public[player].add(node.i_set)
        for successor in node.children:
            self.find_isets_before_public(successor, public)
