from ExtensiveGame import ExtensiveGame
from ExtensiveGame import Node
import numpy as np

class ExtensiveSubgame(ExtensiveGame):
    """A game assembled from several existing subgame roots under one new root.

    ``roots`` are pre-built nodes (presumably ``ExtensiveGame.Node`` instances;
    confirm against callers). Call :meth:`join_with_chance_node` or
    :meth:`gadget_game` to build ``self.root`` on top of them.
    """

    def __init__(self, roots):
        super().__init__()
        self.roots = roots
        # Not read anywhere in this class; presumably consumed by other code.
        self.reaches = None
        # Mapping root -> sequence of values; gadget_game reads index 0 of each
        # entry as the value of that root's "terminate" option.
        self.counterfactual_values = None
        # Next information-set id per index; only index 0 is advanced here.
        self.actual_isets = [0, 0]
        # Mapping root -> sequence whose first three entries are multiplied to
        # weight that root at the top of gadget_game.
        self.safe_reaches = None

    def change_reaches(self, reaches):
        """Replace ``self.reaches``."""
        self.reaches = reaches

    def change_safe_reaches(self, reaches):
        """Replace ``self.safe_reaches`` (root -> reach sequence)."""
        self.safe_reaches = reaches

    def change_counterfactual_values(self, counterfactual_values):
        """Replace ``self.counterfactual_values`` (root -> value sequence)."""
        self.counterfactual_values = counterfactual_values

    def join_with_chance_node(self, chances=None):
        """Put all subgame roots under one new root node with given chances.

        The previous version of this class defined this method twice; the
        no-argument (uniform) variant was silently shadowed and calling it
        raised ``TypeError``. Both behaviours now live in this one method.

        Args:
            chances: iterable with one probability per entry of ``self.roots``,
                in the same order. ``None`` (default) means uniform
                probabilities. Values are stored as given (not normalized).

        Raises:
            ValueError: if ``chances`` does not have one entry per root
                (``zip`` would otherwise silently drop roots).
        """
        if chances is None:
            count = len(self.roots)
            chances = [1 / count for _ in self.roots]
        else:
            chances = list(chances)
            if len(chances) != len(self.roots):
                raise ValueError(
                    "expected %d chances, got %d" % (len(self.roots), len(chances))
                )
        # Node(2, ...) is presumably the chance player, since its ``chance``
        # list is filled below; confirm against ExtensiveGame.Node.
        self.root = Node(2, 0, None, self.getid())
        for root, chance in zip(self.roots, chances):
            self.root.children.append(root)
            # TODO(review): chances are not normalized; a caller passing values
            # that do not sum to 1 gets them stored verbatim.
            self.root.chance.append(chance)

    def _gadget_root_weights(self):
        """Return one top-level probability per root for gadget_game.

        Uses the product of the first three ``safe_reaches`` entries of each
        root, normalized by their sum. Falls back to uniform weights when
        ``safe_reaches`` is unset or every product is 0 (division by zero
        otherwise).
        """
        count = len(self.roots)
        uniform = [1 / count for _ in self.roots]
        if self.safe_reaches is None:
            return uniform
        reaches = [self.safe_reaches[root] for root in self.roots]
        products = [reach[0] * reach[1] * reach[2] for reach in reaches]
        total = sum(products)
        if total == 0:
            return uniform
        return [product / total for product in products]

    def _gadget_terminate_value(self):
        """Return the value given to every "terminate" node in gadget_game.

        This is the mean of ``counterfactual_values[root][0]`` over all roots,
        or 0.0 when no values are available. The previous code returned NaN
        (``np.average([])``) whenever ``safe_reaches`` was unset.
        """
        if self.counterfactual_values is None or not self.roots:
            return 0.0
        return np.average([self.counterfactual_values[root][0] for root in self.roots])

    def gadget_game(self):
        """Build a gadget tree over ``self.roots`` and store it in ``self.root``.

        The new top node (created as ``Node(2, ...)`` with a ``chance`` list)
        has one child per subgame root. Each child is a ``Node(0, ...)`` with
        two children, in this order:

        * the original subgame root ("follow");
        * a ``Node(3, ...)`` carrying the averaged counterfactual value
          ("terminate"). Presumably 3 marks a terminal node; confirm against
          ExtensiveGame.Node.

        Top-level probabilities come from ``_gadget_root_weights``.
        """
        weights = self._gadget_root_weights()
        terminate_value = self._gadget_terminate_value()
        # Restart iset numbering at 100, presumably to stay clear of ids used
        # inside the subgames; confirm.
        self.actual_isets = [100, 100]
        self.root = Node(2, 0, None, self.getid())
        for root, weight in zip(self.roots, weights):
            next_iset = self.actual_isets[0]
            self.actual_isets[0] += 1
            follow_choice = Node(0, next_iset, self.root, self.getid())
            follow_choice.children.append(root)

            terminate_node = Node(3, 0, follow_choice, self.getid(), terminate_value)
            follow_choice.children.append(terminate_node)

            self.root.children.append(follow_choice)
            self.root.chance.append(weight)

    def print_tree(self):
        """Print each subgame root's tree using ``ExtensiveGame.print_tree``.

        The base method is called once per root with ``self.root`` temporarily
        set to that root (the base presumably prints from ``self.root``). The
        previous ``self.root`` is restored afterwards. Before this fix, the
        tree built by join_with_chance_node / gadget_game was silently
        replaced by the last subgame root.
        """
        previous_root = getattr(self, "root", None)
        try:
            for root in self.roots:
                self.root = root
                super().print_tree()
        finally:
            self.root = previous_root
