class Node:
    """A node of a search tree built while generating token sequences.

    Every node stores the context it represents, a cumulative ``score`` and
    links to its parent and children.  ``depth`` is 0 for a node created
    without a parent and ``parent.depth + 1`` otherwise.

    All tree traversals in this class use an explicit stack instead of
    recursion, so trees deeper than the interpreter's recursion limit
    (one level per generated token) can be processed safely.

    NOTE(review): ``context`` is presumably a ``(1, seq_len)`` tensor of token
    ids -- the code only relies on ``context[0]`` being decodable by the
    tokenizer and on ``context[0][-1].item()`` yielding the last token id.
    ``past_key_values`` is stored untouched; presumably a model KV cache.
    """

    def __init__(self, context, score, parent=None, length_penalty=1, past_key_values=None):
        self.context = context
        self.score = score
        self.length_penalty = length_penalty
        self.parent = parent
        self.children = []
        # Depth is fixed at construction time from the parent's depth.
        self.depth = 0 if parent is None else parent.depth + 1
        self.past_key_values = past_key_values

    def add_child(self, child):
        """Append ``child`` to this node's children.

        The child's ``parent`` and ``depth`` are not modified here; they are
        set by the child's own constructor.
        """
        self.children.append(child)

    def path_length(self):
        """Return the number of edges between the tree root and this node."""
        return self.depth

    def print(self, tokenizer, level=0):
        """Print this subtree, one node per line, indented by tree level.

        Args:
            tokenizer: Object whose ``decode(ids, skip_special_tokens=True)``
                turns ``context[0]`` into text.
            level: Indentation level used for this node.
        """
        stack = [(self, level)]
        while stack:
            node, indent = stack.pop()
            text = tokenizer.decode(node.context[0], skip_special_tokens=True)
            print("  " * indent + f"[{node._normalised_score():.2f}] {text}. Length: {node.path_length()}")
            # Push in reverse so children are printed in their stored order.
            stack.extend((child, indent + 1) for child in reversed(node.children))

    def _normalised_score(self):
        """Return ``score / depth ** length_penalty``, or ``-inf`` at depth 0."""
        length = self.path_length()
        if length > 0:
            return self.score / (length ** self.length_penalty)
        return float("-inf")

    def _iter_leaf_links(self):
        """Yield one linked chain per leaf below this node, left to right.

        A chain is a nested tuple ``(node, parent_chain)`` ending in ``None``
        at this node.  Sharing chain prefixes avoids copying a full path for
        every leaf; paths are materialised only for the leaves that are kept.
        """
        stack = [(self, None)]
        while stack:
            node, parent_link = stack.pop()
            link = (node, parent_link)
            if node.children:
                # Reverse so the leftmost child is popped (visited) first.
                stack.extend((child, link) for child in reversed(node.children))
            else:
                yield link

    @staticmethod
    def _materialise_path(link):
        """Turn a chain from ``_iter_leaf_links`` into a top-down node list."""
        path = []
        while link is not None:
            node, link = link
            path.append(node)
        path.reverse()
        return path

    def _ranked_leaves(self):
        """Return ``(chain, score)`` for every leaf, best score first.

        The score is the leaf's own normalised score.  The sort is stable, so
        leaves with equal scores stay in left-to-right traversal order.
        """
        ranked = [(link, link[0]._normalised_score()) for link in self._iter_leaf_links()]
        ranked.sort(key=lambda entry: entry[1], reverse=True)
        return ranked

    def _best_n_paths(self, n):
        """
        Returns a list of up to n best paths starting at this node.
        Each entry is a ``(path, score)`` tuple where ``path`` is the list of
        nodes from this node down to a leaf and ``score`` is that leaf's
        normalised score; entries are sorted by descending score.

        When this node is itself a leaf, the single path ``[self]`` is
        returned regardless of ``n``.
        """
        ranked = self._ranked_leaves()
        if self.children:
            ranked = ranked[:n]
        return [(self._materialise_path(link), score) for link, score in ranked]

    def best_path(self, eos_token_id=None):
        """Return the single best path (a list of nodes) from this node."""
        return self.best_n_paths(1, eos_token_id=eos_token_id)[0]

    def mean_branching_factor(self):
        """Return the average number of children per non-leaf node.

        Leaves are excluded from the average.  Returns 0 when this subtree
        contains no node with children.
        """
        internal_nodes = 0
        branches = 0
        stack = [self]
        while stack:
            node = stack.pop()
            if node.children:
                internal_nodes += 1
                branches += len(node.children)
                stack.extend(node.children)
        return branches / internal_nodes if internal_nodes > 0 else 0

    def _best_n_paths_with_eos_priority(self, n, eos_token_id):
        """Return up to ``n`` ``(path, score)`` tuples, EOS-terminated first.

        Paths whose leaf's last token equals ``eos_token_id`` are taken first
        (best score first); remaining slots are filled with the best paths
        that do not end in EOS.
        """
        eos_ranked = []
        other_ranked = []
        # The input is already sorted, so each partition stays sorted.
        for entry in self._ranked_leaves():
            leaf = entry[0][0]
            last_token = leaf.context[0][-1].item()
            if last_token == eos_token_id:
                eos_ranked.append(entry)
            else:
                other_ranked.append(entry)

        selected = eos_ranked[:n]
        if len(selected) < n:
            selected += other_ranked[: (n - len(selected))]

        return [(self._materialise_path(link), score) for link, score in selected]

    def best_n_paths(self, n, eos_token_id=None):
        """Return up to ``n`` best paths (lists of nodes) from this node.

        Args:
            n: Maximum number of paths to return.
            eos_token_id: When given, paths ending in this token are
                preferred over paths that do not; when ``None``, paths are
                ranked purely by normalised leaf score.
        """
        if eos_token_id is None:
            scored_paths = self._best_n_paths(n)
        else:
            scored_paths = self._best_n_paths_with_eos_priority(n, eos_token_id)
        return [path for (path, _score) in scored_paths]

    def total_branches(self):
        """Return the total number of parent-child edges in this subtree."""
        count = 0
        stack = [self]

        while stack:
            node = stack.pop()
            count += len(node.children)
            stack.extend(node.children)

        return count

    def total_paths(self):
        """Return the number of leaves (i.e. distinct paths) in this subtree."""
        count = 0
        stack = [self]

        while stack:
            node = stack.pop()
            if not node.children:
                # Leaf node, count as one path
                count += 1
            else:
                # Add all children to stack for further traversal
                stack.extend(node.children)

        return count

    def prune_incomplete(self, eos_token_id: int, max_new_tokens: int) -> bool:
        """
        Prunes any subtrees that do not end in a completed path.
        A completed path ends in EOS or reaches max length
        (``path_length() >= max_new_tokens``).
        Returns True if this node (or any of its children) leads to a complete path.

        Pruned nodes are detached: their ``parent`` is set to ``None`` and
        their ``children`` to an empty list.  The node the method is called
        on is never detached, even when it returns False.
        """
        # Collect nodes so that every parent precedes its descendants ...
        order = []
        stack = [self]
        while stack:
            node = stack.pop()
            order.append(node)
            stack.extend(node.children)

        # ... then walk backwards so children are decided before parents.
        complete = {}
        for node in reversed(order):
            if not node.children:
                last_token = node.context[0][-1].item()
                complete[id(node)] = (last_token == eos_token_id) or (
                    node.path_length() >= max_new_tokens
                )
                continue

            survivors = []
            for child in node.children:
                if complete[id(child)]:
                    survivors.append(child)
                else:
                    # Explicitly break references for GC
                    child.parent = None
                    child.children = []
            node.children = survivors
            # An internal node is kept only if at least one child survived.
            complete[id(node)] = len(survivors) > 0

        return complete[id(self)]

    def __lt__(self, other):
        """Order nodes by normalised score (lower score compares as smaller)."""
        return self._normalised_score() < other._normalised_score()
