from CFRD import CFRD
from SubgameRNR import SubgameRNR
import random
import copy
import Combination
from SubgameCFR import SubgameCFR


class RNRD(CFRD):
    """CFRD variant that trains one player against a partly fixed opponent.

    Only ``cfr_player`` accumulates regret and gets a regret-based strategy
    update.  The other player's trunk strategy is replaced, once per
    iteration, either by a copy of ``fixed_strategy`` (when a uniform random
    draw falls below ``p``) or by whatever ``best_response`` installs.
    Subgames are solved with ``SubgameRNR``, which receives the same fixed
    strategy, the CFR player index and ``p``.
    """

    def __init__(self, fname, p, strategy, cfr_player):
        """Set up the solver.

        Args:
            fname: Forwarded unchanged to ``CFRD.__init__``.
            p: Threshold compared against ``random.random()`` when choosing
                the opponent's strategy; also handed to the subgame solver.
            strategy: The opponent's fixed strategy.
            cfr_player: Index (0 or 1) of the player updated from regrets;
                the opponent is addressed as ``1 - cfr_player``.
        """
        super().__init__(fname)
        self.p = p
        self.fixed_strategy = strategy
        self.cfr_player = cfr_player
        # ``self.trunk`` is not defined in this class; presumably it is set by
        # ``CFRD.__init__`` -- verify against the base class.
        self.game = self.trunk
        # True  -> compute_regret / compute_strategy act on the CFR player.
        # False -> compute_strategy acts on the opponent instead.
        self.update_player = True

    def solve_subgame(self, subgame, return_strategy=False):
        """Solve one subgame and record the counterfactual values of its roots.

        The subgame is given the current reaches, solved by ``SubgameRNR`` for
        ``self.subgame_iterations`` iterations, and for every root the
        solver's counterfactual values are written into
        ``self.subgame_root_counterfactual_values`` and added to
        ``self.average_root_counterfactual_values``.

        Args:
            subgame: Subgame object exposing ``change_reaches``.
            return_strategy: When true, the solver's average strategy is
                returned.

        Returns:
            The solver's ``average_strategy`` if ``return_strategy`` is true,
            otherwise ``None``.
        """
        subgame.change_reaches(self.reaches)
        # A plain ``SubgameCFR(subgame)`` solved with only the iteration count
        # is the non-RNR alternative that was used here previously.
        solver = SubgameRNR(subgame, self.fixed_strategy, self.cfr_player)
        solver.solve(self.subgame_iterations, self.p)

        learner = self.cfr_player
        opponent = 1 - learner
        for root in solver.subgame.roots:
            fresh = solver.counterfactual_values[root]
            if self.iteration < 1:
                stored = self.subgame_root_counterfactual_values[root]
                # The learner's value is taken as-is from the solver.
                stored[learner] = fresh[learner]
                # The opponent's value is a weighted blend of the new value
                # (weight: iteration) and the stored one (weight: iter_count).
                # NOTE(review): this divides by iteration + iter_count and
                # would raise if that sum is zero -- confirm it cannot be.
                blended = fresh[opponent] * self.iteration + stored[opponent] * self.iter_count
                stored[opponent] = blended / (self.iteration + self.iter_count)
            else:
                # Store a fresh list so later edits do not touch the solver's.
                self.subgame_root_counterfactual_values[root] = list(fresh)

            running_sum = self.average_root_counterfactual_values[root]
            for player in (0, 1):
                running_sum[player] += fresh[player]

        return solver.average_strategy if return_strategy else None

    def perform_iteration(self):
        """Run one iteration: learner update first, then opponent update.

        Phase 1 (``update_player`` True): compute reaches, solve every
        subgame, compute counterfactual values, then update the learner's
        regret and strategy.

        Phase 2 (``update_player`` False): recompute counterfactual values,
        reusing the reaches and subgame solutions from phase 1, let
        ``compute_strategy`` set the opponent's strategy, then average the
        strategies.
        """
        self.print_statistics()
        print("Strategy at the start of iteration", self.strategy)

        # Phase 1: regret-based update for the CFR player.
        self.update_player = True
        self.compute_reaches()
        for subgame in self.subgames:
            self.solve_subgame(subgame)
        self.compute_counterfactual_values()
        self.compute_regret()
        self.compute_strategy()
        print("Strategy changed for BR", self.strategy)

        # Phase 2: opponent update.  Reaches are not recomputed, subgames are
        # not re-solved and no regret is computed in this phase.
        self.update_player = False
        self.compute_counterfactual_values()
        self.compute_strategy()
        self.average_strategies()

        # A "delayed" ordering was sketched here earlier and is disabled: a
        # single pass of reaches / subgames / counterfactual values / regret,
        # followed by compute_strategy with update_player False and then
        # True, and finally average_strategies.

    def compute_regret(self):
        """Update regret for the CFR player; do nothing in the opponent phase."""
        if not self.update_player:
            return
        self.compute_regret_for_player(self.cfr_player)

    def compute_strategy(self):
        """Update the strategy of whichever side the current phase targets.

        With ``update_player`` set, the CFR player's strategy is recomputed
        via ``compute_strategy_for_player``.  Otherwise the opponent's entry
        in ``self.strategy`` becomes a deep copy of ``fixed_strategy`` when a
        ``random.random()`` draw is below ``self.p``, and ``best_response`` is
        called for the opponent when it is not.
        """
        if self.update_player:
            self.compute_strategy_for_player(self.cfr_player)
            return

        # An earlier, disabled approach built the opponent's strategy with
        # ``Combination.Combination(...).combine_strategies(self.p)`` from the
        # fixed strategy and a best response, instead of sampling one of them.
        draw = random.random()
        opponent = 1 - self.cfr_player
        if draw < self.p:
            self.strategy[opponent] = copy.deepcopy(self.fixed_strategy)
        else:
            self.best_response(opponent)

    def reconstruct_strategy(self, player):
        """Extend the average strategy with ``player``'s subgame strategies.

        The current strategy is set aside, the average strategy is installed
        in its place, reaches are recomputed and every subgame is solved with
        1000 iterations.  Each entry of the solved strategy for ``player``
        whose key is not ``-1`` is copied into the installed strategy, after
        which the set-aside strategy is restored.

        Args:
            player: Index of the player whose subgame entries are copied.

        Returns:
            The strategy object that was filled in.
        """
        saved = copy.deepcopy(self.strategy)
        # NOTE(review): this aliases rather than copies, so the writes below
        # also land in ``self.average_strategy`` and the returned object *is*
        # ``self.average_strategy`` -- confirm that this is intended.
        self.strategy = self.average_strategy
        self.compute_reaches()
        # NOTE(review): the previous ``subgame_iterations`` value is not
        # restored afterwards.
        self.subgame_iterations = 1000

        for subgame in self.subgames:
            solved = self.solve_subgame(subgame, return_strategy=True)
            solved_for_player = solved[player]
            for key in solved_for_player:
                if key == -1:
                    continue
                self.strategy[player][key] = solved_for_player[key]

        reconstructed = self.strategy
        self.strategy = saved
        return reconstructed