import numpy as np
import copy


class SubgameCFR:
    """Counterfactual regret minimization (CFR, optionally CFR+) over a subgame.

    ``subgame`` must expose ``roots`` (iterable of root nodes) and ``reaches``
    (mapping root -> indexable pair of per-player reach values).  Nodes expose
    ``player`` (0/1 decision node, 2 chance node, 3 terminal), ``children``,
    ``i_set`` (decision nodes), ``chance`` (chance probabilities, chance nodes)
    and ``value`` (terminals; player 0's payoff, player 1 receives ``-value``).
    """

    def __init__(self, subgame, cfr_player):
        self.subgame = subgame
        # running sum 1 + 2 + ... + iteration (weight for linear averaging)
        self.iter_count = 0
        # per player: ordered list of information sets in the subgame
        self.i_sets = [[], []]
        # per player: information set -> list of nodes belonging to it
        self.i_sets_to_nodes = [{}, {}]
        # node -> (player 0 value, player 1 value)
        self.counterfactual_values = {}
        # per player: information set -> cumulative regret per action
        self.counterfactual_regret = [{}, {}]
        # per player: information set -> action probabilities
        self.strategy = [{}, {}]
        self.average_strategy = None
        # per player: iset -> action -> set of successor isets / terminal nodes
        self.iset_tree = [{}, {}]
        # node -> reach used by the response computations
        self.response_reaches = {}
        # NOTE(review): set to 0 here and not assigned anywhere else in this
        # class, so solve() returns 0 unless a caller/subclass updates it.
        self.game_value = 0
        self.iteration = 0
        # (player, iset) -> reaches collected during a strategy-combining pass
        self.avg_visited = None
        # parameters (right now unsetable)
        self.cfr_plus = True
        self.progressive_strategy = True
        self.counterfactual_values_based_on_average = True
        self.cfr_player = cfr_player

    def solve(self, iterations):
        """Run ``iterations`` CFR iterations, then recompute values from the average strategy.

        Returns ``self.game_value`` (see the note in ``__init__``).
        """
        self.initialize()
        for iteration in range(1, iterations + 1):
            self.iteration = iteration
            # iter_count includes the current iteration from here on.
            self.iter_count += iteration
            self.perform_iteration()
        # compute cfvs from the average strategy, not the current one
        # (can be switched for cfrqr)
        self.modify_counterfactual_values()
        return self.game_value

    def modify_counterfactual_values(self):
        """Replace the current strategy by (a copy of) the average one and recompute cfvs.

        A deep copy is used so later in-place changes to ``self.strategy``
        (e.g. by the response methods) do not overwrite ``self.average_strategy``.
        """
        self.strategy = copy.deepcopy(self.average_strategy)
        self.compute_counterfactual_values()

    def perform_iteration(self):
        """One CFR iteration: values -> regrets -> current strategy -> average strategy."""
        self.compute_counterfactual_values()
        self.compute_regret()
        self.compute_strategy()
        self.average_strategies()

    def initialize(self):
        """Build information-set structures and reset strategies, regrets and counters."""
        self.create_isets()
        self.create_iset_tree()
        self.initialize_strategy()
        self.initialize_regrets()
        self.iter_count = 0

    #### INITIALIZATION ####
    # constructing the cfr
    def create_isets(self):
        """(Re)build ``i_sets`` and ``i_sets_to_nodes`` from the subgame roots."""
        # Start from scratch so repeated calls (e.g. calling solve() twice)
        # do not register the same node in an information set again.
        self.i_sets = [[], []]
        self.i_sets_to_nodes = [{}, {}]
        for root in self.subgame.roots:
            self.put_to_isets(root, depth=0)

    def put_to_isets(self, node, depth):
        """Register ``node`` (if a decision node) and its descendants in their isets.

        ``depth`` is only passed along the recursion; it is not otherwise used.
        """
        player = node.player
        if player > 2:
            return
        if player != 2:
            iset = node.i_set
            iset_nodes = self.i_sets_to_nodes[player]
            # i_sets and i_sets_to_nodes are filled together, so a dict lookup
            # replaces the linear search through the i_sets list.
            if iset not in iset_nodes:
                iset_nodes[iset] = []
                self.i_sets[player].append(iset)
            iset_nodes[iset].append(node)
        for child in node.children:
            self.put_to_isets(child, depth=depth + 1)

    # strategy initialization
    def initialize_random_strategy(self):
        """Give both players random strategies; the average starts as a copy."""
        self.initialize_strategy_randomly(player=1)
        self.initialize_strategy_randomly(player=0)
        self.average_strategy = copy.deepcopy(self.strategy)

    def initialize_strategy_randomly(self, player):
        for iset in self.i_sets[player]:
            self.set_random_strategy(player=player, iset=iset, strategy=self.strategy)

    def set_random_strategy(self, player, iset, strategy):
        """Store a random normalized distribution over the iset's actions in ``strategy``."""
        sample_node = self.i_sets_to_nodes[player][iset][0]
        n_children = len(sample_node.children)
        strat = np.random.rand(n_children)
        strategy[player][iset] = strat / np.sum(strat)

    def initialize_strategy(self):
        """Give both players uniform strategies; the average starts as a copy."""
        self.initialize_strategy_uniformly(player=1)
        self.initialize_strategy_uniformly(player=0)
        self.average_strategy = copy.deepcopy(self.strategy)

    def initialize_strategy_uniformly(self, player):
        for iset in self.i_sets[player]:
            self.set_uniform_strategy(player=player, iset=iset, strategy=self.strategy)

    def set_uniform_strategy(self, player, iset, strategy):
        """Store the uniform distribution over the iset's actions in ``strategy``."""
        sample_node = self.i_sets_to_nodes[player][iset][0]
        n_children = len(sample_node.children)
        strategy[player][iset] = [1 / n_children] * n_children

    # regret initialization
    def initialize_regrets(self):
        self.initialize_regret_for_player(0)
        self.initialize_regret_for_player(1)

    def initialize_regret_for_player(self, player):
        """Zero the cumulative regrets of every iset of ``player``."""
        for iset in self.i_sets[player]:
            self.counterfactual_regret[player][iset] = [0] * len(self.i_sets_to_nodes[player][iset][0].children)

    # create iset tree
    def create_iset_tree(self):
        """(Re)build ``iset_tree``; the pseudo-iset -1 / action 0 is the tree root."""
        self.iset_tree = [{}, {}]
        for root in self.subgame.roots:
            self.create_iset_tree_step(root, (-1, -1), (0, 0))

    def _add_to_iset_tree(self, player, iset, action, item):
        """Record ``item`` (a successor iset or a terminal node) under ``iset``/``action``."""
        self.iset_tree[player].setdefault(iset, {}).setdefault(action, set()).add(item)

    def create_iset_tree_step(self, node, isets, actions):
        """Recursively link each player's last iset/action to the next iset or terminal.

        ``isets`` / ``actions`` hold, per player, the most recent iset visited
        and the action taken there on the path to ``node``.
        """
        player = node.player
        if player == 3:
            for player_index in range(2):
                self._add_to_iset_tree(player_index, isets[player_index], actions[player_index], node)
        elif player == 2:
            for child in node.children:
                self.create_iset_tree_step(child, isets, actions)
        else:
            new_isets = list(isets)
            new_isets[player] = node.i_set
            new_isets = tuple(new_isets)
            self._add_to_iset_tree(player, isets[player], actions[player], node.i_set)
            for i, child in enumerate(node.children):
                new_actions = list(actions)
                new_actions[player] = i
                self.create_iset_tree_step(child, new_isets, tuple(new_actions))

    #### HEAVY COMPUTATION ####

    # regret computation
    def compute_regret(self):
        self.compute_regret_for_player(0)
        self.compute_regret_for_player(1)

    def compute_regret_for_player(self, player):
        """Accumulate per-action regrets of ``player`` (floored at 0 under CFR+)."""
        for iset in self.i_sets[player]:
            nodes = self.i_sets_to_nodes[player][iset]
            # action values summed over all nodes of the information set
            cfv = [0] * len(nodes[0].children)
            for node in nodes:
                for i, child in enumerate(node.children):
                    cfv[i] += self.counterfactual_values[child][player]
            value = 0
            for action_value, probability in zip(cfv, self.strategy[player][iset]):
                value += action_value * probability
            regrets = self.counterfactual_regret[player][iset]
            for i, cfv_value in enumerate(cfv):
                regrets[i] += cfv_value - value
                if self.cfr_plus:
                    regrets[i] = max(0, regrets[i])

    # computing counterfactual values
    def compute_counterfactual_values(self):
        """Fill ``counterfactual_values`` for the whole subgame under ``self.strategy``.

        Root values are divided by the corresponding root reach (0 when that
        reach is 0); values of non-root nodes are stored unnormalized.
        """
        for root in self.subgame.roots:
            reaches = self.subgame.reaches[root]
            cfvs = self.compute_cfvs(root, reaches[0], reaches[1])
            self.counterfactual_values[root] = [0 if reaches[0] == 0 else cfvs[0] / reaches[0],
                                                0 if reaches[1] == 0 else cfvs[1] / reaches[1]]

    def compute_cfvs(self, node, reach0, reach1):
        """Return ``(value for player 0, value for player 1)`` of ``node``, reach-weighted.

        ``reach0`` is multiplied by player 1's and chance probabilities and
        ``reach1`` by player 0's and chance probabilities on the way down.
        Each child's pair is also stored in ``counterfactual_values``.
        """
        if node.player == 3:
            self.counterfactual_values[node] = (node.value * reach0, -node.value * reach1)
            return node.value * reach0, -node.value * reach1
        elif node.player == 1:
            probabilities = self.strategy[1][node.i_set]
            p0val = 0
            p1val = 0
            for i, child in enumerate(node.children):
                temp0, temp1 = self.compute_cfvs(child, reach0 * probabilities[i], reach1)
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0
                # only the acting player's value is weighted by its own strategy
                p1val += temp1 * probabilities[i]
            return p0val, p1val
        elif node.player == 0:
            probabilities = self.strategy[0][node.i_set]
            p0val = 0
            p1val = 0
            for i, child in enumerate(node.children):
                temp0, temp1 = self.compute_cfvs(child, reach0, reach1 * probabilities[i])
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0 * probabilities[i]
                p1val += temp1
            return p0val, p1val
        else:
            p0val = 0
            p1val = 0
            for child, chance in zip(node.children, node.chance):
                temp0, temp1 = self.compute_cfvs(child, reach0 * chance, reach1 * chance)
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0
                p1val += temp1
            return p0val, p1val

    def compute_expected(self, node):
        """Expected terminal value (player 0's payoff) of ``node`` under the average strategy."""
        exp_val = 0
        if node.player == 3:
            return node.value
        elif node.player == 2:
            for chance, child in zip(node.chance, node.children):
                exp_val += self.compute_expected(child) * chance
        else:
            for i, child in enumerate(node.children):
                exp_val += self.compute_expected(child) * self.average_strategy[node.player][node.i_set][i]
        return exp_val

    # strategy computation
    def compute_strategy(self):
        self.compute_strategy_for_player(0)
        self.compute_strategy_for_player(1)

    def compute_strategy_for_player(self, player):
        """Regret matching: play proportionally to positive regret, uniform if none."""
        for iset in self.i_sets[player]:
            sumed = np.sum(np.clip(self.counterfactual_regret[player][iset], 0, np.inf))
            if sumed > 0:
                for i, regret in enumerate(self.counterfactual_regret[player][iset]):
                    self.strategy[player][iset][i] = max(0, regret) / sumed
            else:
                self.set_uniform_strategy(player, iset, self.strategy)

    def average_strategies(self):
        """Fold the current strategy into ``average_strategy`` (copy on iteration 1).

        Reaches are collected for every iset first and only then applied, so
        every reach is computed from the average strategy as it was before
        this iteration's update.
        """
        if self.iteration > 1:
            self.avg_visited = {}
            for root in self.subgame.roots:
                self.avg_strat_at_node(root, [1, 1], [1, 1])
            for (player, iset), (reach, avg_reach) in self.avg_visited.items():
                self.average_in_is(player, iset, reach, avg_reach)
        else:
            self.average_strategy = copy.deepcopy(self.strategy)

    def avg_strat_at_node(self, node, player_reach, player_average_reach):
        """Record each iset's own-player reach under the current and average strategies.

        Results go to ``self.avg_visited[(player, iset)]`` (first node to finish
        its subtree wins); ``average_strategies`` applies them afterwards.
        """
        player = node.player
        if player == 3:
            return
        elif player == 2:
            for child in node.children:
                self.avg_strat_at_node(child, player_reach, player_average_reach)
        else:
            current = self.strategy[player][node.i_set]
            average = self.average_strategy[player][node.i_set]
            for i, child in enumerate(node.children):
                new_player_reach = [0, 0]
                new_player_reach[1 - player] = player_reach[1 - player]
                new_player_reach[player] = player_reach[player] * current[i]

                new_player_average_reach = [0, 0]
                new_player_average_reach[1 - player] = player_average_reach[1 - player]
                new_player_average_reach[player] = player_average_reach[player] * average[i]
                self.avg_strat_at_node(child, new_player_reach, new_player_average_reach)
            key = (player, node.i_set)
            if key not in self.avg_visited:
                self.avg_visited[key] = (player_reach[player], player_average_reach[player])

    def average_in_is(self, player, iset, reach, avg_reach):
        """Merge the current strategy of one iset into its average.

        The previous average stands for iterations ``1 .. iteration - 1``:
        weight ``avg_reach * (1 + ... + (iteration - 1))`` for linear
        ("progressive") averaging, ``avg_reach * (iteration - 1)`` otherwise.
        If the total weight is zero, the average is reset to uniform.
        """
        if self.progressive_strategy:
            new_weight = reach * self.iteration
            # iter_count already includes the current iteration (see solve)
            old_weight = avg_reach * (self.iter_count - self.iteration)
        else:
            new_weight = reach
            old_weight = avg_reach * (self.iteration - 1)
        total_weight = new_weight + old_weight
        if total_weight == 0:
            self.set_uniform_strategy(player, iset, self.average_strategy)
            return
        self.average_strategy[player][iset] = np.divide(
            np.multiply(self.strategy[player][iset], new_weight)
            + np.multiply(self.average_strategy[player][iset], old_weight),
            total_weight)

    # ------------------------- RESPONSES PART -------------------------
    # abstract response
    def general_response(self, player, response_function, *args):
        """Overwrite ``player``'s strategy with ``response_function`` applied bottom-up."""
        for root in self.subgame.roots:
            reach = self.subgame.reaches[root][player]
            self.compute_reaches(player, root, reach)
        self.compute_response_values_top(player, response_function, *args)

    def compute_reaches(self, player, node, reach):
        """Store, per node, ``reach`` times chance and opponent probabilities on the path."""
        self.response_reaches[node] = reach
        if node.player == 3:
            return
        elif node.player == 2:
            for chance, child in zip(node.chance, node.children):
                self.compute_reaches(player, child, reach * chance)
        elif node.player == player:
            # the responding player's own actions do not change its reach
            for child in node.children:
                self.compute_reaches(player, child, reach)
        else:
            for i, child in enumerate(node.children):
                self.compute_reaches(player, child, reach * self.strategy[1 - player][node.i_set][i])

    def compute_response_values_top(self, player, response_function, *args):
        # -1 is the pseudo-iset at the top of iset_tree (see create_iset_tree)
        self.compute_response_values_step(player, response_function, -1, *args)

    def compute_response_values_step(self, player, response_function, iset, *args):
        """Compute action values of ``iset``, set its strategy and return its value."""
        successors_by_action = self.iset_tree[player][iset]
        action_values = [0] * len(successors_by_action)
        sign = 1 if player == 0 else -1
        for action, successors in successors_by_action.items():
            for node_or_set in successors:
                # integer entries are successor isets, everything else terminal nodes
                if isinstance(node_or_set, (int, np.integer)):
                    action_values[action] += self.compute_response_values_step(player, response_function, node_or_set,
                                                                               *args)
                else:
                    action_values[action] += sign * node_or_set.value * self.response_reaches[node_or_set]
        self.strategy[player][iset] = response_function(action_values, player, iset, *args)
        return np.sum(np.multiply(self.strategy[player][iset], action_values))

    def compute_strategy_in_set_with_function(self, player, response_function, iset):
        # NOTE(review): calls response_function(player, iset), which does not
        # match the (values, player, iset, ...) signature of the response
        # functions below; appears unused here — confirm before relying on it.
        self.strategy[player][iset] = response_function(player, iset)

    # quantal response
    def quantal_response(self, player, rationality):
        return self.general_response(player, self.quantal_response_function, rationality)

    @staticmethod
    def quantal_response_function(values, player, iset, rationality=1):
        """Softmax of ``rationality * values`` (max-shifted for numerical stability)."""
        values = rationality * np.asarray(values)
        exped = np.exp(values - np.max(values))
        return exped / sum(exped)

    # best response
    def best_response(self, player):
        return self.general_response(player, self.best_response_function)

    @staticmethod
    def best_response_function(values, player, iset):
        """Pure strategy on the first action with the maximal value."""
        strategy_from_values = [0] * len(values)
        strategy_from_values[np.argmax(values)] = 1
        return strategy_from_values

    # linear quantal response
    def linear_response(self, player, rationality):
        return self.general_response(player, self.linear_response_function, rationality)

    @staticmethod
    def linear_response_function(values, player, iset, rationality):
        """Shift ``values`` by a rationality-dependent constant and normalize.

        NOTE(review): the denominator ``1 / rationality - numval`` is zero when
        the original ``rationality`` equals 1; confirm the intended behavior.
        """
        rationality = rationality / len(values)
        valsum = np.sum(values)
        if len(values) == 1:
            return [1]
        else:
            valmin = np.min(values)
            if rationality == 0:
                constant = 0
            else:
                numval = len(values)
                constant = (valsum - numval * valmin) / (1 / rationality - numval)
            values = values - valmin + constant
            valsum = np.sum(values)
            if valsum == 0:
                values += 1
            valsum = np.sum(values)
            return values / valsum

    ### STRATEGY COMBINATION ###
    def combine_strategies(self, strategy1, strategy2, alpha):
        """Set ``self.strategy`` to the reach-weighted mix ``alpha*strategy1 + (1-alpha)*strategy2``.

        Reaches are collected for all isets before anything is written, so
        the result is also correct when ``strategy1``/``strategy2`` is
        ``self.strategy`` itself.
        """
        self.avg_visited = {}
        for root in self.subgame.roots:
            reaches = self.subgame.reaches[root]
            self.combine_at_node(root, reaches, reaches, strategy1, strategy2, alpha)
        for (player, iset), (s1_reach, s2_reach) in self.avg_visited.items():
            self.combine_in_is(player, iset, s1_reach, s2_reach, strategy1, strategy2, alpha)

    def combine_at_node(self, node, strategy1_reach, strategy2_reach, strategy1, strategy2, alpha):
        """Record each iset's own-player reach under ``strategy1`` and ``strategy2``.

        Results go to ``self.avg_visited[(player, iset)]`` (first node to finish
        its subtree wins); ``combine_strategies`` applies them afterwards.
        """
        player = node.player
        if player == 3:
            return
        elif player == 2:
            for child in node.children:
                self.combine_at_node(child, strategy1_reach, strategy2_reach, strategy1, strategy2, alpha)
        else:
            for i, child in enumerate(node.children):
                new_s1_reach = [0, 0]
                new_s1_reach[1 - player] = strategy1_reach[1 - player]
                new_s1_reach[player] = strategy1_reach[player] * strategy1[player][node.i_set][i]

                new_s2_reach = [0, 0]
                new_s2_reach[1 - player] = strategy2_reach[1 - player]
                new_s2_reach[player] = strategy2_reach[player] * strategy2[player][node.i_set][i]
                self.combine_at_node(child, new_s1_reach, new_s2_reach, strategy1, strategy2, alpha)
            key = (player, node.i_set)
            if key not in self.avg_visited:
                self.avg_visited[key] = (strategy1_reach[player], strategy2_reach[player])

    def combine_in_is(self, player, iset, s1_reach, s2_reach, strategy1, strategy2, alpha):
        """Write the reach-weighted mix for one iset (uniform if both reaches are 0)."""
        if s1_reach + s2_reach == 0:
            self.set_uniform_strategy(player, iset, self.strategy)
        else:
            self.strategy[player][iset] = np.divide(
                np.multiply(strategy1[player][iset], s1_reach * alpha) + np.multiply(
                    strategy2[player][iset], s2_reach * (1 - alpha)),
                s1_reach * alpha + s2_reach * (1 - alpha))

    ####### regret sum computation ######
    def regret_sum(self, player):
        """Sum of positive cumulative regrets over all of ``player``'s isets."""
        regret_sum = 0
        for iset in self.i_sets[player]:
            regret_sum += np.sum(np.clip(self.counterfactual_regret[player][iset], 0, np.inf))
        return regret_sum