import copy
import logging
import random

import numpy as np

from SubgameCFR import SubgameCFR


class SubgameRNR(SubgameCFR):
    """Subgame solver whose opponent mixes a fixed strategy and a best response.

    On every strategy update the solving player's strategy is recomputed,
    while the other player either plays a copy of ``fixed_strategy`` (with
    probability ``p``) or a best response (with probability ``1 - p``).

    NOTE(review): "RNR" presumably stands for Restricted Nash Response --
    confirm. ``self.strategy``, ``compute_strategy_for_player``,
    ``best_response`` and ``compute_counterfactual_values`` are not defined
    in this class and are expected to be provided by ``SubgameCFR``.
    """

    def __init__(self, subgame, strategy, cfr_player):
        """Set up the solver.

        Args:
            subgame: Subgame to solve; passed through to ``SubgameCFR``.
            strategy: Fixed strategy the opponent plays with probability
                ``p``. It is stored by reference and deep-copied each time
                it is installed, so the caller's object is never mutated
                through ``self.strategy``.
            cfr_player: Index (0 or 1) of the player whose strategy is
                being solved for; the opponent is ``1 - cfr_player``.
        """
        super().__init__(subgame, cfr_player)
        self.cfr_player = cfr_player
        # Mixing probability; stays None until solve() supplies it.
        self.p = None
        self.fixed_strategy = strategy

    def solve(self, iterations=1000, p=0.5):
        """Run the base-class solver with mixing probability ``p``.

        Args:
            iterations: Iteration count forwarded to ``SubgameCFR.solve``.
            p: Probability that the opponent plays ``fixed_strategy``
                in a given strategy update instead of a best response.

        Returns:
            Whatever ``SubgameCFR.solve`` returns.
        """
        self.p = p
        return super().solve(iterations)

    def compute_strategy(self):
        """Update both players' strategies for the current iteration.

        ``solve`` must have been called first so that ``self.p`` is set;
        comparing against the initial ``None`` raises ``TypeError``.
        """
        self.compute_strategy_for_player(self.cfr_player)
        opponent = 1 - self.cfr_player
        # random.random() lies in [0.0, 1.0): p >= 1 always selects the fixed
        # strategy and p <= 0 always selects the best response.
        if random.random() < self.p:
            # Install a private copy so later in-place changes to
            # self.strategy cannot leak into the stored fixed strategy.
            self.strategy[opponent] = copy.deepcopy(self.fixed_strategy)
        else:
            self.best_response(opponent)

    def modify_counterfactual_values(self):
        """Recompute counterfactual values for the current strategy."""
        # Emitted as a debug log record rather than printed, so library
        # users do not get the full strategy table on stdout.
        logging.getLogger(__name__).debug(
            "strategy before counterfactual value computation: %s",
            self.strategy,
        )
        self.compute_counterfactual_values()
