import pickle
from os import PathLike
import math
from copy import deepcopy
from typing import Generic, Optional, NamedTuple, Callable, Hashable
import itertools
from abc import ABC
from abc import ABC
from collections import defaultdict
import random
import numpy as np
import json
from tqdm import trange
import os
import datetime

from .. import SearchAlgorithm, WorldModel, SearchConfig, State, Action, Example, Trace

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Arrays become (nested) lists, NumPy integers become ``int`` and NumPy
        floats become ``float``. Anything else is delegated to the base class,
        which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

class PlanUNode(Generic[State, Action, Example]):
    """Search-tree node carrying a quantile estimate of its cumulative reward.

    Each node stores ``n_atoms`` quantile values (``quantile_values``) at the
    quantile midpoints ``(2i + 1) / (2 * n_atoms)`` (``quantile_probs``). A node
    whose ``state`` is ``None`` has been created as a child but not yet expanded.
    """

    # Shared counter used to hand out node ids; restarted via ``reset_id``.
    id_iter = itertools.count()

    @classmethod
    def reset_id(cls):
        """Restart the node id counter from zero."""
        cls.id_iter = itertools.count()

    def __init__(self, state: Optional[State], action: Optional[Action], parent: "Optional[PlanUNode]" = None,
                 fast_reward: float = 0., fast_reward_details=None,
                 is_terminal: bool = False, calc_q: Callable[[list[float]], float] = np.mean,
                 n_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0):
        """Create a node.

        Args:
            state: State at this node, or ``None`` if not expanded yet.
            action: Action leading from ``parent`` to this node (``None`` for a root).
            parent: Parent node; ``None`` makes this a depth-0 node.
            fast_reward: Cheap reward estimate; also used as the initial ``reward``.
            fast_reward_details: Extra data returned with the fast reward
                (defaults to a fresh empty dict).
            is_terminal: Whether the node is terminal.
            calc_q: Aggregator applied to the recorded cumulative rewards in ``Q``.
            n_atoms: Number of quantiles tracked.
            v_min: Lower bound of the initial quantile grid.
            v_max: Upper bound of the initial quantile grid.
        """
        self.id = next(PlanUNode.id_iter)
        if fast_reward_details is None:
            fast_reward_details = {}
        self.cum_rewards: list[float] = []
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.quantile_probs = np.linspace(1/(2*n_atoms), 1-1/(2*n_atoms), n_atoms)
        # Initial estimate: quantiles spread evenly over [v_min, v_max].
        self.quantile_values = np.linspace(v_min, v_max, n_atoms)
        self.distribution_history = []

        # A non-zero fast reward replaces the initial spread with a point mass
        # at that value and records it as the first history entry.
        if fast_reward != 0:
            self.quantile_values = np.ones(n_atoms) * fast_reward
            self.distribution_history.append(self.quantile_values.copy())

        self.fast_reward = self.reward = fast_reward
        self.fast_reward_details = fast_reward_details
        self.is_terminal = is_terminal
        self.action = action
        self.state = state
        self.parent = parent
        self.children: 'Optional[list[PlanUNode]]' = None
        self.calc_q = calc_q
        if parent is None:
            self.depth = 0
        else:
            self.depth = parent.depth + 1

    @property
    def Q(self) -> float:
        """Scalar value estimate of this node.

        Unexpanded nodes report ``fast_reward``; expanded nodes without any
        recorded cumulative reward report the mean of the quantile values;
        otherwise ``calc_q`` is applied to the recorded cumulative rewards.
        """
        if self.state is None:
            return self.fast_reward
        elif len(self.cum_rewards) == 0:
            return np.mean(self.quantile_values)
        else:
            return self.calc_q(self.cum_rewards)

    def get_distorted_q(self, distortion_fn: Callable[[np.ndarray, np.ndarray], float]) -> float:
        """Return ``distortion_fn(quantile_probs, quantile_values)``.

        Unexpanded nodes (``state is None``) return ``fast_reward`` instead.
        """
        if self.state is None:
            return self.fast_reward
        return distortion_fn(self.quantile_probs, self.quantile_values)

    def get_distribution_summary(self) -> dict:
        """Summarize the current quantile values.

        Returns:
            A dict with ``mean``, ``std``, ``min``, ``max``, ``median``, ``mode``
            and ``quantiles`` (a list of ``(probability, value)`` pairs). The
            median is the value at the quantile closest to 0.5. The mode is the
            argmax of a Gaussian KDE when SciPy is usable, and falls back to the
            median of the quantile values otherwise.
        """
        if len(self.quantile_values) == 0:
            return {
                "mean": 0.0,
                "std": 0.0,
                "min": self.v_min,
                "max": self.v_max,
                "median": 0.0,
                "mode": 0.0,
                "quantiles": []
            }

        mean = np.mean(self.quantile_values)
        std = np.std(self.quantile_values)

        median_idx = np.argmin(np.abs(self.quantile_probs - 0.5))
        median = self.quantile_values[median_idx]

        try:
            from scipy import stats
            kde = stats.gaussian_kde(self.quantile_values)
            x = np.linspace(min(self.quantile_values), max(self.quantile_values), 1000)
            mode = x[np.argmax(kde(x))]
        except Exception:
            # Best-effort fallback (SciPy missing, or the KDE failing on the
            # data). ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            mode = np.median(self.quantile_values)

        return {
            "mean": float(mean),
            "std": float(std),
            "min": float(np.min(self.quantile_values)),
            "max": float(np.max(self.quantile_values)),
            "median": float(median),
            "mode": float(mode),
            "quantiles": [(float(p), float(v)) for p, v in zip(self.quantile_probs, self.quantile_values)]
        }

    def visualize_distribution(self, title=None):
        """Plot quantile value against quantile probability.

        Args:
            title: Plot title; when falsy a title is built from the node's
                action and depth.

        Returns:
            The matplotlib figure, or ``None`` when matplotlib cannot be
            imported. The figure is left open; closing it is up to the caller.
        """
        try:
            import matplotlib.pyplot as plt

            plt.figure(figsize=(10, 6))
            plt.scatter(self.quantile_probs, self.quantile_values, marker='o', color='blue')
            plt.plot(self.quantile_probs, self.quantile_values, 'b-', alpha=0.5)

            if title:
                plt.title(title)
            else:
                action_str = f"Action: {self.action}" if self.action is not None else "Root"
                plt.title(f"Quantile Distribution - {action_str} (Depth: {self.depth})")

            plt.xlabel("Quantile Probability")
            plt.ylabel("Quantile Value")
            plt.grid(True, alpha=0.3)

            summary = self.get_distribution_summary()
            info_text = (f"Mean: {summary['mean']:.4f}\n"
                         f"Std: {summary['std']:.4f}\n"
                         f"Median: {summary['median']:.4f}")

            plt.text(0.02, 0.95, info_text, transform=plt.gca().transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', alpha=0.1))

            plt.tight_layout()
            return plt.gcf()
        except ImportError:
            print("matplotlib is not installed. Cannot visualize distribution.")
            return None


class PlanUResult(NamedTuple):
    """Outcome of one PlanU search, as assembled by ``PlanU.__call__``."""
    # State of the last node on the selected output path (None when no path was selected).
    terminal_state: State
    # Cumulative reward recorded for the selected output path.
    cum_reward: float
    # (states, actions) along the output path; actions skip the root. None when no path was selected.
    trace: Trace
    # The nodes of the selected output path (None when no path was selected).
    trace_of_nodes: list[PlanUNode]
    # Root node of the search tree.
    tree_state: PlanUNode
    # Deep-copied path of every iteration; only filled when output_trace_in_each_iter is enabled.
    trace_in_each_iter: list[list[PlanUNode]] = None
    # First node of each recorded per-iteration path; None unless traces are recorded.
    tree_state_after_each_iter: list[PlanUNode] = None
    # Result of the optional aggregator applied to the tree root.
    aggregated_result: Optional[Hashable] = None


class PlanUAggregation(Generic[State, Action, Example], ABC):
    """Pick an answer by accumulating weights over the terminal nodes of a tree.

    Every terminal node votes for the answer extracted from its state; every
    ancestor of a terminal node adds further weight to the answers found below
    it (for the ``edge`` and ``edge_inverse_depth`` policies).
    """

    def __init__(self, retrieve_answer: Callable[[State], Hashable],
                 weight_policy: str = 'edge'):
        """Configure the aggregation.

        Args:
            retrieve_answer: Maps a terminal state to a hashable answer, or
                ``None`` when no answer can be extracted.
            weight_policy: ``'edge'`` adds the node reward, ``'edge_inverse_depth'``
                adds the reward divided by depth, ``'uniform'`` adds 1.0 per
                terminal node.
        """
        assert weight_policy in ['edge', 'edge_inverse_depth', 'uniform']
        self.retrieve_answer = retrieve_answer
        self.weight_policy = weight_policy

    def __call__(self, tree_state: PlanUNode[State, Action,Example]) -> Optional[Hashable]:
        """Return the answer with the largest accumulated weight.

        Args:
            tree_state: Root of the search tree to aggregate over.

        Returns:
            The best-weighted answer, or ``None`` when no terminal node yielded
            an answer.
        """
        answer_dict = defaultdict(lambda: 0)

        def visit(cur: PlanUNode[State, Action, Example]):
            # Returns the (answer, depth) pairs of all terminal nodes below ``cur``.
            if cur.state is None:
                return []
            if cur.is_terminal:
                answer = self.retrieve_answer(cur.state)
                if answer is None:
                    print("PlanUAggregation: no answer retrieved.")
                    return []
                if self.weight_policy == 'edge':
                    answer_dict[answer] += cur.reward
                elif self.weight_policy == 'edge_inverse_depth':
                    # A terminal root has depth 0; treat it as depth 1 instead
                    # of dividing by zero.
                    answer_dict[answer] += cur.reward / max(cur.depth, 1)
                elif self.weight_policy == 'uniform':
                    answer_dict[answer] += 1.0
                return [(answer, cur.depth)]
            depth_list = defaultdict(list)
            cur_list = []
            # ``children`` is None for a node that was never expanded.
            for child in cur.children or []:
                cur_list.extend(child_info := visit(child))
                for answer, depth in child_info:
                    depth_list[answer].append(depth)
            # Credit this internal node's reward once per distinct answer below it.
            for answer, depths in depth_list.items():
                if self.weight_policy == 'edge':
                    answer_dict[answer] += cur.reward
                elif self.weight_policy == 'edge_inverse_depth':
                    answer_dict[answer] += cur.reward / np.mean(depths)
            return cur_list

        visit(tree_state)

        if len(answer_dict) == 0:
            return None
        return max(answer_dict, key=lambda answer: answer_dict[answer])

class PlanU(SearchAlgorithm, Generic[State, Action, Example]):
    def __init__(self,
         output_trace_in_each_iter: bool = False,
         w_exp: float = 1.,
         depth_limit: int = 8,
         n_iters: int = 10,
         cum_reward: Callable[[list[float]], float] = sum,
         calc_q: Callable[[list[float]], float] = np.mean,
         simulate_strategy: str | Callable[[list[float]], int] = 'max',
         output_strategy: str = 'max_reward',
         uct_with_fast_reward: bool = True,
         aggregator: Optional[PlanUAggregation] = None,
         disable_tqdm: bool = True,
         node_visualizer: Callable[[PlanUNode], dict] = lambda x: x.__dict__,
         n_atoms: int = 51,
         v_min: float = -10.0,
         v_max: float = 10.0,
         distortion_fn: Callable[[np.ndarray, np.ndarray], float] = None,
         risk_distortion: float = 0.0,
         log_distributions: bool = False,
         distribution_log_path: Optional[str] = None,
         visualize_key_nodes: bool = False,
         chain_propagate: bool = True,
         ):
        """Configure the search.

        Args:
            output_trace_in_each_iter: Keep a deep copy of every iteration's path.
            w_exp: Weight of the exploration term in ``_uct``.
            depth_limit: Nodes at or beyond this depth are treated as terminal.
            n_iters: Number of iterations run by ``search``.
            cum_reward: Folds a list of rewards along a path into one value.
            calc_q: Aggregator handed to every node for its ``Q`` property.
            simulate_strategy: ``'max'``, ``'sample'`` or ``'random'``, or a
                callable mapping the children's fast rewards to a child index.
            output_strategy: One of ``'max_reward'``, ``'follow_max'``,
                ``'max_visit'``, ``'max_iter'``, ``'last_iter'``,
                ``'last_terminal_iter'``.
            uct_with_fast_reward: When true, UCT selection also considers
                unexpanded children (scored by their fast reward).
            aggregator: Optional callable applied to the tree root in ``__call__``.
            disable_tqdm: Passed to ``trange`` as ``disable``.
            node_visualizer: Stored on the instance; not used by this class itself.
            n_atoms: Number of quantiles per node.
            v_min: Lower bound of the node value range.
            v_max: Upper bound of the node value range.
            distortion_fn: Maps ``(quantile_probs, quantile_values)`` to the scalar
                used in ``_uct``. When ``None`` a default is built from
                ``risk_distortion``.
            risk_distortion: ``0`` selects the plain weighted sum; a positive value
                averages over quantiles with probability >= ``1 - risk_distortion``;
                a negative value averages over quantiles with probability
                <= ``-risk_distortion``.
            log_distributions: Record a log entry for every distribution update.
            distribution_log_path: Base path of the log directory; the final
                directory name is prefixed with a timestamp.
            visualize_key_nodes: Save plots of selected nodes.
            chain_propagate: Read by ``_back_propagate`` to decide whether to
                recurse towards the root.
        """
        super().__init__()
        self.world_model = None
        self.search_config = None
        self.output_trace_in_each_iter = output_trace_in_each_iter
        self.w_exp = w_exp
        self.depth_limit = depth_limit
        self.n_iters = n_iters
        self.cum_reward = cum_reward
        self.calc_q = calc_q
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.risk_distortion = risk_distortion
        self.disable_tqdm = disable_tqdm
        self.node_visualizer = node_visualizer
        self.aggregator = aggregator

        default_simulate_strategies: dict[str, Callable[[list[float]], int]] = {
            'max': lambda x: np.argmax(x),
            'sample': lambda x: np.random.choice(len(x), p=x/np.sum(x) if np.sum(x) > 0 else np.ones(len(x))/len(x)),
            'random': lambda x: np.random.choice(len(x)),
        }
        # A known strategy name maps to its callable; anything else is used as-is.
        self.simulate_choice: Callable[[list[float]], int] = default_simulate_strategies.get(simulate_strategy,
                                                                                         simulate_strategy)

        assert output_strategy in ['max_reward', 'follow_max', 'max_visit', 'max_iter', 'last_iter',
                               'last_terminal_iter']
        self.output_strategy = output_strategy
        self.uct_with_fast_reward = uct_with_fast_reward
        self._output_iter: list[PlanUNode] = None
        self._output_cum_reward = -math.inf
        self.trace_in_each_iter: list[list[PlanUNode]] = None
        self.root: Optional[PlanUNode] = None

        if distortion_fn is None:
            if risk_distortion == 0.0:
                # NOTE(review): node.quantile_probs are the quantile midpoints
                # (they sum to n_atoms / 2, not 1), so this is a tau-weighted
                # sum rather than the mean of the quantile values - confirm
                # that this is the intended risk-neutral value.
                self.distortion_fn = lambda probs, values: np.sum(probs * values)
            else:
                def cvar_distortion(probs, values):
                    if risk_distortion > 0:
                        # Upper tail: probability-weighted average of the top quantiles.
                        threshold = 1.0 - risk_distortion
                        mask = probs >= threshold
                        if np.any(mask):
                            return np.sum(probs[mask] * values[mask]) / np.sum(probs[mask])
                        return np.max(values)
                    else:
                        # Lower tail: probability-weighted average of the bottom quantiles.
                        threshold = -risk_distortion
                        mask = probs <= threshold
                        if np.any(mask):
                            return np.sum(probs[mask] * values[mask]) / np.sum(probs[mask])
                        return np.min(values)
                self.distortion_fn = cvar_distortion
        else:
            self.distortion_fn = distortion_fn

        self.log_distributions = log_distributions
        self.distribution_logs = []

        if log_distributions and distribution_log_path:
            try:
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                base_log_dir, base_name = os.path.split(distribution_log_path)
                # A bare name such as "logs" has an empty directory part, and
                # os.makedirs('') raises; only create the parent when there is one.
                if base_log_dir:
                    os.makedirs(base_log_dir, exist_ok=True)

                self.distribution_log_path = os.path.join(base_log_dir, f"{timestamp}_{base_name}")
                os.makedirs(self.distribution_log_path, exist_ok=True)
                print(f"log saved to: {self.distribution_log_path}")
            except Exception as e:
                # Logging is best-effort: disable it rather than abort the search.
                print(f"errro: {e}")
                self.log_distributions = False
                self.distribution_log_path = None
        else:
            self.distribution_log_path = None

        self.visualize_key_nodes = visualize_key_nodes
        self.chain_propagate = chain_propagate
        

    def iterate(self, node: PlanUNode) -> list[PlanUNode]:
        """Run one select / expand / simulate / back-propagate cycle from ``node``.

        Depending on ``output_strategy`` the path of this iteration may be
        remembered as the current output candidate.

        Returns:
            The path walked in this iteration, starting at ``node``.
        """
        path = self._select(node)
        leaf = path[-1]
        if not self._is_terminal_with_depth_limit(leaf):
            self._expand(leaf)
            self._simulate(path)

        cum_reward = self._chain_back_propagate(path)

        # Simulation may have extended the path, so re-read its last node.
        reached_terminal = path[-1].is_terminal
        strategy = self.output_strategy
        if strategy == 'max_iter':
            remember = reached_terminal and cum_reward > self._output_cum_reward
        elif strategy == 'last_iter':
            remember = True
        elif strategy == 'last_terminal_iter':
            remember = reached_terminal
        else:
            remember = False

        if remember:
            self._output_cum_reward = cum_reward
            self._output_iter = path
        return path

    def _is_terminal_with_depth_limit(self, node: PlanUNode):
        """Tell whether ``node`` is terminal or has reached ``depth_limit``."""
        at_depth_limit = node.depth >= self.depth_limit
        return node.is_terminal or at_depth_limit

    def _select(self, node: PlanUNode) -> list[PlanUNode]:
        """Descend from ``node`` via UCT until a node that cannot be descended further.

        The walk stops at a node without children, or at one that is terminal
        or at the depth limit.

        Returns:
            The visited nodes, starting with ``node``.
        """
        path = [node]
        while node.children and not self._is_terminal_with_depth_limit(node):
            node = self._uct_select(node)
            path.append(node)
        return path
            
    def _epsilon_greedy_select(self, node: PlanUNode, epsilon: float) -> PlanUNode:
        """Pick a random child with probability ``epsilon``, else the child with the highest ``Q``."""
        explore = random.random() < epsilon
        if explore:
            return random.choice(node.children)
        return max(node.children, key=lambda child: child.Q)
    
    def _uct(self, node: PlanUNode) -> float:
        """UCT score of ``node``: distorted value plus a weighted exploration bonus.

        The bonus is ``sqrt(log(parent visits) / max(1, own visits))``, where
        visits are counted as the number of recorded cumulative rewards.
        """
        value = node.get_distorted_q(self.distortion_fn)
        parent_visits = len(node.parent.cum_rewards)
        own_visits = max(1, len(node.cum_rewards))
        bonus = np.sqrt(np.log(parent_visits) / own_visits)
        return value + self.w_exp * bonus

    def _uct_select(self, node: PlanUNode) -> PlanUNode:
        """Choose the child of ``node`` to descend into.

        UCT is used when ``uct_with_fast_reward`` is set or every child has been
        expanded; otherwise the unexpanded child with the highest fast reward
        is chosen.
        """
        children = node.children
        if self.uct_with_fast_reward or all(child.state is not None for child in children):
            return max(children, key=self._uct)
        pending = (child for child in children if child.state is None)
        return max(pending, key=lambda child: child.fast_reward)

    def _expand(self, node: PlanUNode):
        """Materialize ``node`` and attach its (unexpanded) children.

        If the node has no state yet, the world model steps from the parent's
        state, the reward is computed and the terminal flag is set. Terminal
        nodes get no children.
        """
        if node.state is None:
            parent = node.parent
            node.state, aux = self.world_model.step(parent.state, node.action)
            node.reward, node.reward_details = self.search_config.reward(
                parent, node.action, **node.fast_reward_details, **aux)
            node.is_terminal = self.world_model.is_terminal(node.state)

        if node.is_terminal:
            return

        def make_child(action):
            # One unexpanded child per candidate action, seeded with its fast reward.
            fast_reward, details = self.search_config.fast_reward(node, action)
            return PlanUNode(state=None, action=action, parent=node,
                             fast_reward=fast_reward, fast_reward_details=details,
                             calc_q=self.calc_q, n_atoms=self.n_atoms,
                             v_min=self.v_min, v_max=self.v_max)

        candidate_actions = self.search_config.get_actions(node.state)
        node.children = [make_child(action) for action in candidate_actions]

    def _simulate(self, path: list[PlanUNode]):
        """Roll out from the end of ``path``, appending the visited nodes to it.

        At every step the next child is chosen by ``simulate_choice`` from the
        children's fast rewards. The rollout stops at a terminal node, at the
        depth limit, or at a node without children.
        """
        current = path[-1]
        while True:
            if current.state is None:
                self._expand(current)
            if self._is_terminal_with_depth_limit(current) or len(current.children) == 0:
                return
            choice = self.simulate_choice([child.fast_reward for child in current.children])
            current = current.children[choice]
            path.append(current)

    def _chain_back_propagate(self, path: list[PlanUNode]):
        """Propagate rewards from the end of ``path`` back to its start.

        For every node, walking backwards, the cumulative reward of the path
        suffix starting at that node is recorded in ``cum_rewards``. Expanded
        nodes additionally move each quantile value towards that cumulative
        reward with step size ``1 / visits``, clipped to ``[v_min, v_max]``.
        When distribution logging is enabled, every update is logged and the
        whole log is rewritten to disk at the end of the call.

        Returns:
            The cumulative reward computed for the first node of ``path``
            (``-inf`` for an empty path).
        """
        rewards = []
        cum_reward = -math.inf
        for node in reversed(path):
            rewards.append(node.reward)
            # ``rewards`` is collected back-to-front; restore path order.
            cum_reward = self.cum_reward(rewards[::-1])
            node.cum_rewards.append(cum_reward)

            if node.state is None:
                continue

            visits = len(node.cum_rewards)
            alpha = 1.0 / max(1, visits)

            # Quantile update, vectorized over all quantiles: move up with
            # weight tau when the sample lies above the current value, down
            # with weight (1 - tau) otherwise, scaled by the distance.
            taus = node.quantile_probs
            delta = cum_reward - node.quantile_values
            gradient = np.where(delta > 0, taus, taus - 1.0)
            updated = node.quantile_values + alpha * gradient * np.abs(delta)

            node.quantile_values = np.clip(updated, node.v_min, node.v_max)
            node.distribution_history.append(node.quantile_values.copy())

            if not self.log_distributions:
                continue

            self.distribution_logs.append({
                "node_id": node.id,
                "depth": node.depth,
                "action": str(node.action),
                "reward": node.reward,
                "cum_reward": cum_reward,
                "quantile_probs": node.quantile_probs.tolist(),
                "quantile_values": node.quantile_values.tolist(),
                "visit_count": visits
            })

            # Snapshot the distribution every 10th visit.
            if self.visualize_key_nodes and visits % 10 == 0:
                try:
                    fig = node.visualize_distribution(f"Node {node.id} - Visit {visits}")
                    if fig is not None:
                        # A figure is only returned when matplotlib imported fine.
                        import matplotlib.pyplot as plt
                        try:
                            if self.distribution_log_path:
                                os.makedirs(self.distribution_log_path, exist_ok=True)
                                fig.savefig(f"{self.distribution_log_path}/node_{node.id}_visit_{visits}.png")
                        finally:
                            # Always release the figure, saved or not.
                            plt.close(fig)
                except Exception as e:
                    print(f"Failed to visualize distribution: {e}")

        if self.log_distributions and self.distribution_log_path:
            try:
                # The log file lives inside the log directory itself, so that
                # is the directory that has to exist.
                os.makedirs(self.distribution_log_path, exist_ok=True)
                with open(os.path.join(self.distribution_log_path, "distribution_logs.json"), 'w') as f:
                    json.dump(self.distribution_logs, f, cls=NumpyEncoder)
            except Exception as e:
                print(f"Failed to save distribution logs: {e}")

        return cum_reward

    def _back_propagate(self, path: list[PlanUNode]):
        """Update the second-to-last node of ``path`` from the last one, bootstrapped.

        The target distribution is the last node's reward when that node is
        terminal, otherwise the reward plus the discounted quantile values of
        the last node. When ``chain_propagate`` is set, the same update is then
        applied recursively to ``path[:-1]``.

        NOTE(review): in the visible source this method is only called by
        itself; ``iterate`` uses ``_chain_back_propagate`` - confirm whether it
        is still needed.

        Returns:
            ``0.0`` for a path with at most one node, otherwise ``cum_reward``
            applied to the rewards of every node in ``path``.
        """
        if len(path) <= 1:
            return 0.0

        current_node = path[-2]
        next_node = path[-1]
        reward = next_node.reward
        # Cumulative reward over the whole path, including its first node.
        rewards = [node.reward for node in path]
        cum_reward = self.cum_reward(rewards)
        current_node.cum_rewards.append(cum_reward)
        # Unexpanded nodes only record the cumulative reward.
        if current_node.state is None:
            return cum_reward
        gamma = 0.99  # hard-coded discount factor
        if next_node.is_terminal:
            target_values = np.ones_like(current_node.quantile_values) * reward
        else:
            target_values = reward + gamma * next_node.quantile_values

        # Step size shrinks with the number of recorded visits.
        alpha = 1.0 / max(1, len(current_node.cum_rewards))


        for i, tau in enumerate(current_node.quantile_probs):

            # Read once before the inner loop: every target below is compared
            # against this value, not against the running update.
            # NOTE(review): confirm that not refreshing it is intended.
            current_value = current_node.quantile_values[i]

            # The increments of all targets are summed (not averaged).
            for target_value in target_values:
                delta = target_value - current_value
                gradient = tau if delta > 0 else tau - 1.0
                current_node.quantile_values[i] += alpha * gradient * np.abs(delta)

        current_node.quantile_values = np.clip(current_node.quantile_values, current_node.v_min, current_node.v_max)

        current_node.distribution_history.append(current_node.quantile_values.copy())

        if self.log_distributions:
            self.distribution_logs.append({
                "node_id": current_node.id,
                "depth": current_node.depth,
                "action": str(current_node.action),
                "reward": current_node.reward,
                "cum_reward": cum_reward,
                "quantile_probs": current_node.quantile_probs.tolist(),
                "quantile_values": current_node.quantile_values.tolist(),
                "visit_count": len(current_node.cum_rewards)
            })

            # Snapshot the distribution every 10th visit.
            if self.visualize_key_nodes and len(current_node.cum_rewards) % 10 == 0:
                try:
                    import matplotlib.pyplot as plt
                    fig = current_node.visualize_distribution(f"Node {current_node.id} - Visit {len(current_node.cum_rewards)}")
                    # NOTE(review): the figure is only closed when a log path is set.
                    if fig and self.distribution_log_path:
                        import os
                        os.makedirs(self.distribution_log_path, exist_ok=True)
                        fig.savefig(f"{self.distribution_log_path}/node_{current_node.id}_visit_{len(current_node.cum_rewards)}.png")
                        plt.close(fig)
                except Exception as e:
                    print(f"failed: {e}")

        # Continue towards the root; the recursive call's return value is discarded.
        if len(path) > 2 and self.chain_propagate:
            self._back_propagate(path[:-1])

        return cum_reward
    
    
    def _dfs_max_reward(self, path: list[PlanUNode]) -> tuple[float, list[PlanUNode]]:
        """Find the best terminal path extending ``path`` through expanded nodes.

        Returns:
            ``(cumulative reward, path)`` of the best terminal continuation,
            where the reward skips the first node of the path. When no terminal
            node is reachable, the reward is ``-inf`` and the path is ``path``
            or a dead-end extension of it.
        """
        tail = path[-1]
        if tail.is_terminal:
            return self.cum_reward([step.reward for step in path[1:]]), path

        expanded = [child for child in (tail.children or ()) if child.state is not None]
        if not expanded:
            return -math.inf, path

        best = None
        for child in expanded:
            candidate = self._dfs_max_reward(path + [child])
            # Strict comparison keeps the first of several equally good paths.
            if best is None or candidate[0] > best[0]:
                best = candidate
        return best

    def search(self):
        """Build the search tree and select the output path.

        Runs ``n_iters`` iterations from a fresh root, optionally writes the
        distribution log, then fills ``_output_iter`` / ``_output_cum_reward``
        for the ``'follow_max'`` and ``'max_reward'`` strategies (the other
        strategies are handled inside ``iterate``). Key-node plots are produced
        last so that they can include the selected output path.
        """
        self._output_cum_reward = -math.inf
        self._output_iter = None
        self.root = PlanUNode(state=self.world_model.init_state(), action=None, parent=None,
                             calc_q=self.calc_q, n_atoms=self.n_atoms, v_min=self.v_min, v_max=self.v_max)

        if self.output_trace_in_each_iter:
            self.trace_in_each_iter = []

        self.distribution_logs = []

        for _ in trange(self.n_iters, disable=self.disable_tqdm, desc='PlanU iteration', leave=False):
            path = self.iterate(self.root)
            if self.output_trace_in_each_iter:
                self.trace_in_each_iter.append(deepcopy(path))

        if self.log_distributions and self.distribution_log_path:
            try:
                log_file_path = os.path.join(self.distribution_log_path, "distribution_logs.json")
                with open(log_file_path, 'w') as f:
                    json.dump(self.distribution_logs, f, indent=2, cls=NumpyEncoder)
                print(f" {log_file_path}")
            except Exception as e:
                print(f" {e}")

        if self.output_strategy == 'follow_max':
            # Greedily follow the expanded child with the highest reward.
            self._output_iter = []
            cur = self.root
            while True:
                self._output_iter.append(cur)
                if cur.is_terminal:
                    break
                # ``children`` is None for a node that was never expanded.
                visited_children = [x for x in (cur.children or []) if x.state is not None]
                if len(visited_children) == 0:
                    break
                cur = max(visited_children, key=lambda x: x.reward)
            # Rewards of the whole path without the root, in path order (the
            # same convention as ``_dfs_max_reward``). The former slice
            # ``[1::-1]`` only ever covered the first two nodes.
            self._output_cum_reward = self.cum_reward([node.reward for node in self._output_iter[1:]])
        if self.output_strategy == 'max_reward':
            self._output_cum_reward, self._output_iter = self._dfs_max_reward([self.root])
            if self._output_cum_reward == -math.inf:
                self._output_iter = None

        # Visualize only after the output path is known, otherwise the path
        # nodes are skipped for the 'follow_max' and 'max_reward' strategies.
        if self.visualize_key_nodes:
            self._visualize_key_nodes()
    
    def _visualize_key_nodes(self):
        """Save distribution-history plots for the root, the output path and the top nodes.

        Plots go to a ``visualizations`` directory, located inside the
        distribution log directory when one is configured. Nodes with fewer
        than two history entries are skipped.
        """
        try:
            # Imported only to find out early whether matplotlib is available.
            import matplotlib.pyplot as plt  # noqa: F401

            log_dir = self.distribution_log_path
            save_dir = os.path.join(log_dir, "visualizations") if log_dir else "visualizations"
            os.makedirs(save_dir, exist_ok=True)

            def render(target, filename):
                # A single entry carries no history worth plotting.
                if len(target.distribution_history) > 1:
                    self._visualize_node_distribution_history(target, os.path.join(save_dir, filename))

            render(self.root, "root_node.png")

            for i, node in enumerate(self._output_iter or ()):
                render(node, f"path_node_{i}_depth_{node.depth}.png")

            for i, node in enumerate(self._find_high_reward_nodes()):
                render(node, f"high_reward_node_{i}_depth_{node.depth}.png")

            print(f"picture saved in{save_dir}")

        except ImportError:
            print("matplotlib is not installed. Cannot visualize distributions.")
    
    def _visualize_node_distribution_history(self, node: PlanUNode, save_path: str):
        """Save the evolution of ``node``'s quantile distribution as GIF and PNG.

        The animation (one frame per history entry) is written next to
        ``save_path`` with the ``.png`` suffix replaced by ``.gif``; the final
        distribution is written to ``save_path`` itself. All errors are
        reported via ``print`` instead of being raised.

        Args:
            node: Node whose ``distribution_history`` is plotted.
            save_path: Target path of the final-frame image.
        """
        try:
            import matplotlib.pyplot as plt
            import matplotlib.animation as animation
            from matplotlib.animation import FuncAnimation

            if len(node.distribution_history) == 0:
                print("No distribution history to visualize.")
                return

            fig, ax = plt.subplots(figsize=(10, 6))

            # Artists for the first frame; ``update`` below mutates them in place.
            scatter = ax.scatter(node.quantile_probs, node.distribution_history[0],
                                color='blue', alpha=0.7, s=50)
            line, = ax.plot(node.quantile_probs, node.distribution_history[0],
                        'b-', alpha=0.5)

            action_str = f"Action: {node.action}" if node.action is not None else "Root"
            ax.set_title(f"Quantile Distribution History - {action_str} (Depth: {node.depth})")
            ax.set_xlabel("Quantile Probability")
            ax.set_ylabel("Quantile Value")
            ax.grid(True, alpha=0.3)

            mean_val = np.mean(node.distribution_history[0])
            std_val = np.std(node.distribution_history[0])
            info_text = ax.text(0.02, 0.95, f"Mean: {mean_val:.4f}\nStd: {std_val:.4f}\nFrame: 0/{len(node.distribution_history)-1}",
                            transform=ax.transAxes, verticalalignment='top',
                            bbox=dict(boxstyle='round', alpha=0.1))

            # Fix the y-range over the whole history so the axis does not jump between frames.
            all_values = np.concatenate(node.distribution_history)
            y_min = min(node.v_min, np.min(all_values))
            y_max = max(node.v_max, np.max(all_values))
            margin = (y_max - y_min) * 0.1
            ax.set_ylim(y_min - margin, y_max + margin)

            def update(frame):
                # Redraw the artists for history entry ``frame``.
                scatter.set_offsets(np.column_stack((node.quantile_probs, node.distribution_history[frame])))
                line.set_ydata(node.distribution_history[frame])

                mean_val = np.mean(node.distribution_history[frame])
                std_val = np.std(node.distribution_history[frame])
                info_text.set_text(f"Mean: {mean_val:.4f}\nStd: {std_val:.4f}\nFrame: {frame}/{len(node.distribution_history)-1}")

                return scatter, line, info_text

            ani = FuncAnimation(fig, update, frames=len(node.distribution_history),
                            interval=200, blit=True)

            # A failure to write the GIF must not prevent the PNG below.
            try:
                ani.save(save_path.replace('.png', '.gif'), writer='pillow', fps=5)
                print(f"Animation saved to {save_path.replace('.png', '.gif')}")
            except Exception as e:
                print(f"Failed to save animation: {e}")

            # Second figure: static plot of the last history entry.
            plt.figure(figsize=(10, 6))
            plt.scatter(node.quantile_probs, node.distribution_history[-1],
                    color='blue', alpha=0.7, s=50)
            plt.plot(node.quantile_probs, node.distribution_history[-1],
                    'b-', alpha=0.5)
            plt.title(f"Final Quantile Distribution - {action_str} (Depth: {node.depth})")
            plt.xlabel("Quantile Probability")
            plt.ylabel("Quantile Value")
            plt.grid(True, alpha=0.3)

            mean_val = np.mean(node.distribution_history[-1])
            std_val = np.std(node.distribution_history[-1])
            plt.text(0.02, 0.95, f"Mean: {mean_val:.4f}\nStd: {std_val:.4f}",
                    transform=plt.gca().transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', alpha=0.1))

            plt.tight_layout()
            plt.savefig(save_path)
            # NOTE(review): closes every open pyplot figure, not only the two
            # created here, and is skipped when an exception occurs above.
            plt.close('all')
            print(f"Final frame saved to {save_path}")

        except ImportError as e:
            print(f"Required visualization libraries not installed: {e}")
        except Exception as e:
            print(f"Error visualizing distribution history: {e}")
    
    def _find_high_reward_nodes(self, top_k=5):
        """Return up to ``top_k`` expanded nodes with the highest recorded cumulative reward.

        Only nodes that have a state and at least one recorded cumulative
        reward are considered; they are ranked by their best such reward.
        """
        candidates = []
        pending = [self.root]
        # Iterative pre-order walk; children are pushed reversed so that they
        # are popped in their original order.
        while pending:
            node = pending.pop()
            if node.state is not None and len(node.cum_rewards) > 0:
                candidates.append(node)
            if node.children:
                pending.extend(reversed(node.children))

        ranked = sorted(candidates, key=lambda n: max(n.cum_rewards), reverse=True)
        return ranked[:top_k]

    def __call__(self,
                 world_model: WorldModel[State, Action, Example],
                 search_config: SearchConfig[State, Action, Example],
                 log_file: Optional[str] = None,
                 **kwargs) -> PlanUResult:
        """Run a full search and package its outcome.

        Args:
            world_model: Provides the initial state, transitions and the terminal test.
            search_config: Provides candidate actions and rewards.
            log_file: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            A ``PlanUResult``; ``terminal_state`` and ``trace`` are ``None``
            when no output path was selected.
        """
        PlanUNode.reset_id()
        self.world_model = world_model
        self.search_config = search_config

        self.search()

        best_path = self._output_iter
        if best_path is None:
            terminal_state = None
            trace = None
        else:
            terminal_state = best_path[-1].state
            states = [step.state for step in best_path]
            # The root has no incoming action, so actions start at the second node.
            actions = [step.action for step in best_path[1:]]
            trace = (states, actions)

        if self.output_trace_in_each_iter:
            trace_in_each_iter = self.trace_in_each_iter
            tree_state_after_each_iter = [iter_path[0] for iter_path in trace_in_each_iter]
        else:
            trace_in_each_iter = None
            tree_state_after_each_iter = None

        aggregated = self.aggregator(self.root) if self.aggregator is not None else None

        return PlanUResult(terminal_state=terminal_state,
                           cum_reward=self._output_cum_reward,
                           trace=trace,
                           trace_of_nodes=best_path,
                           tree_state=self.root,
                           trace_in_each_iter=trace_in_each_iter,
                           tree_state_after_each_iter=tree_state_after_each_iter,
                           aggregated_result=aggregated)
