import argparse
import concurrent.futures
import datetime
import heapq
import json
import logging
import os
import random
import re
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import ale_bench
import graphviz
import numpy as np
from PIL import Image
from ale_bench.data import ScoreType, list_problem_ids
from ale_bench.result import CaseResult, JudgeResult, ResourceUsage, Result
from ale_bench.session import Session
from ale_bench.tool_wrappers.case_runner import run_cases
from ale_bench.utils import parse_statement
from google import genai
from google.genai import types as genai_types
from google.genai.chats import Chat

# Import prompt templates
from prompt_templates import (
    DEFAULT_IMPROVEMENT_GUIDANCE,
    FEEDBACK_WITH_SUMMARY_TEMPLATE,
    get_improvement_guidance_with_weights,
    get_message_templates_with_weights,
)


# Logging settings
def setup_logging(problem_id=None):
    """Configure the root logger and return it.

    Handlers already attached to the root logger are detached and closed,
    then a console handler is installed. When ``problem_id`` is truthy a file
    handler writing to ``results/<problem_id>/log.txt`` (truncated on open) is
    installed as well.

    Args:
        problem_id: Optional problem identifier used to build the log path.

    Returns:
        The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach *and close* existing handlers: this function is called again once
    # the problem id is known, and an unclosed FileHandler would keep its file
    # descriptor open for the rest of the process.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    log_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Handler for console output
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(log_format)
    logger.addHandler(console_handler)

    # Handler for file output (only if problem_id is specified)
    if problem_id:
        log_path = Path("results") / problem_id / "log.txt"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_path, mode="w")
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)

    return logger


# Module-wide default logger: console output only (no problem_id is passed here).
# update_global_variables() rebinds it via setup_logging(problem_id).
logger = setup_logging()

# Language identifiers that share the C++ file extension and code fence.
_CPP_VARIANTS = ("cpp17", "cpp20", "cpp23")

# Human-readable language name for each language identifier.
LANG_EXPR = dict(
    cpp17="C++17",
    cpp20="C++20",
    cpp23="C++23",
    python="Python",
    rust="Rust",
)
# Source-file extension for each language identifier.
LANG_FILE = {
    **dict.fromkeys(_CPP_VARIANTS, "cpp"),
    "python": "py",
    "rust": "rs",
}
# Placeholder fenced code block embedded in prompt text.
CODE_BLOCK = {
    **dict.fromkeys(_CPP_VARIANTS, "```cpp\\n// Your code here\\n```"),
    "python": "```python\\n# Your code here\\n```",
    "rust": "```rust\\n// Your code here\\n```",
}
# Fenced-block template with a single ``{}`` placeholder for the code.
CODE_BLOCK_FORMAT = {
    **dict.fromkeys(_CPP_VARIANTS, "```cpp\\n{}\\n```"),
    "python": "```python\\n{}\\n```",
    "rust": "```rust\\n{}\\n```",
}
# Regex capturing the body of a fenced code block for each language.
CODE_BLOCK_MATCH = {
    **dict.fromkeys(_CPP_VARIANTS, re.compile(r"```cpp\n(.+?)\n```", re.DOTALL)),
    "python": re.compile(r"```python\n(.+?)\n```", re.DOTALL),
    "rust": re.compile(r"```rust\n(.+?)\n```", re.DOTALL),
}
# Gemini model identifiers listed as supported.
SUPPORTED_MODELS = [
    "gemini-2.5-pro-preview-03-25",
    "gemini-2.5-pro-exp-03-25",
    "gemini-2.5-pro-preview-05-06",
    "gemini-2.5-flash-preview-04-17",
    "gemini-2.0-flash-lite",
    "gemini-2.0-flash-thinking-exp",
    "gemini-2.0-flash",
    "gemini-1.5-flash",
]

# Root directory for per-problem outputs (logs, states, history, usage).
PATH_RESULTS = Path("results")
# Optional hard constraint on the solution technique; alternatives kept for reference:
# PROMPT_ENFORCE = "**You must implement the solution using simulated annealing.**"
# PROMPT_ENFORCE = "**You must implement the solution using beam search.**"
PROMPT_ENFORCE = ""


class GLOBAL:
    """Process-wide configuration and shared mutable state.

    Throughout the visible code, attributes are read and written directly on
    the class (``GLOBAL.attr``); no instance is created there.
    """

    # Run configuration; most of these are assigned by update_global_variables().
    problem_id: str = ""
    lite_version: bool = False
    argument_code_language: str = "cpp23"
    path_session: Path | None = None
    ale_bench_session: Session | None = None
    llm_model: str | None = None
    llm_client: genai.Client | None = None
    problem_statement: list[dict[str, str | dict[str, str]]] | None = None
    score_type: ScoreType | None = None
    time_limit: float | None = None
    worst_score: float | None = None
    verbose: bool = False
    realtime: bool = False  # When True, get_private_eval_result() runs a public evaluation instead
    break_on_AC: bool = True  # Whether to stop generation when AC is achieved
    use_domain_knowledge: bool = False  # Whether to use domain knowledge prompts
    # NOTE(review): update_global_variables() also assigns GLOBAL.best_normalized_score,
    # which is not declared here.

    num_states: int = 0  # Next index handed out by State.__post_init__
    num_code_patience: int = 3
    num_cases_exp: int = 50  # Number of experiment cases generated in update_global_variables()
    list_cases_exp: list[str] = []  # The generated experiment cases
    num_workers: int = 16

    # Number of parallel expansions (threads) for child node generation
    num_expansion_threads: int = 100

    max_performance: float = 5000.0

    # Locks for session management, state index allocation and tree modifications
    session_lock: threading.Lock = threading.Lock()
    state_lock: threading.Lock = threading.Lock()
    tree_lock: threading.Lock = threading.Lock()  # Lock for tree modifications

    # Variables and lock for API rate limiting (see rate_limit_api_call)
    api_rate_limit_lock: threading.Lock = threading.Lock()
    last_api_call_time: float = 0.0
    api_call_interval: float = 0.1  # Minimum 0.1 second interval

    # Job queue related for parallel tree search
    job_queue_lock: threading.Lock = threading.Lock()
    public_eval_jobs: List[Tuple[str, int, int]] = []  # (code, tree_id, step)
    public_eval_results: dict = {}  # key: (tree_id, step, code_hash), value: Result

    # Execution time limit
    end_time: Optional[float] = None

    # Variables for summary feature
    codes_history_lock: threading.Lock = threading.Lock()
    # History for each tree ID: dict[tree_id, list[tuple[Result, str, str, str]]]
    # tuple elements: (result, code_language, code, feedback)
    codes_history: Dict[int, List[Tuple[Optional[Result], str, str, str]]] = {}
    summary_history: Dict[int, List[Optional[str]]] = {}  # {tree_id: [summaries]}
    best_codes: Dict[int, Tuple[Optional[Result], str, str]] = {}  # {tree_id: (best_result, code_language, best_code)}
    use_summary: bool = True  # Whether to use the summary feature


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that tolerates PIL images and raw bytes.

    Images are serialized as their ``str()`` form and ``bytes`` values as the
    fixed placeholder ``"binary data"``; every other type is deferred to the
    base encoder.
    """

    def default(self, obj):
        if isinstance(obj, Image.Image):
            return str(obj)
        if isinstance(obj, bytes):
            # The payload itself is intentionally not written out.
            return "binary data"
        return super().default(obj)


def reset_ale_bench_session():
    """Start a fresh ALE-Bench session and store it on ``GLOBAL``.

    This function does not take ``GLOBAL.session_lock`` itself; it is meant to
    be called while that lock is held.
    """
    # Quotas and durations are set to very large values so the session is
    # effectively unlimited.
    session_options = dict(
        problem_id=GLOBAL.problem_id,
        lite_version=GLOBAL.lite_version,
        use_same_time_scale=False,
        maximum_num_case_gen=int(1e18),
        maximum_num_case_eval=int(1e18),
        maximum_execution_time_case_eval=1e18,
        maximum_num_call_public_eval=int(1e18),
        session_duration=1e9,
        num_workers=GLOBAL.num_workers,
        run_visualization_server=False,
    )
    GLOBAL.ale_bench_session = ale_bench.start(**session_options)


def update_global_variables(
    problem_id: str, lite_version: bool, argument_code_language: str, model: str, num_workers: int, verbose: bool, realtime: bool, break_on_ac: bool
) -> None:
    """Populate ``GLOBAL`` for a run on ``problem_id`` and start its session.

    Stores the run options, re-creates the module logger with a per-problem
    log file, then (under ``GLOBAL.session_lock``) starts the ALE-Bench
    session, creates the LLM client, derives the problem-dependent values and
    generates the experiment cases.

    Raises:
        KeyError: If the ``GEMINI_API_KEY`` environment variable is not set.
    """
    global logger

    GLOBAL.problem_id = problem_id
    GLOBAL.lite_version = lite_version
    GLOBAL.argument_code_language = argument_code_language
    GLOBAL.path_session = PATH_RESULTS / problem_id
    GLOBAL.num_workers = num_workers
    GLOBAL.realtime = realtime
    GLOBAL.break_on_AC = break_on_ac

    # Switch logging over to the per-problem log file.
    logger = setup_logging(problem_id)

    logger.info(f"Updating global variables with\\n\\t{problem_id=}\\n\\t{lite_version=}\\n\\t{argument_code_language=}\\n\\t{model=}\\n\\t{num_workers=}\\n\\t{realtime=}")

    with GLOBAL.session_lock:
        reset_ale_bench_session()
        GLOBAL.llm_model = model
        GLOBAL.llm_client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

        session = GLOBAL.ale_bench_session
        problem = session.problem
        GLOBAL.problem_statement = parse_statement(problem.statement, problem.statement_images)
        GLOBAL.score_type = problem.metadata.score_type.value
        GLOBAL.time_limit = problem.constraints.time_limit
        if GLOBAL.score_type == ScoreType.MAXIMIZE:
            GLOBAL.worst_score = 0
        else:
            GLOBAL.worst_score = float("inf")
        GLOBAL.best_normalized_score = 1
        GLOBAL.verbose = verbose

        # State indices restart from zero for every run.
        GLOBAL.num_states = 0

        # One experiment case per seed in [0, num_cases_exp).
        seeds = list(range(GLOBAL.num_cases_exp))
        GLOBAL.list_cases_exp = session.case_gen(seed=seeds)


def get_case_exp(seed: int | None = None):
    """Return one of the pre-generated experiment cases.

    Args:
        seed: Index into ``GLOBAL.list_cases_exp``. When ``None`` an index is
            drawn at random from ``[0, GLOBAL.num_cases_exp)``.

    Raises:
        ValueError: If ``seed`` is given but lies outside that range.
    """
    if seed is None:
        seed = np.random.randint(0, GLOBAL.num_cases_exp)
    elif seed < 0 or GLOBAL.num_cases_exp <= seed:
        raise ValueError(f"Seed must be between 0 and {GLOBAL.num_cases_exp - 1}")
    return GLOBAL.list_cases_exp[seed]


def get_chat(system_instruction: str | None = None) -> Chat:
    """Create a new text-only chat on the configured LLM client.

    Args:
        system_instruction: System prompt for the chat. When omitted, a
            default prompt naming the currently configured language is used.
            It is built at call time: a default-argument f-string would be
            evaluated once at definition time and always name the initial
            language, even after ``GLOBAL.argument_code_language`` changes.

    Returns:
        The chat created via ``GLOBAL.llm_client.chats.create``.
    """
    if system_instruction is None:
        system_instruction = (
            "You are a world-class algorithm engineer and have been tasked with solving a heuristics problem. "
            f"You are surprisingly good at {LANG_EXPR[GLOBAL.argument_code_language]} programming. Now, you are a competitor in a programming contest."
        )
    logger.info(f"Start chat with {GLOBAL.llm_model=}")
    return GLOBAL.llm_client.chats.create(
        model=GLOBAL.llm_model,
        config=genai_types.GenerateContentConfig(
            system_instruction=system_instruction,
            response_modalities=["Text"],
        ),
    )


def get_code(response_text: str) -> str | None:
    """Extract the last fenced code block for the configured language.

    Returns:
        The body of the last matching code block, or ``None`` when there is no
        match or when any error occurs during extraction (the error is logged).
    """
    try:
        pattern = CODE_BLOCK_MATCH[GLOBAL.argument_code_language]
        blocks = pattern.findall(response_text)
        return blocks[-1] if len(blocks) > 0 else None
    except Exception as e:
        logger.error(f"Ignoring error in get_code and returning None: {e}")
        return None


def is_first_result_better(result_left: Result, result_right: Result) -> bool:
    """Return True when ``result_left`` strictly beats ``result_right``.

    "Better" means a higher overall absolute score for maximization problems
    and a lower one otherwise; ties are not better.
    """
    left = result_left.overall_absolute_score
    right = result_right.overall_absolute_score
    maximizing = GLOBAL.score_type == ScoreType.MAXIMIZE
    return left > right if maximizing else left < right


def case_result_feedback(case_result: CaseResult) -> str:
    """Format a single case result as a multi-line feedback text.

    Memory usage is reported in whole MiB (floor division) and execution time
    with three decimals.
    """
    lines = [
        "[Case result]",
        f"Judge: {case_result.judge_result.value}",
        f"Absolute score: {case_result.absolute_score}",
        f"Execution time: {case_result.execution_time:.3f} sec",
        f"Memory usage: {case_result.memory_usage // 1024 // 1024} MiB",
        f"Message: {case_result.message}",
    ]
    return "\n".join(lines)


# Ported result_feedback function from common_resource.py
def result_feedback(result: Result) -> str:
    """Summarize a public test result as feedback text.

    For an accepted result the overall score and every per-case score are
    listed. Otherwise the first case whose verdict equals the overall verdict
    is detailed (falling back to the first case when none matches).

    Returns:
        The feedback text, or ``"No feedback available."`` when ``result`` is
        ``None``.
    """
    if result is None:
        return "No feedback available."

    header = f"[Public test result]\\nOverall judge result: {result.overall_judge_result.value}\\n"

    if result.overall_judge_result == JudgeResult.ACCEPTED:
        score_line = f"Overall absolute score: {result.overall_absolute_score}\\n"
        per_case = [f"- Case {i}: {case_result.absolute_score}" for i, case_result in enumerate(result.case_results, 1)]
        return header + score_line + "\\n".join(per_case)

    # Index of the first case sharing the overall verdict; 0 when none does.
    representative_idx = next(
        (idx for idx, case_result in enumerate(result.case_results) if case_result.judge_result == result.overall_judge_result),
        0,
    )
    return header + case_result_feedback(result.case_results[representative_idx])


def feedback_message(public_result: Result) -> str:
    """Build the follow-up prompt sent to the model after a public evaluation.

    Accepted results produce an "improve performance" request containing the
    full result summary; any other verdict produces a "fix the code" request
    containing the first case that shares the overall verdict.

    Raises:
        ValueError: If the result is not accepted and no case shares the
            overall verdict.
    """
    if public_result.overall_judge_result == JudgeResult.ACCEPTED:
        return (
            f"{result_feedback(public_result)}\\n\\nBased on the above test result, please improve your code to achieve better performance. "
            f"The code should be implemented in the {CODE_BLOCK[GLOBAL.argument_code_language]} code block. "
        )

    failing_feedback = None
    for case_result in public_result.case_results:
        if case_result.judge_result != public_result.overall_judge_result:
            continue
        failing_feedback = case_result_feedback(case_result=case_result)
        break
    if failing_feedback is None:
        raise ValueError("No case result found")

    return (
        f"{failing_feedback}\\n\\nBased on the above test result, please fix the code to make it accepted. "
        "Please analyze the error message and fix the code accordingly. "
        f"The code should be implemented in the {CODE_BLOCK[GLOBAL.argument_code_language]} code block. "
        "Make sure to address the specific issue mentioned in the error message."
    )


@dataclass(order=True)
class SearchNode:
    """Node of the search tree, ordered only by ``priority``.

    ``priority`` is the sole field taking part in comparisons, so instances
    can be pushed directly onto a ``heapq`` min-heap: the smallest priority
    tuple is popped first.
    """

    # (-ac_ratio, signed score, depth); computed in __post_init__.
    priority: Tuple[float, float, int] = field(init=False, repr=False)
    score: float = field(compare=False)
    depth: int = field(compare=False)
    state: Any = field(compare=False)  # The state of the search (an instance of State class)
    parent: Optional["SearchNode"] = field(default=None, compare=False)
    children: List["SearchNode"] = field(default_factory=list, compare=False)
    expand_idx: int = field(default=-1, compare=False)  # Node expansion order index

    def __post_init__(self):
        """Derive the heap priority from AC ratio, score and depth."""
        # 1. Negated AC ratio: a larger share of accepted cases sorts first.
        accepted_key = -self.get_ac_ratio()

        # 2. Score, sign-flipped for maximization so that "better" is smaller.
        if GLOBAL.score_type == ScoreType.MAXIMIZE:
            score_key = -self.score
        else:
            score_key = self.score

        # 3. Depth: shallower nodes win remaining ties.
        self.priority = (accepted_key, score_key, self.depth)

    def add_child(self, child_node: "SearchNode"):
        """Attach ``child_node`` below this node and set its parent link."""
        child_node.parent = self
        self.children.append(child_node)

    def get_ac_ratio(self):
        """Return the fraction of accepted public cases (0.0 if unavailable)."""
        if not self.state:
            return 0.0
        public_result = getattr(self.state, "public_result", None)
        if not public_result:
            return 0.0
        case_results = public_result.case_results
        total = len(case_results)
        if total > 0:
            accepted = sum(1 for case in case_results if case.judge_result == JudgeResult.ACCEPTED)
            return accepted / total
        return 0.0

class SearchTree:
    """Container for all nodes of one search tree.

    Node insertion and node listing are serialized through ``self.lock``.
    """

    def __init__(self, root_state: Any, initial_score: float):
        """Create the tree with a root node (``root_state`` may be ``None``)."""
        root = SearchNode(score=initial_score, depth=0, state=root_state, expand_idx=0)
        self.root = root
        self.nodes: List[SearchNode] = [root]
        self.expand_counter = 1  # One past the largest expand_idx used so far
        self.lock = threading.Lock()

    def add_node(self, state: Any, score: float, parent_node: SearchNode) -> SearchNode:
        """Create a node for ``state`` under ``parent_node`` and register it.

        The node's ``expand_idx`` is ``state.idx`` when a state is given and
        the tree's running counter otherwise.

        Returns:
            The newly created node.
        """
        with self.lock:
            if state is None:
                expand_idx = self.expand_counter
            else:
                expand_idx = state.idx

            child = SearchNode(
                score=score,
                depth=parent_node.depth + 1,
                state=state,
                parent=parent_node,
                expand_idx=expand_idx,
            )

            # Sanity check that the node index mirrors the state's own index.
            if state is not None:
                assert expand_idx == state.idx, f"expand_idx {expand_idx} does not match State idx {state.idx}"

            parent_node.add_child(child)
            self.nodes.append(child)

            # Keep the counter strictly above every expand_idx seen so far.
            next_free = expand_idx + 1
            if next_free > self.expand_counter:
                self.expand_counter = next_free

            return child

    def get_nodes(self) -> List[SearchNode]:
        """Return a snapshot (shallow copy) of the node list."""
        with self.lock:
            return list(self.nodes)


class BestFirstSearch:
    """Implementation of Best First Search algorithm"""

    def __init__(self, num_samples: int = 1, num_parallel: int | None = None):
        """Set up an empty search.

        Args:
            num_samples: Number of child nodes to generate per expansion.
            num_parallel: Number of parallel threads for expansion. When
                omitted, ``GLOBAL.num_expansion_threads`` is read at
                construction time; a default argument would freeze whatever
                value it had when the class was defined.
        """
        if num_parallel is None:
            num_parallel = GLOBAL.num_expansion_threads
        self.num_samples = num_samples
        self.num_parallel = num_parallel
        self.priority_queue: List[SearchNode] = []  # Min-heap of candidate nodes
        self.priority_queue_lock = threading.Lock()  # Guards priority_queue
        self.tree: Optional[SearchTree] = None
        self.expand_history: List[int] = []  # For debugging: indices of expanded nodes
        self.expand_history_lock = threading.Lock()  # Guards expand_history

    def initialize(self, initial_state: Any, initial_score: float):
        """Create the search tree and seed the priority queue with its root."""
        self.tree = SearchTree(root_state=initial_state, initial_score=initial_score)
        with self.priority_queue_lock:
            heapq.heappush(self.priority_queue, self.tree.root)

    def _expand_node_worker(self, current_node: SearchNode, generate_action_fn: callable) -> None:
        """Expand ``current_node`` once; intended to run in a worker thread.

        ``generate_action_fn`` is called with the node's state and must return
        ``(new_state, new_score)``. A ``None`` state is discarded. Any
        exception is logged with its traceback and swallowed so that a failing
        expansion does not propagate out of the worker.
        """
        try:
            new_state, new_score = generate_action_fn(current_node.state)

            # Only add valid states
            if new_state is not None:
                new_node = self.tree.add_node(new_state, new_score, current_node)

                with self.priority_queue_lock:
                    heapq.heappush(self.priority_queue, new_node)

                if GLOBAL.verbose:
                    if GLOBAL.score_type == ScoreType.MAXIMIZE:
                        score_quality = "higher is better"
                    else:
                        score_quality = "lower is better"
                    logger.info(
                        f"    Created child node with State Idx: {new_state.idx}, Score: {new_score:.4f} ({score_quality})"
                    )
        except Exception as e:
            # logger.exception routes the traceback through the logging
            # handlers, so it reaches the log file as well as the console.
            logger.exception(f"Error in node expansion worker: {e}")

    def evaluate_node(self, node: SearchNode) -> None:
        """Run a private evaluation for ``node`` and store the outcome on its state.

        Nodes without a state, root nodes and nodes without code are skipped.
        On success the state's private score, relative score, rank and
        performance are updated and the state is saved. Errors are logged with
        their traceback and swallowed.
        """
        if node.state is None:
            logger.info(f"  Node {node.expand_idx} has no state, skipping evaluation")
            return

        if node.state.is_root:
            logger.info(f"  Node {node.expand_idx} is a root node, skipping evaluation")
            return

        if node.state.code is None:
            logger.info(f"  Node {node.expand_idx} has no code, skipping evaluation")
            return

        logger.info(f"  Evaluating node {node.expand_idx} (State Idx: {node.state.idx})")

        job_queue = get_job_queue()

        try:
            job = PrivateEvalJob(
                code=node.state.code,
                tree_id=node.state.tree_id,
                step=node.state.step,
                state_idx=node.state.idx,
                node_expand_idx=node.expand_idx,
            )

            # Blocks until the queue has produced the private result.
            private_result, rank, performance = job_queue.add_private_job_and_wait(job)

            node.state.private_absolute_score = private_result.overall_absolute_score
            node.state.relative_score = private_result.overall_relative_score
            node.state.rank = rank
            node.state.performance = performance

            save_state(node.state)

        except Exception as e:
            logger.exception(f"  Error during node evaluation: {e}")


def visualize_search_tree(
    tree: SearchTree,
    save_path: str,
    show_scores: bool = True,
    max_label_length: int = 30,  # Adjust label length
    title: Optional[str] = None,
    format: str = "png",
):
    """Render the search tree with Graphviz and write it to ``save_path``.

    Every node becomes a filled box labelled with its state index, expansion
    index and (when ``show_scores`` is set) its scores and judge result.
    Fill colors: green for accepted nodes, coral for nodes that have an
    absolute score but are not accepted, gray otherwise (including the root).

    Args:
        tree: Tree whose nodes are drawn.
        save_path: Output path without extension.
        show_scores: Whether to include score/rank/judge lines in labels.
        max_label_length: State labels longer than this are truncated.
        title: Optional graph title.
        format: Graphviz output format.

    Rendering errors are logged, not raised.
    """
    graph = graphviz.Digraph(comment=title or "Search Tree Visualization")
    graph.attr(rankdir="TB")  # Draw from top to bottom

    if title:
        graph.attr(label=title)

    for node in tree.get_nodes():
        state = node.state

        if node.parent is None:
            # Root node: fixed headline, always gray.
            label_lines = ["ROOT", f"ExpandIdx: {node.expand_idx}"]
            if show_scores:
                label_lines.append(f"Score: {node.score:.4f}")
            fill = "lightgray"
        else:
            # Headline identifies the state (by its own idx when available).
            if state is None:
                headline = "None"
            elif hasattr(state, "idx"):
                headline = f"State Idx: {state.idx}"
            else:
                headline = str(state)
            if len(headline) > max_label_length:
                headline = headline[:max_label_length] + "..."

            label_lines = [headline, f"ExpandIdx: {node.expand_idx}"]
            public_result = getattr(state, "public_result", None)

            if show_scores:
                label_lines.append(f"Score: {node.score:.4f}")

                if getattr(state, "relative_score", None) is not None:
                    label_lines.append(f"Relative Score: {state.relative_score:.4f}")

                if getattr(state, "rank", None) is not None:
                    label_lines.append(f"Rank: {state.rank}")

                if state and public_result:
                    label_lines.append(f"Judge: {public_result.overall_judge_result.value}")

                if public_result:
                    label_lines.append(f"AC Ratio: {node.get_ac_ratio():.2f}")

            # Color by evaluation status.
            if not state:
                fill = "lightgray"  # Nodes without state
            elif public_result and public_result.overall_judge_result == JudgeResult.ACCEPTED:
                fill = "lightgreen"  # AC
            elif getattr(state, "absolute_score", None) is not None:
                fill = "lightcoral"  # Not AC but evaluated
            else:
                fill = "lightgray"  # Not evaluated

        node_id = str(node.expand_idx)
        graph.node(node_id, label="\n".join(label_lines), style="filled", fillcolor=fill, shape="box")

        # Add edge to parent node
        if node.parent:
            graph.edge(str(node.parent.expand_idx), node_id)

    # Render and save (view=False: do not open a viewer)
    try:
        graph.render(save_path, format=format, cleanup=True, view=False)
        logger.info(f"Tree visualization saved to {save_path}.{format}")
    except Exception as e:
        logger.error(f"Error rendering graphviz: {e}")
        logger.error("Please ensure graphviz is installed and configured correctly.")


@dataclass
class State:
    """Information for each state (alternative to treequest.State)"""

    parent_state: Optional["State"] = None
    code: str | None = None
    # Score used for search prioritization. None means "not given" and is
    # replaced by GLOBAL.worst_score in __post_init__; a class-level default of
    # GLOBAL.worst_score would be captured once at class-definition time,
    # before update_global_variables() has set it.
    score: float | None = None
    chat_history: list[dict[str, Any]] | None = None
    idx: int = -1  # -1 means "allocate the next global index"
    # Additional information (absolute score, performance, etc.) as needed
    absolute_score: float | None = None
    feedback: str | None = None
    is_root: bool = False  # Flag to indicate if this is a root state
    # Private evaluation results
    private_absolute_score: float | None = None
    relative_score: float | None = None
    rank: int | None = None
    performance: float | None = None
    # Search tree information
    tree_id: int = 0  # Tree ID
    step: int = 0  # Step
    # Public evaluation result object
    public_result: Result | None = None

    def __post_init__(self):
        """Resolve the default score and allocate a unique index if needed."""
        if self.score is None:
            self.score = GLOBAL.worst_score

        # Only allocate when no index was provided; the lock keeps indices
        # unique across threads.
        if self.idx == -1:
            with GLOBAL.state_lock:
                self.idx = GLOBAL.num_states
                GLOBAL.num_states += 1


def save_state(state: State) -> None:
    """Write a JSON snapshot of ``state`` to ``<session>/states/<idx>.json``.

    The snapshot holds the scores, code, feedback and private-evaluation
    fields plus a short summary of the public result (``None`` if absent).
    """
    target = GLOBAL.path_session / "states" / f"{state.idx:03d}.json"
    target.parent.mkdir(parents=True, exist_ok=True)

    public_result = state.public_result
    if public_result is None:
        public_summary = None
    else:
        public_summary = {
            "judge_result": public_result.overall_judge_result.value,
            "absolute_score": public_result.overall_absolute_score,
            "case_count": len(public_result.case_results) if public_result.case_results else 0,
        }

    snapshot = {
        "score": state.score,  # Search score
        "code": state.code,
        "absolute_score": state.absolute_score,
        "feedback": state.feedback,
        "idx": state.idx,
        "private_absolute_score": state.private_absolute_score,
        "relative_score": state.relative_score,
        "rank": state.rank,
        "performance": state.performance,
        "public_result_summary": public_summary,
    }
    with open(target, "w") as out_file:
        json.dump(snapshot, out_file, indent=4)


def save_history(chat_history_or_chat: list[dict[str, Any]] | Chat, idx_state: int) -> None:
    """Dump a chat history to ``<session>/history/<idx_state>/``.

    For message ``i`` the following files are written:
      * ``message_i.txt``  - concatenated text of all parts,
      * ``message_i.<ext>`` - the first code block found in the text (if any),
      * ``message_i.json`` - the raw message,
      * ``message_i_k.png`` - each inline PNG image, numbered from 0.

    Args:
        chat_history_or_chat: Either a ``Chat`` (its history is dumped via
            ``model_dump``) or an already dumped list of message dicts.
        idx_state: State index used as the directory name.
    """
    if isinstance(chat_history_or_chat, Chat):
        messages = [content.model_dump() for content in chat_history_or_chat.get_history()]
    else:
        messages = chat_history_or_chat

    out_dir = GLOBAL.path_session / "history" / f"{idx_state:03d}"
    out_dir.mkdir(parents=True, exist_ok=True)

    for msg_no, content in enumerate(messages):
        stem = f"message_{msg_no:03d}"
        extracted_code = None

        # Plain text of the message; remember the first code block seen.
        with open(out_dir / f"{stem}.txt", "w") as text_file:
            parts = content["parts"] if "parts" in content else None
            if parts is not None:
                for part in parts:
                    if "text" not in part or part["text"] is None:
                        continue
                    text_file.write(part["text"])
                    if extracted_code is None:
                        extracted_code = get_code(part["text"])

        if extracted_code is not None:
            code_path = out_dir / f"{stem}.{LANG_FILE[GLOBAL.argument_code_language]}"
            with open(code_path, "w") as code_file:
                code_file.write(extracted_code)

        with open(out_dir / f"{stem}.json", "w") as json_file:
            json.dump(content, json_file, indent=4, cls=CustomJSONEncoder)

        # Inline images: only PNG payloads are written out.
        if parts is not None and any(part["inline_data"] is not None for part in parts):
            image_no = 0
            for part in parts:
                inline = part["inline_data"]
                if inline is None:
                    continue
                if inline["mime_type"] == "image/png":
                    with open(out_dir / f"{stem}_{image_no:03d}.png", "wb") as image_file:
                        image_file.write(inline["data"])
                    image_no += 1


def get_chat_history(llm_chat: Chat) -> list[dict[str, Any]]:
    """Return the chat's history with every entry converted via ``model_dump()``."""
    dumped = []
    for content in llm_chat.get_history():
        dumped.append(content.model_dump())
    return dumped


# Function to implement API rate limiting
def rate_limit_api_call():
    """Block until at least ``GLOBAL.api_call_interval`` has passed since the last call.

    The wait happens while holding ``GLOBAL.api_rate_limit_lock``, so
    concurrent callers are released one at a time. Per the original author's
    note, image handling does not work correctly unless API calls are spaced
    out.
    """
    with GLOBAL.api_rate_limit_lock:
        since_last_call = time.time() - GLOBAL.last_api_call_time

        # Sleep off whatever is left of the minimum interval.
        if since_last_call < GLOBAL.api_call_interval:
            time.sleep(GLOBAL.api_call_interval - since_last_call)

        # Stamp the moment this caller is let through.
        GLOBAL.last_api_call_time = time.time()


@dataclass
class EvalJob:
    """Public evaluation job, comparable by ``priority`` for use in a heap.

    ``code_hash`` and ``priority`` are derived from the other fields after
    construction.
    """

    code: str
    tree_id: int
    step: int
    state_idx: int
    code_hash: str = field(init=False)
    priority: int = field(init=False)

    def __post_init__(self):
        # Smaller value is served first: earlier steps come first, and within
        # a step the smaller tree id.
        self.priority = 1000 * self.step + self.tree_id
        self.code_hash = str(hash(self.code))

    def __lt__(self, other):
        return other.priority > self.priority


@dataclass
class PrivateEvalJob:
    """Private evaluation job, comparable by ``priority`` for use in a heap.

    ``code_hash`` and ``priority`` are derived from the other fields after
    construction.
    """

    code: str
    tree_id: int
    step: int
    state_idx: int
    node_expand_idx: int
    code_hash: str = field(init=False)
    priority: int = field(init=False)

    def __post_init__(self):
        # Same (step, tree_id) ordering as public jobs, shifted by a large
        # negative offset so private jobs get smaller priority values.
        base_priority = 1000 * self.step + self.tree_id
        self.priority = base_priority - 1000000
        self.code_hash = str(hash(self.code))

    def __lt__(self, other):
        return other.priority > self.priority


# Wrap the original Chat.send_message to apply rate limiting
def rate_limited_send_message(chat: Chat, message: str):
    """Send ``message`` on ``chat`` after honoring the API rate limit.

    When the response carries usage metadata, it is dumped as JSON to
    ``<session>/usage/<UTC send timestamp>_<uuid4>.json``.

    Returns:
        The response returned by ``chat.send_message``.
    """
    rate_limit_api_call()
    sent_at = datetime.datetime.now(tz=datetime.timezone.utc)
    response = chat.send_message(message=message)

    if not (response and response.usage_metadata):
        return response

    usage_dir = GLOBAL.path_session / "usage"
    usage_dir.mkdir(parents=True, exist_ok=True)
    # The random suffix keeps file names unique even for identical timestamps.
    timestamp = sent_at.strftime("%Y%m%d_%H%M%S_%f")
    usage_path = usage_dir / f"{timestamp}_{uuid.uuid4()}.json"
    with open(usage_path, "w") as usage_file:
        json.dump(response.usage_metadata.model_dump(), usage_file)
    return response


def get_public_eval_result(code: str) -> Result:
    """Run ``code`` on the pre-generated experiment cases and wrap the outcome.

    The code and language are first passed through the session's own argument
    check, then all cases in ``GLOBAL.list_cases_exp`` are executed with the
    problem's time and memory limits.

    Returns:
        A ``Result`` built from the case results (non-AC scores allowed).
    """
    logger.info("[START] Public eval")
    session = GLOBAL.ale_bench_session
    _, code, code_language, judge_version, _, _ = session._check_run_cases_arguments(
        code=code,
        code_language=GLOBAL.argument_code_language,
    )

    problem = session.problem
    run_options = dict(
        inputs=GLOBAL.list_cases_exp,
        code=code,
        code_language=code_language,
        judge_version=judge_version,
        time_limit=problem.constraints.time_limit,
        memory_limit=problem.constraints.memory_limit,
        problem_id=session.problem_id,
        problem_type=problem.metadata.problem_type,
        tool_dir=session.tool_dir,
        return_details=False,
        skip_local_visualization=True,
        num_workers=GLOBAL.num_workers,
    )
    case_results = run_cases(**run_options)

    result = Result(
        allow_score_non_ac=True,
        resource_usage=ResourceUsage(num_call_public_eval=1),
        case_results=case_results,
    )
    logger.info("[END] Public eval")
    return result


def get_private_eval_result(code: str) -> tuple[Result, int, float]:
    """Evaluate ``code`` privately and return ``(result, rank, performance)``.

    In realtime mode (``GLOBAL.realtime``) no private evaluation is run: the
    public evaluation result is returned together with placeholder rank and
    performance values of ``1``. Otherwise the session's ``private_eval`` is
    used; the session is reset beforehand if it is already finished, and reset
    again after the evaluation.

    Args:
        code: Source code of the submission to evaluate.

    Returns:
        Tuple of the evaluation ``Result``, the rank and the performance.
    """
    logger.info("[START] Private eval")

    if not GLOBAL.realtime:
        # A finished session cannot be evaluated again; start a fresh one.
        if GLOBAL.ale_bench_session.session_finished:
            reset_ale_bench_session()
        outcome = GLOBAL.ale_bench_session.private_eval(code, code_language=GLOBAL.argument_code_language)
        private_result, rank, performance = outcome
        logger.info("[END] Private eval")
        reset_ale_bench_session()
        return private_result, rank, performance

    # Realtime mode: substitute the public evaluation for the private one.
    logger.info("Realtime mode enabled, using public eval result")
    public_result = get_public_eval_result(code)
    logger.info("[END] Private eval (using public eval result)")
    return public_result, 1, 1


class JobQueue:
    """Priority queue of evaluation jobs served by background worker threads.

    Jobs are kept in a ``heapq`` heap, so the job that compares smallest is
    executed first. Public (``EvalJob``) and private (``PrivateEvalJob``) jobs
    share the heap but store their outcomes in separate tables. Every outcome
    is keyed by ``(tree_id, step, code_hash)``.

    A job whose evaluation raises is recorded as a failure for its key, so
    callers blocked in ``add_job_and_wait`` / ``add_private_job_and_wait`` are
    woken and receive a ``RuntimeError`` instead of waiting forever.

    Creating an instance registers it as the module-level singleton
    ``_job_queue_instance``.
    """

    def __init__(self):
        self.jobs = []  # Heap of pending jobs (public and private evaluations)
        self.results = {}  # Public results: (tree_id, step, code_hash) -> Result
        self.private_results = {}  # Private results: key -> (Result, rank, performance)
        self.failures = {}  # Public jobs that raised: key -> exception
        self.private_failures = {}  # Private jobs that raised: key -> exception
        self.lock = threading.Lock()
        self.condition = threading.Condition(self.lock)  # Signals "a job is available" / "stop"
        self.result_conditions = {}  # Per-key conditions for public result waiters
        self.private_result_conditions = {}  # Per-key conditions for private result waiters
        self.workers = []
        self.stop_flag = False

        # Store singleton instance
        global _job_queue_instance
        _job_queue_instance = self

    @staticmethod
    def _job_key(job):
        """Return the ``(tree_id, step, code_hash)`` key identifying ``job``."""
        return (job.tree_id, job.step, job.code_hash)

    def _tables(self, private: bool):
        """Return the (results, failures, conditions) dicts for the job kind."""
        if private:
            return self.private_results, self.private_failures, self.private_result_conditions
        return self.results, self.failures, self.result_conditions

    def _is_queued(self, job_type, key) -> bool:
        """Tell whether a pending job of ``job_type`` with ``key`` is in the heap.

        Must be called with ``self.lock`` held.
        """
        return any(isinstance(queued, job_type) and self._job_key(queued) == key for queued in self.jobs)

    def add_job(self, job: EvalJob | PrivateEvalJob):
        """Add a job to the queue without waiting for its result.

        The job is skipped when a result for its key already exists or an
        identical job of the same kind is still pending.
        """
        with self.lock:
            if isinstance(job, EvalJob):
                job_type, private = EvalJob, False
            elif isinstance(job, PrivateEvalJob):
                job_type, private = PrivateEvalJob, True
            else:
                job_type, private = None, False

            if job_type is not None:
                key = self._job_key(job)
                results, _, _ = self._tables(private)
                if key in results or self._is_queued(job_type, key):
                    return

            heapq.heappush(self.jobs, job)
            self.condition.notify()

    def _submit_and_wait(self, job, job_type, private: bool, reuse_cached: bool):
        """Queue ``job`` (unless already pending) and block until it has an outcome.

        Args:
            job: Job to evaluate.
            job_type: Class used to recognise duplicates of ``job`` in the heap.
            private: Selects the private or the public result tables.
            reuse_cached: When True an existing result is returned immediately;
                when False it is discarded so the job is evaluated again.

        Returns:
            The stored result for the job's key.

        Raises:
            RuntimeError: If the worker recorded a failure for this job; the
                original exception is attached as ``__cause__``.
        """
        key = self._job_key(job)
        results, failures, conditions = self._tables(private)

        with self.lock:
            if key in results:
                if reuse_cached:
                    return results[key]
                del results[key]

            # A failure left over from an earlier attempt must not abort this one.
            failures.pop(key, None)

            if key not in conditions:
                conditions[key] = threading.Condition(self.lock)
            waiter = conditions[key]

            if not self._is_queued(job_type, key):
                heapq.heappush(self.jobs, job)
                self.condition.notify()

            # Wait until the worker publishes either a result or a failure.
            while key not in results and key not in failures:
                waiter.wait()

            if key in results:
                return results[key]
            error = failures[key]

        raise RuntimeError(f"Evaluation job {key} failed in worker thread") from error

    def add_job_and_wait(self, job: EvalJob) -> Result:
        """Add a public evaluation job and wait until the result is available.

        Any previously stored result for the same key is discarded first, so
        the job is always evaluated again.

        Raises:
            RuntimeError: If the evaluation raised in the worker thread.
        """
        return self._submit_and_wait(job, EvalJob, private=False, reuse_cached=False)

    def add_private_job_and_wait(self, job: PrivateEvalJob) -> tuple[Result, int, float]:
        """Add a private evaluation job and wait until the result is available.

        An already stored result for the same key is returned immediately.

        Raises:
            RuntimeError: If the evaluation raised in the worker thread.
        """
        return self._submit_and_wait(job, PrivateEvalJob, private=True, reuse_cached=True)

    def get_result(self, tree_id: int, step: int, code_hash: str) -> Optional[Result]:
        """Get the result of a public evaluation job (None if not available)"""
        with self.lock:
            return self.results.get((tree_id, step, code_hash))

    def get_private_result(self, tree_id: int, step: int, code_hash: str) -> Optional[tuple[Result, int, float]]:
        """Get the result of a private evaluation job (None if not available)"""
        with self.lock:
            return self.private_results.get((tree_id, step, code_hash))

    def start_workers(self, num_workers: int):
        """Start ``num_workers`` daemon worker threads"""
        self.stop_flag = False
        for _ in range(num_workers):
            worker = threading.Thread(target=self._worker_loop)
            worker.daemon = True
            self.workers.append(worker)
            worker.start()

    def stop_workers(self):
        """Ask all worker threads to stop and join them.

        Jobs still pending in the heap are not executed.
        """
        with self.lock:
            self.stop_flag = True
            self.condition.notify_all()

        for worker in self.workers:
            worker.join()
        self.workers = []

    def _publish(self, key, private: bool, outcome=None, error=None):
        """Store a result (or a failure when ``error`` is given) and wake waiters."""
        with self.lock:
            results, failures, conditions = self._tables(private)
            if error is None:
                results[key] = outcome
                failures.pop(key, None)
            else:
                failures[key] = error

            # Notify threads waiting for this key
            waiter = conditions.get(key)
            if waiter is not None:
                waiter.notify_all()

    def _run_job(self, job):
        """Evaluate one job and publish its result or failure.

        Exceptions never propagate out of this method, so the worker thread
        stays alive; they are logged and handed to waiting callers instead.
        """
        if isinstance(job, EvalJob):
            private = False
        elif isinstance(job, PrivateEvalJob):
            private = True
        else:
            return

        key = self._job_key(job)
        published = False
        try:
            # Evaluations share the ALE-Bench session, so they are serialized.
            with GLOBAL.session_lock:
                if private:
                    private_result, rank, performance = get_private_eval_result(job.code)
                    outcome = (private_result, rank, performance)
                else:
                    outcome = get_public_eval_result(job.code)

            self._publish(key, private, outcome=outcome)
            published = True

            if private:
                logger.info(
                    f"  Private evaluation results for node {job.node_expand_idx} (State Idx: {job.state_idx}):"
                )
                logger.info(f"    Private judge: {private_result.overall_judge_result}")
                logger.info(f"    Private absolute score: {private_result.overall_absolute_score}")
                logger.info(f"    Relative score: {private_result.overall_relative_score}")
                logger.info(f"    Rank: {rank}")
                logger.info(f"    Performance: {performance}")
        except Exception as e:
            # logger.exception appends the traceback to the configured handlers.
            logger.exception(f"Error in worker thread: {e}")
            if not published:
                # Without this, callers waiting on the key would block forever.
                self._publish(key, private, error=e)

    def _worker_loop(self):
        """Main loop for worker threads: pop the next job and run it until stopped"""
        while True:
            with self.lock:
                while not self.jobs and not self.stop_flag:
                    self.condition.wait()

                if self.stop_flag:
                    return

                job = heapq.heappop(self.jobs)

            self._run_job(job)


# Global JobQueue instance (module-level singleton).
# It is assigned by JobQueue.__init__ and lazily created by get_job_queue();
# it stays None until a JobQueue has been constructed.
_job_queue_instance = None


def get_job_queue() -> JobQueue:
    """Return the shared JobQueue, creating it on first use.

    On the first call a JobQueue is constructed and its worker threads are
    started with ``GLOBAL.num_workers`` workers; later calls return the same
    instance.
    """
    global _job_queue_instance
    if _job_queue_instance is not None:
        return _job_queue_instance

    queue = JobQueue()
    _job_queue_instance = queue
    queue.start_workers(GLOBAL.num_workers)
    return _job_queue_instance


def generate_code_repeat(
    llm_chat: Chat,
    num_code_patience: int = GLOBAL.num_code_patience,
    idx_state: int = GLOBAL.num_states,
    tree_id: int = 0,  # Tree ID (default is 0)
    step: int = 0,  # Current step (default is 0)
    break_on_AC: bool = True,
) -> tuple[str, Result] | tuple[None, None]:
    """Ask the LLM for code up to ``num_code_patience`` times, evaluating each attempt.

    Each turn sends a message to ``llm_chat``, extracts a code block from the
    reply and evaluates it publicly through the job queue. The judge feedback
    becomes the next turn's message.

    Args:
        llm_chat: Chat session used to request code.
        num_code_patience: Maximum number of turns.
        idx_state: State index under which the chat history is saved.
        tree_id: Tree ID recorded on the evaluation jobs.
        step: Step recorded on the evaluation jobs.
        break_on_AC: When True, return as soon as a submission is accepted.
            When False, keep improving for all turns.

    Returns:
        ``(code, public_result)`` of the first accepted submission when
        ``break_on_AC`` is True. Otherwise, or if nothing was accepted, the
        best evaluated submission as chosen by ``select_best_submission``.
        ``(None, None)`` when no reply contained a valid code block.
    """
    if GLOBAL.verbose:
        if break_on_AC:
            logger.info("Generating code until AC...")
        else:
            logger.info(f"Generating code... with {num_code_patience} turns")
    initial_message = (
        f"Please implement your above solution in {LANG_EXPR[GLOBAL.argument_code_language]} and {GLOBAL.score_type} absolute scores as much as possible. "
        f"The execution time limit is {GLOBAL.time_limit} second. "
        f"Your submission code should be written in the {CODE_BLOCK[GLOBAL.argument_code_language]} code block. "
    )
    if not break_on_AC:
        initial_message += f"You have {num_code_patience} turns to implement your code and to improve it. "

    message = initial_message

    # Get job queue
    job_queue = get_job_queue()

    # Every evaluated submission, in the tuple format expected by select_best_submission
    list_code_result = []

    for idx_turn in range(num_code_patience):
        is_last_turn = idx_turn == num_code_patience - 1
        if GLOBAL.verbose:
            logger.info(f"Generating code... ({idx_turn + 1}/{num_code_patience})")
        submission_response = rate_limited_send_message(llm_chat, message)  # Use rate-limited send_message
        save_history(llm_chat, idx_state)
        code = get_code(submission_response.text)
        if code is None:
            if GLOBAL.verbose:
                if is_last_turn:
                    logger.info("No valid code block found. Exiting")
                else:
                    logger.info("No valid code block found. Regenerating...")
            message = (
                f"No valid code block found. Please implement your solution in the {CODE_BLOCK[GLOBAL.argument_code_language]} code block. "
            ) + initial_message
            continue

        # Execute public evaluation using job queue
        job = EvalJob(code=code, tree_id=tree_id, step=step, state_idx=idx_state)
        public_result = job_queue.add_job_and_wait(job)

        list_code_result.append((public_result, None, code, None))  # format for select_best_submission

        accepted = public_result.overall_judge_result == JudgeResult.ACCEPTED
        if accepted and break_on_AC:
            if GLOBAL.verbose:
                logger.info("AC! Returning code...")
            return code, public_result

        if GLOBAL.verbose:
            if accepted:
                logger.info("AC! But not returning code because break_on_AC is False")
            elif is_last_turn:
                logger.info(f"Not AC. Got {public_result.overall_judge_result.value}. Exiting...")
            else:
                logger.info(f"Not AC. Got {public_result.overall_judge_result.value}. Regenerating...")

        # Accepted or not, the judge feedback drives the next turn.
        message = feedback_message(public_result)

        if not break_on_AC:
            message += (
                f"\nYou have {num_code_patience - idx_turn - 1} turns left to implement your code and to improve it."
            )

    # No reply ever contained code: report failure explicitly instead of
    # handing an empty list to select_best_submission.
    if not list_code_result:
        if GLOBAL.verbose:
            logger.info("No valid code was generated. Returning None")
        return None, None

    # Return the best evaluated code and result even if not AC
    if GLOBAL.verbose:
        logger.info("Returning best code")
    public_result, _, code = select_best_submission(list_code_result, GLOBAL.score_type)
    return code, public_result


def generate_action_initial(parent_state: State) -> Tuple[State, float]:
    """Generate the first child state of a root state (expected to be called from BestFirstSearch).

    The LLM is first asked to analyze the problem statement, then
    ``generate_code_repeat`` produces and evaluates code. The root's ``idx``
    is used as the tree ID and the step is 0.

    Args:
        parent_state: Root state to expand; must satisfy ``is_root``.

    Returns:
        ``(state, score)``: the new state and its public absolute score, or
        the new state and ``GLOBAL.worst_score`` when code generation failed.
    """
    assert parent_state.is_root, "Parent state must be a root state"

    # Create state with its own idx (no need to specify idx parameter)
    state_new = State(parent_state=parent_state)
    if GLOBAL.verbose:
        logger.info(f"Generating {parent_state.idx} -> {state_new.idx}")

    llm_chat = get_chat()
    # Real newlines ("\n") are required here; a doubled backslash would send a
    # literal backslash-n to the model.
    message = [
        (
            "Below is the problem statement. First, please analyze the problem statement and consider the solution. "
            "You do not have to implement the solution yet. Instead, please think about the essential points of the problem "
            "and possible algorithms to get higher rank in the contest.\n\n"
            f"{PROMPT_ENFORCE}\n\n"
            "[Problem statement]\n"
        )
    ]
    message += GLOBAL.problem_statement
    # The reply itself is not needed here; it stays in the chat history.
    rate_limited_send_message(llm_chat, message)  # Use rate-limited send_message
    current_state_idx = state_new.idx
    save_history(llm_chat, current_state_idx)

    # Get tree ID and step from parent node (for root node, tree ID is its own ID)
    tree_id = parent_state.idx  # Root node's idx becomes tree ID
    step = 0  # Step is 0 for initial state

    code, public_result = generate_code_repeat(
        llm_chat, GLOBAL.num_code_patience, current_state_idx, tree_id, step, break_on_AC=GLOBAL.break_on_AC
    )
    save_history(llm_chat, current_state_idx)

    if code is None:
        # Handle code generation failure: set score to minimum and save the failed state
        state_new.chat_history = get_chat_history(llm_chat)
        state_new.score = GLOBAL.worst_score
        state_new.feedback = "Code generation failed."
        save_state(state_new)
        return state_new, GLOBAL.worst_score

    # Generate feedback from public_result
    feedback_str = result_feedback(public_result)
    score = public_result.overall_absolute_score

    if GLOBAL.verbose:
        logger.info("Initial Node Generation:")
        logger.info(f"  State Idx: {state_new.idx}")
        logger.info(f"  Public judge: {public_result.overall_judge_result}")
        logger.info(f"  Public absolute score: {public_result.overall_absolute_score}")
        logger.info(f"  Feedback: {feedback_str}")

    state_new.code = code
    state_new.score = score
    state_new.chat_history = get_chat_history(llm_chat)
    state_new.absolute_score = public_result.overall_absolute_score
    state_new.feedback = feedback_str
    state_new.public_result = public_result

    # Update code history
    update_code_history(tree_id, public_result, GLOBAL.argument_code_language, code, feedback_str)

    # Set initial summary
    initial_summary = "Your first submission has been completed. From here, we will start recording your attempts."
    update_summary_history(tree_id, initial_summary)

    save_state(state_new)
    return state_new, score


def generate_action_llm(parent_state: State, num_step=10) -> Tuple[Optional[State], Optional[float]]:
    """Generate a new state from an existing state with a multi-step LLM dialogue.

    For ``num_step`` steps the LLM may plan or implement. Every reply that
    contains code is evaluated publicly, and its feedback is sent back in the
    next step.

    Args:
        parent_state: State to expand. For a non-root parent, its code, score
            and feedback are included in the first prompt.
        num_step: Number of dialogue steps.

    Returns:
        ``(state, score)`` using the best accepted submission if there was
        one, otherwise the last evaluated submission.
        ``(None, None)`` when no reply contained code; the failed state is
        still saved with ``GLOBAL.worst_score``.
        NOTE(review): sibling generators return the failed state rather than
        None - confirm which one callers expect.
    """
    state = State(parent_state=parent_state)
    state.tree_id = parent_state.tree_id
    state.step = parent_state.step + 1  # Increment step

    if GLOBAL.verbose:
        logger.info(f"Generating idx {parent_state.idx} -> {state.idx}")

    llm_chat = get_chat()
    job_queue = get_job_queue()

    # Prompt and log strings use real "\n" / "\t" escapes; doubled backslashes
    # would emit the two-character sequences literally.
    message_initial = [
        (
            "Below is the problem statement. I want you to carefully analyze it and develop an optimized solution strategy.\n\n"
            "[PROBLEM STATEMENT]\n"
        )
    ]
    message_initial += GLOBAL.problem_statement

    if parent_state.is_root:
        message_initial += ["Now, please generate a new solution strategy to get a high rank."]
    else:
        # Add parent node's feedback to prompt
        parent_feedback_str = (
            parent_state.feedback if parent_state.feedback else "No feedback available for the previous code."
        )
        message_initial += [
            "The following is a previous implementation of the solution.\n"
            f"{CODE_BLOCK_FORMAT[GLOBAL.argument_code_language].format(parent_state.code)}\n\n"
            f"The code was evaluated with {GLOBAL.num_cases_exp} cases.\n"
            f"[Result] Public Score: {parent_state.score}\n"
            f"[Feedback]\n{parent_feedback_str}\n\n"
            "Now, please analyze the feedback and generate a new solution strategy based on the previous implementation and feedback. "
            "You can change the algorithm, data structure, or any other part of the code. "
            "You can also add new algorithms or data structures, or any other new features to improve the score. "
            "You must also include discussions about the limitation of the solution and possible improvements that you want to refer during later steps. "
        ]

    message_initial += [f"You have {num_step} steps in total to generate a new solution strategy. "]

    # Best accepted submission seen so far
    public_result_best = None
    code_best = None
    feedback_best = None

    previous_feedback = ""  # Replaced by the feedback generated in each loop iteration

    # Last evaluated submission (kept even if not AC)
    last_code = None
    last_result = None
    last_feedback = None

    for idx_step in range(1, num_step + 1):
        logger.info(f"Idx: {state.idx} \tStep {idx_step}/{num_step}")
        if idx_step == 1:
            message = message_initial.copy()
        else:
            # Create message using previous feedback
            message = [f"Based on the previous attempt's feedback:\n{previous_feedback}\n\nPlease refine the solution."]

        message += [
            f"The current step is {idx_step} out of total {num_step} steps. "
            "You may implement the solution right now, or you may plan strategies. "
            f"If you will implement the solution, please implement your above solution in {LANG_EXPR[GLOBAL.argument_code_language]} and {GLOBAL.score_type} absolute scores as much as possible. "
            f"The execution time limit is {GLOBAL.time_limit} second. "
            f"Your submission code should be written in the {CODE_BLOCK[GLOBAL.argument_code_language]} code block. "
            f"Your submission is evaluated with {GLOBAL.num_cases_exp} cases.\n"
            # Use the actual step budget rather than a hard-coded 10.
            f"In your {num_step} steps, first implement a high quality greedy solution."
            "**If you already have a strong greedy solution, then extend your solution to beam search. You can easily extend a greedy solution to beam search by holding multiple candidates instead of one during the greedy search.** "
            "**Please think deeply and broadly about the possible improvements, and output your thoughts in detail, while also implementing your idea.** "
            "**You must include discussions that you want to refer during later steps.** "
        ]
        save_history(llm_chat, state.idx)
        submission_response = rate_limited_send_message(llm_chat, message)
        save_history(llm_chat, state.idx)
        code = get_code(submission_response.text)
        if code is None:
            previous_feedback = "No code block found in the response. Please provide the code."
            continue

        job = EvalJob(code=code, tree_id=state.tree_id, step=state.step, state_idx=state.idx)
        public_result = job_queue.add_job_and_wait(job)

        # Generate feedback for the current attempt and keep it for the next step
        current_feedback = result_feedback(public_result)
        previous_feedback = current_feedback

        # Save code, result and feedback each time
        last_code = code
        last_result = public_result
        last_feedback = current_feedback

        if public_result.overall_judge_result == JudgeResult.ACCEPTED:
            if (public_result_best is None) or (is_first_result_better(public_result, public_result_best)):
                public_result_best = public_result
                code_best = code
                feedback_best = current_feedback
            logger.info(
                f"[Result] Idx: {state.idx} \tStep {idx_step}/{num_step} \tPublic Score: {public_result.overall_absolute_score}"
            )
        else:
            # Detailed feedback for errors is included in current_feedback
            logger.info(
                f"[Result] Idx: {state.idx} \tStep {idx_step}/{num_step} \tJudge: {public_result.overall_judge_result}"
            )

    if public_result_best is not None:
        # If AC was obtained at least once, use the best result
        logger.info(f"Node Generation Completed: {parent_state.idx} -> {state.idx}:")
        logger.info(f"  Public judge: {public_result_best.overall_judge_result}")
        logger.info(f"  Public absolute score: {public_result_best.overall_absolute_score}")

        state.code = code_best
        state.score = public_result_best.overall_absolute_score
        state.chat_history = get_chat_history(llm_chat)
        state.absolute_score = public_result_best.overall_absolute_score
        state.feedback = feedback_best  # Save feedback for the best result
        save_state(state)

        return state, public_result_best.overall_absolute_score

    if last_code is not None:
        # If AC was not obtained but code was generated, use the last code
        logger.info(f"No AC, but using last generated code: {parent_state.idx} -> {state.idx}")
        logger.info(f"  Public judge: {last_result.overall_judge_result}")
        logger.info(f"  Public absolute score: {last_result.overall_absolute_score}")

        state.code = last_code
        state.score = last_result.overall_absolute_score
        state.chat_history = get_chat_history(llm_chat)
        state.absolute_score = last_result.overall_absolute_score
        state.feedback = last_feedback
        save_state(state)

        return state, last_result.overall_absolute_score

    # If no code was generated at all
    logger.warning(f"Idx {state.idx} failed to generate any code within {num_step} steps.")
    state.code = None
    state.score = GLOBAL.worst_score
    state.chat_history = get_chat_history(llm_chat)
    state.absolute_score = None
    state.feedback = "Failed to generate any code."
    save_state(state)

    return None, None


def generate_action(parent_state: State) -> Tuple[State, float]:
    """Generate a new state from an existing state (expected to be called from BestFirstSearch).

    Root parents are delegated to ``generate_action_initial``. Otherwise a
    strategy prompt is sent to the LLM, then ``generate_code_repeat`` produces
    and evaluates code. The strategy prompt is either the summary-based one
    (``GLOBAL.use_summary``) or a randomly chosen weighted template.

    Args:
        parent_state: State to expand.

    Returns:
        ``(state, score)``: the new state with its public absolute score, or
        the new state with ``GLOBAL.worst_score`` when the parent has no code
        or code generation failed.
    """
    # return generate_action_llm(parent_state) # Enable this to use the llm version

    if parent_state.is_root:
        # If parent node is root, call generate_action_initial
        return generate_action_initial(parent_state)

    if parent_state.code is None:
        # If parent node doesn't have valid code, treat as error
        logger.warning("Warning: parent_state.code is None in generate_action. Returning minimum score.")

        # Create failed state using its own idx
        failed_state = State(score=GLOBAL.worst_score, parent_state=parent_state)
        failed_state.feedback = "Parent state had no code."
        save_state(failed_state)
        return failed_state, GLOBAL.worst_score

    # Create new state with its own idx
    state_new = State(parent_state=parent_state)

    # Inherit parent's tree ID; the step is the parent's step plus one
    state_new.tree_id = parent_state.tree_id
    state_new.step = parent_state.step + 1

    if GLOBAL.verbose:
        logger.info(f"Generating {parent_state.idx} -> {state_new.idx}")

    llm_chat = get_chat()
    current_state_idx = state_new.idx

    if GLOBAL.use_summary:
        # Get last summary
        last_summary = get_last_summary(state_new.tree_id)

        # Generate prompt with feedback
        message = create_feedback_with_summary(state_new.tree_id, GLOBAL.problem_statement, last_summary)

        # Send prompt to LLM
        solution_response = rate_limited_send_message(llm_chat, [message])
        save_history(llm_chat, current_state_idx)

        # Extract summary from response
        new_summary = get_summary_from_response(solution_response.text)
        if new_summary:
            update_summary_history(state_new.tree_id, new_summary)
            if GLOBAL.verbose:
                logger.info(
                    f"Found new summary for tree {state_new.tree_id}, step {state_new.step}: {new_summary[:100]}..."
                )
        else:
            # Warn if summary is not found
            logger.warning(f"No summary found in response for tree {state_new.tree_id}, step {state_new.step}")
            # Reuse previous summary
            update_summary_history(state_new.tree_id, last_summary)
    else:
        # Use traditional prompt method. Real newlines ("\n") are required
        # here; a doubled backslash would send a literal backslash-n.
        message = [
            (
                "Below is the problem statement. I want you to carefully analyze it and develop an optimized solution strategy.\n\n"
                "[PROBLEM STATEMENT]\n"
            )
        ]
        message += GLOBAL.problem_statement

        # Get parent node's feedback
        parent_feedback_str = (
            parent_state.feedback if parent_state.feedback else "No feedback available for the previous code."
        )

        # Get template using function from prompt_templates module
        message_templates_with_weights = get_message_templates_with_weights(
            CODE_BLOCK_FORMAT[GLOBAL.argument_code_language], parent_state.code, parent_feedback_str, PROMPT_ENFORCE
        )

        # Weighted random selection
        # (PROMPT_ENFORCE is already passed to get_message_templates_with_weights)
        templates, weights = zip(*message_templates_with_weights)
        selected_template = random.choices(templates, weights=weights, k=1)[0]
        message += [selected_template]

        # The reply itself is not needed here; it stays in the chat history.
        rate_limited_send_message(llm_chat, message)
        save_history(llm_chat, current_state_idx)

    # Pass the new state's tree ID and step to generate_code_repeat
    code, public_result = generate_code_repeat(
        llm_chat,
        GLOBAL.num_code_patience,
        current_state_idx,
        state_new.tree_id,
        state_new.step,
        break_on_AC=GLOBAL.break_on_AC,
    )
    save_history(llm_chat, current_state_idx)

    if code is None:
        # Code generation failed
        state_new.chat_history = get_chat_history(llm_chat)
        state_new.score = GLOBAL.worst_score
        state_new.feedback = "Code generation failed based on the previous feedback."
        save_state(state_new)
        return state_new, GLOBAL.worst_score

    # Generate feedback from public_result
    feedback_str = result_feedback(public_result)
    score = public_result.overall_absolute_score

    if GLOBAL.verbose:
        logger.info(f"Node Generation (Parent Idx: {parent_state.idx}):")
        logger.info(f"  State Idx: {state_new.idx}")
        logger.info(f"  Public judge: {public_result.overall_judge_result}")
        logger.info(f"  Public absolute score: {public_result.overall_absolute_score}")
        logger.info(f"  Feedback: {feedback_str}")

    state_new.code = code
    state_new.score = score
    state_new.chat_history = get_chat_history(llm_chat)
    state_new.absolute_score = public_result.overall_absolute_score
    state_new.feedback = feedback_str
    state_new.public_result = public_result

    # Update code history
    update_code_history(state_new.tree_id, public_result, GLOBAL.argument_code_language, code, feedback_str)

    save_state(state_new)
    return state_new, score


class ParallelBestFirstSearch:
    """Run several independent best-first search trees, each on its own worker thread.

    Each tree is a ``BestFirstSearch`` instance created in :meth:`initialize`.
    :meth:`run_steps` starts one thread per tree; every thread repeatedly pops
    the head of its tree's priority queue, generates child nodes from it, and
    writes the best code found so far into ``GLOBAL.path_session``.
    """

    def __init__(self, num_trees: int = 1, num_samples: int = 1, num_parallel: int = GLOBAL.num_expansion_threads):
        """Store the search configuration and create the synchronization primitives.

        Args:
            num_trees: Number of independent search trees.
            num_samples: Number of child nodes generated per expansion.
            num_parallel: Size of the thread pool used for one expansion.
                NOTE: the default is read from ``GLOBAL`` once, when the class
                is defined, so later changes to ``GLOBAL.num_expansion_threads``
                only take effect when the value is passed explicitly.
        """
        self.num_trees = num_trees
        self.num_samples = num_samples
        self.num_parallel = num_parallel
        self.trees = []  # One BestFirstSearch per tree, filled by initialize()
        self.job_queue = get_job_queue()  # Use existing job queue
        self.tree_locks = [threading.Lock() for _ in range(num_trees)]  # Guards tree_steps[i]
        self.tree_steps = [0 for _ in range(num_trees)]  # Current step for each tree
        self.tree_threads = []  # Worker thread for each tree
        self.stop_flag = False  # Not read inside this class; stop_event is what workers check
        self.max_steps = 0  # Step budget per tree, set by run_steps()
        self.stop_event = threading.Event()  # Set to ask every worker to stop
        self.best_overall_lock = threading.Lock()  # Serializes writes of the overall best code file
        self.overall_best_score = None  # Score of the last saved overall best node

    def initialize(self, root_states: List[Any], initial_scores: List[float]):
        """Create one ``BestFirstSearch`` per tree and seed it with its root state.

        Raises:
            ValueError: If either list does not contain exactly ``num_trees`` items.
        """
        if len(root_states) != self.num_trees or len(initial_scores) != self.num_trees:
            raise ValueError("Root states and initial scores must match the number of trees")

        self.trees = []
        for root_state, initial_score in zip(root_states, initial_scores):
            search_algo = BestFirstSearch(num_samples=self.num_samples, num_parallel=self.num_parallel)
            search_algo.initialize(root_state, initial_score)
            self.trees.append(search_algo)

    @staticmethod
    def _select_best_node(search_algo):
        """Return the node with the smallest priority among nodes that carry code, or None.

        Smaller priority is better; this matches the min-heap ordering that
        ``heapq`` applies to the priority queue.
        """
        candidates = [n for n in (search_algo.tree.get_nodes() or []) if n.state and n.state.code]
        return min(candidates, key=lambda n: n.priority, default=None)

    def _save_best_code(self, tree_id: int, current_step: int) -> None:
        """Write the best code of tree ``tree_id`` and the best code across all trees to disk.

        Any failure is logged (with traceback) and swallowed so that a problem
        while saving never stops the search itself.
        """
        try:
            # 1. Best code of this tree.
            best_node_in_tree = self._select_best_node(self.trees[tree_id])
            if best_node_in_tree is not None:
                best_code_filename = f"best_tree_{tree_id}.{LANG_FILE[GLOBAL.argument_code_language]}"
                best_code_path = GLOBAL.path_session / best_code_filename
                # Explicit UTF-8: generated code may contain non-ASCII characters.
                with open(best_code_path, "w", encoding="utf-8") as f:
                    f.write(best_node_in_tree.state.code)
                logger.info(
                    f"Tree {tree_id}: Best code for step {current_step} (node id {best_node_in_tree.expand_idx}) saved to {best_code_path}"
                )
                logger.info(
                    f"  Selection criteria: Priority={best_node_in_tree.priority}, AC ratio={best_node_in_tree.get_ac_ratio():.2f}, Score={best_node_in_tree.score:.4f}, Depth={best_node_in_tree.depth}"
                )

            # 2. Best code across all trees; the lock keeps workers from interleaving writes.
            with self.best_overall_lock:
                per_tree_best = [self._select_best_node(search_algo) for search_algo in self.trees]
                # min() returns the first minimum, so ties go to the lowest tree index.
                current_overall_best_node = min(
                    (n for n in per_tree_best if n is not None), key=lambda n: n.priority, default=None
                )

                if current_overall_best_node is not None:
                    self.overall_best_score = current_overall_best_node.score
                    best_code_overall_path = GLOBAL.path_session / f"best.{LANG_FILE[GLOBAL.argument_code_language]}"
                    with open(best_code_overall_path, "w", encoding="utf-8") as f:
                        f.write(current_overall_best_node.state.code)
                    logger.info(
                        f"Overall best node (node id {current_overall_best_node.expand_idx}) updated at step {current_step}: Priority={current_overall_best_node.priority}, "
                        f"AC ratio={current_overall_best_node.get_ac_ratio():.2f}, Score={current_overall_best_node.score:.4f}, Depth={current_overall_best_node.depth}. "
                        f"Saved best code to {best_code_overall_path}"
                    )

        except Exception as e:
            # logger.exception sends the traceback through the configured handlers (console and log file).
            logger.exception(f"Tree {tree_id}: Error saving best code for step {current_step}: {e}")

    def _tree_worker(self, tree_id: int, generate_action_fn: callable) -> None:
        """Worker loop for one tree.

        Runs until the stop event is set, the tree has used ``max_steps`` steps,
        or ``GLOBAL.end_time`` has passed. Each iteration expands the tree once,
        advances the step counter, renders the tree, and saves the best code.
        """
        logger.info(f"Tree {tree_id} worker started")

        while not self.stop_event.is_set() and self.tree_steps[tree_id] < self.max_steps:
            # Time limit check; setting the event also tells the other workers to stop.
            if GLOBAL.end_time is not None and time.time() >= GLOBAL.end_time:
                logger.info(f"Tree {tree_id}: Time limit reached. Stopping worker.")
                self.stop_event.set()
                break

            current_step = self.tree_steps[tree_id]

            # Expand tree by 1 step
            self._expand_tree(tree_id, generate_action_fn, current_step)

            # Advance step
            with self.tree_locks[tree_id]:
                self.tree_steps[tree_id] += 1
                new_step = self.tree_steps[tree_id]

            # Tree visualization
            visualize_search_tree(
                self.trees[tree_id].tree,
                str(GLOBAL.path_session / f"tree_{tree_id}_step_{current_step}"),
                title=f"Tree {tree_id} at Step {current_step}",
            )

            # Save best code at the end of the step
            self._save_best_code(tree_id, current_step)

            logger.info(f"Tree {tree_id} completed step {current_step}, moving to step {new_step}")

    def _expand_tree(self, tree_id: int, generate_action_fn: callable, step: int) -> None:
        """Pop the head of the tree's priority queue and generate ``num_samples`` children from it.

        If the popped node is a non-root node with code, ``evaluate_node`` is
        run on it in the same thread pool. Errors are logged, never raised.
        """
        try:
            search_algo = self.trees[tree_id]

            # Check for emptiness and pop under a single lock acquisition so the
            # queue cannot change between the check and the pop.
            with search_algo.priority_queue_lock:
                current_node = heapq.heappop(search_algo.priority_queue) if search_algo.priority_queue else None

            if current_node is None:
                logger.info(f"Tree {tree_id} queue is empty at step {step}, no more nodes to expand")
                return

            # Record expansion history
            with search_algo.expand_history_lock:
                search_algo.expand_history.append(current_node.expand_idx)

            # Evaluate the node and generate child nodes in parallel
            futures = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_parallel) as executor:
                if current_node.state and not current_node.state.is_root and current_node.state.code:
                    futures.append(executor.submit(search_algo.evaluate_node, current_node))

                for _ in range(self.num_samples):
                    futures.append(
                        executor.submit(
                            self._generate_child, search_algo, current_node, generate_action_fn, tree_id, step
                        )
                    )

                # Wait for all tasks to complete
                concurrent.futures.wait(futures)

            # Surface failures that would otherwise vanish inside the futures
            # (evaluate_node has no exception handler of its own here).
            for future in futures:
                task_error = future.exception()
                if task_error is not None:
                    logger.error(f"Tree {tree_id}, Step {step}: Expansion task failed: {task_error!r}")

            if GLOBAL.verbose:
                logger.info(
                    f"Tree {tree_id}, Step {step}: Expanded node with Idx: {current_node.expand_idx}, "
                    f"Score: {current_node.score:.4f}, State Idx: {current_node.state.idx if current_node.state else 'None'}"
                )
                logger.info(
                    f"  Generated up to {self.num_samples} new child node(s) using {self.num_parallel} thread(s)"
                )

        except Exception as e:
            logger.exception(f"Error expanding tree {tree_id} at step {step}: {e}")

    def _generate_child(
        self,
        search_algo: BestFirstSearch,
        current_node: SearchNode,
        generate_action_fn: callable,
        tree_id: int,
        step: int,
    ) -> None:
        """Generate one child of ``current_node`` and push it onto the tree's priority queue.

        ``generate_action_fn`` is called with the parent state and must return
        ``(new_state, new_score)``; a ``None`` state is ignored. Errors are
        logged, never raised.
        """
        try:
            new_state, new_score = generate_action_fn(current_node.state)

            if new_state is not None:
                # Associate the tree ID with the generated state. The step is
                # not overwritten here; the log below reads it from new_state.
                new_state.tree_id = tree_id

                # Add node to tree
                new_node = search_algo.tree.add_node(new_state, new_score, current_node)

                # Add to priority queue
                with search_algo.priority_queue_lock:
                    heapq.heappush(search_algo.priority_queue, new_node)

                if GLOBAL.verbose:
                    score_quality = "higher is better" if GLOBAL.score_type == ScoreType.MAXIMIZE else "lower is better"
                    logger.info(
                        f"Tree {tree_id}, Step {new_state.step}: Created child node with State Idx: {new_state.idx}, "
                        f"Score: {new_score:.4f} ({score_quality})"
                    )
                    if new_state.code is None:
                        logger.info("  Note: This node has no code associated with it.")

        except Exception as e:
            logger.exception(f"Error in child generation for tree {tree_id} at step {step}: {e}")

    def run_steps(self, num_steps: int, generate_action_fn: callable) -> None:
        """Run every tree for up to ``num_steps`` steps and block until all workers finish.

        Trees progress asynchronously; each has its own step counter.
        """
        self.max_steps = num_steps
        self.stop_event.clear()

        logger.info(f"Starting parallel search with {self.num_trees} trees for up to {num_steps} steps each")

        # Start thread for each tree
        self.tree_threads = []
        for tree_id in range(self.num_trees):
            thread = threading.Thread(
                target=self._tree_worker,
                args=(tree_id, generate_action_fn),
                name=f"Tree-{tree_id}-Worker",
                daemon=True,  # Terminate when main thread terminates
            )
            thread.start()
            self.tree_threads.append(thread)
            logger.info(f"Started worker thread for tree {tree_id}")

        # Wait for all tree threads to complete
        for thread in self.tree_threads:
            thread.join()

        logger.info("All tree worker threads have completed")

    def stop(self):
        """Signal every worker to stop and wait for the worker threads to exit."""
        logger.info("Stopping parallel search")
        self.stop_event.set()

        # Wait for all threads to complete
        for thread in self.tree_threads:
            thread.join()

        logger.info("Parallel search stopped")

    def get_best_nodes(self) -> List[Tuple[SearchNode, int]]:
        """Return ``(best_node, tree_id)`` for every tree that has at least one node with code.

        "Best" means smallest priority, the same criterion used when saving
        best code during the search.
        """
        best_nodes = []
        for tree_id, search_algo in enumerate(self.trees):
            best_node = self._select_best_node(search_algo)
            if best_node is not None:
                best_nodes.append((best_node, tree_id))
        return best_nodes


def run_parallel(
    problem_id: str,
    lite_version: bool,
    argument_code_language: str,
    model: str,
    num_trees: int = 2,
    num_workers: int = 16,
    verbose: bool = False,
    num_steps: int = 10,
    num_samples: int = 1,
    num_parallel: Optional[int] = None,  # Number of threads for parallel expansion
    realtime: bool = False,  # Realtime mode flag
    duration: Optional[int] = None,  # Maximum execution duration in minutes
    use_summary: bool = True,  # Whether to use the summary feature
    break_on_ac: bool = True,  # Whether to stop generation when AC is achieved
    use_domain_knowledge: bool = False,  # Whether to use domain knowledge prompts
) -> None:
    """
    Execute multiple independent search trees in parallel to find the optimal solution.

    Configures ``GLOBAL``, creates one root state per tree, runs
    ``ParallelBestFirstSearch`` and then writes histories, the best code and
    the final tree visualizations into ``GLOBAL.path_session``.

    Args:
        problem_id, lite_version, argument_code_language, model, num_workers,
            verbose: Forwarded to ``update_global_variables``.
        num_trees: Number of independent search trees (one root state each).
        num_steps: Maximum number of expansion steps per tree.
        num_samples: Number of child nodes generated per expansion.
        num_parallel: If not None, overrides ``GLOBAL.num_expansion_threads``.
        realtime: Stored in ``GLOBAL.realtime`` and forwarded to
            ``update_global_variables``.
        duration: If not None, wall-clock budget in minutes; sets ``GLOBAL.end_time``.
        use_summary: Stored in ``GLOBAL.use_summary``; when true, code and
            summary histories are saved as JSON after the search.
        break_on_ac: Stored in ``GLOBAL.break_on_AC``.
        use_domain_knowledge: Stored in ``GLOBAL.use_domain_knowledge``.
    """
    # Set number of parallel threads
    if num_parallel is not None:
        GLOBAL.num_expansion_threads = num_parallel
    GLOBAL.realtime = realtime
    GLOBAL.use_summary = use_summary
    GLOBAL.break_on_AC = break_on_ac
    GLOBAL.use_domain_knowledge = use_domain_knowledge

    update_global_variables(problem_id, lite_version, argument_code_language, model, num_workers, verbose, realtime, break_on_ac)

    # Set execution time limit
    if duration is not None:
        start_time = time.time()
        GLOBAL.end_time = start_time + duration * 60
        logger.info(f"Setting execution time limit to {duration} minutes.")
        logger.info(
            f"Program will terminate around {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(GLOBAL.end_time))}."
        )

    if GLOBAL.use_summary:
        logger.info("Summary feature is enabled. LLM will use code history and summaries for better prompts.")
    else:
        logger.info("Summary feature is disabled. Using traditional prompts.")

    if GLOBAL.use_domain_knowledge:
        logger.info("Domain knowledge prompts are enabled. LLM will use specialized prompts for solution refinement.")
    else:
        logger.info("Domain knowledge prompts are disabled. Using generic improvement guidance.")

    # Initialize global job queue
    job_queue = get_job_queue()

    # Start state index from 0
    with GLOBAL.state_lock:
        GLOBAL.num_states = 0

    # Create root state for each tree
    root_states = []
    initial_scores = []
    for i in range(num_trees):
        root_state = State(
            idx=i, is_root=True, score=0.0, tree_id=i, step=0, feedback="Initial root state."
        )  # Roots get a simple feedback string as well
        save_state(root_state)
        root_states.append(root_state)
        initial_scores.append(0.0)
        with GLOBAL.state_lock:
            GLOBAL.num_states += 1

    # Initialize parallel search algorithm
    search_algo = ParallelBestFirstSearch(
        num_trees=num_trees, num_samples=num_samples, num_parallel=GLOBAL.num_expansion_threads
    )
    search_algo.initialize(root_states, initial_scores)

    logger.info("Parallel search initialized with multiple root states.")
    logger.info(f"Created {num_trees} independent search trees that will run asynchronously.")
    logger.info(
        f"Will generate up to {num_samples} child nodes per expansion with {GLOBAL.num_expansion_threads} parallel threads."
    )

    try:
        # Execute asynchronous parallel search
        search_algo.run_steps(num_steps, generate_action)
    except KeyboardInterrupt:
        # If interrupted by Ctrl+C
        logger.info("Search interrupted by user.")
        search_algo.stop()
    except Exception as e:
        # logger.exception routes the traceback through the configured handlers.
        logger.exception(f"Error during search: {e}")
        search_algo.stop()
    finally:
        # Always stop the job-queue workers, even if stop() itself is interrupted.
        job_queue.stop_workers()

    logger.info("Parallel search finished.")

    # Save code history and summary history
    if GLOBAL.use_summary:
        for tree_id in range(num_trees):
            # Save code history
            if tree_id in GLOBAL.codes_history:
                codes_history_path = GLOBAL.path_session / f"codes_history_tree_{tree_id}.json"
                codes_history_json = [
                    [None if result is None else result.model_dump(), code_language, code, feedback]
                    for result, code_language, code, feedback in GLOBAL.codes_history[tree_id]
                ]
                with open(codes_history_path, "w", encoding="utf-8") as f:
                    json.dump(codes_history_json, f, indent=4)

            # Save summary history
            if tree_id in GLOBAL.summary_history:
                summary_history_path = GLOBAL.path_session / f"summary_history_tree_{tree_id}.json"
                with open(summary_history_path, "w", encoding="utf-8") as f:
                    json.dump(GLOBAL.summary_history[tree_id], f, indent=4)

    # Collect best results from all trees
    best_nodes = search_algo.get_best_nodes()
    if best_nodes:
        # Pick the best per-tree node by score. max()/min() return the first
        # extreme element, so ties go to the lowest tree index.
        pick = max if GLOBAL.score_type == ScoreType.MAXIMIZE else min
        best_node, best_tree_id = pick(best_nodes, key=lambda pair: pair[0].score)

        # Display best result
        logger.info(
            f"Best overall score: {best_node.score:.4f} from Tree {best_tree_id} (Expand Idx: {best_node.expand_idx})"
        )
        if best_node.state:  # Check if state exists
            logger.info(f"  State Idx: {best_node.state.idx}")

            # Save best code
            if best_node.state.code:
                best_code_path = GLOBAL.path_session / f"best_code_{best_node.state.idx}.{LANG_FILE[GLOBAL.argument_code_language]}"
                # Explicit UTF-8: generated code may contain non-ASCII characters.
                with open(best_code_path, "w", encoding="utf-8") as f:
                    f.write(best_node.state.code)
                logger.info(f"  Best code saved to: {best_code_path}")
            else:
                logger.info("  Best node has no code associated.")
        else:
            logger.info("  Best node has no state associated.")

        # Final tree visualization
        for tree_id in range(search_algo.num_trees):
            visualize_search_tree(
                search_algo.trees[tree_id].tree,
                str(GLOBAL.path_session / f"tree_{tree_id}_final"),
                title=f"Final Tree {tree_id}",
            )


def get_summary_from_response(response_text: str) -> Optional[str]:
    """Return the body of the last ```md fenced block in an LLM response.

    Returns None when the response contains no such block, or when the
    search itself fails (the error is logged and ignored).
    """
    fenced_md_pattern = r"```md\n(.+?)\n```"
    try:
        # DOTALL lets the captured body span several lines.
        blocks = re.findall(fenced_md_pattern, response_text, flags=re.DOTALL)
    except Exception as e:
        logger.error(f"Ignoring error in get_summary_from_response and returning None: {e}")
        return None
    return blocks[-1] if blocks else None


def select_best_submission(
    codes_history: List[Tuple[Optional[Result], str, str, str]], score_type: ScoreType
) -> Tuple[Optional[Result], str, str]:
    """Pick the best accepted submission from a code history.

    Each history entry is ``(result, code_language, code, feedback)``. Among
    entries whose result is ACCEPTED, the one with the best absolute score for
    ``score_type`` wins; on ties the earliest entry is kept. If nothing was
    accepted, the most recent entry is returned instead. An empty history
    yields ``(None, GLOBAL.argument_code_language, "")``.
    """
    if not codes_history:
        return None, GLOBAL.argument_code_language, ""

    maximize = score_type == ScoreType.MAXIMIZE
    minimize = score_type == ScoreType.MINIMIZE
    # Start from the worst possible score so the first accepted entry always wins.
    score_to_beat = float("-inf") if maximize else float("inf")
    winner = None

    for submission, language, source, _feedback in codes_history:
        if submission is None or submission.overall_judge_result != JudgeResult.ACCEPTED:
            continue
        value = submission.overall_absolute_score
        improved = (maximize and value > score_to_beat) or (minimize and value < score_to_beat)
        if improved:
            score_to_beat = value
            winner = (submission, language, source)

    if winner is None:
        # No accepted submission: fall back to the latest entry.
        latest = codes_history[-1]
        return latest[0], latest[1], latest[2]

    return winner


def create_feedback_with_summary(
    tree_id: int, problem_statement: list[dict[str, str | dict[str, str]]], last_summary: Optional[str] = None
) -> str:
    """Build the next prompt for a tree from its code history and last summary.

    The prompt consists of the problem statement followed by
    ``FEEDBACK_WITH_SUMMARY_TEMPLATE`` filled with the best submission, the
    latest submission, the previous summary and improvement guidance. When
    the tree has no history yet, a short fixed message is returned instead.
    """
    no_history_message = "No history yet. Please generate your first code."

    with GLOBAL.codes_history_lock:
        # A missing or empty history means nothing has been submitted for this tree.
        history = GLOBAL.codes_history.get(tree_id)
        if not history:
            return no_history_message

        best_result, best_code_language, best_code = select_best_submission(history, GLOBAL.score_type)
        latest_result, _latest_language, latest_code, _ = history[-1]

        # Feedback text for both submissions
        if best_result is not None:
            best_feedback = result_feedback(best_result)
        else:
            best_feedback = "No submissions yet."
        if latest_result is not None:
            latest_feedback = result_feedback(latest_result)
        else:
            latest_feedback = "No submissions yet."

        # Avoid repeating the code and feedback when latest and best coincide.
        if latest_code != best_code:
            latest_code_display = latest_code
            latest_feedback_display = latest_feedback
        else:
            latest_code_display = "The latest code is the same as the best code."
            latest_feedback_display = "The latest feedback is the same as the best feedback."

        # Without a previous summary, ask the model to start writing one.
        if last_summary is None:
            action_summary = "No summary of previous attempts found. Please write your summary in the ```md\n<!-- Your summary here -->\n``` code block."
        else:
            action_summary = last_summary

        # Generic guidance by default; a weighted random domain-knowledge template when enabled.
        improvement_guidance = DEFAULT_IMPROVEMENT_GUIDANCE
        if GLOBAL.use_domain_knowledge:
            improvement_templates, improvement_weights = zip(*get_improvement_guidance_with_weights(PROMPT_ENFORCE))
            improvement_guidance = random.choices(improvement_templates, weights=improvement_weights, k=1)[0]
            if GLOBAL.verbose:
                logger.info(f"Using domain knowledge prompt for tree {tree_id}")

        template_fields = {
            "action_summary": action_summary,
            "lang": LANG_FILE[best_code_language],
            "best_code": best_code,
            "best_feedback": best_feedback,
            "latest_code": latest_code_display if isinstance(latest_code_display, str) else "",
            "latest_feedback": latest_feedback_display,
            "improvement_guidance": improvement_guidance,
        }
        prompt = FEEDBACK_WITH_SUMMARY_TEMPLATE.format(**template_fields)

        # Prepend the problem statement; dict parts contribute their "text", anything else its str().
        statement_parts = (
            part["text"] if isinstance(part, dict) and "text" in part else str(part) for part in problem_statement
        )
        problem_text = "".join(statement_parts)
        return f"[Problem Statement]\n{problem_text}\n\n{prompt}"


def update_code_history(tree_id: int, result: Optional[Result], code_language: str, code: str, feedback: str) -> None:
    """Append a submission to a tree's code history and refresh that tree's best code.

    The history entry is ``(result, code_language, code, feedback)``. Both the
    append and the recomputation of ``GLOBAL.best_codes[tree_id]`` happen
    while holding ``GLOBAL.codes_history_lock``.
    """
    entry = (result, code_language, code, feedback)
    with GLOBAL.codes_history_lock:
        # Create the per-tree list on first use, then record the submission.
        tree_history = GLOBAL.codes_history.setdefault(tree_id, [])
        tree_history.append(entry)

        # Re-select the best submission over the updated history.
        best_result, best_code_language, best_code = select_best_submission(tree_history, GLOBAL.score_type)
        GLOBAL.best_codes[tree_id] = (best_result, best_code_language, best_code)


def update_summary_history(tree_id: int, summary: Optional[str]) -> None:
    """Append ``summary`` (which may be None) to the summary history of tree ``tree_id``.

    The per-tree list is created on first use. The update is done while
    holding ``GLOBAL.codes_history_lock``.
    """
    with GLOBAL.codes_history_lock:
        GLOBAL.summary_history.setdefault(tree_id, []).append(summary)


def get_last_summary(tree_id: int) -> Optional[str]:
    """Return the most recent non-None summary recorded for tree ``tree_id``.

    Returns None when the tree has no summary history, the history is empty,
    or every recorded summary is None.
    """
    with GLOBAL.codes_history_lock:
        recorded = GLOBAL.summary_history.get(tree_id)
        if not recorded:
            return None

        # Walk backwards so the newest usable summary wins.
        return next((entry for entry in reversed(recorded) if entry is not None), None)


if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, log the
    # configuration and run the parallel best-first search.
    parser = argparse.ArgumentParser(description="Run Parallel Best First Search")
    parser.add_argument("--problem_id", type=str, default="ahc046", help="ALE-Bench problem ID")
    parser.add_argument("--lite_version", action="store_true", help="Use lite version of the model")
    parser.add_argument("--code_language", type=str, default="cpp20", help="Programming language for code generation")
    parser.add_argument("--model", type=str, default="gemini-2.5-pro-preview-03-25", help="Model name to use")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers for ALE Bench")
    parser.add_argument("--num_steps", type=int, default=10, help="Number of steps to execute")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of child nodes to generate per expansion")
    parser.add_argument("--num_parallel", type=int, default=4, help="Number of parallel threads for node expansion")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--num_trees", type=int, default=2, help="Number of parallel trees for search")
    parser.add_argument("--realtime", action="store_true", help="Enable realtime mode (skip private eval)")
    parser.add_argument("--duration", type=int, default=None, help="Maximum execution duration in minutes")
    parser.add_argument("--use_summary", action="store_true", help="Enable summary feature for improved prompts")
    parser.add_argument(
        "--break_on_ac", action="store_true", default=False, help="End code generation when AC is achieved"
    )
    parser.add_argument(
        "--use_domain_knowledge", action="store_true", help="Use domain knowledge prompts for solution generation"
    )
    args = parser.parse_args()

    # Validate against the same problem list (lite or full) that is shown in
    # the error message, and use a real newline in the messages.
    valid_problem_ids = list_problem_ids(lite_version=args.lite_version)
    if args.problem_id not in valid_problem_ids:
        raise ValueError(f"Invalid problem id: {args.problem_id}\nPlease choose from {valid_problem_ids}")
    if args.model not in SUPPORTED_MODELS:
        raise ValueError(f"Invalid model: {args.model}\nPlease choose from {SUPPORTED_MODELS}")

    # Log the effective configuration
    logger.info(f"Using problem ID: {args.problem_id}")
    logger.info(f"Using model: {args.model}")
    logger.info(f"Number of trees: {args.num_trees}")
    logger.info(f"Number of steps: {args.num_steps}")
    logger.info(f"Number of samples per expansion: {args.num_samples}")
    logger.info(f"Number of parallel threads: {args.num_parallel or GLOBAL.num_expansion_threads}")
    logger.info(f"Using summary: {args.use_summary}")
    logger.info(f"Breaking on AC: {args.break_on_ac}")
    logger.info(f"Using domain knowledge prompts: {args.use_domain_knowledge}")

    # Execute parallel search. Keyword arguments keep the call correct even if
    # the parameter order of run_parallel changes.
    run_parallel(
        problem_id=args.problem_id,
        lite_version=args.lite_version,
        argument_code_language=args.code_language,
        model=args.model,
        num_trees=args.num_trees,
        num_workers=args.num_workers,
        verbose=args.verbose,
        num_steps=args.num_steps,
        num_samples=args.num_samples,
        num_parallel=args.num_parallel,
        realtime=args.realtime,
        duration=args.duration,
        use_summary=args.use_summary,
        break_on_ac=args.break_on_ac,
        use_domain_knowledge=args.use_domain_knowledge,
    )
