import numpy as np
import copy
import pickle
from DataManipulation import load_from_file
import types
from SubgameCFR import SubgameCFR
from SubgameRNR import SubgameRNR


class CFRD:
    """CFR solver for a game decomposed into a trunk and a list of subgames.

    The trunk tree and the subgames are loaded from ``<fname>_trunk.trunk`` and
    ``<fname>_subgames.sg``.  Every iteration (see ``perform_iteration``) pushes
    the current trunk reaches into each subgame, solves it with ``SubgameCFR``,
    propagates the resulting subgame-root values up the trunk and then updates
    the trunk regrets, the current strategy and the average strategy.

    Node conventions, as used by the code below:

    * ``node.player`` is 0 or 1 for a decision node, 2 for a chance node and
      3 for a terminal node;
    * decision nodes have ``node.i_set`` (a hashable information-set id) and
      ``node.children``; chance nodes have ``node.chance`` (probabilities
      aligned with ``node.children``); terminal nodes have ``node.value``, the
      payoff of player 0 (player 1 receives ``-node.value``);
    * a non-terminal trunk node without children is treated as a subgame root:
      its values are read from ``subgame_root_counterfactual_values``.

    Strategies and regrets are stored as ``[dict_player0, dict_player1]``, each
    mapping an iset id to a per-action list.
    """

    def __init__(self, fname, cfr_player=1):
        """Load the game stored under the file prefix ``fname``.

        :param fname: file prefix; ``_trunk.trunk`` and ``_subgames.sg`` are appended.
        :param cfr_player: forwarded unchanged to every ``SubgameCFR`` instance
            (its meaning is defined by ``SubgameCFR``).
        """
        self.fname = fname
        self.trunk, self.subgames = self.load_game(fname)
        # trunk information sets of both players, in discovery order
        self.i_sets = [[], []]
        # per player: iset -> list of trunk nodes belonging to it
        self.i_sets_to_nodes = [{}, {}]
        # trunk node -> counterfactual values (player 0, player 1)
        self.counterfactual_values = {}
        # per player: iset -> cumulative regret per action
        self.counterfactual_regret = [{}, {}]
        # per player: iset -> current / average action probabilities
        self.strategy = [{}, {}]
        self.average_strategy = None
        # per player: iset -> action -> set of successor isets or leaf nodes
        self.iset_tree = [{}, {}]
        # progression / strategy saving flags (stored by solve, not read in this class)
        self.save_strategy = False
        self.save_progression = False
        # verbosity level; > 0 prints each iteration number
        self.verbose = 0
        # game value (not updated in this class)
        self.game_value = 0
        # subgame-root node -> reach pair for the current iteration; reaches used by responses
        self.reaches = {}
        self.response_reaches = {}
        # SubgameCFR iterations per trunk iteration
        self.subgame_iterations = 0
        # subgame-root values: latest iteration / running sum over iterations
        self.subgame_root_counterfactual_values = {}
        self.average_root_counterfactual_values = {}
        # iteration tracking; iter_count is the running sum 1 + 2 + ... + iteration
        self.iter_count = 0
        self.iteration = 0
        # algorithm switches (not exposed through solve())
        self.cfr_plus = True
        self.progressive_strategy = False
        self.result_from_average_strategy = True
        # (player, iset) pairs already averaged during the current sweep
        self.avg_visited = set()
        # iteration at which the average strategy starts being accumulated
        self.average_from = 0
        self.cfr_player = cfr_player
        # optional [dict, dict] of iset -> action probabilities kept fixed
        self.fixed_strategy = None

    def solve(self, iterations=1000, verbose=0, save_progression=False, save_strategy=False, subgame_iterations=100,
              average_from=1, fixed_strategy=None):
        """Run ``iterations`` trunk iterations and return the trunk-root values.

        :param iterations: number of trunk iterations.
        :param verbose: values > 0 print each iteration number.
        :param save_progression: stored on the instance only.
        :param save_strategy: stored on the instance only.
        :param subgame_iterations: ``SubgameCFR`` iterations per subgame per trunk iteration.
        :param average_from: iteration at which averaging of the strategy starts.
        :param fixed_strategy: optional ``[dict, dict]`` mapping isets to action
            probabilities that regret matching never changes.
        :return: counterfactual values ``(player 0, player 1)`` at the trunk root,
            recomputed with the average strategy when
            ``result_from_average_strategy`` is set.
        """
        self.fixed_strategy = fixed_strategy
        self.average_from = average_from
        self.save_progression = save_progression
        self.save_strategy = save_strategy
        self.verbose = verbose
        self.subgame_iterations = subgame_iterations
        self.initialize()
        for self.iteration in range(1, iterations + 1):
            # iter_count is updated *before* the iteration runs (see average_in_is)
            self.iter_count += self.iteration
            self.perform_iteration()
        if self.result_from_average_strategy:
            # Evaluate the trunk with the average strategy, then restore the
            # current one.  Subgame-root values are those of the last iteration.
            backup = self.strategy
            self.strategy = self.average_strategy
            self.compute_counterfactual_values()
            self.strategy = backup
        return self.counterfactual_values[self.trunk.root]

    def solve_subgame(self, subgame):
        """Solve one subgame with the current trunk reaches and store its root values."""
        subgame.change_reaches(self.reaches)
        subgame_cfr = SubgameCFR(subgame, cfr_player=self.cfr_player)
        subgame_cfr.solve(self.subgame_iterations)
        for root in subgame_cfr.subgame.roots:
            root_values = subgame_cfr.counterfactual_values[root]
            self.subgame_root_counterfactual_values[root] = root_values
            # running sums over trunk iterations (not read anywhere in this class)
            self.average_root_counterfactual_values[root][0] += root_values[0]
            self.average_root_counterfactual_values[root][1] += root_values[1]

    def compute_reaches(self):
        """Recompute ``self.reaches`` for every subgame root of the trunk."""
        self.compute_reaches_step(self.trunk.root, p1_reach=1, p2_reach=1)

    def compute_reaches_step(self, node, p1_reach, p2_reach):
        """Propagate counterfactual reaches down the trunk.

        ``p1_reach`` collects chance and player 1 probabilities (the
        counterfactual reach used for player 0), ``p2_reach`` collects chance and
        player 0 probabilities.  The pair is stored for childless non-terminal
        nodes (subgame roots); terminals are ignored.
        """
        if node.player == 3:
            return
        if len(node.children) == 0:
            self.reaches[node] = (p1_reach, p2_reach)
        elif node.player == 2:
            for chance, child in zip(node.chance, node.children):
                self.compute_reaches_step(child, p1_reach * chance, p2_reach * chance)
        elif node.player == 1:
            for i, child in enumerate(node.children):
                self.compute_reaches_step(child, p1_reach * self.strategy[1][node.i_set][i], p2_reach)
        elif node.player == 0:
            for i, child in enumerate(node.children):
                self.compute_reaches_step(child, p1_reach, p2_reach * self.strategy[0][node.i_set][i])

    # regret computation
    def compute_regret(self):
        """Accumulate this iteration's regrets for both players."""
        self.compute_regret_for_player(0)
        self.compute_regret_for_player(1)

    def compute_regret_for_player(self, player):
        """Accumulate this iteration's regrets for every trunk iset of ``player``.

        The value of an action is the sum, over all nodes of the iset, of the
        counterfactual value of the child reached by that action.  With
        ``cfr_plus`` the cumulative regret is clipped at zero *after* adding the
        instantaneous regret (regret-matching+); otherwise plain CFR
        accumulation is used.
        """
        for iset in self.i_sets[player]:
            nodes = self.i_sets_to_nodes[player][iset]
            cfv = [0] * len(nodes[0].children)
            for node in nodes:
                for i, child in enumerate(node.children):
                    cfv[i] += self.counterfactual_values[child][player]
            # value of the iset under the current strategy
            value = sum(action_value * prob for action_value, prob in zip(cfv, self.strategy[player][iset]))
            regrets = self.counterfactual_regret[player][iset]
            for i, cfv_value in enumerate(cfv):
                if self.cfr_plus:
                    # CFR+: R <- max(0, R + r).  Clipping only the instantaneous
                    # regret (R += max(0, r)) would never let a regret decrease.
                    regrets[i] = max(0, regrets[i] + cfv_value - value)
                else:
                    regrets[i] += cfv_value - value

    # strategy computation
    def compute_strategy(self):
        """Regret-matching update of the current strategy for both players."""
        self.compute_strategy_for_player(0)
        self.compute_strategy_for_player(1)

    def compute_strategy_for_player(self, player):
        """Set the strategy of each non-fixed iset proportional to its positive regrets.

        Falls back to the uniform strategy when no regret is positive.  The
        per-iset strategy list is updated in place.
        """
        for iset in self.i_sets[player]:
            if self.fixed_strategy is not None and iset in self.fixed_strategy[player]:
                continue
            regrets = self.counterfactual_regret[player][iset]
            sumed = np.sum(np.clip(regrets, 0, np.inf))
            if sumed > 0:
                for i, regret in enumerate(regrets):
                    self.strategy[player][iset][i] = max(0, regret) / sumed
            else:
                self.set_uniform_strategy(player, iset, self.strategy)

    def average_strategies(self):
        """Fold the current strategy into the average strategy.

        At iteration ``average_from`` the average is reset to a copy of the
        current strategy; afterwards every iteration is averaged in.
        """
        if self.iteration > self.average_from:
            self.avg_visited = set()
            self.avg_strat_at_node(self.trunk.root, [1, 1], [1, 1])
        elif self.iteration == self.average_from:
            self.average_strategy = copy.deepcopy(self.strategy)

    def avg_strat_at_node(self, node, player_reach, player_average_reach):
        """Depth-first sweep updating the average strategy of each iset once.

        ``player_reach`` / ``player_average_reach`` hold, per player, the
        product of that player's own action probabilities along the path under
        the current / average strategy.  Children are processed before the
        node's own iset is averaged, so descendants see the average strategy
        of this iset from before the update.
        """
        player = node.player
        if player == 3:
            return
        if player == 2:
            for child in node.children:
                self.avg_strat_at_node(child, player_reach, player_average_reach)
            return
        if len(node.children) == 0:
            return
        for i, child in enumerate(node.children):
            new_player_reach = [0, 0]
            new_player_average_reach = [0, 0]
            new_player_reach[1 - player] = player_reach[1 - player]
            new_player_reach[player] = player_reach[player] * self.strategy[player][node.i_set][i]

            new_player_average_reach[1 - player] = player_average_reach[1 - player]
            new_player_average_reach[player] = player_average_reach[player] * \
                                               self.average_strategy[player][node.i_set][i]
            self.avg_strat_at_node(child, new_player_reach, new_player_average_reach)
        key = (player, node.i_set)
        if key not in self.avg_visited:
            self.avg_visited.add(key)
            self.average_in_is(player, node.i_set, player_reach[player], player_average_reach[player])

    def average_in_is(self, player, iset, reach, avg_reach):
        """Blend the current strategy of ``iset`` into its average strategy.

        :param reach: the player's own reach of the iset under the current strategy.
        :param avg_reach: the player's own reach under the average strategy.
        """
        if reach + avg_reach == 0:
            self.set_uniform_strategy(player, iset, self.average_strategy)
        elif self.progressive_strategy:
            # Linear weighting: iteration t gets weight t.
            # NOTE(review): iter_count already includes the current iteration
            # (solve() increments it before perform_iteration), so the old
            # average is weighted by 1 + ... + t rather than 1 + ... + (t - 1);
            # confirm this is intended.
            self.average_strategy[player][iset] = np.divide(
                np.multiply(self.strategy[player][iset], reach * self.iteration) + np.multiply(
                    self.average_strategy[player][iset], avg_reach * self.iter_count),
                reach * self.iteration + avg_reach * self.iter_count)
        else:
            # Uniform weighting: the old average stands for the
            # (iteration - average_from) iterations already averaged.
            averaged_iterations = self.iteration - self.average_from
            self.average_strategy[player][iset] = np.divide(
                np.multiply(self.strategy[player][iset], reach) + np.multiply(
                    self.average_strategy[player][iset], avg_reach * averaged_iterations),
                reach + avg_reach * averaged_iterations)

    def print_statistics(self):
        """Print the iteration number when ``verbose`` is positive."""
        if self.verbose > 0:
            print(self.iteration)

    def perform_iteration(self):
        """One trunk iteration: reaches, subgames, values, regrets, strategies."""
        self.compute_reaches()
        for subgame in self.subgames:
            self.solve_subgame(subgame)
        self.compute_counterfactual_values()
        self.compute_regret()
        self.compute_strategy()
        self.average_strategies()
        self.print_statistics()

    def load_game(self, fname):
        """Load ``(trunk, subgames)`` from ``<fname>_trunk.trunk`` and ``<fname>_subgames.sg``."""
        trunk = load_from_file(fname + "_trunk.trunk")
        subgames = load_from_file(fname + "_subgames.sg")
        return trunk, subgames

    def initialize(self):
        """Build all per-solve structures from the loaded trunk and subgames."""
        self.create_isets()
        self.create_iset_tree()
        self.initialize_strategy()
        self.initialize_regrets()
        self.initialize_structures()
        self.iter_count = 0

    def initialize_structures(self):
        """Zero the stored values of every subgame root."""
        for subgame in self.subgames:
            for root in subgame.roots:
                self.subgame_root_counterfactual_values[root] = [0, 0]
                self.average_root_counterfactual_values[root] = [0, 0]

    def create_isets(self):
        """(Re)build ``i_sets`` and ``i_sets_to_nodes`` from the trunk.

        The containers are reset first so a repeated ``solve`` does not register
        every node a second time (which would double the action values summed in
        ``compute_regret_for_player``).
        """
        self.i_sets = [[], []]
        self.i_sets_to_nodes = [{}, {}]
        self.put_to_isets(self.trunk.root, depth=0)

    def put_to_isets(self, node, depth):
        """Register every decision node with children under its player's iset.

        ``depth`` is tracked for the recursion but not otherwise used.
        """
        player = node.player
        if player > 2:
            return
        if player == 2:
            for child in node.children:
                self.put_to_isets(child, depth=depth + 1)
            return
        if len(node.children) == 0:
            return
        iset = node.i_set
        self.i_sets_to_nodes[player].setdefault(iset, []).append(node)
        if iset not in self.i_sets[player]:
            self.i_sets[player].append(iset)
        for child in node.children:
            self.put_to_isets(child, depth=depth + 1)

    # strategy initialization
    def initialize_random_strategy(self):
        """Give every iset a random strategy and copy it as the average strategy."""
        self.initialize_strategy_randomly(player=1)
        self.initialize_strategy_randomly(player=0)
        self.average_strategy = copy.deepcopy(self.strategy)

    def initialize_strategy_randomly(self, player):
        """Give every iset of ``player`` a random strategy."""
        for iset in self.i_sets[player]:
            self.set_random_strategy(player=player, iset=iset, strategy=self.strategy)

    def set_random_strategy(self, player, iset, strategy):
        """Store a random normalised action distribution for ``iset`` in ``strategy``."""
        sample_node = self.i_sets_to_nodes[player][iset][0]
        n_children = len(sample_node.children)
        if n_children == 0:
            return
        strat = np.random.rand(n_children)
        strategy[player][iset] = strat / np.sum(strat)

    def initialize_strategy(self):
        """Uniform strategy everywhere, then apply ``fixed_strategy``, then copy as average.

        The average is copied *after* the fixed strategies are applied so that
        fixed isets start (and stay) at their fixed values in the average too.
        """
        self.initialize_strategy_uniformly(player=1)
        self.initialize_strategy_uniformly(player=0)
        if self.fixed_strategy is not None:
            for player in range(2):
                for iset in self.i_sets[player]:
                    if iset in self.fixed_strategy[player]:
                        self.strategy[player][iset] = self.fixed_strategy[player][iset]
        self.average_strategy = copy.deepcopy(self.strategy)

    def initialize_strategy_uniformly(self, player):
        """Give every iset of ``player`` the uniform strategy."""
        for iset in self.i_sets[player]:
            self.set_uniform_strategy(player=player, iset=iset, strategy=self.strategy)

    def set_uniform_strategy(self, player, iset, strategy):
        """Store the uniform distribution over the actions of ``iset`` in ``strategy``."""
        sample_node = self.i_sets_to_nodes[player][iset][0]
        n_children = len(sample_node.children)
        if n_children == 0:
            return
        strategy[player][iset] = [x / n_children for x in [1] * n_children]

    # regret initialization
    def initialize_regrets(self):
        """Zero the cumulative regrets of both players."""
        self.initialize_regret_for_player(0)
        self.initialize_regret_for_player(1)

    def initialize_regret_for_player(self, player):
        """Zero regrets for every iset that has a strategy entry for ``player``."""
        for iset in self.strategy[player]:
            self.counterfactual_regret[player][iset] = [0] * len(self.strategy[player][iset])

    # create iset tree
    def create_iset_tree(self):
        """(Re)build ``iset_tree`` starting from the virtual root iset ``-1``.

        ``modify_iset_tree`` is currently not applied to the result.
        """
        self.iset_tree = [{}, {}]
        self.create_iset_tree_step(self.trunk.root, (-1, -1), (0, 0))

    def modify_iset_tree(self, player, key):
        """Replace successor isets of ``player`` that have no own entry by ``(iset, 0)``.

        :return: True when ``key`` is not an int or has an entry in the tree,
            False when it is an int without an entry.
        """
        if not isinstance(key, int):
            return True
        if key not in self.iset_tree[player]:
            return False
        for successors in self.iset_tree[player][key].values():
            # Iterate over a snapshot: the set is modified inside the loop.
            for new_key in list(successors):
                if not self.modify_iset_tree(player, new_key):
                    successors.remove(new_key)
                    successors.add((new_key, 0))
        return True

    def _iset_tree_bucket(self, player, iset, action):
        """Return (creating if needed) the successor set of ``player``'s ``iset``/``action``."""
        return self.iset_tree[player].setdefault(iset, {}).setdefault(action, set())

    def _add_leaf_to_iset_trees(self, node, isets, actions):
        """Record ``node`` as a leaf under both players' current iset/action."""
        for player_index in range(2):
            self._iset_tree_bucket(player_index, isets[player_index], actions[player_index]).add(node)

    def create_iset_tree_step(self, node, isets, actions):
        """Walk the trunk, recording each player's successor isets and leaves.

        :param isets: per player, the last iset of that player on the path (``-1`` at the root).
        :param actions: per player, the action taken at that iset.
        """
        player = node.player
        if player == 3:
            self._add_leaf_to_iset_trees(node, isets, actions)
            return
        if player == 2:
            if len(node.children) == 0:
                self._add_leaf_to_iset_trees(node, isets, actions)
            for child in node.children:
                self.create_iset_tree_step(child, isets, actions)
            return
        # decision node: make sure the parent entry of this player exists
        own_bucket = self._iset_tree_bucket(player, isets[player], actions[player])
        if len(node.children) == 0:
            self._add_leaf_to_iset_trees(node, isets, actions)
            return
        own_bucket.add(node.i_set)
        new_isets = list(isets)
        new_isets[player] = node.i_set
        new_isets = tuple(new_isets)
        for i, child in enumerate(node.children):
            new_actions = list(actions)
            new_actions[player] = i
            self.create_iset_tree_step(child, new_isets, tuple(new_actions))

    def compute_counterfactual_values(self):
        """Recompute ``counterfactual_values`` for the whole trunk with the current strategy."""
        self.counterfactual_values[self.trunk.root] = self.compute_cfvs(self.trunk.root, 1, 1)

    def compute_cfvs(self, node, reach0, reach1):
        """Return the counterfactual values ``(player 0, player 1)`` of ``node``.

        ``reach0`` is the product of chance and player 1 probabilities on the
        path, ``reach1`` the product of chance and player 0 probabilities.
        Values of the children are stored in ``counterfactual_values``.
        """
        if node.player == 3:
            self.counterfactual_values[node] = (node.value * reach0, -node.value * reach1)
            return node.value * reach0, -node.value * reach1
        if len(node.children) == 0:
            # NOTE(review): SubgameCFR receives the trunk reaches via
            # change_reaches; confirm its root values are not already
            # reach-weighted, otherwise this multiplication counts reaches twice.
            cfvs = self.subgame_root_counterfactual_values[node]
            self.counterfactual_values[node] = [cfvs[0] * reach0, cfvs[1] * reach1]
            return cfvs[0] * reach0, cfvs[1] * reach1
        p0val = 0
        p1val = 0
        if node.player == 2:
            for child, chance in zip(node.children, node.chance):
                temp0, temp1 = self.compute_cfvs(child, reach0 * chance, reach1 * chance)
                p0val += temp0
                p1val += temp1
                self.counterfactual_values[child] = (temp0, temp1)
            return p0val, p1val
        if node.player == 1:
            for i, child in enumerate(node.children):
                prob = self.strategy[1][node.i_set][i]
                temp0, temp1 = self.compute_cfvs(child, reach0 * prob, reach1)
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0
                p1val += temp1 * prob
            return p0val, p1val
        if node.player == 0:
            for i, child in enumerate(node.children):
                prob = self.strategy[0][node.i_set][i]
                temp0, temp1 = self.compute_cfvs(child, reach0, reach1 * prob)
                self.counterfactual_values[child] = (temp0, temp1)
                p0val += temp0 * prob
                p1val += temp1
            return p0val, p1val
        return None

    # ------------------------- RESPONSES PART -------------------------
    # abstract response
    def general_response(self, player, response_function, *args):
        """Overwrite ``player``'s strategy with a response computed bottom-up over ``iset_tree``.

        ``response_function(action_values, player, iset, *args)`` returns the
        new action distribution of an iset given its action values.
        """
        self.compute_response_reaches(player, self.trunk.root, 1)
        self.compute_counterfactual_values()
        self.compute_response_values_top(player, response_function, *args)

    def compute_response_reaches(self, player, node, reach):
        """Store in ``response_reaches`` the chance-and-opponent reach of every node."""
        self.response_reaches[node] = reach
        if node.player == 3:
            return
        if node.player == 2:
            for chance, child in zip(node.chance, node.children):
                self.compute_response_reaches(player, child, reach * chance)
        elif node.player == player:
            for child in node.children:
                self.compute_response_reaches(player, child, reach)
        else:
            for i, child in enumerate(node.children):
                self.compute_response_reaches(player, child, reach * self.strategy[1 - player][node.i_set][i])

    def compute_response_values_top(self, player, response_function, *args):
        """Start the response recursion at the virtual root iset ``-1``."""
        self.compute_response_values_step(player, response_function, -1, *args)

    def compute_response_values_step(self, player, response_function, iset, *args):
        """Set the response strategy of ``iset`` and return its value under it.

        Successors stored as ints are isets (recursed into); anything else is a
        leaf node whose counterfactual value for ``player`` is used directly.
        The virtual root ``-1`` also receives a strategy entry.
        """
        action_values = [0] * len(self.iset_tree[player][iset])
        for action, successors in self.iset_tree[player][iset].items():
            for node_or_set in successors:
                if isinstance(node_or_set, int):
                    action_values[action] += self.compute_response_values_step(player, response_function, node_or_set,
                                                                               *args)
                else:
                    action_values[action] += self.counterfactual_values[node_or_set][player]
        self.strategy[player][iset] = response_function(action_values, player, iset, *args)
        return np.sum(np.multiply(self.strategy[player][iset], action_values))

    def compute_strategy_in_set_with_function(self, player, response_function, iset):
        """Set the strategy of ``iset`` to ``response_function(player, iset)``."""
        self.strategy[player][iset] = response_function(player, iset)

    # quantal response
    def quantal_response(self, player, rationality):
        """Replace ``player``'s strategy by a logit quantal response."""
        return self.general_response(player, self.quantal_response_function, rationality)

    @staticmethod
    def quantal_response_function(values, player, iset, rationality=1):
        """Softmax of ``rationality * values`` (shifted by the max for numerical stability)."""
        values = rationality * np.asarray(values)
        exped = np.exp(values - np.max(values))
        return exped / sum(exped)

    # best_response
    def best_response(self, player):
        """Replace ``player``'s strategy by a pure best response."""
        return self.general_response(player, self.best_response_function)

    @staticmethod
    def best_response_function(values, player, iset):
        """Put probability 1 on the first action with maximal value."""
        strategy_from_values = np.zeros(len(values))
        strategy_from_values[np.argmax(values)] = 1
        return strategy_from_values

    # linear quantal response
    def linear_response(self, player, rationality):
        """Replace ``player``'s strategy by a linear quantal response."""
        return self.general_response(player, self.linear_response_function, rationality)

    @staticmethod
    def linear_response_function(values, player, iset, rationality):
        """Probabilities proportional to ``values - min(values) + constant``.

        With ``rationality == 0`` the constant is 0; otherwise it is derived
        from ``rationality / len(values)``.  All-zero shifted values yield the
        uniform distribution; a single action gets ``[1]``.
        """
        if len(values) == 1:
            return [1]
        values = np.asarray(values, dtype=float)
        numval = len(values)
        rationality = rationality / numval
        valmin = np.min(values)
        if rationality == 0:
            constant = 0
        else:
            constant = (np.sum(values) - numval * valmin) / (1 / rationality - numval)
        values = values - valmin + constant
        if np.sum(values) == 0:
            values += 1
        return values / np.sum(values)