from CFR import CFR
import random
import copy
import numpy as np


class RNR(CFR):
    """CFR variant that trains ``cfr_player`` against a partially fixed opponent.

    On each iteration, :meth:`compute_regret` draws ``fixed_update``. It is
    True with probability ``p``, set by :meth:`solve`. When it is True, the
    opponent ``1 - cfr_player`` plays ``fixed_strategy``. Otherwise the
    opponent plays either its own CFR strategy (``is_opponent_cfr=True``) or
    a best response to the current profile. The class name suggests the
    Restricted Nash Response construction; treat that as an assumption.

    Args:
        fname: Game description forwarded to :class:`CFR`. It is also passed to
            ``QRQR`` when no ``strategy`` is given.
        strategy: Fixed opponent strategy. When ``None``, one is computed by
            :meth:`qrqr_strategy`.
        cfr_player: Index (0 or 1) of the player whose strategy is learned.
        is_opponent_cfr: If True, the opponent's non-fixed strategy comes from
            its own accumulated regrets instead of a best response.
    """

    def __init__(self, fname, strategy=None, cfr_player=1, is_opponent_cfr=False):
        super().__init__(fname, cfr_plus=False)
        self.cfr_player = cfr_player
        # Not read in this class; presumably consumed by CFR.
        self.progressive_strategy = False
        # Probability of a fixed-opponent iteration; assigned in solve().
        self.p = None
        self.fixed_strategy = strategy
        self.is_opponent_cfr = is_opponent_cfr
        # True when the current iteration pins the opponent to fixed_strategy.
        self.fixed_update = False
        if strategy is None:
            self.fixed_strategy = self.qrqr_strategy(fname)

    def qrqr_strategy(self, fname):
        """Solve the game with ``QRQR`` and return the opponent's strategy.

        NOTE(review): ``QRQR`` is not imported in the visible part of this
        file. Unless it is defined elsewhere in this module, this raises
        NameError. It is reached from ``__init__`` when ``strategy is None``.
        """
        qrqr = QRQR(fname)
        qrqr.solve()
        return qrqr.strategy[1 - self.cfr_player]

    def solve(self, iterations=1000, verbose=0, save_progression=False, save_strategy=False, skip=1, p=0.5, fixed=None):
        """Run the solver with fixed-opponent probability ``p``.

        Every other argument is forwarded unchanged to :meth:`CFR.solve`.
        """
        self.p = p
        super().solve(iterations=iterations, verbose=verbose, save_progression=save_progression,
                      save_strategy=save_strategy, skip=skip, fixed=fixed)

    def normalize_average_strategy_top(self, player):
        """Normalize the average strategy at every information set in the game tree.

        ``player`` is accepted for interface compatibility but has no effect.
        :meth:`normalize_average_strategy` uses the acting player of each
        node instead.
        """
        self.avg_visited = set()
        self.normalize_average_strategy(self.game.root, player)

    def normalize_average_strategy(self, node, player):
        """Recursively normalize average strategies in the subtree at ``node``.

        The ``player`` argument is ignored; the acting player is read from
        ``node.player``. Each (player, information set) pair is normalized at
        most once per traversal. Visited pairs are tracked in the set
        ``self.avg_visited``, which :meth:`normalize_average_strategy_top`
        initializes.
        """
        acting = node.player
        if acting == 3:  # presumably a terminal node: no children to visit
            return
        for child in node.children:
            self.normalize_average_strategy(child, acting)
        if acting == 2:  # presumably a chance node: it owns no strategy
            return
        key = (acting, node.i_set)
        if key not in self.avg_visited:
            self.avg_visited.add(key)
            self.normalize_in_iset(node.i_set, acting)

    def normalize_in_iset(self, iset, player):
        """Rescale ``player``'s average strategy at ``iset`` in place so it sums to 1.

        A row whose entries sum to zero (no accumulated weight) becomes
        uniform, which avoids a division by zero.
        """
        probs = self.average_strategy[player][iset]
        total = np.sum(probs)
        size = len(probs)
        for i in range(size):
            # Divide by the row sum; multiplying by it would not normalize.
            probs[i] = probs[i] / total if total > 0 else 1.0 / size

    def compute_strategy(self):
        """Update the current strategy profile for this iteration.

        The learning player's strategy is always recomputed. The opponent gets
        one of three strategies:

        - ``fixed_strategy``, when ``fixed_update`` is set;
        - its own CFR strategy, when ``is_opponent_cfr`` is True;
        - otherwise, a best response to the current profile.

        ``fixed_update`` is whatever the most recent :meth:`compute_regret`
        drew; the call order is decided by CFR.
        """
        self.compute_strategy_for_player(self.cfr_player)

        if self.fixed_update:
            self.strategy[1 - self.cfr_player] = copy.deepcopy(self.fixed_strategy)
        elif self.is_opponent_cfr:
            self.compute_strategy_for_player(1 - self.cfr_player)
        else:
            self.strategy = self.best_response(1 - self.cfr_player, self.strategy)[1]

    def compute_regret(self):
        """Draw this iteration's opponent mode and accumulate regrets.

        ``self.fixed_update`` becomes True with probability ``self.p``. The
        learning player's regrets are always updated. If the opponent is a CFR
        learner and this is not a fixed iteration, three further steps run:
        the opponent's strategy is recomputed, the counterfactual values are
        refreshed, and the opponent's regrets are updated.
        """
        self.fixed_update = random.random() < self.p
        self.compute_regret_for_player(self.cfr_player)
        if self.is_opponent_cfr and not self.fixed_update:
            opponent = 1 - self.cfr_player
            self.compute_strategy_for_player(opponent)
            self.compute_counterfactual_values()
            self.compute_regret_for_player(opponent)

    def compute_regret_for_player_mccfr(self, player, fixed_update):
        """Accumulate ``player``'s regrets, weighted by the probability of this iteration's mode.

        This is an alternative to ``compute_regret_for_player``; it is not
        called from this class. Each regret increment is scaled by ``p`` when
        ``fixed_update`` is True and by ``1 - p`` otherwise.
        """
        weight = self.p if fixed_update else 1 - self.p
        for iset in self.i_sets[player]:
            nodes = self.i_sets_to_nodes[player][iset]
            # Sum each action's counterfactual value over all nodes in the information set.
            cfv = [0] * len(nodes[0].children)
            for node in nodes:
                for i, child in enumerate(node.children):
                    cfv[i] += self.counterfactual_values[child][player]
            action_probs = self.strategy[player][iset]
            value = 0
            for i in range(len(cfv)):
                value += cfv[i] * action_probs[i]
            regrets = self.counterfactual_regret[player][iset]
            for i, cfv_value in enumerate(cfv):
                regrets[i] += weight * (cfv_value - value)
                if self.cfr_plus:
                    regrets[i] = max(0, regrets[i])

    def print_responses(self):
        """Print best-response values of the opponent and return a responses list.

        NOTE(review): the returned values use player 0, but the printed values
        use ``1 - cfr_player``. The two agree only for the default
        ``cfr_player=1``; confirm which player the caller expects.
        """
        print()
        responses = [self.best_response(0, self.average_strategy)[0], self.best_response(0, self.strategy)[0], [], []]
        print("BR:", self.best_response(1 - self.cfr_player, self.average_strategy)[0], end=" ")
        print("BRc:", self.best_response(1 - self.cfr_player, self.strategy)[0], end=" ")
        return responses

    def against_fixed(self, strategy):
        """Return the game value of ``strategy`` for the learning player against ``fixed_strategy``.

        ``self.strategy`` is swapped temporarily. It is restored even if
        ``compute_game_value`` raises.
        """
        backup_strategy = copy.deepcopy(self.strategy)
        try:
            self.strategy[self.cfr_player] = copy.deepcopy(strategy[self.cfr_player])
            self.strategy[1 - self.cfr_player] = copy.deepcopy(self.fixed_strategy)
            self.compute_game_value(self.strategy)
        finally:
            self.strategy = backup_strategy
        return self.game_value
