import numpy as np
import torch
import ipdb
from copy import deepcopy
import cProfile
class Node:
    """Search-tree node holding visit statistics and links to its children.

    Attributes set by ``__init__``:
        children: dict mapping a value to the child node created for it.
        parent: the parent node, or ``None`` for the root.
        n_visits: number of times ``update`` has been called on this node.
        Q: mean of the children's ``Q`` values (0 when there are no children).
        u: exploration bonus last computed by ``get_value``.
        P: the ``prob`` passed to the constructor.

    NOTE(review): ``select`` and ``delete`` read ``self.var``, which is not
    assigned anywhere in this class -- presumably a subclass or the caller
    sets it; confirm before instantiating ``Node`` directly.
    """

    def __init__(self, assignments, domains, constraints, parent, prob):
        """Initialise an unvisited node with no children.

        ``assignments``, ``domains`` and ``constraints`` are accepted but not
        used by this class.
        """
        self.children = {}
        self.parent = parent
        self.n_visits = 0
        self.Q = 0
        self.u = 0
        self.P = prob

    def expand(self, domain):
        """Create one ``ValNode`` child for every value in ``domain``.

        An existing child stored under the same value is replaced.
        """
        for val in domain:
            self.children[val] = ValNode(self, val)

    def select(self, env, c_puct):
        """Pick a child uniformly at random among the values kept by ``env``.

        The current child keys are copied into a list and handed to
        ``env.check_assign(self.var, vals)``; its return value is ignored.
        Every child whose key is missing from the list afterwards is removed
        via its own ``delete()``.
        NOTE(review): this only prunes anything if ``check_assign`` mutates
        the list in place -- confirm against the ``env`` implementation.

        ``c_puct`` is currently unused: the choice is uniformly random and
        does not consult ``get_value``.

        Returns:
            ``(val, child)`` for the chosen value, or ``(-1, None)`` when the
            node has no children or none remain after the check.
        """
        if not self.children:
            return -1, None

        vals = list(self.children)
        env.check_assign(self.var, vals)

        # The difference is materialised before the loop starts, so removing
        # entries from self.children while iterating it is safe.
        for del_val in set(self.children) - set(vals):
            self.children[del_val].delete()

        if not vals:
            return -1, None
        val = np.random.choice(vals)
        return val, self.children[val]

    def update(self, leaf_value):
        """Count one visit and refresh ``Q`` from the children.

        ``Q`` becomes the mean of the children's ``Q`` values, or 0 when the
        node has no children. ``leaf_value`` is accepted but not used.

        Raises:
            AttributeError: if a child has no ``Q`` attribute. The error now
                propagates to the caller instead of dropping into a debugger.
        """
        self.n_visits += 1
        if not self.children:
            self.Q = 0
        else:
            # Q is the average of the children's Q values.
            self.Q = np.mean([child.Q for child in self.children.values()])

    def update_recursive(self, leaf_value):
        """Call ``update`` on every ancestor (root first) and then on this node."""
        # Ancestors are refreshed before this node, so each ancestor's Q is
        # computed from its children's values as they were before this call.
        if self.parent:
            self.parent.update_recursive(leaf_value)
        self.update(leaf_value)

    def get_value(self, c_puct):
        """Return ``Q + u`` and store the freshly computed bonus in ``self.u``.

        ``u = c_puct * sqrt(parent.n_visits) / (1 + n_visits)``.

        Raises:
            AttributeError: if called on a node whose ``parent`` is ``None``.
        """
        self.u = c_puct * np.sqrt(self.parent.n_visits) / (1 + self.n_visits)
        return self.Q + self.u

    def is_leaf(self):
        """Return True when the node has no children."""
        return not self.children

    def is_root(self):
        """Return True when the node has no parent."""
        return self.parent is None

    def delete(self):
        """Remove this node from its parent's ``children``, keyed by ``self.var``."""
        del self.parent.children[self.var]
        
