import logging
import math
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel, Field
from tenacity import retry, wait_random_exponential, stop_after_attempt

from util import print_detailed_node_analysis, print_tree_visualization, print_search_analytics

# Configure root logging when this module is imported: INFO level with a timestamped format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Attempt limit passed to tenacity's stop_after_attempt() for every retried LLM call in this module.
MAX_RETRIES = 3


def print_search_analytics(final_state: Dict[str, Any]) -> None:
    """Print a human-readable summary of the MCTS search analytics.

    Args:
        final_state: Either a plain dict with a ``'mcts_tree'`` key or an object
            exposing a ``mcts_tree`` attribute. The tree must provide
            ``get_search_analytics()``.

    NOTE: this local definition shadows the ``print_search_analytics`` imported
    from ``util`` at the top of the file.
    """
    print("\n" + "=" * 80)
    print("SEARCH ANALYTICS")
    print("=" * 80)

    if isinstance(final_state, dict):
        mcts_tree = final_state.get('mcts_tree')
    else:
        mcts_tree = getattr(final_state, 'mcts_tree', None)

    if not mcts_tree:
        print("No MCTS tree available for analysis.")
        return

    analytics = mcts_tree.get_search_analytics()
    # With no recorded actions, the analytics dict only has an 'error' key;
    # indexing the report sections below would raise KeyError.
    if 'error' in analytics:
        print(f"No search analytics available: {analytics['error']}")
        return

    summary = analytics['search_summary']
    print("Search Summary:")
    print(f"  - Total Records: {summary['total_records']}")
    print(f"  - Duration: {summary['total_duration_seconds']:.2f} seconds")
    print(f"  - Iterations: {summary['iterations']}")

    print("\nAction Breakdown:")
    for action, count in analytics['action_breakdown'].items():
        print(f"  - {action}: {count}")

    stats = analytics['tree_statistics']
    print("\nTree Statistics:")
    print(f"  - Total Nodes: {stats['total_nodes']}")
    print(f"  - Total Visits: {stats['total_visits']}")
    print(f"  - Average Reward: {stats['average_reward']:.3f}")

    print("\nTop 5 Nodes by Average Reward:")
    for i, node_info in enumerate(analytics['top_nodes'], 1):
        print(
            f"  {i}. Node: {node_info['node_id']}, Visits: {node_info['visits']}, Avg Reward: {node_info['avg_reward']:.3f}")


def print_tree_visualization(final_state: Dict[str, Any]) -> None:
    """Print the ASCII rendering of the MCTS tree.

    Args:
        final_state: Either a plain dict with ``'mcts_tree'`` and ``'nodes'`` keys
            or an object exposing those attributes. ``nodes`` maps node ids to
            node objects or dicts; the tree must provide ``visualize_tree``.

    The root is the first node without a ``parent_id``; if none is found, the
    first node in ``nodes`` is used.
    """
    print("\n" + "=" * 80)
    print("TREE STRUCTURE VISUALIZATION")
    print("=" * 80)

    if isinstance(final_state, dict):
        mcts_tree = final_state.get('mcts_tree')
        nodes = final_state.get('nodes')
    else:
        mcts_tree = getattr(final_state, 'mcts_tree', None)
        nodes = getattr(final_state, 'nodes', None)

    if not mcts_tree or not nodes:
        print("No tree data available for visualization.")
        return

    def parent_of(node_data: Any) -> Optional[str]:
        # Nodes may be model objects or plain dicts, depending on how the state was produced.
        if isinstance(node_data, dict):
            return node_data.get('parent_id')
        return getattr(node_data, 'parent_id', None)

    # Don't rely on insertion order alone: pick the parentless node when there is one.
    root_node_id = next(
        (node_id for node_id, node_data in nodes.items() if not parent_of(node_data)),
        next(iter(nodes)),
    )
    tree_viz = mcts_tree.visualize_tree(root_node_id, nodes)
    print(tree_viz)


def _node_field(obj: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from a dict key or an object attribute, returning ``default`` if absent."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _escape_and_truncate(text: Optional[str], limit: int = 500) -> str:
    """Escape newlines as a literal backslash-n and cut to ``limit`` chars, appending '...' if cut."""
    escaped = (text or '').replace('\n', '\\n')
    return f"{escaped[:limit]}..." if len(escaped) > limit else escaped


def print_detailed_node_analysis(final_state: Dict[str, Any]) -> None:
    """Print per-node details: depth, parent, metadata, MCTS stats, evaluations and plan.

    Args:
        final_state: Either a plain dict with ``'mcts_tree'`` and ``'nodes'`` keys
            or an object exposing those attributes. Nodes, their metadata and
            their evaluation results may each be model objects or plain dicts.
            When present, the tree must expose ``N`` and ``Q`` mappings.
    """
    print("\n" + "=" * 80)
    print("DETAILED NODE ANALYSIS")
    print("=" * 80)

    mcts_tree = _node_field(final_state, 'mcts_tree')
    nodes = _node_field(final_state, 'nodes')

    if not nodes:
        print("No nodes available for analysis.")
        return

    for i, (node_id, node_data) in enumerate(nodes.items(), 1):
        depth = _node_field(node_data, 'depth', 0)
        parent_id = _node_field(node_data, 'parent_id')
        evaluations = _node_field(node_data, 'evaluations', {})
        metadata = _node_field(node_data, 'metadata')
        plan = _node_field(node_data, 'plan', 'Unknown Plan')

        print(f"\nNode {i} (ID: {node_id}):")
        print(f"  Depth: {depth}")
        print(f"  Parent: {'Root' if not parent_id else parent_id}")

        if metadata:
            creation_time = _node_field(metadata, 'creation_time')
            creation_iteration = _node_field(metadata, 'creation_iteration', 0)
            total_evaluations = _node_field(metadata, 'total_evaluations', 0)

            if creation_time:
                # Serialized state may carry the timestamp as a string instead of a datetime.
                if isinstance(creation_time, str):
                    print(f"  Created: {creation_time}")
                else:
                    print(f"  Created: {creation_time.strftime('%H:%M:%S')}")
            print(f"  Iteration: {creation_iteration}")
            print(f"  Evaluations: {total_evaluations}")

        if mcts_tree:
            # .get() avoids inserting zero entries into defaultdict-backed statistics.
            visits = mcts_tree.N.get(node_id, 0)
            total_reward = mcts_tree.Q.get(node_id, 0.0)
            avg_reward = total_reward / visits if visits > 0 else 0.0
            print(f"  MCTS: Visits={visits}, Total_Q={total_reward:.3f}, Avg_Q={avg_reward:.3f}")

        if evaluations:
            print("  Evaluation Scores:")
            for eval_name, eval_result in evaluations.items():
                score = _node_field(eval_result, 'score', 0) or 0
                feedback = _node_field(eval_result, 'feedback', '')
                print(f"    - {eval_name}: {score:.1f}/100")
                print(f"      Feedback: {_escape_and_truncate(feedback)}")

        print(f"  Plan: {_escape_and_truncate(plan)}")


@dataclass
class SearchRecord:
    """One audit-log entry appended to ``MCTS.search_records`` by ``MCTS._record_action``."""

    # MCTS.current_iteration at the time the action was recorded.
    iteration: int
    # Wall-clock time of the record (datetime.now(), naive local time).
    timestamp: datetime
    # Action label, e.g. "select_start", "expand", "backprop_update", "uct_select".
    action: str
    # Id of the node the action concerns.
    node_id: str
    # Action-specific extra data (parent id, path length, Q/N values, ...).
    details: Dict[str, Any]


@dataclass
class NodeMetadata:
    """Bookkeeping attached to an ``MCTSNode`` via its ``metadata`` field."""

    # When the node was created.
    creation_time: datetime
    # Search iteration in which the node was created.
    creation_iteration: int
    # Number of evaluations run on the node; NOTE(review): updated outside this chunk — confirm.
    total_evaluations: int = 0
    # Time of the last update, if any; NOTE(review): not written anywhere in the visible code.
    last_updated: Optional[datetime] = None


class PromptConfig(BaseModel):
    """Prompt templates for the critic agents, the plan generator and the executor.

    The evaluator and planner templates are filled with ``str.format`` by the
    agents in this module, so any literal braces would need doubling.
    Placeholders by template:
      * ``logical_consistency_user`` / ``feasibility_user``: ``{question}``,
        ``{plan}``, ``{history_section}``
      * ``initial_plan_prompt``: ``{problem}``
      * ``plan_modification_user``: ``{problem}``, ``{plan}``, ``{feedback}``
      * ``executor_user``: ``{problem}``, ``{plan}``
    """

    # --- LogicalConsistencyAgent critic ---
    logical_consistency_system: str = (
        "You are an expert plan evaluator specializing in logical consistency and plan structure. "
        "Keep all responses under 500 words."
    )
    logical_consistency_user: str = (
        "# Problem Context\n{question}\n\n"
        "# Plan to Evaluate\n{plan}\n\n"
        "{history_section}"
        "# Your Task\n"
        "Your job is to identify logical flaws, structural issues, and missing steps in the plan. Focus on ensuring the plan is logically sound and well-structured at a strategic level.\n\n"
        "**Evaluation Criteria:**\n"
        "1.  **Logical Flow**: Verify that steps follow a logical sequence where each step builds upon the previous ones and leads naturally to the next.\n"
        "2.  **Contradictions**: Identify any contradictions, impossible steps, or conflicting instructions within the plan.\n"
        "3.  **Completeness**: Check for any critical missing steps or logical gaps that would prevent successful problem resolution.\n"
        "4.  **Step Structure**: Ensure steps are properly numbered (1., 2., 3.) with sub-steps clearly indicated (1.1, 1.2, etc.) when needed.\n"
        "5.  **Abstraction Level**: Verify that the plan maintains appropriate abstraction without diving into implementation details.\n"
        "6.  **Feedback Adherence**: If history is provided, verify if the new plan has successfully addressed the previous logical concerns.\n\n"
        "# Required Output Format\n"
        "1. Keep your entire response under 500 words.\n"
        "2. Provide your evaluation as a structured response with the following keys:\n"
        "    - `score`: A score from 0-100 based on the rubric (0-60: Unacceptable, 60-70: Major Issues, 70-80: Minor Issues, 80-100: Good).\n"
        "    - `feedback`: A brief, high-level explanation focusing on logical flow and structural soundness.\n"
        "    - `suggestions`: A list of specific, actionable suggestions to improve logical consistency and plan structure."
    )

    # --- FeasibilityAgent critic ---
    feasibility_system: str = (
        "You are an expert plan evaluator specializing in feasibility, strategic clarity, and practical guidance. "
        "Keep all responses under 500 words."
    )
    feasibility_user: str = (
        "# Problem Context\n{question}\n\n"
        "# Plan to Evaluate\n{plan}\n\n"
        "{history_section}"
        "# Your Task\n"
        "Your job is to assess if the plan provides clear strategic guidance while maintaining appropriate abstraction. Focus on ensuring the plan offers comprehensive direction without diving into implementation details.\n\n"
        "**Evaluation Criteria:**\n"
        "1.  **Strategic Clarity**: Each step must provide clear strategic direction that guides toward the solution (avoid implementation details or specific technical execution).\n"
        "2.  **Appropriate Abstraction**: Steps should offer high-level guidance rather than detailed instructions. They should be general enough to allow flexible execution.\n"
        "3.  **Comprehensiveness**: The plan should cover all major strategic aspects needed to solve the problem without being overly prescriptive.\n"
        "4.  **Clarity**: Instructions must be unambiguous at the strategic level while avoiding unnecessary detail.\n"
        "5.  **Conciseness**: The plan should be as brief as possible while remaining strategically complete.\n"
        "6.  **Feedback Adherence**: If history is provided, verify if the new plan has successfully addressed the previous strategic concerns.\n\n"
        "# Required Output Format\n"
        "1. Keep your entire response under 500 words.\n"
        "2. Provide your evaluation as a structured response with the following keys:\n"
        "    - `score`: A score from 0-100 based on the rubric (0-60: Unacceptable, 60-70: Major Issues, 70-80: Minor Issues, 80-100: Good).\n"
        "    - `feedback`: A brief, high-level explanation focusing on strategic clarity and appropriate abstraction.\n"
        "    - `suggestions`: A list of specific suggestions to improve strategic guidance while maintaining abstraction."
    )

    # --- PlanGenerator: initial plan (single user message, no system prompt) ---
    # Declared exactly once; a duplicate declaration would silently override this one.
    initial_plan_prompt: str = (
        "# Problem\n{problem}\n\n"
        "# Your Task\n"
        "Create a clear, strategic plan within 500 words to solve the provided problem. The plan must provide high-level guidance without diving into implementation details.\n\n"
        "# Plan Requirements\n"
        "-   **Strategic Steps**: Each step must be a high-level strategic action that guides toward the solution without specifying implementation details.\n"
        "-   **Clear Numbering**: Use proper numbering (1., 2., 3.) and sub-steps when needed (1.1, 1.2, etc.).\n"
        "-   **Logical Sequence**: Arrange steps in logical order where each step builds upon previous ones.\n"
        "-   **Appropriate Abstraction**: Keep steps general and abstract - avoid specific technical details or exact procedures.\n"
        "-   **Concise but Complete**: Include all necessary strategic elements while avoiding redundancy.\n"
        "-   **Global Guidance**: Focus on what needs to be accomplished rather than how to accomplish it.\n\n"
        "# Example Format\n"
        "1. [First strategic objective]\n"
        "2. [Second strategic objective]\n"
        "   2.1 [Strategic sub-objective if needed]\n"
        "   2.2 [Another strategic sub-objective if needed]\n"
        "3. [Third strategic objective]\n\n"
        "Your output should contain ONLY the numbered plan with no additional commentary or explanation. Keep it abstract and focused on strategic guidance."
    )

    # --- PlanGenerator: feedback-driven plan modification ---
    plan_modification_system: str = (
        "You are an expert strategic planner who creates improved, high-level plans based on feedback. "
        "Keep all responses under 500 words."
    )

    plan_modification_user: str = (
        "# Problem\n{problem}\n\n"
        "# Current Plan\n{plan}\n\n"
        "# Feedback to Address\n{feedback}\n\n"
        "# Your Task\n"
        "Create a significantly improved strategic plan within 500 words that addresses the identified weaknesses while maintaining appropriate abstraction and global guidance focus.\n\n"
        "**Improvement Guidelines:**\n"
        "1.  **Address Feedback**: Carefully review and directly address each piece of feedback provided.\n"
        "2.  **Maintain Abstraction**: Keep steps at a strategic level - avoid diving into implementation specifics or technical details.\n"
        "3.  **Improve Structure**: Ensure proper numbering (1., 2., 3.) and use sub-steps (1.1, 1.2) where appropriate.\n"
        "4.  **Optimize Flow**: Rearrange or modify steps to create a more logical strategic sequence.\n"
        "5.  **Eliminate Redundancy**: Remove unnecessary or duplicate steps while ensuring strategic completeness.\n"
        "6.  **Enhance Clarity**: Make strategic guidance clearer without adding unnecessary implementation detail.\n"
        "7.  **Global Focus**: Ensure the plan provides comprehensive strategic direction rather than step-by-step execution.\n\n"
        "# Required Output Format\n"
        "Provide your response as a structured object with two keys:\n"
        "-   `plan`: The full text of the new, improved strategic plan with proper numbering and high-level, abstract steps.\n"
        "-   `changes_made`: A detailed list of the specific structural and content changes you made to address the feedback while maintaining abstraction."
    )

    # --- Plan executor ---
    executor_system: str = (
        "You are a helpful AI assistant."
    )

    executor_user: str = (
        "# Problem\n{problem}\n\n"
        "# Plan to Execute\n{plan}\n\n"
        "Let's execute the plan step-by-step to solve the problem."
    )


class LLMConfig(BaseModel):
    """Connection and sampling settings for the chat model endpoint."""

    # Credential for the endpoint. NOTE(review): plain str; consider pydantic.SecretStr to keep it out of reprs/logs.
    api_key: str
    # Base URL of the (OpenAI-compatible, presumably) API endpoint.
    base_url: str
    # Model identifier sent to the endpoint.
    model_name: str
    # Sampling temperature; defaults to 0.7.
    temperature: float = 0.7
    # Optional completion-length cap; None presumably defers to the provider default.
    max_tokens: Optional[int] = None
    # Optional request timeout; unit assumed to be seconds — confirm against the client that consumes it.
    timeout: Optional[int] = None


class MCTSConfig(BaseModel):
    """Hyper-parameters controlling the MCTS plan search.

    ``reward_weights`` maps critic names to weights. They are expected to sum
    to 1; otherwise a warning is logged after construction (no error is raised).
    """

    max_iterations: int = 8
    num_expansions: int = 2
    max_depth: int = 5
    # UCT exploration constant.
    exploration_weight: float = 1.0
    # Compared against the /100-normalized weighted critic score (see MCTSNode.get_reward).
    reward_threshold: float = 0.95
    reward_weights: Dict[str, float] = Field(
        default_factory=lambda: dict(LogicalConsistency=0.5, Feasibility=0.5)
    )

    def model_post_init(self, __context: Any) -> None:
        """Warn when the configured reward weights do not add up to 1."""
        if not self.reward_weights:
            return
        total = sum(self.reward_weights.values())
        if abs(total - 1.0) > 1e-6:
            logger.warning(
                f"Reward weights sum to {total:.6f}, not 1.0. This deviates from paper assumptions.")


class ExperimentConfig(BaseModel):
    """Everything needed to run one planning experiment: problem, model and search settings."""

    problem_description: str
    llm_config: LLMConfig
    mcts_config: MCTSConfig
    prompts: PromptConfig = Field(default_factory=PromptConfig)

    @classmethod
    def create_default(cls, problem: str, llm_config: LLMConfig) -> 'ExperimentConfig':
        """Build a config for ``problem`` with default MCTS settings and default prompts."""
        default_search = MCTSConfig()
        default_prompts = PromptConfig()
        return cls(
            problem_description=problem,
            llm_config=llm_config,
            mcts_config=default_search,
            prompts=default_prompts,
        )


class EvaluationOutput(BaseModel):
    """Structured result returned by a critic agent for one plan."""

    # Critic score; pydantic enforces 0 <= score <= 100.
    score: float = Field(ge=0.0, le=100.0, description="Evaluation score between 0 and 100")
    # Free-text justification for the score.
    feedback: str = Field(description="Detailed textual feedback on the plan, explaining the score.")
    # Actionable improvement suggestions; empty list by default.
    suggestions: List[str] = Field(default_factory=list, description="Specific suggestions for improvement.")


class ModifiedPlanOutput(BaseModel):
    """Structured result of a plan-modification request (PlanGenerator.modify_plan)."""

    # The complete rewritten plan text.
    plan: str = Field(description="The full text of the new, modified plan.")
    # Summary of what changed relative to the input plan.
    changes_made: List[str] = Field(description="A list of specific changes made to the original plan.")


class MCTSNode(BaseModel):
    """A candidate plan in the search tree together with its critic evaluations."""

    node_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique identifier for the node")
    plan: str
    # None for the root node.
    parent_id: Optional[str] = None
    depth: int = 0
    evaluations: Dict[str, EvaluationOutput] = Field(default_factory=dict,
                                                     description="Stores evaluation results from critic agents.")

    metadata: Optional[NodeMetadata] = None

    def is_terminal(self, mcts_config: MCTSConfig) -> bool:
        """Return True if the node should not be expanded further.

        A node is terminal when it reached ``mcts_config.max_depth`` or, if it
        has evaluations, when its weighted reward meets ``reward_threshold``.
        """
        if self.depth >= mcts_config.max_depth:
            logger.debug(
                f"Node {self.node_id} is terminal: reached max depth ({self.depth} >= {mcts_config.max_depth}).")
            return True

        # Compute the reward once; get_reward() logs on every call, so calling it
        # again just for the final debug message duplicated work and log lines.
        current_reward = self.get_reward(mcts_config.reward_weights)

        if self.evaluations:
            logger.debug(
                f"Node {self.node_id} at depth {self.depth}: current_reward={current_reward:.3f}, threshold={mcts_config.reward_threshold}")
            if current_reward >= mcts_config.reward_threshold:
                logger.info(
                    f"Node {self.node_id} is terminal: reward ({current_reward:.3f}) meets/exceeds threshold ({mcts_config.reward_threshold}).")
                return True

        logger.debug(
            f"Node {self.node_id} at depth {self.depth} is NOT terminal (reward: {current_reward:.3f})")
        return False

    def get_reward(self, weights: Dict[str, float]) -> float:
        """Return the weighted critic score divided by 100 (0.0 if not evaluated).

        Critics missing from ``weights`` contribute 0. The result lies in [0, 1]
        when the weights are non-negative and sum to 1.
        """
        if not self.evaluations:
            return 0.0

        raw_total_reward = sum(
            evaluation.score * weights.get(name, 0.0)
            for name, evaluation in self.evaluations.items()
        )
        normalized_reward = raw_total_reward / 100.0

        logger.debug(
            f"Node {self.node_id}: calculated raw reward: {raw_total_reward:.3f}, normalized reward for MCTS: {normalized_reward:.3f}")
        return normalized_reward


class MCTSState(BaseModel):
    """Search state carried through the planning workflow (presumably the LangGraph state)."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    # Arbitrary types are allowed because the state holds non-pydantic objects.
    model_config = {"arbitrary_types_allowed": True}

    config: ExperimentConfig

    # The MCTS statistics object (typed Any to avoid pydantic validation of it).
    mcts_tree: Optional[Any] = None
    nodes: Dict[str, MCTSNode] = Field(default_factory=dict,
                                       description="A dictionary mapping node_id to MCTSNode objects.")

    iteration: int = 0

    # Best result found so far.
    best_plan: str = ""
    best_node_id: str = ""
    best_score: float = 0.0
    final_solution: Optional[str] = None

    search_records: List[SearchRecord] = Field(default_factory=list)


class MCTS:
    """Statistics and structure for Monte Carlo Tree Search over node ids.

    Only string ids are stored here; node payloads live with the caller.
      * ``Q[node_id]``: sum of rewards backpropagated through the node.
      * ``N[node_id]``: number of backpropagations through the node. A node is
        considered "explored" once it is a key of ``N``.
      * ``children[node_id]``: set of child ids added via :meth:`expand`.
    Each phase appends a ``SearchRecord`` to ``search_records`` for analytics.
    """

    # Only the deepest N updates of each backpropagated path are written to the
    # audit log, keeping ``search_records`` compact on long paths.
    _BACKPROP_RECORD_LIMIT = 5

    def __init__(self, exploration_weight: float = 1.0):
        self.Q = defaultdict(float)
        self.N = defaultdict(int)
        self.children = defaultdict(set)
        self.exploration_weight = exploration_weight

        self.search_records = []
        # NOTE(review): never advanced inside this class; presumably the caller
        # sets it once per search iteration.
        self.current_iteration = 0

        logger.debug(f"MCTS object initialized with exploration_weight={exploration_weight}.")

    def _record_action(self, action: str, node_id: str, details: Optional[Dict[str, Any]] = None) -> None:
        """Append a timestamped ``SearchRecord`` tagged with ``current_iteration``."""
        record = SearchRecord(
            iteration=self.current_iteration,
            timestamp=datetime.now(),
            action=action,
            node_id=node_id,
            details=details or {}
        )
        self.search_records.append(record)

    def select(self, root_node_id: str) -> List[str]:
        """Descend from the root to the node that should be expanded/evaluated next.

        At each level, if the current node has a child that is not yet a key of
        ``N`` (never backpropagated), an arbitrary such child is returned
        immediately. Otherwise the child with the highest UCT score is followed.
        The walk stops at a node without children.

        Returns:
            Node ids from ``root_node_id`` down to the selected node.
        """
        logger.debug("--- MCTS: Selection phase ---")
        path = [root_node_id]
        current_node_id = root_node_id

        self._record_action("select_start", root_node_id, {"path_length": 1})

        while self.children.get(current_node_id):
            unexplored = self.children[current_node_id] - self.N.keys()
            if unexplored:
                new_leaf = unexplored.pop()
                path.append(new_leaf)
                logger.debug(f"Selected an unexplored child node: {new_leaf}")

                self._record_action("select_unexplored", new_leaf, {
                    "parent": current_node_id,
                    "path_length": len(path)
                })
                return path

            current_node_id = self._uct_select(current_node_id)
            path.append(current_node_id)

        self._record_action("select_complete", current_node_id, {
            "path_length": len(path),
            "final_depth": len(path) - 1
        })

        logger.debug(f"Selection path to leaf: {path}")
        return path

    def expand(self, parent_node_id: str, child_node_id: str) -> None:
        """Register ``child_node_id`` under ``parent_node_id`` (no-op if already present)."""
        logger.debug("--- MCTS: Expansion phase ---")
        if child_node_id not in self.children[parent_node_id]:
            self.children[parent_node_id].add(child_node_id)
            logger.debug(f"Expanded node {parent_node_id} with new child {child_node_id}")

            self._record_action("expand", child_node_id, {
                "parent": parent_node_id,
                "children_count": len(self.children[parent_node_id])
            })

    def backpropagate(self, path: List[str], reward: float) -> None:
        """Add ``reward`` to ``Q`` and 1 to ``N`` for every node on ``path``.

        An empty path is logged and ignored instead of raising IndexError.
        """
        logger.debug("--- MCTS: Backpropagation phase ---")
        logger.debug(f"Backpropagating reward {reward:.3f} along path of length {len(path)}.")

        if not path:
            logger.warning("Backpropagation called with an empty path; nothing to update.")
            return

        self._record_action("backprop_start", path[-1], {
            "reward": reward,
            "path_length": len(path)
        })

        # Walk leaf -> root so the deepest nodes are the ones recorded.
        for i, node_id in enumerate(reversed(path)):
            old_n = self.N[node_id]
            old_q = self.Q[node_id]

            self.N[node_id] += 1
            self.Q[node_id] += reward

            logger.debug(f"  - Updated node {node_id}: N={self.N[node_id]}, Q={self.Q[node_id]:.3f}")

            if i < self._BACKPROP_RECORD_LIMIT:
                self._record_action("backprop_update", node_id, {
                    "old_N": old_n,
                    "new_N": self.N[node_id],
                    "old_Q": old_q,
                    "new_Q": self.Q[node_id],
                    "avg_reward": self.Q[node_id] / self.N[node_id]
                })

    def _uct_select(self, parent_node_id: str) -> str:
        """Return the child of ``parent_node_id`` maximizing UCT = Q/N + c*sqrt(ln N_parent / N)."""
        # max(..., 1) guards math.log(0) should the parent have no recorded visits;
        # .get() avoids inserting zero-count keys into the defaultdicts.
        log_n_parent = math.log(max(self.N.get(parent_node_id, 0), 1))

        def uct(node_id: str) -> float:
            visits = self.N.get(node_id, 0)
            if visits == 0:
                return float("inf")

            average_reward = self.Q.get(node_id, 0.0) / visits
            exploration_term = self.exploration_weight * math.sqrt(log_n_parent / visits)
            uct_score = average_reward + exploration_term
            logger.debug(
                f"    - UCT for {node_id}: avg_reward={average_reward:.3f}, exploration={exploration_term:.3f}, total={uct_score:.3f}")
            return uct_score

        logger.debug(f"Calculating UCT scores for children of {parent_node_id}")
        # Score each child once and reuse the value for the audit record.
        scores = {child_id: uct(child_id) for child_id in self.children[parent_node_id]}
        selected = max(scores, key=scores.get)

        self._record_action("uct_select", selected, {
            "parent": parent_node_id,
            "num_children": len(self.children[parent_node_id]),
            "selected_score": scores[selected]
        })

        return selected

    def choose_best_node(self, root_node_id: str, record: bool = True) -> str:
        """Return the visited node (anywhere in the tree) with the highest Q/N.

        Args:
            root_node_id: Returned when no node has been visited yet.
            record: When False, no "select_best" audit record is appended
                (used by read-only callers such as :meth:`visualize_tree`).
        """
        if not self.N:
            logger.warning("MCTS tree has no nodes with visit counts. Returning root node.")
            return root_node_id

        candidate_nodes = list(self.N.keys())

        def average_reward(node_id: str) -> float:
            visits = self.N.get(node_id, 0)
            return self.Q.get(node_id, 0.0) / visits if visits > 0 else 0.0

        logger.info("Choosing best node based on highest average reward...")
        # sorted() is stable with reverse=True, so ranked[0] matches max() tie-breaking.
        ranked = sorted(candidate_nodes, key=average_reward, reverse=True)
        for node_id in ranked:
            logger.debug(
                f"  - Candidate: {node_id} | Avg Reward: {average_reward(node_id):.3f} | Visits: {self.N.get(node_id, 0)}")

        best_node_id = ranked[0]

        if record:
            self._record_action("select_best", best_node_id, {
                "num_candidates": len(candidate_nodes),
                "best_avg_reward": average_reward(best_node_id),
                "best_visits": self.N.get(best_node_id, 0)
            })

        return best_node_id

    def visualize_tree(self, root_node_id: str, nodes_dict: Dict[str, MCTSNode], max_width: int = 80) -> str:
        """Render the tree below ``root_node_id`` as ASCII art, one node per line.

        Each line shows ``[short_id D:depth V:visits Q:avg_reward]`` followed by
        the plan flattened to one line and truncated to fit ``max_width``.
        Rendering does not append any search records.
        """
        def truncate_plan(plan: str, width: int = max_width) -> str:
            # Flatten first so the length check sees the text actually printed.
            flat = plan.replace('\n', ' | ')
            if len(flat) <= width:
                return flat
            # Clamp so deep nodes (tiny or negative width) never slice with a negative index.
            return flat[:max(width - 3, 0)] + "..."

        def get_node_info(node_id: str) -> str:
            visits = self.N.get(node_id, 0)
            total_reward = self.Q.get(node_id, 0.0)
            avg_reward = total_reward / visits if visits > 0 else 0.0

            node = nodes_dict.get(node_id)
            depth = node.depth if node else 0
            short_id = node_id[:8]

            return f"[{short_id} D:{depth} V:{visits} Q:{avg_reward:.3f}]"

        def build_tree_lines(node_id: str, prefix: str = "", is_last: bool = True) -> List[str]:
            lines = []

            node_info = get_node_info(node_id)
            node = nodes_dict.get(node_id)
            plan_text = truncate_plan(node.plan if node else "Unknown Plan",
                                      max_width - len(prefix) - len(node_info) - 5)
            connector = "└── " if is_last else "├── "
            lines.append(f"{prefix}{connector}{node_info} {plan_text}")

            children = sorted(self.children.get(node_id, ()))
            for i, child_node_id in enumerate(children):
                is_child_last = (i == len(children) - 1)
                child_prefix = prefix + ("    " if is_last else "│   ")
                lines.extend(build_tree_lines(child_node_id, child_prefix, is_child_last))

            return lines

        best_node_id = self.choose_best_node(root_node_id, record=False)
        tree_lines = [
            f"MCTS Tree Visualization (Root: {get_node_info(root_node_id)} Best: {get_node_info(best_node_id)})"]
        tree_lines.append("=" * max_width)
        tree_lines.extend(build_tree_lines(root_node_id))
        tree_lines.append("=" * max_width)
        tree_lines.append("Legend: [UUID D:depth V:visits Q:avg_reward]")
        tree_lines.append(
            f"Total nodes: {len(self.N)}, Total edges: {sum(len(children) for children in self.children.values())}")

        return "\n".join(tree_lines)

    def _max_tree_depth(self) -> int:
        """Return the longest root-to-leaf chain (in edges) described by ``children``."""
        child_ids = {child for kids in self.children.values() for child in kids}
        known_ids = set(self.N) | set(self.children)
        stack = [(node_id, 0) for node_id in known_ids if node_id not in child_ids]
        seen = set()
        deepest = 0
        while stack:
            node_id, depth = stack.pop()
            if node_id in seen:
                continue
            seen.add(node_id)
            deepest = max(deepest, depth)
            stack.extend((child, depth + 1) for child in self.children.get(node_id, ()))
        return deepest

    def get_search_analytics(self) -> Dict[str, Any]:
        """Summarize ``search_records`` and the tree statistics.

        Returns ``{"error": ...}`` when nothing has been recorded yet. In
        ``tree_statistics``, ``max_depth`` is the longest root-to-leaf chain (in
        edges) and ``max_branching`` is the largest child count among visited
        nodes.
        """
        if not self.search_records:
            return {"error": "No search records available"}

        action_counts: Dict[str, int] = {}
        for record in self.search_records:
            action_counts[record.action] = action_counts.get(record.action, 0) + 1

        # Records are appended in call order, so first/last bound the search.
        total_duration = (self.search_records[-1].timestamp - self.search_records[0].timestamp).total_seconds()

        total_nodes = len(self.N)
        total_visits = sum(self.N.values())
        avg_reward = sum(self.Q.values()) / total_visits if total_visits > 0 else 0

        def node_avg_reward(node_id: str) -> float:
            visits = self.N.get(node_id, 0)
            return self.Q.get(node_id, 0.0) / visits if visits > 0 else 0.0

        top_node_ids = sorted(self.N.keys(), key=node_avg_reward, reverse=True)[:5]

        return {
            "search_summary": {
                "total_records": len(self.search_records),
                "total_duration_seconds": total_duration,
                "iterations": self.current_iteration
            },
            "action_breakdown": action_counts,
            "tree_statistics": {
                "total_nodes": total_nodes,
                "total_visits": total_visits,
                "average_reward": avg_reward,
                "max_depth": self._max_tree_depth() if self.N else 0,
                "max_branching": max(len(self.children.get(node_id, ())) for node_id in self.N) if self.N else 0
            },
            "top_nodes": [
                {
                    "node_id": node_id,
                    "visits": self.N[node_id],
                    "avg_reward": node_avg_reward(node_id)
                }
                for node_id in top_node_ids
            ]
        }


class LLMEvaluator(ABC):
    """Interface for critic agents that score a plan against the problem."""

    @abstractmethod
    def evaluate(self, question_text: str, plan_text: str, history_section: str = "") -> EvaluationOutput:
        """Score ``plan_text`` for ``question_text``.

        ``history_section`` is optional extra prompt text (e.g. prior feedback)
        that implementations insert into their prompt.
        """


class LogicalConsistencyAgent(LLMEvaluator):
    """Critic that scores a plan's logical flow and structure with a structured-output LLM."""

    def __init__(self, llm: ChatOpenAI, prompts: PromptConfig):
        self.llm = llm.with_structured_output(EvaluationOutput)
        self.prompts = prompts

    @retry(
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(MAX_RETRIES),
        reraise=True
    )
    def _invoke_llm(self, messages: List[Dict[str, str]]) -> EvaluationOutput:
        """Call the LLM, raising on failure so tenacity can retry the request."""
        result = self.llm.invoke(messages)
        if result is None:
            # Presumably a structured-output parse failure; retry rather than crash on `.score`.
            raise ValueError("LLM returned no structured evaluation output")
        return result

    def evaluate(self, question_text: str, plan_text: str, history_section: str = "") -> EvaluationOutput:
        """Score ``plan_text`` for logical consistency.

        The LLM call is retried up to ``MAX_RETRIES`` times. If it still fails,
        the error is logged and a neutral fallback (score 50.0) is returned so
        the search can continue.
        """
        logger.debug("Evaluating plan with LogicalConsistencyAgent...")
        system_prompt = self.prompts.logical_consistency_system
        user_prompt = self.prompts.logical_consistency_user.format(
            history_section=history_section,
            question=question_text,
            plan=plan_text
        )
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        # The try sits outside the retried call: catching inside it (as before)
        # made @retry dead code because no exception ever reached tenacity.
        try:
            result = self._invoke_llm(messages)
        except Exception as e:
            logger.error(f"Logical consistency evaluation failed: {e}")
            return EvaluationOutput(score=50.0, feedback="Failed to evaluate logical consistency.", suggestions=[])
        logger.debug(f"LogicalConsistencyAgent score: {result.score:.2f}/100")
        return result


class FeasibilityAgent(LLMEvaluator):
    """Critic that scores a plan's strategic clarity and abstraction level with a structured-output LLM."""

    def __init__(self, llm: ChatOpenAI, prompts: PromptConfig):
        self.llm = llm.with_structured_output(EvaluationOutput)
        self.prompts = prompts

    @retry(
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(MAX_RETRIES),
        reraise=True
    )
    def _invoke_llm(self, messages: List[Dict[str, str]]) -> EvaluationOutput:
        """Call the LLM, raising on failure so tenacity can retry the request."""
        result = self.llm.invoke(messages)
        if result is None:
            # Presumably a structured-output parse failure; retry rather than crash on `.score`.
            raise ValueError("LLM returned no structured evaluation output")
        return result

    def evaluate(self, question_text: str, plan_text: str, history_section: str = "") -> EvaluationOutput:
        """Score ``plan_text`` for feasibility and strategic clarity.

        The LLM call is retried up to ``MAX_RETRIES`` times. If it still fails,
        the error is logged and a neutral fallback (score 50.0) is returned so
        the search can continue.
        """
        logger.debug("Evaluating plan with FeasibilityAgent...")
        system_prompt = self.prompts.feasibility_system
        user_prompt = self.prompts.feasibility_user.format(
            history_section=history_section,
            question=question_text,
            plan=plan_text
        )
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        # The try sits outside the retried call: catching inside it (as before)
        # made @retry dead code because no exception ever reached tenacity.
        try:
            result = self._invoke_llm(messages)
        except Exception as e:
            logger.error(f"Feasibility evaluation failed: {e}")
            return EvaluationOutput(score=50.0, feedback="Failed to evaluate feasibility.", suggestions=[])
        logger.debug(f"FeasibilityAgent score: {result.score:.2f}/100")
        return result


class PlanGenerator:
    """Produces the initial plan and LLM-revised variants of existing plans."""

    def __init__(self, llm: ChatOpenAI, prompts: PromptConfig):
        self.llm = llm
        self.prompts = prompts

    @retry(
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(MAX_RETRIES)
    )
    def generate_initial_plan(self, problem: str) -> str:
        """Ask the LLM for a first plan for ``problem``.

        Returns:
            The model's reply text with surrounding whitespace stripped.

        Raises:
            tenacity.RetryError: if all ``MAX_RETRIES`` attempts fail.
        """
        logger.info("Generating initial plan...")
        prompt = self.prompts.initial_plan_prompt.format(problem=problem)
        response = self.llm.invoke([{"role": "user", "content": prompt}])
        plan = response.content.strip()
        logger.debug(f"Generated initial plan:\n{plan}")
        return plan

    @retry(
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(MAX_RETRIES),
        reraise=True,
    )
    def _invoke_structured_with_retry(self, messages: List[Dict[str, str]]) -> ModifiedPlanOutput:
        """Request a ``ModifiedPlanOutput``, retrying failures up to ``MAX_RETRIES`` attempts.

        Raises:
            ValueError: if the final attempt produced no parsed ``ModifiedPlanOutput``.
            Exception: whatever the final failed call raised (``reraise=True``).
        """
        structured_llm = self.llm.with_structured_output(ModifiedPlanOutput)
        result = structured_llm.invoke(messages)
        if result is None:
            # Structured-output parsing can yield None; retry instead of returning it.
            raise ValueError("LLM returned no structured ModifiedPlanOutput")
        return result

    def modify_plan(self, problem: str, plan: str, evaluations: Dict[str, EvaluationOutput]) -> ModifiedPlanOutput:
        """Ask the LLM for a revision of ``plan`` that addresses evaluator feedback.

        Args:
            problem: The original problem statement.
            plan: The plan to revise.
            evaluations: Evaluator name -> its ``EvaluationOutput`` for ``plan``.

        Returns:
            The model's ``ModifiedPlanOutput``; if every attempt fails, the
            unchanged ``plan`` with a single explanatory ``changes_made`` entry.
        """
        logger.debug("Modifying plan based on evaluation feedback...")
        feedback_text = "\n".join(
            f"- {name} Feedback (Score: {evaluation.score}/100): {evaluation.feedback}\n"
            f"  Suggestions: {', '.join(evaluation.suggestions)}"
            for name, evaluation in evaluations.items()
        )

        user_prompt = self.prompts.plan_modification_user.format(
            problem=problem, plan=plan, feedback=feedback_text
        )
        messages = [
            {"role": "system", "content": self.prompts.plan_modification_system},
            {"role": "user", "content": user_prompt},
        ]
        try:
            # Retries happen inside the helper; this handler only sees the
            # failure left after the last attempt.
            modified_plan = self._invoke_structured_with_retry(messages)
            logger.debug(f"Plan modified. Changes: {modified_plan.changes_made}")
            return modified_plan
        except Exception as e:
            logger.error(f"Plan modification failed: {e}")
            return ModifiedPlanOutput(plan=plan, changes_made=["No modifications made due to error"])


class PlanExecutor:
    """Asks the LLM to carry out a plan and produce the final solution."""

    def __init__(self, llm: ChatOpenAI, prompts: PromptConfig):
        self.llm = llm
        self.prompts = prompts

    @retry(
        wait=wait_random_exponential(multiplier=1, max=5),
        stop=stop_after_attempt(MAX_RETRIES),
        reraise=True,
    )
    def _invoke_with_retry(self, messages: List[Dict[str, str]]):
        """Invoke the LLM, retrying failures up to ``MAX_RETRIES`` attempts.

        Raises:
            Exception: whatever the final failed ``invoke`` raised (``reraise=True``).
        """
        return self.llm.invoke(messages)

    def execute(self, problem: str, plan: str):
        """Have the LLM execute ``plan`` for ``problem``.

        Returns:
            The raw ``self.llm.invoke`` result on success, or the string
            ``"Execution failed."`` once every retry attempt has failed.
        """
        logger.info("Executing final plan with PlanExecutor...")
        user_prompt = self.prompts.executor_user.format(problem=problem, plan=plan)
        messages = [
            {"role": "system", "content": self.prompts.executor_system},
            {"role": "user", "content": user_prompt},
        ]
        try:
            # Retries happen inside the helper; this handler runs only after
            # the last attempt fails.
            result = self._invoke_with_retry(messages)
            logger.info("Plan execution complete.")
            return result
        except Exception as e:
            logger.error(f"Plan execution failed: {e}")
            return "Execution failed."


def initialize_mcts(state: MCTSState) -> MCTSState:
    """Create the search tree and seed it with an LLM-generated root plan.

    Sets ``state.mcts_tree`` to a fresh ``MCTS``, and registers the root node in
    ``state.nodes`` with a visit count of 1. Marks the root as the current best
    plan with ``best_score`` 0.0, since the root has not been evaluated yet.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("--- (1/4) Initializing MCTS search ---")
    state.mcts_tree = MCTS(exploration_weight=state.config.mcts_config.exploration_weight)

    llm_config = state.config.llm_config
    llm = ChatOpenAI(
        api_key=llm_config.api_key,
        base_url=llm_config.base_url,
        model=llm_config.model_name,
        temperature=llm_config.temperature,
        # Honour the configured completion-token cap from LLMConfig.
        max_tokens=llm_config.max_tokens,
    )
    generator = PlanGenerator(llm, state.config.prompts)
    initial_plan = generator.generate_initial_plan(state.config.problem_description)

    root_node = MCTSNode(
        plan=initial_plan,
        depth=0,
        metadata=NodeMetadata(
            creation_time=datetime.now(),
            creation_iteration=0,
            total_evaluations=0
        )
    )
    # The root is inserted first; later stages rely on dict insertion order to find it.
    state.nodes[root_node.node_id] = root_node

    # Seed the root with one visit before any backpropagation has happened.
    state.mcts_tree.N[root_node.node_id] = 1

    state.best_plan = initial_plan
    state.best_node_id = root_node.node_id
    state.best_score = 0.0
    logger.info(f"MCTS initialized with root node {root_node.node_id}.")
    return state


def _format_evaluation_feedback(evaluations: Dict[str, "EvaluationOutput"]) -> str:
    """Render each evaluator's score, critique and suggestions as one prompt-ready block.

    Sections are separated by a blank line; an empty suggestion list renders as ``None``.
    """
    sections = []
    for name, evaluation in evaluations.items():
        suggestions = '; '.join(evaluation.suggestions) if evaluation.suggestions else 'None'
        sections.append(
            f"Feedback from {name} (Score: {evaluation.score:.0f}/100):\n"
            f"Critique: {evaluation.feedback}\n"
            f"Suggestions: {suggestions}"
        )
    return "\n\n".join(sections)


def _build_history_section(previous_plan: str, feedback_text: str) -> str:
    """Return the context block telling evaluators that a plan revises ``previous_plan``."""
    return (
        "### CONTEXT: PLAN IMPROVEMENT HISTORY ###\n"
        "The plan you are about to evaluate is a revised version of a previous one. "
        "Your task is to assess if the revision successfully addressed the feedback.\n\n"
        f"**Previous Plan:**\n{previous_plan}\n\n"
        f"**Feedback on Previous Plan:**\n{feedback_text}\n\n"
        "-----------------------------------\n\n"
    )


def mcts_iteration(state: MCTSState) -> MCTSState:
    """Run one MCTS iteration: select from the root, evaluate, expand, backpropagate.

    1. Selection walks ``state.mcts_tree`` from the root node to a leaf.
    2. The leaf is scored by every evaluator if it has no evaluations yet.
    3. A terminal leaf is backpropagated with its own reward and not expanded.
    4. Otherwise up to ``num_expansions`` revised plans are requested. Plans
       already present in the tree are skipped. Each new plan becomes a child
       that is evaluated with history context, then backpropagated. If it beats
       ``state.best_score``, it is recorded as the new best plan.

    ``state.iteration`` is incremented on every path.

    Returns:
        The same ``state`` object, mutated in place.
    """
    mcts_config = state.config.mcts_config
    logger.info(f"--- (2/4) MCTS search iteration {state.iteration + 1}/{mcts_config.max_iterations} ---")

    mcts = state.mcts_tree
    mcts.current_iteration = state.iteration + 1

    llm_config = state.config.llm_config
    llm = ChatOpenAI(
        api_key=llm_config.api_key,
        base_url=llm_config.base_url,
        model=llm_config.model_name,
        temperature=llm_config.temperature,
        # Honour the configured completion-token cap from LLMConfig.
        max_tokens=llm_config.max_tokens,
    )
    evaluators = {
        "LogicalConsistency": LogicalConsistencyAgent(llm, state.config.prompts),
        "Feasibility": FeasibilityAgent(llm, state.config.prompts),
    }

    # Always select from the tree root: initialize_mcts inserts it first and
    # dicts keep insertion order. Starting from the current best node instead
    # would confine the search to that node's subtree and leave its ancestors
    # out of backpropagation. It can also stall the search once that node is
    # terminal, because every iteration would then re-select it and never expand.
    root_node_id = next(iter(state.nodes))
    path_to_leaf = mcts.select(root_node_id)
    leaf_node_id = path_to_leaf[-1]
    leaf_node = state.nodes[leaf_node_id]
    logger.info(f"Selected leaf node {leaf_node_id} at depth {leaf_node.depth}")

    if not leaf_node.evaluations:
        logger.info("Evaluating selected leaf node for the first time...")
        for name, evaluator in evaluators.items():
            leaf_node.evaluations[name] = evaluator.evaluate(state.config.problem_description, leaf_node.plan)

        if leaf_node.metadata:
            leaf_node.metadata.total_evaluations += len(evaluators)
            leaf_node.metadata.last_updated = datetime.now()

    if leaf_node.is_terminal(mcts_config):
        logger.info("Selected leaf node is terminal. No expansion will be performed.")

        reward = leaf_node.get_reward(mcts_config.reward_weights)
        mcts.backpropagate(path_to_leaf, reward)
        state.iteration += 1
        return state

    logger.info(f"Expanding leaf node with {mcts_config.num_expansions} new child plan(s)...")
    generator = PlanGenerator(llm, state.config.prompts)
    history_section = _build_history_section(
        leaf_node.plan, _format_evaluation_feedback(leaf_node.evaluations)
    )

    # Plan texts already in the tree, built once so each duplicate check is O(1).
    known_plans = {node.plan for node in state.nodes.values()}

    for _ in range(mcts_config.num_expansions):
        modified_output = generator.modify_plan(state.config.problem_description, leaf_node.plan, leaf_node.evaluations)
        child_plan_text = modified_output.plan

        if child_plan_text in known_plans:
            logger.info("  - Generated a plan that already exists in the tree. Skipping.")
            continue
        known_plans.add(child_plan_text)

        child_node = MCTSNode(
            plan=child_plan_text,
            parent_id=leaf_node_id,
            depth=leaf_node.depth + 1,
            metadata=NodeMetadata(
                creation_time=datetime.now(),
                creation_iteration=state.iteration + 1,
                total_evaluations=0
            )
        )
        state.nodes[child_node.node_id] = child_node
        mcts.expand(leaf_node_id, child_node.node_id)
        logger.info(f"  - New child {child_node.node_id} (depth {child_node.depth}) created")

        logger.info(f"  - Evaluating new child node with history context...")
        for name, evaluator in evaluators.items():
            child_node.evaluations[name] = evaluator.evaluate(
                state.config.problem_description, child_node.plan, history_section=history_section
            )

        if child_node.metadata:
            child_node.metadata.total_evaluations += len(evaluators)
            child_node.metadata.last_updated = datetime.now()

        reward = child_node.get_reward(mcts_config.reward_weights)
        logger.info(f"  - Evaluated new child with combined normalized reward: {reward:.3f}. Backpropagating...")
        mcts.backpropagate(path_to_leaf + [child_node.node_id], reward)

        if reward > state.best_score:
            logger.info(f"  - New best plan found! Score: {reward:.3f} > {state.best_score:.3f}")
            state.best_score = reward
            state.best_plan = child_node.plan
            state.best_node_id = child_node.node_id

    # Move this iteration's tree records onto the state and reset the tree's buffer.
    state.search_records.extend(mcts.search_records)
    mcts.search_records = []

    state.iteration += 1
    return state


def select_best_plan(state: MCTSState) -> MCTSState:
    """Pick the final plan from the search tree and record it on ``state``.

    Asks ``mcts_tree.choose_best_node`` for a node, starting at the root. When it
    returns a node present in ``state.nodes``, that node's plan and freshly
    computed reward become ``best_plan`` / ``best_score``. Otherwise the root
    plan is used as a fallback, and ``best_score`` is recomputed so the reported
    score describes the plan actually chosen. Any search records still held by
    the tree are appended to ``state.search_records``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("--- (3/4) Selecting best final plan from the search tree ---")
    mcts = state.mcts_tree
    reward_weights = state.config.mcts_config.reward_weights
    # initialize_mcts inserts the root first, and dicts preserve insertion order.
    root_node_id = next(iter(state.nodes))
    best_node_id = mcts.choose_best_node(root_node_id)

    if best_node_id and best_node_id in state.nodes:
        best_node = state.nodes[best_node_id]
        state.best_plan = best_node.plan
        state.best_node_id = best_node_id
        state.best_score = best_node.get_reward(reward_weights)
        avg_reward = mcts.Q.get(best_node_id, 0.0) / max(mcts.N.get(best_node_id, 1), 1)
        visits = mcts.N.get(best_node_id, 0)
        logger.info(f"Best plan selected: {best_node_id}")
        logger.info(f"  - Average Normalized Reward (Q/N): {avg_reward:.3f}")
        logger.info(f"  - Visit Count (N): {visits}")
    else:
        logger.warning("Could not identify a best plan. Using the initial plan as a fallback.")
        root_node = state.nodes[root_node_id]
        state.best_node_id = root_node_id
        state.best_plan = root_node.plan
        # Without this, best_score would still describe whichever node led during
        # the search, not the root plan now being reported.
        state.best_score = root_node.get_reward(reward_weights) if root_node.evaluations else 0.0

    state.search_records.extend(mcts.search_records)

    return state


def execute_plan(state: MCTSState) -> MCTSState:
    """Run the selected ``state.best_plan`` through ``PlanExecutor``.

    Stores the executor's message text, or the string form of its result when
    it has no ``content`` attribute (e.g. the failure string), in
    ``state.final_solution``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    logger.info("--- (4/4) Executing the best plan ---")
    llm_config = state.config.llm_config
    llm = ChatOpenAI(
        api_key=llm_config.api_key,
        base_url=llm_config.base_url,
        model=llm_config.model_name,
        temperature=llm_config.temperature,
        # Honour the configured completion-token cap from LLMConfig.
        max_tokens=llm_config.max_tokens,
    )
    executor = PlanExecutor(llm, state.config.prompts)
    result = executor.execute(state.config.problem_description, state.best_plan)
    state.final_solution = result.content if hasattr(result, 'content') else str(result)
    return state


def should_continue(state: MCTSState) -> str:
    """Route the graph after a search step.

    Returns ``"select_best"`` when the iteration budget is used up or the best
    normalized reward has reached the configured threshold; otherwise returns
    ``"search"`` to run another MCTS iteration.
    """
    cfg = state.config.mcts_config
    logger.info(
        f"Checking continuation condition: iteration {state.iteration}/{cfg.max_iterations}, "
        f"best_score {state.best_score:.3f}/{cfg.reward_threshold}")

    budget_exhausted = state.iteration >= cfg.max_iterations
    if budget_exhausted:
        logger.info("Condition met: Maximum iterations reached. Ending search.")
        return "select_best"

    threshold_reached = state.best_score >= cfg.reward_threshold
    if threshold_reached:
        logger.info("Condition met: High-scoring plan found. Ending search early.")
        return "select_best"

    logger.info("Condition not met. Continuing search.")
    return "search"


def build_mcts_workflow() -> CompiledStateGraph:
    """Wire the four MCTS stages into a LangGraph state machine and compile it.

    Flow: initialize -> search (repeated while ``should_continue`` says so)
    -> select_best -> execute -> END.
    """
    graph = StateGraph(MCTSState)

    stages = (
        ("initialize", initialize_mcts),
        ("search", mcts_iteration),
        ("select_best", select_best_plan),
        ("execute", execute_plan),
    )
    for stage_name, stage_fn in stages:
        graph.add_node(stage_name, stage_fn)

    graph.set_entry_point("initialize")
    graph.add_edge("initialize", "search")

    # should_continue returns a route key; each key maps to the node of the same name.
    routes = {key: key for key in ("search", "select_best")}
    graph.add_conditional_edges("search", should_continue, routes)
    graph.add_edge("select_best", "execute")
    graph.add_edge("execute", END)

    logger.info("Compiling MCTS workflow graph...")
    return graph.compile()


def main():
    """Run the MCTS plan search end to end on a sample problem and print the results.

    Fill in the LLM connection placeholders below before running.
    """
    llm_config = LLMConfig(
        api_key="API KEY HEER",
        base_url="URL HEER",
        model_name="MODEL NAME HEER",
        temperature=0.7,
        max_tokens=1024,
    )

    mcts_config = MCTSConfig(
        max_iterations=3,
        num_expansions=2,
        max_depth=3,
        exploration_weight=1.0,
        reward_threshold=0.95,
        reward_weights={
            "LogicalConsistency": 0.5,
            "Feasibility": 0.5,
        }
    )

    config = ExperimentConfig(
        problem_description=r"Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$",
        llm_config=llm_config,
        mcts_config=mcts_config,
        prompts=PromptConfig()
    )

    initial_state = MCTSState(config=config)

    workflow = build_mcts_workflow()

    print("=" * 60)
    print(f"Problem: {config.problem_description}")
    print("=" * 60)

    print("Running MCTS search...")
    # LangGraph aborts a run after `recursion_limit` super-steps (25 by default).
    # Every search iteration is one super-step, plus initialize, select_best and
    # execute. Scale the limit with max_iterations so longer searches do not end
    # in GraphRecursionError.
    recursion_limit = max(25, mcts_config.max_iterations + 10)
    final_state = workflow.invoke(initial_state, config={"recursion_limit": recursion_limit})
    print("Workflow finished.")

    # invoke() returns the final state as a mapping, hence the key access below.
    print("=" * 60)
    print(f"Best Plan Found:\n{final_state['best_plan']}")
    print(f"Best Node ID: {final_state['best_node_id']}")
    print(f"\nFinal Plan Score (Normalized): {final_state['best_score']:.3f}")
    print("=" * 60)
    print(f"Final Solution:\n{final_state['final_solution']}")
    print("=" * 60)

    print_search_analytics(final_state)
    print_tree_visualization(final_state)
    print_detailed_node_analysis(final_state)


# Script entry point: run the demo search only when executed directly, not on import.
if __name__ == "__main__":
    main()
