import os
import copy
import math
import time
from copy import deepcopy
from platform import node
from typing import Generic, Optional, NamedTuple, Callable, Hashable
from abc import ABC
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from typing import Callable, Any, Optional
import torch
import numpy as np
from langagent.reasoner_base import WorldModel, Policy, RewardModel, State, Action, Example, Trace
from langagent.search.node import MCTSNode, _TYPE_REGISTRY, SearchNode
from langagent.search.config import BaseSearchConfig
from langagent.search.common import visualize_node, visualize_path, _sample_actions_with_existing, _world_modeling, _is_terminal_with_depth_limit, _is_terminal_with_depth_limit_and_r_threshold
import logging
from tqdm import trange 
from langagent.search.continuation import _continuation

logger = logging.getLogger(__name__)

class MCTSResult(NamedTuple):
    """Bundle of everything `mcts` hands back to its caller for one example."""
    # Cumulative reward reported for the returned trace.
    cum_reward: float
    # (states, actions) of the returned trace; `mcts` passes None when no output path was chosen.
    trace: Optional[Trace]
    # Nodes of the returned trace, root first; None when no output path was chosen.
    trace_of_nodes: Optional[list[MCTSNode]]
    # Root node of the search tree.
    root: MCTSNode
    # One node path per recorded iteration (deep copies are stored by `mcts`).
    trace_in_each_iter: Optional[list[list[MCTSNode]]] = None
    # Paths collected by `_simulate` for terminal nodes that the rollout did not follow.
    unselected_terminal_paths_during_simulate: Optional[list[list[MCTSNode]]] = None

def get_result_from_mcts( root: MCTSNode[State, Action], question, retrieve_answer, weight_policy: str = 'edge') -> Optional[Hashable]:
    """Aggregate the answers of all visited terminal nodes and return the best one.

    The tree is walked depth-first. Every terminal node votes for the answer
    extracted from it by ``retrieve_answer(state, question)``; every visited
    inner node adds its own reward to each distinct answer reachable below it.

    Args:
        root: Root of the search tree.
        question: Passed through to ``retrieve_answer``.
        retrieve_answer: Callable ``(state, question) -> answer``; answers are
            used as dictionary keys, so they must be hashable.
        weight_policy: ``'edge'`` adds the raw node reward; ``'edge_inverse_depth'``
            divides the reward by the terminal depth (or, for inner nodes, by the
            mean depth of the terminals below that produced the answer).

    Returns:
        The answer with the highest accumulated weight, or ``None`` when no
        terminal node was reached.
    """
    assert weight_policy in ['edge', 'edge_inverse_depth']
    answer_dict = defaultdict(lambda: 0)

    def visit(cur: MCTSNode[State, Action]):
        # Unvisited nodes (no state yet) contribute nothing.
        if cur.state is None:
            return []
        if cur.is_terminal:
            answer = retrieve_answer(cur.state, question)
            if weight_policy == 'edge':
                answer_dict[answer] += cur.reward
            elif weight_policy == 'edge_inverse_depth':
                answer_dict[answer] += cur.reward / cur.depth
            return [(answer, cur.depth)]

        depth_list = defaultdict(list)
        cur_list = []
        # `children` is treated as possibly None elsewhere in this module
        # (see `_select` and `_dfs_max_reward`), so guard against it here too.
        for child in cur.children or ():
            child_info = visit(child)
            cur_list.extend(child_info)
            for answer, depth in child_info:
                depth_list[answer].append(depth)

        # Credit this node's reward once per distinct answer found beneath it.
        for answer, depths in depth_list.items():
            if weight_policy == 'edge':
                answer_dict[answer] += cur.reward
            elif weight_policy == 'edge_inverse_depth':
                answer_dict[answer] += cur.reward / np.mean(depths)
        return cur_list

    visit(root)

    if not answer_dict:
        return None
    return max(answer_dict, key=lambda answer: answer_dict[answer])

# ~~~~~~~ Search Config (BEGIN) ~~~~~~~~~
# --- registries to reconstruct callables by name ---
# Maps a stable string name to the callable it stands for. `_func_to_name`
# searches the values by identity to serialize a callable; `_name_to_func`
# looks a name up here to restore it. Only callables listed here can be restored.
FUNC_REGISTRY = {
    "sum": sum,
    "max": max,
    "np.mean": np.mean,
    "np.argmax": np.argmax,
    "np.random.choice": np.random.choice,  # rarely used directly
}

def _func_to_name(f: Callable) -> str:
    """Return a string name for callable ``f`` suitable for JSON output.

    Callables present in ``FUNC_REGISTRY`` (matched by identity) yield their
    registry key. Anything else falls back to ``"<module>.<qualname>"``.

    Raises:
        TypeError: if ``f`` is not registered and exposes no usable
            module / name attributes.
    """
    # Prefer the stable registry key when the callable is a known one.
    registered = next(
        (label for label, candidate in FUNC_REGISTRY.items() if candidate is f),
        None,
    )
    if registered is not None:
        return registered

    # Otherwise describe it by where it lives; this is only a label and is
    # not resolvable through FUNC_REGISTRY.
    module_name = getattr(f, "__module__", None)
    qualified = getattr(f, "__qualname__", None) or getattr(f, "__name__", None)
    if not (module_name and qualified):
        raise TypeError(f"Unrecognized callable: {f}. Add it to FUNC_REGISTRY.")
    return f"{module_name}.{qualified}"

def _name_to_func(name: str) -> Callable:
    """Look up the callable registered under ``name`` in ``FUNC_REGISTRY``.

    Raises:
        KeyError: if ``name`` has not been registered. No dynamic import is
            attempted.
    """
    if name not in FUNC_REGISTRY:
        raise KeyError(f"Callable '{name}' not in FUNC_REGISTRY. Add it first.")
    return FUNC_REGISTRY[name]

@dataclass
class MCTSSearchConfig(BaseSearchConfig):
    """
    Settings specific to the MCTS search, added on top of BaseSearchConfig.

    ``simulate_choice`` is not a constructor argument; it is derived from
    ``simulate_strategy`` in ``__post_init__``.
    """
    # --- selection ---
    # Weight of the exploration term in the UCT score (forwarded to `_select`).
    w_exp: float = 1.
    uct_with_fast_reward: bool = True
    # Number of MCTS iterations run by `mcts`.
    n_iters: int = 10

    # --- simulation ---
    # Upper bound on rollout steps per simulation.
    roll_out_steps: int = 10000
    # Folds a list of per-node rewards into a single cumulative reward.
    cum_reward: Callable = sum
    # Handed to MCTSNode.set_default_calc_q at the start of `mcts`.
    calc_q: Callable = np.mean
    # Built-in rollout choosers: each takes the children's fast rewards and returns an index.
    default_simulate_strategies: dict = field(default_factory=lambda: {
        'max': lambda x: np.argmax(x),
        'sample': lambda x: np.random.choice(len(x), p=x),
        'random': lambda x: np.random.choice(len(x)),
    })
    # Key into default_simulate_strategies; any other value is used as-is.
    simulate_strategy: str = 'max'
    # Resolved chooser, filled in by __post_init__.
    simulate_choice: Any = field(init=False)
    # Number of actions requested from the policy at each rollout step.
    n_action_for_simulate: int = 1
    n_confidence: int = 1

    # --- output ---
    output_strategy: str = 'max_reward'
    output_trace_in_each_iter: bool = True

    use_critic: bool = False

    def __post_init__(self):
        """Resolve ``simulate_strategy`` into the ``simulate_choice`` value."""
        known = self.default_simulate_strategies
        requested = self.simulate_strategy
        # A name found in the table maps to its chooser; anything else
        # (e.g. a custom callable) is stored unchanged.
        self.simulate_choice = known[requested] if requested in known else requested

    def verify(self):
        """Assert that ``output_strategy`` is one of the recognised names."""
        allowed = (
            'max_reward', 'follow_max', 'max_visit', 'max_iter', 'last_iter', 'last_terminal_iter'
        )
        assert self.output_strategy in allowed

    def to_dict(self) -> dict:
        """Return a plain-dict view of the config with callables stored by name.

        The strategy table and the resolved chooser are runtime-only and are
        left out; ``cum_reward`` and ``calc_q`` are replaced by the string
        produced by ``_func_to_name``.
        """
        runtime_only = ('default_simulate_strategies', 'simulate_choice')
        payload = {
            key: value
            for key, value in asdict(self).items()
            if key not in runtime_only
        }
        payload['cum_reward'] = _func_to_name(self.cum_reward)
        payload['calc_q'] = _func_to_name(self.calc_q)
        return payload
# ~~~~~~~ Search Config (END) ~~~~~~~~~

##### SELECT (Begin) #####
def _select(w_exp: float, node: MCTSNode, depth_limit: int, force_terminating_on_depth_limit: bool) -> list[MCTSNode]:
    """Walk down the tree from ``node`` and return the selected path.

    At each step the walk stops when the current node has no children or is
    terminal according to ``_is_terminal_with_depth_limit``. Otherwise the
    next node is chosen as follows:

    * if the first child is marked ``is_continuous``, follow that single child;
    * if every child already has a state, pick the child with the best UCT score;
    * otherwise pick the unvisited child (``state is None``) with the highest
      ``fast_reward``.

    Args:
        w_exp: Weight of the exploration term in the UCT score.
        node: Node to start the selection from.
        depth_limit: Forwarded to ``_is_terminal_with_depth_limit``.
        force_terminating_on_depth_limit: Forwarded to
            ``_is_terminal_with_depth_limit``.

    Returns:
        The list of nodes from ``node`` down to the selected leaf (inclusive).
    """
    logger.debug("\n=========== [Select Begin] ===========")

    def _uct_select(w_exp: float, node: MCTSNode, return_detail=False) -> MCTSNode:
        """Return the child of ``node`` with the highest UCT score."""
        best_child = None
        best_score = -np.inf
        best_detail = ""
        # Clamp the parent visit count to 1: np.log(0) is -inf, which turns the
        # exploration term into NaN, so no child could ever beat -inf and the
        # function would return None. The log is loop-invariant, so compute it once.
        log_parent_trials = np.log(max(1, len(node.cum_rewards)))
        for child in node.children:
            num_trials_cur = len(child.cum_rewards)
            exploration_score = np.sqrt(log_parent_trials / max(1, num_trials_cur))
            score = child.Q + w_exp * exploration_score

            if score > best_score:
                best_score = score
                best_child = child
                best_detail = f"(ID: {child.id}) - Q: {child.Q:.3f}, Exploration: {exploration_score:.3f}, Score: {score:.3f})"
        if return_detail:
            return best_child, best_detail
        return best_child

    path = []
    record_select_types = []
    while True:
        path.append(node)

        if node.children is None or len(node.children) <= 0 or \
            _is_terminal_with_depth_limit(node, depth_limit, force_terminating_on_depth_limit):

            logger.debug(visualize_path(path))
            select_types_str = "->".join(record_select_types)
            logger.debug(f"Select Types: {select_types_str}")
            logger.debug("=========== [Select End] ===========\n")
            return path

        # continuous select
        if node.children[0].is_continuous:
            assert len(node.children) == 1
            node = node.children[0]  # only one child in continuous mode
            record_select_types.append('continuation')
            continue

        ### uct-select the next node ###
        if all(x.state is not None for x in node.children):
            logger.debug(f"All children of node {node.id} are visited, using UCT select.")

            node, select_detail = _uct_select(w_exp, node, return_detail=True)
            record_select_types.append('uct' + select_detail)
        else:
            # Unvisited children have no state/reward from the world and reward
            # models yet, so rank them by their fast reward instead.
            logger.debug(f"Unvisited children exist for node {node.id}, selecting based on fast reward.")
            record_select_types.append('unvisited/fast_reward')
            unvisited_children = (x for x in node.children if x.state is None)
            node = max(unvisited_children, key=lambda x: x.fast_reward)
##### SELECT (End) #####


##### EXPAND (Begin) #####
def _expand(
    example, 
    example_idx, 
    node, 
    policy, 
    n_actions, 
    reward_model, 
    world_model=None, 
    assign_rewards=True, 
    use_critic=False, 
    from_phase="expand"
):
    """Attach new child nodes to ``node`` and make sure children carry a fast reward.

    New actions come from ``_sample_actions_with_existing``. Each becomes an
    ``MCTSNode`` child with no state, is tagged with the phase that created it
    and, when ``assign_rewards`` is true, scored via ``reward_model.fast_reward``.
    Afterwards every child of ``node`` whose ``fast_reward`` is still ``-1`` is
    scored as well (again only when ``assign_rewards`` is true).

    Raises:
        ValueError: if a child is created while ``from_phase`` is not one of
            ``"expand"``, ``"simulate"`` or ``"continuation"``.
    """
    logger.debug(f"\n=========== [Expand for Example {example_idx} Begin] ===========")

    # Attribute that records which phase produced a child.
    phase_flags = {
        "simulate": "from_simulate",
        "expand": "from_expand",
        "continuation": "from_continuation",
    }

    def _score(target, action):
        """Ask the reward model for a fast reward and store it on ``target``."""
        value, details = reward_model.fast_reward(
            example, example_idx, node.state, action, from_phase=from_phase
        )  # action evaluation, e.g., usefulness of a subquestion
        target.fast_reward = value
        target.fast_reward_details = details
        return value

    sampled_actions = _sample_actions_with_existing(
        example,
        example_idx,
        node,
        policy,
        n_actions,
        world_model=world_model,
        use_critic=use_critic,
        from_phase=from_phase
    )

    for action in sampled_actions:
        fresh = MCTSNode(state=None, action=action, parent=node)
        # Sentinel action emitted when the policy keeps repeating itself.
        fresh.is_terminal_for_repeat = (action == "ALWAY REPEAT. TERMINATE")

        if not assign_rewards:
            logger.debug(f"assign_rewards is False, skipping fast reward assignment for child: Node {fresh.id}")
        else:
            assert fresh.fast_reward == -1, "fast_reward should be -1 for newly created child"
            logger.debug(f"Assigning fast reward for newly created child: Node {fresh.id}")
            _score(fresh, action)

        flag_name = phase_flags.get(from_phase)
        if flag_name is None:
            raise ValueError(f"from_phase should be 'expand' or 'simulate' or 'continuation', got {from_phase}")
        setattr(fresh, flag_name, True)

        node.children.append(fresh)
        logger.debug(visualize_node(fresh))

    # Second pass: children that still carry the -1 placeholder get scored now.
    for existing in node.children:
        if existing.fast_reward != -1:
            logger.debug(f"Child's (Node {existing.id}) fast_reward already assigned as {existing.fast_reward}")
            continue
        if not assign_rewards:
            logger.debug(f"Child's (Node {existing.id}) fast_reward not been assigned and not required to be assigned")
            continue
        fast_reward = _score(existing, existing.action)
        logger.debug(f"Child's (Node {existing.id}) fast_reward now assigned as {fast_reward}")
    logger.debug("=========== [Expand End] ===========\n")
##### EXPAND (END) #####

##### SIMULATE (Begin) (REUSE EXPAND...) #####
def _simulate(
    example, 
    example_idx, 
    path, 
    mcts_search_config, 
    world_model, 
    policy, 
    reward_model, 
    use_critic=False, 
    roll_out_steps=10000
):
    """Roll out from the last node of ``path``, extending ``path`` in place.

    Each step expands the current node (reusing ``_expand``), lets
    ``mcts_search_config.simulate_choice`` pick one child from the children's
    fast rewards, runs ``_world_modeling`` on that child and appends it to
    ``path``. The rollout stops after ``roll_out_steps`` steps, when the
    current node is flagged ``is_terminal_for_repeat``, or when
    ``_is_terminal_with_depth_limit_and_r_threshold`` reports a terminal node.

    Args:
        path: Selected path; its last node must already have a state.
        mcts_search_config: Supplies ``n_action_for_simulate``,
            ``simulate_choice`` and the termination settings.
        roll_out_steps: Maximum number of rollout steps.

    Returns:
        ``(stopped_for_repeat, unselected_terminal_paths)`` where the first item
        is True only when the rollout ended on an ``is_terminal_for_repeat``
        node, and the second lists deep-copied paths ending in terminal
        siblings that the rollout did not follow.
    """
    assert path[-1].state is not None, "node.state should not be None for rollout"

    logger.debug("\n=========== [Simulate Begin] ===========")
    node = path[-1]
    unselected_terminal_paths = []
    for step in range(roll_out_steps):
        logger.debug(f"Rollout Step {step+1}")

        _expand(
            example, 
            example_idx, 
            node, 
            policy, 
            n_actions=mcts_search_config.n_action_for_simulate,
            reward_model=reward_model, 
            world_model=world_model,
            assign_rewards=True,
            use_critic=use_critic,
            from_phase="simulate"
        )

        if node.is_terminal_for_repeat:
            logger.debug(f"!!!!! is_terminal_for_repeat")
            logger.debug("=========== [Simulate End] ===========\n")
            return True, unselected_terminal_paths

        # Keep a handle on the expanded node: `selected_idx` indexes ITS
        # children, so the siblings must be read from here, not from the
        # child that is about to become `node`.
        parent = node
        fast_rewards = [child.fast_reward for child in parent.children]
        selected_idx = mcts_search_config.simulate_choice(fast_rewards)
        node = parent.children[selected_idx]
        node.is_simulated = True
        _world_modeling(example, example_idx, node, world_model, reward_model, from_phase="simulate")
        logger.debug(f"NEW NODE Transfer with the action: {node.action}. The resulting state: {node.state}")

        # Record terminal siblings that were not followed. `path` still ends at
        # `parent` here, so each recorded path is [..., parent, sibling].
        for sibling_idx, sibling in enumerate(parent.children):
            if sibling_idx != selected_idx and sibling.is_terminal:
                unselected_terminal_paths.append(deepcopy(path + [sibling]))

        path.append(node)

        # ====== Terminate Check (Begin) ======
        if _is_terminal_with_depth_limit_and_r_threshold(node,  mcts_search_config.depth_limit, mcts_search_config.force_terminating_on_depth_limit, mcts_search_config.r_terminating):
            logger.debug("=========== [Simulate End] ===========\n")
            return False, unselected_terminal_paths
        # ====== Terminate Check (End) ======

    logger.debug("=========== [Simulate End] ===========\n")
    return False, unselected_terminal_paths
##### SIMULATE (END) #####

##### BACK-PROPAGATE (BEGIN) #####
def _back_propagate(path: list[MCTSNode], cum_reward_func):
    """Propagate rewards from the leaf of ``path`` back up to its first node.

    Walking the path leaf-to-root, each node receives (appended to its
    ``cum_rewards`` list) the value of ``cum_reward_func`` applied to the
    rewards of that node and all nodes below it on the path, listed in
    root-to-leaf order.

    Args:
        path: Nodes from root to leaf; must not be empty.
        cum_reward_func: Callable folding a list of rewards into one value.

    Returns:
        The cumulative reward appended to the first node of ``path``.

    Raises:
        ValueError: if ``path`` is empty.
    """
    if not path:
        raise ValueError("Cannot back-propagate along an empty path")

    logger.debug("\n=========== [Backpropagate Begin] ===========")
    rewards = []
    cum_rewards_appened = []
    cum_reward = None
    for node in reversed(path):
        rewards.append(node.reward)
        # `rewards` is collected leaf-first; reverse it so the aggregate sees
        # root-to-leaf order. Evaluate the aggregate once and reuse the value.
        cum_reward = cum_reward_func(rewards[::-1])
        node.cum_rewards.append(cum_reward)
        cum_rewards_appened.append(cum_reward)
    logger.debug(f"Rewards (leaf -> root): {rewards}")
    logger.debug(f"Cumulative rewards appended to the nodes (leaf -> root): {cum_rewards_appened}")
    logger.debug("=========== [Backpropagate End] ===========\n")
    return cum_reward
##### BACK-PROPAGATE (END)

##### BACK-PROPAGATE (BEGIN) #####
# https://github.com/THUDM/ReST-MCTS/blob/main/MCTS/mcts.py#L213
# def rest_back_propagate(node):
#     while node is not None:
#         node.numVisits += 1
#         if node.isFullyExpanded:
#             child_Vs = [child.V * child.numVisits for child in node.children.values()]
#             total_num_visits = sum([child.numVisits for child in node.children.values()])
#             if total_num_visits > 0:
#                 node.V = sum(child_Vs) / total_num_visits
#         node = node.parent
##### BACK-PROPAGATE (END)

##### MCTS (BEGIN) #####
def mcts(example, example_idx, mcts_search_config, world_model: WorldModel, policy: Policy, reward_model: RewardModel, bn_evaluator=None) -> MCTSResult:
    """Run Monte-Carlo tree search for one example.

    Each iteration selects a path from the root, optionally extends it with
    ``_continuation`` (when ``mcts_search_config.add_continuation`` is set),
    materialises the leaf state, expands the leaf, rolls out with ``_simulate``
    and back-propagates the rewards. Whenever the end of the current path is
    terminal (per ``_is_terminal_with_depth_limit_and_r_threshold``) the path is
    recorded and the iteration ends early without back-propagation; the whole
    search stops if ``terminate_on_terminal_node`` is set.

    A ``ValueError`` (including the one raised here when
    ``runtime_limit_before_iter`` is exceeded) or a CUDA out-of-memory error
    ends the search early; the result is still built from the tree so far.

    Args:
        example: The question / task; also stored as the root node's action.
        example_idx: Index of the example, used for logging and forwarded to helpers.
        mcts_search_config: Search settings (iterations, limits, strategies).
        world_model: Provides the initial state and is forwarded to the helpers.
        policy: Action proposer; ``policy.n_actions`` sets the expansion width.
        reward_model: Forwarded to the expansion / world-modeling helpers.
        bn_evaluator: Forwarded to ``_continuation``.

    Returns:
        An ``MCTSResult``. Only the ``'max_reward'`` output strategy selects an
        output trace here; for any other strategy ``trace`` and
        ``trace_of_nodes`` are ``None`` and ``cum_reward`` is ``-inf``.
    """
    logger.debug(f"Question: {example}")
    logger.debug(f"\n\n\n=========== [MCTS for Example {example_idx} Begin] ===========")

    MCTSNode.set_default_calc_q(mcts_search_config.calc_q)

    def _dfs_max_reward(path: list[MCTSNode]) -> tuple[float, list[MCTSNode]]:
        """Return (cumulative reward, path) of the best terminal path below ``path``."""
        cur = path[-1]
        if cur.is_terminal:
            return mcts_search_config.cum_reward([node.reward for node in path[1:]]), path
        if cur.children is None:
            return -math.inf, path
        visited_children = [x for x in cur.children if x.state is not None]
        if len(visited_children) == 0:
            return -math.inf, path
        return max((_dfs_max_reward(path + [child]) for child in visited_children), key=lambda x: x[0])

    # updated during search
    # NOTE: this must be a plain float; a trailing comma here would make it a 1-tuple.
    _output_cum_reward = -math.inf
    _output_iter = None
    SearchNode.reset_id() # MCTSNode.reset_id() only resets MCTSNode.id_iter, not SearchNode.id_iter. But my constructor always does next(SearchNode.id_iter)

    root = MCTSNode(state=world_model.init_state(), action=example, parent=None)
    assert root.id == 0, f"Root node ID should be 0 not {root.id}"

    trace_in_each_iter = []
    unselected_terminal_paths_during_simulate = []

    def _record_if_terminal(path: list[MCTSNode]) -> bool:
        """Record ``path`` and log if its last node is terminal; return whether it was."""
        if not _is_terminal_with_depth_limit_and_r_threshold(
            path[-1],
            mcts_search_config.depth_limit,
            mcts_search_config.force_terminating_on_depth_limit,
            mcts_search_config.r_terminating,
        ):
            return False
        trace_in_each_iter.append(deepcopy(path))
        if mcts_search_config.terminate_on_terminal_node:
            logger.debug("!!!!! The MCTS terminates due to terminal node")
        else:
            logger.debug("!!!!! The MCTS continues to next iteration due to terminal node")
        return True

    start_time = time.time()   # <--- record start time
    try:
        for idx_iter in trange(mcts_search_config.n_iters, desc='MCTS iteration', leave=False):
            if mcts_search_config.runtime_limit_before_iter and time.time() - start_time > mcts_search_config.runtime_limit_before_iter:
                raise ValueError(f"MCTS exceeded runtime limit: {mcts_search_config.runtime_limit_before_iter}")  # will be caught by except below
            logger.debug(f"\n\n\n=========== [MCTS iteration {idx_iter} Begin] ===========")
            path = _select(mcts_search_config.w_exp, root, mcts_search_config.depth_limit, mcts_search_config.force_terminating_on_depth_limit)  ####### select

            # ====== Terminate Check ======
            if _record_if_terminal(path):
                if mcts_search_config.terminate_on_terminal_node:
                    break
                continue

            if mcts_search_config.add_continuation:
                # no branching; no exploration selection
                continuous_trace = _continuation(
                    example, 
                    example_idx, 
                    path[-1], 
                    world_model, 
                    policy, 
                    reward_model, 
                    expand_func=_expand, 
                    world_modeling_func=_world_modeling, 
                    bn_evaluator=bn_evaluator, 
                    depth_limit=mcts_search_config.depth_limit,
                    threshold_alpha=mcts_search_config.reward_alpha, 
                    threshold_conf=mcts_search_config.reward_beta, 
                    threshold_gamma=mcts_search_config.reward_gamma,
                    threshold_gamma1=mcts_search_config.reward_gamma1,
                    n_actions_for_bne=mcts_search_config.n_actions_for_bne,
                    use_critic=mcts_search_config.use_critic)
                path.extend(continuous_trace[1:]) # the 1st node is the last node from selection

                # ====== Terminate Check ======
                if _record_if_terminal(path):
                    if mcts_search_config.terminate_on_terminal_node:
                        break
                    continue

            # ====== Expansion (Begin) ======
            # A leaf picked by fast reward has no state yet; materialise it first.
            if path[-1].state is None:
                _world_modeling(example, example_idx, path[-1], world_model, reward_model, from_phase="expand")
            # ====== Terminate Check ======
            if _record_if_terminal(path):
                if mcts_search_config.terminate_on_terminal_node:
                    break
                continue

            _expand(
                example, 
                example_idx, 
                path[-1], 
                policy, 
                n_actions=policy.n_actions,
                reward_model=reward_model, 
                world_model=world_model,
                assign_rewards=True,
                use_critic=mcts_search_config.use_critic,
                from_phase="expand"
            ) ####### expand
            # ====== Expansion (End) ======

            # ====== Simulate (Begin) ======
            if path[-1].state is None:
                _world_modeling(example, example_idx, path[-1], world_model, reward_model, from_phase="expand")
            # ====== Terminate Check ======
            if _record_if_terminal(path):
                if mcts_search_config.terminate_on_terminal_node:
                    break
                continue

            # `_simulate` extends `path` in place; its first return value is not used here.
            _, unselected_terminal_paths = _simulate(
                example, 
                example_idx, 
                path, 
                mcts_search_config,
                world_model, 
                policy, 
                reward_model, 
                use_critic=mcts_search_config.use_critic, 
                roll_out_steps=mcts_search_config.roll_out_steps
            )  ####### simulate
            # ====== Simulate (End) ======

            _back_propagate(path, mcts_search_config.cum_reward)

            ##### Save trace in this iteration  #####
            trace_in_each_iter.append(deepcopy(path))
            unselected_terminal_paths_during_simulate.extend(unselected_terminal_paths)
            ##### Save trace in this iteration (END) #####

    except (ValueError, torch.cuda.OutOfMemoryError) as e:
        if isinstance(e, torch.cuda.OutOfMemoryError):
            # OOM handling
            torch.cuda.empty_cache()

        # The search is abandoned but a result is still produced from the tree so far.
        msg = str(e)
        logger.debug(msg)
        trace_in_each_iter.append([deepcopy(root)])
    num_hour_used = (time.time() - start_time) / 3600
    logger.debug(f"Used Hours: {num_hour_used}")

    # retrieve the path with maximum cumulative reward
    if mcts_search_config.output_strategy == 'max_reward':
        _output_cum_reward, _output_iter = _dfs_max_reward([root])

    logger.debug(f"=========== [MCTS for Example {example_idx} End] ===========\n")

    result = MCTSResult(cum_reward=_output_cum_reward,
                        trace=([node.state for node in _output_iter], [node.action for node in _output_iter[1:]]) if _output_iter is not None else None,
                        trace_of_nodes=_output_iter,
                        root=root,
                        trace_in_each_iter=trace_in_each_iter,
                        unselected_terminal_paths_during_simulate=unselected_terminal_paths_during_simulate)

    return result
##### MCTS (END) #####