# from xxlimited import new
from CFR import CFR
from ExtensiveGame import ExtensiveGame, Node
from ExtensiveSubgame import ExtensiveSubgame
from GenerateRnRGameNotFixed import GenerateRNRGameNotFixed
from DataManipulation import load_from_file, save_to_file
from SubgameCFR import SubgameCFR
import copy
import numpy as np


class CDRNR_G(CFR):
    """CFR+ solver for an RNR game (presumably restricted Nash response) built from a base game.

    The constructor generates an RNR variant of ``fname`` with
    ``GenerateRNRGameNotFixed`` (written under ``data/cdrnr/``), loads it via
    the ``CFR`` base class with ``cfr_plus=True`` and collects the player-0
    nodes used as value-function leaves.  The ``solve_*`` methods then solve
    the trunk and the rest of the game in separate phases, pinning player 1's
    strategy to ``opponent_strategy`` where required.

    Node ``player`` codes used throughout: 0 and 1 are decision nodes, 2 is a
    chance node (probabilities in ``node.chance``), 3 is terminal (payoff in
    ``node.value``).
    """

    def __init__(self, fname="data/RPS.efg", p=0.5, vf_iterations=10, strat_file=None, use_vf=True):
        """Generate the RNR game and initialise solver state.

        Args:
            fname: path of the base ``.efg`` game.
            p: weight passed to the RNR generator as ``[p, 1 - p]``; also used
                by ``add_chance_left`` to renormalise left-subtree reaches
                (presumably the root chance probability of the left subtree).
            vf_iterations: number of SubgameCFR iterations per value-function solve.
            strat_file: optional file whose entry at index 1 is the opponent
                strategy.  When falsy, a default opponent strategy that puts all
                mass on the last action of each decision is built, printed and
                saved to ``data/bad_strategies/gadget_game/gadget_game_0_iterations.strat``.
            use_vf: stored on the instance; not read by this class itself.
        """
        rnr_generator = GenerateRNRGameNotFixed(fname)
        temp_fname = f"data/cdrnr/CDRNR_test_{p:.1f}.efg"
        if not strat_file:
            game = ExtensiveGame()
            game.load(fname)
            # First distribution is sized by the number of root children, the
            # remaining four are binary; each is a pure "last action" strategy.
            self.opponent_strategy = [np.zeros((len(game.root.children))), np.zeros(2), np.zeros(2), np.zeros(2), np.zeros(2)]
            for distribution in self.opponent_strategy:
                distribution[-1] = 1.0
            print(self.opponent_strategy)
            save_to_file(self.opponent_strategy, "data/bad_strategies/gadget_game/gadget_game_0_iterations.strat")
        else:
            self.opponent_strategy = load_from_file(strat_file)[1]
        rnr_generator.generate(temp_fname, [p, 1 - p], 1, [self.opponent_strategy])
        super().__init__(temp_fname, cfr_plus=True)
        self.use_vf = use_vf
        self.p = p
        self.vf_iterations = vf_iterations
        # Per-node (player-1 reach, player-0 reach), each including chance.
        self.reaches = {}
        # Per-node total reach (chance * player 0 * player 1).
        self.reaches_b = {}
        # Same as above, but computed from the average strategies.
        self.avg_reaches = {}
        self.avg_reaches_b = {}
        # Player-0 nodes acting as value-function leaves (post-order).
        self.vf_nodes = []
        # Maps each vf node to a (player-0 value, player-1 value) pair.
        self.value_function = {}
        # Per-player {i_set: strategy} entries the solver must not update.
        self.fixed = [{}, {}]
        self.find_vf_nodes(self.game.root)

    def fix_opponent_strategy_first(self):
        """Pin player 1's information set at ``root.children[1]`` to ``opponent_strategy[0]``.

        Independent deep copies are written into ``fixed``, ``strategy`` and
        ``average_strategy`` so the three tables never alias each other.
        """
        iset = self.game.root.children[1].i_set
        target = self.opponent_strategy[0]
        self.fixed[1][iset] = copy.deepcopy(target)
        self.strategy[1][iset] = copy.deepcopy(target)
        self.average_strategy[1][iset] = copy.deepcopy(target)

    def fix_opponent_strategy_second(self, node):
        """Recursively pin every player-1 information set in the subtree of ``node``.

        The strategy for an information set is taken from
        ``opponent_strategy[node.i_set - len(opponent_strategy)]``; independent
        deep copies go into ``fixed``, ``strategy`` and ``average_strategy``.
        """
        if node.player == 3:
            return
        if node.player == 1:
            # NOTE(review): this index is negative whenever i_set is smaller than
            # the list length, i.e. it counts from the end of opponent_strategy;
            # confirm this matches the generated game's i_set numbering.
            target = self.opponent_strategy[node.i_set - len(self.opponent_strategy)]
            self.fixed[1][node.i_set] = copy.deepcopy(target)
            self.strategy[1][node.i_set] = copy.deepcopy(target)
            self.average_strategy[1][node.i_set] = copy.deepcopy(target)
        for child in node.children:
            self.fix_opponent_strategy_second(child)

    def find_vf_nodes(self, node):
        """Append every player-0 node under ``node`` to ``vf_nodes`` in post-order."""
        if node.player == 3:
            return
        for child in node.children:
            self.find_vf_nodes(child)
        if node.player == 0:
            self.vf_nodes.append(node)

    def solve_trunk_fully(self, iterations):
        """Run the base CFR solver on the whole game, then refresh ``reaches``."""
        super().solve(iterations)
        self.compute_reaches()

    def solve_trunk_partially(self, iterations):
        """Run ``iterations`` trunk iterations that use the value function at player-0 nodes.

        Player 1's information set at ``root.children[1]`` is pinned to the
        opponent strategy before iterating.
        """
        self.initialize()
        self.fix_opponent_strategy_first()
        for self.iteration in range(iterations):
            # NOTE(review): this adds the 0-based iteration index (0, 1, 2, ...)
            # rather than 1 per iteration; confirm this is the weighting the
            # CFR base class expects from iter_count.
            self.iter_count += self.iteration
            self.perform_iteration_full()

    def solve_rest_full(self, iterations):
        """Pin player 1 in the right subtree and run base CFR on the unchanged game."""
        self.fix_opponent_strategy_second(self.game.root.children[1])
        self.iter_count = 0
        super().solve(iterations, fixed=self.fixed)

    def solve_rest_chance(self, iterations):
        """Pin player 1 on the right, replace the left trunk with a chance node, then run base CFR."""
        self.fix_opponent_strategy_second(self.game.root.children[1])
        self.compute_avg_reaches()
        self.add_chance_left()
        self.iter_count = 0
        super().solve(iterations, fixed=self.fixed)

    def solve_rest_gadget(self, iterations):
        """Pin player 1 on the right, replace the left trunk with a gadget, then run base CFR."""
        self.fix_opponent_strategy_second(self.game.root.children[1])
        self.compute_avg_reaches()
        self.add_gadget_left()
        self.iter_count = 0
        super().solve(iterations, fixed=self.fixed)

    def perform_iteration_full(self):
        """Run one trunk iteration: reaches, value function, trunk values, then player-1 update.

        Only player 1's regrets and strategy are updated here.
        """
        self.compute_reaches()
        self.compute_value_function()
        self.compute_counterfactual_vs()
        self.compute_regret_for_p1()
        self.compute_strategy_for_p1()
        self.average_strategies()

    def _solve_value_function(self, reaches):
        """Solve the subgame rooted at ``vf_nodes`` under ``reaches`` and store the results.

        Runs ``vf_iterations`` iterations of ``SubgameCFR`` (with
        ``cfr_player=0``) and records, for every vf node, the pair of entries
        0 and 1 of its counterfactual values in ``value_function``.
        """
        subgame = ExtensiveSubgame(self.vf_nodes)
        subgame.change_reaches(reaches)
        subgame_cfr = SubgameCFR(subgame, cfr_player=0)
        subgame_cfr.solve(self.vf_iterations)
        for node in self.vf_nodes:
            values = subgame_cfr.counterfactual_values[node]
            self.value_function[node] = (values[0], values[1])

    def compute_value_function(self):
        """Refresh ``value_function`` using reaches of the current strategies."""
        self._solve_value_function(self.reaches)

    def compute_value_avg_function(self):
        """Refresh ``value_function`` using reaches of the average strategies."""
        self._solve_value_function(self.avg_reaches)

    def compute_counterfactual_vs(self):
        """Compute trunk counterfactual values and store the root's pair."""
        self.counterfactual_values[self.game.root] = self.compute_cvs(self.game.root, 1, 1)

    def compute_cvs(self, node, reach0, reach1):
        """Return ``(player-0 value, player-1 value)`` of ``node`` within the trunk.

        The trunk may contain only chance nodes and player-1 nodes whose
        children are keys of ``value_function``.  Each visited child's pair is
        also stored in ``counterfactual_values``.  At a player-1 node, player
        0's value is the plain sum of the children's values while player 1's
        is weighted by its current strategy.  ``reach0``/``reach1`` are scaled
        by chance and passed down but are not otherwise read.

        Raises:
            AssertionError: if a node of any other player type is reached
                (raised explicitly so the check survives ``python -O``).
        """
        if node.player == 2:
            p0val = 0
            p1val = 0
            for child, chance in zip(node.children, node.chance):
                temp0, temp1 = self.compute_cvs(child, reach0 * chance, reach1 * chance)
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0
                p1val += temp1
            return p0val, p1val
        if node.player == 1:
            p0val = 0
            p1val = 0
            for i, child in enumerate(node.children):
                temp0, temp1 = self.value_function[child]
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0
                p1val += temp1 * self.strategy[1][node.i_set][i]
            return p0val, p1val
        raise AssertionError(f"compute_cvs: unexpected player {node.player} in the trunk")

    def compute_regret_for_p1(self):
        """Accumulate player 1's regrets at the information sets of the root's children.

        Action values are summed over all nodes in each information set; the
        regret of each action is its value minus the strategy-weighted value.
        With CFR+ regrets are clipped at zero after every update.
        """
        player = 1
        for root_child in self.game.root.children:
            iset = root_child.i_set
            nodes = self.i_sets_to_nodes[player][iset]
            cfv = [0] * len(nodes[0].children)
            for node in nodes:
                for i, child in enumerate(node.children):
                    cfv[i] += self.counterfactual_values[child][player]
            strategy = self.strategy[player][iset]
            value = 0
            for i, cfv_value in enumerate(cfv):
                value += cfv_value * strategy[i]
            regrets = self.counterfactual_regret[player][iset]
            for i, cfv_value in enumerate(cfv):
                regrets[i] += cfv_value - value
                if self.cfr_plus:
                    regrets[i] = max(0, regrets[i])

    def compute_strategy_for_p1(self):
        """Regret-match player 1's strategy at the root's children, skipping fixed sets.

        Falls back to a uniform strategy when no action has positive regret.
        """
        player = 1
        for root_child in self.game.root.children:
            iset = root_child.i_set
            if self.fixed is not None and iset in self.fixed[player]:
                continue
            regrets = self.counterfactual_regret[player][iset]
            positive_total = np.sum(np.clip(regrets, 0, np.inf))
            if positive_total > 0:
                for i, regret in enumerate(regrets):
                    self.strategy[player][iset][i] = max(0, regret) / positive_total
            else:
                self.set_uniform_strategy(player, iset, self.strategy)

    def set_uniform_elsewhere(self, pt_id):
        """Reset to uniform every subtree listed in ``self.public_tree[pt_id]``.

        ``public_tree`` is not assigned in this class; presumably it comes from
        the CFR base class — TODO confirm.
        """
        for node in self.public_tree[pt_id]:
            self.set_uniform(node)

    def set_uniform(self, node):
        """Set uniform current and average strategies at every decision node under ``node``."""
        if node.player == 3:
            return
        if node.player != 2:
            super().set_uniform_strategy(node.player, node.i_set, self.average_strategy)
            super().set_uniform_strategy(node.player, node.i_set, self.strategy)
        for child in node.children:
            self.set_uniform(child)

    def compute_exp_val(self, node):
        """Return the expected terminal value of ``node`` under chance and the average strategies."""
        if node.player == 3:
            return node.value
        exp_val = 0
        if node.player == 2:
            for chance, child in zip(node.chance, node.children):
                exp_val += self.compute_exp_val(child) * chance
        else:
            for i, child in enumerate(node.children):
                exp_val += self.compute_exp_val(child) * self.average_strategy[node.player][node.i_set][i]
        return exp_val

    def add_chance_left(self):
        """Replace ``root.children[0]`` with a chance node over the left vf nodes.

        Every node in ``vf_nodes`` whose ``id_val`` is below that of
        ``root.children[1]`` (presumably ids are assigned in pre-order, so this
        selects the left subtree) becomes a child, with probability
        ``avg_reaches_b[node] / p``, or 0.0 for every child when ``p <= 1e-8``.
        Reads ``avg_reaches_b``, so ``compute_avg_reaches`` must have run.
        """
        new_chance = Node(2, 0, self.game.root, self.game.getid())
        right_id = self.game.root.children[1].id_val
        for node in self.vf_nodes:
            if node.id_val >= right_id:
                continue
            new_chance.children.append(node)
            if self.p <= 1e-8:
                new_chance.chance.append(0.0)
            else:
                new_chance.chance.append(self.avg_reaches_b[node] / self.p)
        # NOTE(review): unlike add_gadget_left, node.parent is not re-pointed at
        # new_chance here; confirm the CFR base class does not rely on it.
        self.game.root.children[0] = new_chance

    def add_gadget_left(self):
        """Replace ``root.children[0]`` with a chance node over player-1 gadget decisions.

        For each left vf node (selected as in ``add_chance_left``) a fresh
        player-1 information set is registered with the game and solver.  Its
        single decision node chooses between a terminal paying the negated
        player-1 entry of ``value_function[node]`` and continuing into ``node``.
        The new information set gets uniform current and average strategies.
        """
        new_chance = Node(2, 0, self.game.root, self.game.getid())
        right_id = self.game.root.children[1].id_val
        for node in self.vf_nodes:
            if node.id_val >= right_id:
                continue
            new_iset = self.game.actual_set_p2
            self.game.isets_p2[str(new_iset)] = new_iset
            self.i_sets[1].append(new_iset)

            decision_node = Node(1, new_iset, new_chance, self.game.getid())
            terminate_node = Node(3, 0, decision_node, self.game.getid(), -self.value_function[node][1])
            decision_node.children.append(terminate_node)
            decision_node.children.append(node)
            node.parent = decision_node

            # Every gadget branch gets weight 1.0 because each one replaces a
            # single player-2 node after the root chance (per the original note).
            new_chance.chance.append(1.0)
            new_chance.children.append(decision_node)

            self.i_sets_to_nodes[1][new_iset] = [decision_node]
            self.set_uniform_strategy(player=1, iset=new_iset, strategy=self.strategy)
            self.set_uniform_strategy(player=1, iset=new_iset, strategy=self.average_strategy)
            self.game.actual_set_p2 += 1
        self.game.root.children[0] = new_chance

    def find_depths(self):
        """Return a dict mapping each depth (root = 0) to the nodes at that depth, in pre-order."""
        depth_dict = {}
        self.find_depths_iter(self.game.root, depth_dict, 0)
        return depth_dict

    def find_depths_iter(self, node, depth_dict, depth):
        """Append ``node`` to ``depth_dict[depth]`` and recurse into its children."""
        depth_dict.setdefault(depth, []).append(node)
        for child in node.children:
            self.find_depths_iter(child, depth_dict, depth + 1)

    def _propagate_reaches(self, node, reach1, reach2, reach, strategy, reaches, reaches_b):
        """Recursively record reach probabilities for ``node`` and its subtree.

        ``reaches[node]`` receives ``(reach1, reach2)``: ``reach1`` accumulates
        chance and player-1 action probabilities, ``reach2`` chance and
        player-0 action probabilities.  ``reaches_b[node]`` receives ``reach``,
        the product of all of them.  Action probabilities come from ``strategy``.
        """
        reaches[node] = (reach1, reach2)
        reaches_b[node] = reach
        if node.player == 3:
            return
        if node.player == 2:
            for chance, child in zip(node.chance, node.children):
                self._propagate_reaches(child, reach1 * chance, reach2 * chance, reach * chance,
                                        strategy, reaches, reaches_b)
        elif node.player == 1:
            probs = strategy[1][node.i_set]
            for i, child in enumerate(node.children):
                self._propagate_reaches(child, reach1 * probs[i], reach2, reach * probs[i],
                                        strategy, reaches, reaches_b)
        elif node.player == 0:
            probs = strategy[0][node.i_set]
            for i, child in enumerate(node.children):
                self._propagate_reaches(child, reach1, reach2 * probs[i], reach * probs[i],
                                        strategy, reaches, reaches_b)

    def compute_reaches(self):
        """Fill ``reaches`` / ``reaches_b`` from the root using the current strategies."""
        self.compute_reaches_iter(self.game.root, 1.0, 1.0, 1.0)

    def compute_reaches_iter(self, node, reach1, reach2, reach):
        """Fill ``reaches`` / ``reaches_b`` for the subtree of ``node`` from ``strategy``."""
        self._propagate_reaches(node, reach1, reach2, reach, self.strategy, self.reaches, self.reaches_b)

    def compute_avg_reaches(self):
        """Fill ``avg_reaches`` / ``avg_reaches_b`` from the root using the average strategies."""
        self.compute_avg_reaches_iter(self.game.root, 1.0, 1.0, 1.0)

    def compute_avg_reaches_iter(self, node, reach1, reach2, reach):
        """Fill ``avg_reaches`` / ``avg_reaches_b`` for the subtree of ``node`` from ``average_strategy``."""
        self._propagate_reaches(node, reach1, reach2, reach, self.average_strategy,
                                self.avg_reaches, self.avg_reaches_b)

    def find_publics(self):
        """Return, for each child of the root, a new list holding that child's children."""
        return [list(root_child.children) for root_child in self.game.root.children]
