#!/usr/bin/env python3
import argparse
import ast
import difflib
import re
import shlex
import shutil
import sys
import traceback
import yaml
import asyncio
import threading
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from pydantic import BaseModel
from claude_llm import ClaudeLLM
from lib import run_code, limit_csv_precision

# --- Retry / prompt-size limits -------------------------------------------
# Attempt budget for one LLM request. DialogueSaver reports attempt indices as
# 0 .. MAX_ATTEMPTS - 1 and treats index MAX_ATTEMPTS - 1 as the final attempt.
MAX_ATTEMPTS = 5
# Maximum number of ratio-reduction passes in
# MCTSOptimizer.adaptive_sample_and_build_prompt_async before it falls back to
# MIN_CSV_SAMPLING_RATIO.
MAX_CSV_SAMPLING_ITERATIONS = 32
# Byte budget for a UTF-8 encoded prompt; adaptive sampling shrinks the CSV
# whenever a prompt exceeds 95% of this value.
LLM_MAX_TEXT_BYTES = 16000000.0
# NOTE(review): the next two are not referenced in this part of the file;
# presumably the model's token window and a safety margin - confirm at use sites.
LLM_MAX_TOKENS = 200000
LLM_BUFFER_TOKENS = 1500
# Last-resort CSV sampling ratio used once the iteration limit above is hit.
MIN_CSV_SAMPLING_RATIO = 0.0001
# Multiplier applied to the sampling ratio each time a prompt is still too large.
CSV_REDUCTION_FACTOR = 0.8
# NOTE(review): not referenced in this part of the file; presumably the
# response-token limits and sampling temperatures for the judge and editor
# LLM calls - confirm at use sites.
JUDGE_MAX_TOKENS = 5000
EDITOR_MAX_TOKENS = 12000
JUDGE_TEMPERATURE = 0.1
EDITOR_TEMPERATURE = 0.2
# Weights used by MCTSOptimizer.compute_expansion_score:
#   node_value + DEPTH_BONUS_COEFFICIENT * depth
#              + EXPLORATION_BONUS_COEFFICIENT / (1 + expansion_count)
DEPTH_BONUS_COEFFICIENT = 0.5
EXPLORATION_BONUS_COEFFICIENT = 2.0
# NOTE(review): not referenced in this part of the file; the prompt templates
# ask for a summary "in the first 200 characters", so this is presumably the
# length kept when reasoning text is shortened for the tree context - confirm.
CONTEXT_TRUNCATION_LENGTH = 200
# (name, instruction) pairs describing how boldly code should be modified.
# NOTE(review): the consumer is not visible in this part of the file; the editor
# prompt template tells the model to "use diversification strategy from tree
# context", so these are presumably injected there - confirm at use sites.
DIVERSIFICATION_STRATEGIES = [
    ('BREAKTHROUGH', 'Make significant, high-impact changes designed to achieve major performance improvements'),
    ('AGGRESSIVE', 'Make bold structural or algorithmic changes'),
    ('AMPLIFY', 'Identify and significantly increase the most promising parameters or mechanisms'),
    ('EXPLORATORY', 'Try completely different parameter combinations'),
    ('TARGETED', 'Focus on specific high-impact parameters identified from analysis'),
    ('CONTRARIAN', 'Try approaches opposite to current trends or patterns'),
    ('BALANCED', 'Make moderate parameter changes with good risk/reward ratio'),
    ('CONSERVATIVE', 'Make small, incremental parameter adjustments'),
]
SCORING_STRATEGIES = [
    (
        'v1',
        'Use scores from 0.0 to 10.0 (higher is better)\n\nIMPORTANT: If the simulation failed (SIMULATION SUCCESS: No) OR if the simulation results contain NaN/Inf values (indicating numerical collapse), assign a score of 0.0 regardless of other factors.\n\nCOLLAPSED SYSTEM DETECTION:\n- Look for "NaN", "inf", "-inf" values in the CSV output\n- Empty cells in CSV output may indicate NaN values were written\n- Systems with numerical overflow/underflow are considered failed\n- Treat collapsed systems the same as simulation failures\n\n- 0.0-2.0: Poor/failing performance\n- 2.0-4.0: Below average performance\n- 4.0-6.0: Average performance\n- 6.0-8.0: Good performance\n- 8.0-10.0: Excellent performance',
        {'min_score': 0.0, 'max_score': 10.0},
    ),
    (
        'v2',
        'Use scores from 0.0 and up (higher is better, no upper bound)\n\nIMPORTANT: If the simulation failed (SIMULATION SUCCESS: No) OR if the simulation results contain NaN/Inf values (indicating numerical collapse), assign a score of 0.0 regardless of other factors.\n\nCOLLAPSED SYSTEM DETECTION:\n- Look for "NaN", "inf", "-inf" values in the CSV output\n- Empty cells in CSV output may indicate NaN values were written\n- Systems with numerical overflow/underflow are considered failed\n- Treat collapsed systems the same as simulation failures\n\n- Score based on BOTH absolute goal achievement AND relative performance within the search tree\n- Primary consideration: How well this node achieves the optimization goal\n- Secondary calibration: Compare against other nodes in the tree context for proper differentiation\n- Use the reference summary and node comparisons to ensure meaningful score gaps\n- Score ranges (flexible guidelines):\n  - 0.0-3.0: Poor goal achievement that fails to make meaningful progress\n  - 3.0-5.0: Below-average goal progress with minimal success\n  - 5.0-7.0: Moderate goal achievement with some measurable progress\n  - 7.0-9.0: Good goal achievement showing clear progress toward the objective\n  - 9.0+: Excellent goal achievement that meets or exceeds the target\n  - 10.0+: Exceptional goal achievement that significantly surpasses the objective\n- No upper limit - outstanding goal achievement can score 15.0, 20.0 or higher as appropriate\n- When multiple nodes have similar goal achievement, use tree context to differentiate scores\n- Ensure score gaps reflect both absolute performance differences and relative rankings',
        {'min_score': 0.0, 'max_score': float('inf')},
    ),
    (
        'v3',
        'ABSOLUTE PERFORMANCE SCORING (depth-constrained scale)\nPrimary Principle: Score based on how well this code achieves the optimization goal, regardless of other nodes in the tree.\n\nIMPORTANT: If the simulation failed (SIMULATION SUCCESS: No) OR if the simulation results contain NaN/Inf values (indicating numerical collapse), assign a score of 0.0 regardless of other factors.\n\nCOLLAPSED SYSTEM DETECTION:\n- Look for "NaN", "inf", "-inf" values in the CSV output\n- Empty cells in CSV output may indicate NaN values were written\n- Systems with numerical overflow/underflow are considered failed\n- Treat collapsed systems the same as simulation failures\n\nSCORING SCALE (use 2 decimal places like 7.25, 12.75):\n- 0.00-2.00: Failed progress - no meaningful progress toward the goal\n- 2.00-4.00: Partial progress - some improvement but far from achieving the goal\n- 4.00-6.00: Moderate progress - measurable improvement and moving toward the goal\n- 6.00-8.00: Good progress - clear advancement with substantial goal achievement\n- 8.00-10.00: Excellent progress - meets most requirements of the optimization goal\n- 10.00+: VERY GOOD performance - exceeds the goal expectations significantly\n- 20.00+: EXCEPTIONAL performance - far surpasses what the goal was asking for\n\nIMPORTANT CONSTRAINT: Maximum possible score is score <= 10.0 + 2.5 * depth. Within this limit, high scores are encouraged when performance truly merits them.\n\nHOW TO SCORE:\n1. First, check if simulation was successful (SIMULATION SUCCESS: Yes/No)\n2. If simulation failed, assign score of 0.0 immediately\n3. If simulation succeeded, evaluate how well this code achieves the goal in absolute terms\n4. Use the tree context only to calibrate what score ranges mean - don\'t let it constrain your scoring\n5. For very good performance that exceeds expectations, don\'t hesitate to give high scores\n6. When unsure between score ranges, favor the higher score if genuine progress is evident\n7. 
Large score increases (5+ points) are appropriate for significant improvements\n8. Always respect the depth constraint: score <= 10.0 + 2.5 * depth\n\nRemember: Judge this code\'s actual achievement of the goal, not its relative position in the search tree.',
        {'min_score': 0.0, 'max_score': float('inf')},
    ),
]
# Prompt sent to the judge model. It is a str.format template with the fields
# {goal}, {simulation_result}, {code}, {tree_context} and {scoring_guidelines},
# and asks for a reply made of a REASONING section followed by a SCORE section.
TREE_JUDGE_PROMPT_TEMPLATE = '\nANALYZE SIMULATION RESULTS FOR TREE-BASED OPTIMIZATION\n\nGOAL:\n{goal}\n\nSIMULATION RESULTS:\n{simulation_result}\n\nCURRENT CODE:\n{code}\n\nTREE CONTEXT:\n{tree_context}\n\n=== For your reference ===\n\nFormat for code in tree context:\nThe tree context shows only the differences in each reference node compared to current code.\nSince the full current code is shown above, diffs only display lines that are different in reference nodes:\n- Lines starting with "+" show what the reference node has instead of current code\n- "Code identical to current node" means no differences exist\n\nCOLLAPSED SYSTEM ASSESSMENT:\n- Check CSV output for NaN, inf, -inf values indicating numerical collapse\n- Empty values in CSV may represent NaN/missing data from overflow\n- If ANY numerical collapse is detected, treat as system failure\n- Collapsed systems indicate parameter combinations that cause instability\n- Even if simulation completed, numerical collapse = failure\n\nAnalyze this node\'s performance considering:\n- How well this specific code variant achieves the optimization goal\n- Performance relative to ALL other nodes in the reference (for calibrated scoring)\n- Whether this represents meaningful progress in the search space\n\nSCORING GUIDELINES:\n{scoring_guidelines}\n\nUse the tree context information to ensure your scoring is well-calibrated relative to all explored alternatives.\n\n=== Task ====\n\nPlease provide your analysis in this format:\n\nREASONING: [Start with a brief summary of key findings in the first 200 characters, then provide detailed reasoning about how well this code variant achieves the optimization goal]\n\nSCORE: [numerical score based on the scoring guidelines above]\n'
TREE_EDITOR_PROMPT_TEMPLATE = '\nSUGGEST CODE MODIFICATIONS FOR TREE SEARCH OPTIMIZATION\n\nGOAL:\n{goal}\n\nSIMULATION RESULTS:\n{simulation_result}\n\nCURRENT CODE:\n{code}\n\nTREE CONTEXT:\n{tree_context}\n\n=== For your reference ===\n\nFormat for code in tree context:\nThe tree context shows only the differences in each reference node compared to current code.\nSince the full current code is shown above, diffs only display lines that are different in reference nodes:\n- Lines starting with "+" show what the reference node has instead of current code\n- "Code identical to current node" means no differences exist\n\nHow to modify the code to better achieve the goal within the tree search context:\nFocus on:\n- Parameters within the marked BEGIN/END edit sections\n- State variables, constants, and simulation configuration\n- Derivative equations and helper expressions\n- Table values and time-dependent inputs\n- **NUMERICAL STABILITY**: If NaN/Inf detected, prioritize fixing numerical issues\n- **PARAMETER BOUNDS**: Add constraints to prevent overflow/underflow\n- **STABILITY FIRST**: Ensure system runs without collapse before optimizing\n- Learning from tree context and reference node outcomes\nWhen you modify the code:\n- Faithfully follow the CURRENT CODE. Starting from EVERYTHING in CURRENT CODE (**not only** the core code in tree context), then\n- only modify code between the "BEGIN ... / END ..." 
markers.\n- DO NOT change imports, wrappers, loop structure, or OUTPUT section.\n\nNUMERICAL STABILITY AND COLLAPSED SYSTEM REPAIR:\n- If simulation results contain NaN/Inf values, the system has collapsed\n- Common causes: parameter overflow, division by zero, exponential growth\n- Fix strategies:\n  * Add bounds/clamps to prevent extreme values\n  * Reduce parameter magnitudes that cause exponential growth\n  * Add numerical stability checks (max/min bounds)\n  * Modify equations to prevent division by small numbers\n  * Use more conservative parameter ranges\n- Priority: Achieve numerical stability first, then optimize performance\n- A stable system with lower performance >> an unstable system with high peaks\n\nTree Search Strategy:\n- Consider your position in the search tree\n- Learn from performance patterns in tree context\n- Use diversification strategy from tree context\n- Reference successful/failed approaches from other nodes\n\nSCORING GUIDELINES:\n{scoring_guidelines}\n\n=== Task ====\n\nPlease provide your response in this format:\n\nREASONING: [Start with a concise description of your main modification strategy in the first 200 characters, then explain detailed reasoning and assess the expected improvement using the scoring guidelines above]\n\nSELF_ASSESSED_SCORE: [numerical score based on the scoring guidelines above for your expected improvement]\n\nMODIFIED_CODE: [Complete modified Python code following all requirements - start from EVERYTHING in CURRENT CODE, then only modify code between BEGIN/END markers]\n'


def sample_csv_rows(
    csv_content: str,
    sample_ratio: float = 1.0,
    max_sampled_rows: Optional[int] = None,
    columns: Optional[List[str]] = None,
    add_header_suffix: str = '',
    add_header_to_value: bool = False,
    add_timestamp_to_value: bool = False,
    transpose: bool = False,
) -> str:
    """Down-sample and optionally reshape CSV text, returning CSV text.

    Processing order: column selection, row sampling, capture of the ``t``
    column, header suffixing, ``header=value`` decoration, ``(@t=...)``
    decoration, transposition. Input with at most one row is returned
    unchanged.

    Args:
        csv_content: CSV text whose first row is the header.
        sample_ratio: Fraction of data rows to keep; every
            ``round(1 / ratio)``-th row is kept, plus the first and last row.
        max_sampled_rows: When set and exceeded, lowers the effective ratio to
            at most ``max_sampled_rows / row_count``.
        columns: Header names to keep, in the given order.
        add_header_suffix: Text appended to every header name.
        add_header_to_value: Rewrite each cell as ``<header>=<cell>``.
        add_timestamp_to_value: Append ``(@t=<t>)`` to every cell outside the
            ``t`` column, using the row's original ``t`` value.
        transpose: Swap rows and columns (short rows are padded with '').

    Returns:
        The resulting CSV text, stripped of surrounding whitespace (rows are
        separated by the csv module's default line terminator).

    Raises:
        ValueError: If ``sample_ratio`` is not positive, a requested column is
            missing, or ``add_timestamp_to_value`` is set without a ``t`` column.
    """
    if sample_ratio <= 0.0:
        raise ValueError('csv_sample_ratio must be > 0.0')
    import csv
    from io import StringIO

    # Cheap textual check first, then the real CSV parse.
    if len(csv_content.strip().split('\n')) <= 1:
        return csv_content
    parsed = list(csv.reader(StringIO(csv_content)))
    if len(parsed) <= 1:
        return csv_content
    header, data_rows = parsed[0], parsed[1:]

    if columns:
        wanted: List[int] = []
        absent: List[str] = []
        for name in columns:
            if name in header:
                wanted.append(header.index(name))
            else:
                absent.append(name)
        if absent:
            raise ValueError(f'Columns not found in CSV: {absent}')
        if not wanted:
            raise ValueError('No valid columns specified')
        header = [header[pos] for pos in wanted]
        # Short rows yield '' for columns they do not have.
        data_rows = [[record[pos] if pos < len(record) else '' for pos in wanted] for record in data_rows]

    total = len(data_rows)
    effective_ratio = sample_ratio
    if max_sampled_rows and total > max_sampled_rows:
        effective_ratio = min(sample_ratio, max_sampled_rows / total)
    if effective_ratio < 1.0 and total > 1:
        stride = max(1, round(1.0 / effective_ratio))
        # Always keep the first and the last row so the full range stays visible.
        kept = sorted({*range(0, total, stride), 0, total - 1})
        data_rows = [data_rows[pos] for pos in kept]

    # Capture raw 't' values before any decoration changes header or cells.
    t_index = None
    t_values = []
    if add_timestamp_to_value:
        if 't' not in header:
            raise ValueError("Column 't' not found in CSV, required for add_timestamp_to_value")
        t_index = header.index('t')
        t_values = [record[t_index] if t_index < len(record) else '' for record in data_rows]

    if add_header_suffix:
        header = [f'{col}{add_header_suffix}' for col in header]

    if add_header_to_value:
        width = len(header)
        data_rows = [
            [f'{header[j]}={cell}' if j < width else cell for (j, cell) in enumerate(record)]
            for record in data_rows
        ]

    if add_timestamp_to_value:
        data_rows = [
            [f'{cell}(@t={stamp})' if j != t_index else cell for (j, cell) in enumerate(record)]
            for (record, stamp) in zip(data_rows, t_values)
        ]

    if transpose:
        table = [header] + data_rows
        if table and table[0]:
            widest = max(len(record) for record in table)
            padded = [list(record) + [''] * (widest - len(record)) for record in table]
            flipped = [[record[col] for record in padded] for col in range(widest)]
            header, data_rows = flipped[0], flipped[1:]

    buffer = StringIO()
    out = csv.writer(buffer)
    out.writerow(header)
    out.writerows(data_rows)
    return buffer.getvalue().strip()


# Validated judge result: free-text reasoning plus a numeric score.
# NOTE(review): presumably built from the REASONING / SCORE sections requested by
# TREE_JUDGE_PROMPT_TEMPLATE - the parsing code is not visible in this part of the file.
class JudgeResponse(BaseModel):
    reasoning: str
    score: float

    class Config:
        # Reject any field that is not declared above.
        extra = 'forbid'


# Validated editor result: reasoning, the editor's own expected score and the full modified code.
# NOTE(review): presumably built from the REASONING / SELF_ASSESSED_SCORE / MODIFIED_CODE
# sections requested by TREE_EDITOR_PROMPT_TEMPLATE - the parsing code is not visible here.
class EditorResponse(BaseModel):
    reasoning: str
    self_assessed_score: float
    modified_code: str

    class Config:
        # Reject any field that is not declared above.
        extra = 'forbid'


# One node of the search tree: a code variant together with its evaluation
# results and expansion bookkeeping.
class MCTSNode(BaseModel):
    node_id: str  # key of this node in MCTSTreeState.nodes; the root is created with '0'
    parent_id: Optional[str] = None  # the root is created with None
    children: List[str] = []  # presumably node_ids of child nodes - confirm where children are appended
    depth: int  # root is created with 0; feeds the depth bonus in compute_expansion_score
    sibling_index: int
    adding_order: int
    strategy_name: str = ''  # presumably a DIVERSIFICATION_STRATEGIES name - TODO confirm
    code: str  # source code of this variant
    node_value: Optional[float] = None  # score assigned after judging; None means "not evaluated" in compute_metadata
    judge_reasoning: str = ''
    csv_output: str = ''  # presumably the CSV produced by running `code` - confirm
    simulation_success: Optional[bool] = None
    simulation_error: Optional[str] = None
    expansion_count: int = 0  # compared with max_expansions_per_node in can_expand_node
    expansion_score: Optional[float] = None  # refreshed by MCTSOptimizer.compute_metadata
    action_description: str = ''
    editor_reasoning: str = ''

    class Config:
        # Reject any field that is not declared above.
        extra = 'forbid'

    @classmethod
    def create_child_id(cls, parent_id: str, total_children: int) -> str:
        """Return the id for a new child: ``'<parent_id>_<total_children>'``."""
        return f'{parent_id}_{total_children}'


# Whole-tree snapshot; load_tree_state rebuilds it from a YAML file via model_validate.
class MCTSTreeState(BaseModel):
    nodes: Dict[str, MCTSNode] = {}  # node_id -> node
    iteration: int = 0  # presumably the number of completed search iterations - TODO confirm
    metadata: Dict[str, Any] = {}  # summary statistics written by MCTSOptimizer.compute_metadata

    class Config:
        # Reject any field that is not declared above.
        extra = 'forbid'


def build_simulation_result(success: bool, csv_output: str, error: str) -> str:
    """Render a simulation outcome as the text block embedded in LLM prompts.

    The block has a ``SIMULATION SUCCESS: Yes|No`` line, an ``ERROR:`` line when
    *error* contains non-whitespace text, and a ``CSV OUTPUT:`` section. The CSV
    section shows *csv_output* only when the run succeeded and produced output;
    otherwise it shows a fixed placeholder sentence.
    """
    sections = ['SIMULATION SUCCESS: ' + ('Yes' if success else 'No')]
    if error and error.strip():
        sections.append(f'ERROR: {error}')
    if success and csv_output:
        csv_section = csv_output
    else:
        csv_section = 'No output generated due to simulation failure'
    sections.append(f'CSV OUTPUT:\n{csv_section}')
    return '\n'.join(sections)


def preprocess_goal(
    goal: str,
    logger=None,
    csv_add_header_to_value: bool = False,
    csv_add_timestamp_to_value: bool = False,
    csv_transpose: bool = False,
) -> str:
    """Expand ``@load(...)`` clauses in *goal* by appending the referenced CSV data.

    Each clause has the form ``@load(<path>, type=csv[, max_sampled_lines=N]
    [, columns=a/b/c])``. The clause itself stays in the goal text; for every
    clause a "GOAL REFERENCE DATA" section is appended containing either the
    (sampled) CSV text or an error description. A final clarification section
    states that the loaded data is reference material only. A goal without any
    ``@load(...)`` clause is returned unchanged.

    Args:
        goal: Goal text, possibly containing ``@load(...)`` clauses.
        logger: Optional object with a ``write(str)`` method for progress messages.
        csv_add_header_to_value: Forwarded to ``sample_csv_rows``.
        csv_add_timestamp_to_value: Forwarded to ``sample_csv_rows``.
        csv_transpose: Forwarded to ``sample_csv_rows``.

    Returns:
        The goal text with reference-data sections appended.
    """

    def log(message: str) -> None:
        if logger:
            logger.write(message)

    def reference_block(body: str) -> str:
        # Common frame shared by the success and error sections.
        return f'\n\n=== GOAL REFERENCE DATA (NOT CURRENT SYSTEM) ===\n{body}\n=== END GOAL REFERENCE DATA ==='

    def parse_load_params(param_string: str) -> dict:
        # First comma-separated item is the path; the rest are key=value options.
        params = [p.strip() for p in param_string.split(',')]
        result = {'filepath': params[0], 'type': None, 'max_sampled_lines': None, 'columns': None}
        for param in params[1:]:
            if '=' not in param:
                continue
            (key, value) = param.split('=', 1)
            key = key.strip()
            if key == 'type':
                result['type'] = value.strip()
            elif key == 'max_sampled_lines':
                result['max_sampled_lines'] = int(value.strip())
            elif key == 'columns':
                result['columns'] = [col.strip() for col in value.split('/')]
        return result

    def load_csv_in_goal(
        filepath: str, max_sampled_lines: Optional[int] = None, columns: Optional[List[str]] = None
    ) -> Tuple[bool, str]:
        # Returns (loaded, text). An explicit flag is used instead of sniffing the
        # text for an 'Error:' prefix: 'Error reading CSV: ...' does not carry
        # that prefix and used to be reported in the log as a successful load.
        file_path = Path(filepath)
        if not file_path.exists():
            return (False, 'File not found')
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                csv_content = f.read().strip()
            if not csv_content:
                return (False, 'Empty file')
            sampled = sample_csv_rows(
                csv_content=csv_content,
                sample_ratio=1.0,
                max_sampled_rows=max_sampled_lines,
                columns=columns,
                add_header_suffix='_IN_GOAL_REFERENCE',
                add_header_to_value=csv_add_header_to_value,
                add_timestamp_to_value=csv_add_timestamp_to_value,
                transpose=csv_transpose,
            )
            return (True, sampled)
        except ValueError as e:
            return (False, f'Error: {e}')
        except Exception as e:
            return (False, f'Error reading CSV: {e}')

    matches = re.findall(r'@load\(([^)]+)\)', goal)
    if not matches:
        log('No @load() clauses found in goal\n')
        return goal
    log(f'Found {len(matches)} @load() clauses to process\n')
    enhanced_goal = goal
    for match in matches:
        try:
            log(f'Processing @load({match})\n')
            params = parse_load_params(match)
            if params['type'] != 'csv':
                error_msg = f'''Error: type must be csv, got "{params['type']}"'''
                log(f'@load({match}) - {error_msg}\n')
                enhanced_goal += reference_block(f'@load({match}) - {error_msg}')
                continue
            (loaded, csv_content) = load_csv_in_goal(
                params['filepath'], params['max_sampled_lines'], params['columns']
            )
            if loaded:
                log(f'@load({match}) - Successfully loaded CSV ({len(csv_content)} characters)\n')
            else:
                log(f'@load({match}) - {csv_content}\n')
            # The load result (data or failure text) is embedded either way so the
            # model can see what happened to every clause.
            enhanced_goal += reference_block(f'@load({match}) contains:\n{csv_content}')
        except Exception as e:
            # Best effort: a malformed clause must not abort goal preparation.
            error_msg = f'Error: {e}'
            log(f'@load({match}) - {error_msg}\n')
            enhanced_goal += reference_block(f'@load({match}) - {error_msg}')
    enhanced_goal += "\n\n=== IMPORTANT CLARIFICATION ===\nAll data loaded above using @load() statements are REFERENCE MATERIALS for understanding the GOAL.\nThese are NOT part of the current system being optimized.\nThe current system's code and simulation results are provided separately below.\n=== END CLARIFICATION ==="
    log(f'Enhanced goal length: {len(enhanced_goal)} characters\n')
    return enhanced_goal


def sanitize_name(name: str) -> str:
    """Turn a file name or path into a short identifier-like string.

    Takes the stem of *name* (directory and last extension removed), replaces
    every character outside ``[a-zA-Z0-9_-]`` with ``_``, collapses runs of
    underscores, strips leading/trailing underscores and keeps at most the
    first 50 characters.
    """
    stem = Path(name).stem
    collapsed = re.sub('_+', '_', re.sub('[^a-zA-Z0-9_-]', '_', stem))
    trimmed = collapsed.strip('_')
    return trimmed[:50]


def setup_logging(run_dir: Path):
    """Create a logger that mirrors every message to stdout and to a log file.

    The log file is ``<run_dir>/mcts_logging.log``, opened for writing as UTF-8.
    The returned object offers ``write``, ``flush`` and ``close``; each of them
    runs under one lock.
    """

    class LoggerTee:
        """File-like sink duplicating writes to the terminal and the log file."""

        def __init__(self, log_path: Path):
            self.terminal = sys.stdout
            self.log_file = open(log_path, 'w', encoding='utf-8')
            self.lock = threading.Lock()

        def write(self, message: str):
            with self.lock:
                for stream in (self.terminal, self.log_file):
                    stream.write(message)
                # The file is flushed on every write; the terminal is not.
                self.log_file.flush()

        def flush(self):
            with self.lock:
                for stream in (self.terminal, self.log_file):
                    stream.flush()

        def close(self):
            # Only the file is closed; stdout stays open.
            with self.lock:
                self.log_file.close()

    return LoggerTee(run_dir / 'mcts_logging.log')


def validate_python_code(code: str) -> Tuple[bool, str]:
    """Check that *code* parses as Python source.

    Returns:
        ``(True, '')`` when ``ast.parse`` accepts the code, otherwise
        ``(False, message)`` where the message describes the syntax error (with
        its line number) or any other parse failure.
    """
    try:
        ast.parse(code)
    except SyntaxError as exc:
        problem = f'Syntax error at line {exc.lineno}: {exc.msg}'
    except Exception as exc:
        problem = f'Parse error: {str(exc)}'
    else:
        return (True, '')
    return (False, problem)


class DialogueSaver:
    """Persist LLM prompts, responses and failures as text files for inspection.

    Files are written to ``<output_dir>/run/<run_timestamp>/dialogues``, which is
    created on construction. Every file name starts with a second-resolution
    timestamp followed by the prompt type, node id, optional child index and
    attempt number. All files are written as UTF-8 so that non-ASCII prompt or
    response text does not depend on the platform's default encoding.
    """

    def __init__(self, output_dir: Path, run_timestamp: str):
        self.output_dir = output_dir
        self.dialogues_dir = output_dir / 'run' / run_timestamp / 'dialogues'
        self.dialogues_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _timestamped_identifier(
        prompt_type: str, node_id: str, attempt: int, child_index: Optional[int] = None
    ) -> str:
        """Return ``<timestamp>_<type>_node_<id>[_child_<n>]_attempt_<k>``."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        if child_index is not None:
            identifier = f'{prompt_type}_node_{node_id}_child_{child_index}_attempt_{attempt}'
        else:
            identifier = f'{prompt_type}_node_{node_id}_attempt_{attempt}'
        return f'{timestamp}_{identifier}'

    def _write(self, filename: str, content: str) -> None:
        """Write *content* to *filename* inside the dialogues directory as UTF-8."""
        (self.dialogues_dir / filename).write_text(content, encoding='utf-8')

    def save_prompt_to_file(
        self,
        prompt_type: str,
        node_id: str,
        prompt_content: str,
        context: dict,
        attempt: int,
        child_index: Optional[int] = None,
    ):
        """Save a prompt as ``<base>.txt`` and its context as ``<base>_context.yaml``."""
        base_filename = self._timestamped_identifier(prompt_type, node_id, attempt, child_index)
        self._write(f'{base_filename}.txt', prompt_content)
        self._write(f'{base_filename}_context.yaml', yaml.dump(context, default_flow_style=False))

    def save_response_to_file(
        self,
        prompt_type: str,
        node_id: str,
        response: str,
        parsed_data: dict,
        success: bool,
        attempt: int,
        child_index: Optional[int] = None,
    ):
        """Save a raw LLM response and its parsed fields as ``<base>_response.txt``.

        An empty or missing *parsed_data* is reported as a parse failure.
        """
        base_filename = self._timestamped_identifier(prompt_type, node_id, attempt, child_index)
        response_timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        header = f'=== LLM RESPONSE ===\nTimestamp: {response_timestamp}\nStatus: {("SUCCESS" if success else "FAILED")}\nParsed Successfully: {("Yes" if parsed_data else "No")}'
        raw_section = f'=== RAW RESPONSE ===\n{response}'
        if parsed_data:
            parsed_details = '\n'.join(f'{key}: {value}' for (key, value) in parsed_data.items())
            parsed_section = f'=== PARSED DATA ===\n{parsed_details}'
        else:
            parsed_section = '=== PARSED DATA ===\nFailed to parse response'
        metadata = f'=== METADATA ===\nResponse length: {len(response):,} characters'
        content = f'{header}\n\n{raw_section}\n\n{parsed_section}\n\n{metadata}'
        self._write(f'{base_filename}_response.txt', content)

    def save_failed_response(
        self,
        prompt_type: str,
        node_id: str,
        attempt: int,
        error: str,
        prompt_length: int,
        child_index: Optional[int] = None,
    ):
        """Save a record of a failed LLM attempt as ``<base>_failed.txt``.

        Attempts are 0-indexed; attempt ``MAX_ATTEMPTS - 1`` is marked as final
        and the record notes that a ValueError will terminate the optimization.
        """
        base_filename = self._timestamped_identifier(prompt_type, node_id, attempt, child_index)
        failure_timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        is_final_attempt = attempt == MAX_ATTEMPTS - 1
        if child_index is not None:
            node_identifier = f'Node: {node_id}, Child: {child_index}, Attempt: {attempt}/[0-{MAX_ATTEMPTS - 1}]'
        else:
            node_identifier = f'Node: {node_id}, Attempt: {attempt}/[0-{MAX_ATTEMPTS - 1}]'
        header = f'=== FAILED LLM RESPONSE ===\nTimestamp: {failure_timestamp}\nType: {prompt_type.upper()} FAILED - {node_identifier}\nError: {error}\nPrompt length: {prompt_length:,} characters'
        termination_info = ''
        if is_final_attempt:
            termination_info = '\nRESULT: ValueError will be raised - optimization terminated'
        metadata = f'=== METADATA ===\nFinal attempt: {is_final_attempt}'
        content = f'{header}\n{termination_info}\n\n{metadata}'
        self._write(f'{base_filename}_failed.txt', content)


class MCTSOptimizer:
    def __init__(
        self,
        claude_llm,
        runner,
        goal: str,
        logger,
        output_dir: Path,
        run_timestamp: str,
        expansion_width: int = 1,
        expand_internal: int = 0,
        max_expansions_per_node: int = 3,
        csv_sample_ratio: float = 0.5,
        csv_add_header_to_value: bool = False,
        csv_add_timestamp_to_value: bool = False,
        csv_transpose: bool = False,
        scoring_strategy: str = 'v2',
    ):
        if csv_sample_ratio <= 0.0:
            raise ValueError('csv_sample_ratio must be > 0.0')
        if csv_sample_ratio > 1.0:
            raise ValueError('csv_sample_ratio must be <= 1.0')
        self.state = MCTSTreeState()
        self.root_id: str = '0'
        self.expansion_width = expansion_width
        self.expand_internal = expand_internal
        self.max_expansions_per_node = max_expansions_per_node
        self.output_dir = output_dir
        self.csv_sample_ratio = csv_sample_ratio
        self.csv_add_header_to_value = csv_add_header_to_value
        self.csv_add_timestamp_to_value = csv_add_timestamp_to_value
        self.csv_transpose = csv_transpose
        self.scoring_strategy = scoring_strategy
        valid_strategies = [name for (name, _, _) in SCORING_STRATEGIES]
        if scoring_strategy not in valid_strategies:
            raise ValueError(f"Invalid scoring strategy '{scoring_strategy}'. Valid options: {valid_strategies}")
        self.claude = claude_llm
        self.runner = runner
        self.goal = goal
        self.logger = logger
        self.dialogue_saver = DialogueSaver(output_dir, run_timestamp)

    def get_scoring_guidelines(self) -> str:
        """Return the guideline text of the configured scoring strategy.

        Raises:
            ValueError: If the strategy name is not present in SCORING_STRATEGIES.
        """
        matching = [text for (label, text, _) in SCORING_STRATEGIES if label == self.scoring_strategy]
        if not matching:
            raise ValueError(f"Scoring strategy '{self.scoring_strategy}' not found in SCORING_STRATEGIES")
        # First entry wins should a name ever be listed twice.
        return matching[0]

    def get_scoring_bounds(self) -> dict:
        """Return the ``{'min_score': ..., 'max_score': ...}`` dict of the configured strategy.

        Raises:
            ValueError: If the strategy name is not present in SCORING_STRATEGIES.
        """
        matching = [limits for (label, _, limits) in SCORING_STRATEGIES if label == self.scoring_strategy]
        if not matching:
            raise ValueError(f"Scoring strategy '{self.scoring_strategy}' not found in SCORING_STRATEGIES")
        # First entry wins should a name ever be listed twice.
        return matching[0]

    def compute_expansion_score(self, node: MCTSNode) -> float:
        if node.node_value is None:
            return 0.0
        exploitation = node.node_value
        depth_bonus = DEPTH_BONUS_COEFFICIENT * node.depth
        exploration_bonus = EXPLORATION_BONUS_COEFFICIENT / (1 + node.expansion_count)
        return exploitation + depth_bonus + exploration_bonus

    def load_tree_state(self, iteration: int):
        if iteration <= 0:
            raise ValueError(f'Invalid iteration number: {iteration}. Must be > 0.')
        metadata = self.load_tree_state_metadata()
        if iteration not in metadata.get('state_files', {}):
            available_iters = sorted(metadata.get('state_files', {}).keys())
            raise ValueError(f'No tree state found for iteration {iteration}. Available iterations: {available_iters}')
        filename = metadata['state_files'][iteration]
        state_file = self.output_dir / 'tree_state' / filename
        if not state_file.exists():
            raise ValueError(f'Tree state file not found: {state_file}')
        try:
            with open(state_file, 'r') as f:
                state_data = yaml.safe_load(f)
            self.state = MCTSTreeState.model_validate(state_data)
            self.compute_metadata()
            self.logger.write(
                f'Loaded tree state from iteration {iteration} ({filename}): {len(self.state.nodes)} nodes\n'
            )
        except Exception as e:
            raise ValueError(f'Failed to load tree state from iteration {iteration}: {e}')

    def compute_metadata(self):
        evaluated_nodes = [n for n in self.state.nodes.values() if n.node_value is not None]
        if evaluated_nodes:
            best_node = max(evaluated_nodes, key=lambda n: n.node_value if n.node_value is not None else float('-inf'))
            if best_node.node_value is None:
                raise ValueError(f'Best node {best_node.node_id} has None value but should be evaluated')
            best_score = best_node.node_value
            best_node_id = best_node.node_id
        else:
            best_score = 0.0
            best_node_id = self.root_id
        scores = [n.node_value for n in evaluated_nodes if n.node_value is not None]
        score_distribution = {}
        if scores:
            score_distribution = {
                'min': min(scores),
                'max': max(scores),
                'avg': sum(scores) / len(scores),
                'count': len(scores),
            }
        else:
            score_distribution = {'min': 0, 'max': 0, 'avg': 0, 'count': 0}
        for node in self.state.nodes.values():
            node.expansion_score = self.compute_expansion_score(node)
        self.state.metadata.update(
            {
                'best_score': best_score,
                'best_node_id': best_node_id,
                'total_nodes': len(self.state.nodes),
                'max_depth': max((node.depth for node in self.state.nodes.values())) if self.state.nodes else 0,
                'timestamp': datetime.now().isoformat(),
                'expandable_nodes': len([nid for nid in self.state.nodes.keys() if self.can_expand_node(nid)]),
                'leaf_nodes': len([nid for nid in self.state.nodes.keys() if len(self.state.nodes[nid].children) == 0]),
                'score_distribution': score_distribution,
            }
        )

    async def adaptive_sample_and_build_prompt_async(
        self, prompt_builder_func: Callable[[float], str], max_prompt_tokens: int, context_name: str = 'sampling'
    ) -> Tuple[str, float]:
        """Build a prompt, lowering the CSV sampling ratio until it fits the limits.

        Starting from ``self.csv_sample_ratio``, the prompt is rebuilt with the
        ratio multiplied by CSV_REDUCTION_FACTOR after each pass whose prompt is
        too large in bytes (above 95% of LLM_MAX_TEXT_BYTES) or in tokens (above
        *max_prompt_tokens*). Tokens are only counted once the byte check passes.
        After MAX_CSV_SAMPLING_ITERATIONS passes the prompt is built once more at
        MIN_CSV_SAMPLING_RATIO and returned with a warning, whatever its size.

        Returns:
            ``(prompt_text, sampling_ratio_used)``.
        """
        ratio = self.csv_sample_ratio
        byte_budget = LLM_MAX_TEXT_BYTES * 0.95
        passes = 0
        while True:
            prompt = prompt_builder_func(ratio)
            if not len(prompt.encode('utf-8')) > byte_budget:
                token_count = await self.claude.count_tokens_async(prompt)
                if token_count <= max_prompt_tokens:
                    self.logger.write(
                        f'{context_name} final adaptive sampling: ratio {ratio:.3f} (tokens: {token_count:,})\n'
                    )
                    return (prompt, ratio)
            # Too large one way or the other: shrink and try again.
            ratio *= CSV_REDUCTION_FACTOR
            passes += 1
            if passes >= MAX_CSV_SAMPLING_ITERATIONS:
                break

        # Fallback: smallest allowed ratio, accepted without further checks.
        ratio = MIN_CSV_SAMPLING_RATIO
        prompt = prompt_builder_func(ratio)
        fallback_tokens = await self.claude.count_tokens_async(prompt)
        self.logger.write(
            f'{context_name} WARNING: Max iterations reached, using minimum ratio {ratio:.3f} (tokens: {fallback_tokens:,})\n'
        )
        return (prompt, ratio)

    def can_expand_node(self, node_id: str) -> bool:
        """Tell whether the node may still be expanded.

        A node is expandable while its expansion count is below
        ``self.max_expansions_per_node`` and it is either childless or
        ``self.expand_internal`` is non-zero.
        """
        candidate = self.state.nodes[node_id]
        if candidate.expansion_count >= self.max_expansions_per_node:
            return False
        if self.expand_internal != 0:
            return True
        # Internal expansion disabled: only childless nodes qualify.
        return len(candidate.children) == 0

    def initialize_root(self, initial_code: str):
        """Create node ``'0'`` from ``initial_code``, score it and log it.

        The root is registered in ``self.state.nodes`` before it is evaluated
        with ``_evaluate_node_with_judge_async`` (driven by ``asyncio.run``);
        tree metadata is recomputed once the score has been stored.
        """
        root_id = '0'
        root = MCTSNode(
            node_id=root_id,
            parent_id=None,
            children=[],
            depth=0,
            sibling_index=0,
            adding_order=0,
            strategy_name='',
            code=initial_code,
            action_description='Initial code',
        )
        self.state.nodes[root_id] = root
        self.logger.write('Initializing root node...\n')
        score = asyncio.run(self._evaluate_node_with_judge_async(root_id, root))
        root.node_value = score
        self.compute_metadata()
        self.logger.write(f'Root node initialized with score: {score:.3f}\n')
        self.log_node_creation(root_id)

    def select_node_for_expansion(self) -> Optional[str]:
        """Pick the expandable node with the highest expansion score.

        Every node is classified with ``can_expand_node``. Expandable nodes are
        ranked by ``compute_expansion_score``; the exploitation / depth /
        exploration figures computed here are only used for the log breakdown.
        Inexpandable nodes that have a score are logged with the reason from
        ``_get_inexpandable_reason``.

        Returns:
            The id of the top-ranked expandable node, or ``None`` when no node
            can be expanded.
        """
        expandable_nodes = []
        inexpandable_nodes = []
        for candidate_id in self.state.nodes.keys():
            bucket = expandable_nodes if self.can_expand_node(candidate_id) else inexpandable_nodes
            bucket.append(candidate_id)
        if not expandable_nodes:
            return None

        ranked = []
        for candidate_id in expandable_nodes:
            candidate = self.state.nodes[candidate_id]
            assert candidate.node_value is not None
            total = self.compute_expansion_score(candidate)
            ranked.append(
                {
                    'node_id': candidate_id,
                    'exploitation': candidate.node_value,
                    'depth_bonus': DEPTH_BONUS_COEFFICIENT * candidate.depth,
                    'exploration_bonus': EXPLORATION_BONUS_COEFFICIENT / (1 + candidate.expansion_count),
                    'total_selection_score': total,
                    'depth': candidate.depth,
                    'expansion_count': candidate.expansion_count,
                }
            )

        blocked_nodes = [self.state.nodes[blocked_id] for blocked_id in inexpandable_nodes]
        blocked = [
            {
                'node_id': blocked_node_id,
                'node_value': blocked_node.node_value,
                'depth': blocked_node.depth,
                'expansion_count': blocked_node.expansion_count,
                'reason': self._get_inexpandable_reason(blocked_node),
            }
            for blocked_node_id, blocked_node in zip(inexpandable_nodes, blocked_nodes)
            if blocked_node.node_value is not None
        ]

        # sorted() is stable, so ties keep their insertion order.
        ranked = sorted(ranked, key=lambda x: x['total_selection_score'], reverse=True)
        blocked = sorted(blocked, key=lambda x: x['node_value'], reverse=True)

        self.logger.write(f'\nNode selection analysis ({len(expandable_nodes)} expandable nodes):\n')
        for i, score_info in enumerate(ranked):
            rank_indicator = '    ' if i else '>>> '
            self.logger.write(
                f'{rank_indicator}{score_info["node_id"]}: total={score_info["total_selection_score"]:.3f} (exploit={score_info["exploitation"]:.3f} + depth={score_info["depth_bonus"]:.3f} + explore={score_info["exploration_bonus"]:.3f}) [depth={score_info["depth"]}, expansions={score_info["expansion_count"]}]\n'
            )
        if blocked:
            self.logger.write(f'\nInexpandable nodes ({len(blocked)} nodes):\n')
            for info in blocked:
                self.logger.write(
                    f'    {info["node_id"]}: value={info["node_value"]:.3f} [depth={info["depth"]}, expansions={info["expansion_count"]}] - {info["reason"]}\n'
                )
        selected = ranked[0]['node_id']
        self.logger.write(f'Selected node {selected} for expansion\n')
        return selected

    def _get_inexpandable_reason(self, node) -> str:
        """Describe why ``node`` cannot be expanded.

        Returns the applicable reasons joined by ``'; '``: the expansion limit
        was reached and/or the node has children while internal expansion is
        disabled. ``'unknown reason'`` is returned when neither applies.
        """
        found = []
        limit = self.max_expansions_per_node
        if node.expansion_count >= limit:
            found.append(f'max expansions reached ({node.expansion_count}/{self.max_expansions_per_node})')
        if self.expand_internal == 0 and len(node.children) > 0:
            found.append(f'internal node expansion disabled ({len(node.children)} children)')
        # Every reason text is non-empty, so an empty join means "no reason".
        return '; '.join(found) or 'unknown reason'

    def extract_core_simulation_code(self, full_code: str) -> str:
        """Return the comment-free core of a simulation script.

        The core is the text between the first line containing
        ``END WHOLE SYSTEM`` and the last preceding line containing
        ``BEGIN WHOLE SYSTEM`` (marker lines excluded). If either marker is
        missing the whole script is used. Blank lines and whole-line comments
        are dropped, and trailing ``#`` comments are removed.

        A ``#`` inside a single- or double-quoted string on the same line is
        kept (e.g. ``color = '#fff'``). Quote tracking restarts on every line,
        so text inside a multi-line string is still treated like code.

        Args:
            full_code: Complete source text of the simulation script.

        Returns:
            The remaining code lines joined with ``'\\n'``.
        """

        def strip_trailing_comment(line: str) -> str:
            """Cut ``line`` at the first ``#`` that is outside a quoted string."""
            open_quote = None
            skip_next = False
            for pos, ch in enumerate(line):
                if skip_next:
                    # Character following a backslash inside a string.
                    skip_next = False
                    continue
                if open_quote is None:
                    if ch == '#':
                        return line[:pos].rstrip()
                    if ch in ('"', "'"):
                        open_quote = ch
                elif ch == '\\':
                    skip_next = True
                elif ch == open_quote:
                    open_quote = None
            return line

        lines = full_code.split('\n')
        start_idx = None
        end_idx = None
        for i, line in enumerate(lines):
            if 'BEGIN WHOLE SYSTEM' in line:
                start_idx = i + 1
            elif 'END WHOLE SYSTEM' in line:
                end_idx = i
                break
        if start_idx is None or end_idx is None:
            # No complete marker pair: fall back to the whole script.
            start_idx, end_idx = 0, len(lines)

        core_lines = []
        for line in lines[start_idx:end_idx]:
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            code_part = strip_trailing_comment(line)
            if code_part.strip():
                core_lines.append(code_part)
        return '\n'.join(core_lines)

    def get_code_diff_preview(self, reference_node_id: str, current_node_id: str) -> str:
        """List the core-code lines the reference node has that the current node lacks.

        Both nodes' code is reduced with ``extract_core_simulation_code`` and
        compared with a zero-context unified diff from the current node to the
        reference node. Only added (``+``) lines are reported.

        Lines are split without their line endings so that, together with
        ``lineterm=''``, every diff line is newline-free and the joined preview
        is not double-spaced.

        Args:
            reference_node_id: Node whose code is being described.
            current_node_id: Node the reference is compared against.

        Returns:
            ``'Node not found'`` if either id is unknown,
            ``'Code identical to current node'`` when the core code matches,
            a note when the reference only removes lines, otherwise the
            ``+``-prefixed added lines joined by newlines.
        """
        nodes = self.state.nodes
        if reference_node_id not in nodes or current_node_id not in nodes:
            return 'Node not found'
        reference_code = self.extract_core_simulation_code(nodes[reference_node_id].code)
        current_code = self.extract_core_simulation_code(nodes[current_node_id].code)
        if reference_code == current_code:
            return 'Code identical to current node'
        diff_lines = list(
            difflib.unified_diff(
                current_code.splitlines(),
                reference_code.splitlines(),
                fromfile=f'current_node_{current_node_id}',
                tofile=f'reference_node_{reference_node_id}',
                lineterm='',
                n=0,
            )
        )
        if not diff_lines:
            return 'Code identical to current node'
        # The first two entries are the '---' / '+++' file headers; skipping
        # them keeps the '+++' header out of the added-line filter.
        added_lines = [line for line in diff_lines[2:] if line.startswith('+')]
        if not added_lines:
            # The codes differ, but only by lines missing from the reference.
            return 'No added lines (reference node only removes lines present in current node)'
        return '\n'.join(added_lines)

    def construct_reference_context(self, current_node_id: str) -> str:
        """Render a textual overview of every node in the search tree.

        Nodes are listed ordered by the number of ``_``-separated id parts and
        then by id. Each entry shows the score, the diff preview against
        ``current_node_id``, the judge reasoning (cut to
        ``CONTEXT_TRUNCATION_LENGTH`` characters) and the action description.
        A summary block is appended when at least one node has a score.

        Returns:
            The rendered overview, or ``'No reference nodes available.'`` for
            an empty tree.
        """
        all_nodes = self.state.nodes
        if not all_nodes:
            return 'No reference nodes available.'
        parts = [
            f"REFERENCE - ALL NODES IN SEARCH TREE:\n{'=' * 50}\nNote: Code sections show diffs from current node ({current_node_id}) to each reference node's code."
        ]
        for node_id in sorted(all_nodes.keys(), key=lambda x: (len(x.split('_')), x)):
            node = all_nodes[node_id]
            is_scored = node.node_value is not None
            score_str = f'{node.node_value:.2f}' if is_scored else 'not evaluated yet'
            code_diff = self.get_code_diff_preview(node_id, current_node_id)
            reasoning = node.judge_reasoning
            if not reasoning:
                reasoning_preview = 'not available' if is_scored else 'not available (evaluation pending)'
            elif len(reasoning) > CONTEXT_TRUNCATION_LENGTH:
                reasoning_preview = reasoning[:CONTEXT_TRUNCATION_LENGTH] + '...'
            else:
                reasoning_preview = reasoning
            parts.append(
                f'\nNode {node_id} (depth {node.depth}):\n  Score: {score_str}\n  Code Diff: {code_diff}\n  Judge Reasoning: {reasoning_preview}\n  Action: {node.action_description}'
            )
        # One score per evaluated node, so len(scores) is the evaluated count.
        scores = [n.node_value for n in all_nodes.values() if n.node_value is not None]
        if scores:
            parts.append(
                f'\n{"=" * 50}\nREFERENCE SUMMARY:\n  Total nodes: {len(self.state.nodes)}\n  Evaluated nodes: {len(scores)}\n  Score range: {min(scores):.2f} - {max(scores):.2f}\n  Average score: {sum(scores) / len(scores):.2f}\n  Best node: {self.state.metadata.get("best_node_id", "0")} ({self.state.metadata.get("best_score", 0.0):.2f})'
            )
        return ''.join(parts)

    def get_path_to_root(self, node_id: str) -> List[str]:
        """Return the node ids from ``node_id`` up to the root, starting node first.

        The walk follows ``parent_id`` links and stops at the node whose
        ``parent_id`` is ``None``.
        """
        lineage = []
        cursor = node_id
        while True:
            if cursor is None:
                return lineage
            lineage.append(cursor)
            cursor = self.state.nodes[cursor].parent_id

    def get_path_display(self, node_id: str) -> str:
        """Format the root-to-node path as ``id(score) → id(score) → ...``.

        Scores are shown with two decimals. ``N/A`` is used only for nodes
        without a score (``node_value is None``); a score of ``0.0`` is a real
        value and is rendered as ``0.00``.
        """
        labels = []
        for pid in reversed(self.get_path_to_root(node_id)):
            value = self.state.nodes[pid].node_value
            # Compare against None: a truthiness test would hide 0.0 scores.
            score_str = f'{value:.2f}' if value is not None else 'N/A'
            labels.append(f'{pid}({score_str})')
        return ' → '.join(labels)

    def get_diversification_strategy(self, child_index: int) -> str:
        """Return ``'<NAME> APPROACH: <description>'`` for a child index.

        The index wraps around ``DIVERSIFICATION_STRATEGIES``.
        """
        slot = child_index % len(DIVERSIFICATION_STRATEGIES)
        label, description = DIVERSIFICATION_STRATEGIES[slot]
        return f'{label} APPROACH: {description}'

    def get_strategy_name(self, child_index: int) -> str:
        """Return only the name of the strategy assigned to a child index.

        The index wraps around ``DIVERSIFICATION_STRATEGIES``.
        """
        slot = child_index % len(DIVERSIFICATION_STRATEGIES)
        label, _description = DIVERSIFICATION_STRATEGIES[slot]
        return label

    def construct_tree_context_for_judge(self, node_id: str) -> str:
        """Describe a node's place in the tree for the judge prompt.

        The text contains, in order: the node position and root path, parent
        information (score, reasoning, the modification made, score change),
        the other children of the parent, and finally the full reference
        context from ``construct_reference_context``. The root node gets fixed
        "no parent" / "no siblings" lines instead.
        """
        node = self.state.nodes[node_id]
        sections = [
            f'Node Position: {node_id} (depth {node.depth})\nPath from Root: {self.get_path_display(node_id)}'
        ]
        if node.parent_id is not None:
            parent = self.state.nodes[node.parent_id]
            if parent.node_value is None:
                parent_score_str = 'not evaluated yet'
            else:
                parent_score_str = f'{parent.node_value:.2f}'
            if parent.node_value is None or node.node_value is None:
                score_change_str = 'Score change: not available (evaluation pending)'
            else:
                score_change_str = f'Score change: {node.node_value - parent.node_value:+.2f}'
            # The ellipsis is appended whenever the text is non-empty.
            if parent.judge_reasoning:
                parent_reasoning = parent.judge_reasoning[:CONTEXT_TRUNCATION_LENGTH] + "..."
            else:
                parent_reasoning = "not available"
            if node.action_description:
                modification = node.action_description[:CONTEXT_TRUNCATION_LENGTH] + "..."
            else:
                modification = "not available"
            sections.append(
                f'Parent: {node.parent_id} scored {parent_score_str}\nParent reasoning: {parent_reasoning}\nModification made: {modification}\n{score_change_str}'
            )
            siblings = [self.state.nodes[child_id] for child_id in parent.children if child_id != node_id]
            if siblings:
                sibling_details = []
                for sib in siblings:
                    score_info = 'not evaluated yet' if sib.node_value is None else f'{sib.node_value:.2f}'
                    sibling_details.append(
                        f'  {sib.node_id}: {score_info} - {sib.action_description[:CONTEXT_TRUNCATION_LENGTH]}...'
                    )
                joined_details = '\n'.join(sibling_details)
                sections.append(f'Sibling Nodes:\n{joined_details}')
            else:
                sections.append('No siblings yet.')
        else:
            sections.append('This is the root node - no parent.')
            sections.append('No siblings - this is the root node.')
        reference = self.construct_reference_context(node_id)
        return '\n'.join(sections) + f'\n\n{reference}'

    def construct_tree_context_for_editor(self, parent_id: str, child_index: int) -> str:
        """Describe the parent's place in the tree for the editor prompt.

        The text contains the score header (the parent's score is reported as
        both ``CURRENT_SCORE`` and ``PARENT_SCORE``), the children the parent
        already has, strategy guidance for the child about to be generated,
        and the full reference context from ``construct_reference_context``.

        Existing children without a score are shown as ``N/A``; a score of
        ``0.0`` is a real value and is rendered as ``0.00``.

        Args:
            parent_id: Node that is being expanded.
            child_index: Index of the child within this expansion; it selects
                the diversification strategy.
        """
        parent = self.state.nodes[parent_id]
        diversification_strategy = self.get_diversification_strategy(child_index)
        scores_info = f'CURRENT_SCORE: {parent.node_value:.2f}\nPARENT_SCORE: {parent.node_value:.2f}\nGLOBAL_BEST_SCORE: {self.state.metadata.get("best_score", 0.0):.2f}\nNODE_POSITION: {parent_id} (child {child_index}/{self.expansion_width})\nDIVERSIFICATION_STRATEGY: {diversification_strategy}\nPath from Root: {self.get_path_display(parent_id)}'
        # The parent's existing children are the new child's siblings.
        siblings = [self.state.nodes[cid] for cid in parent.children]
        if siblings:
            sibling_details = []
            for sib in siblings:
                # Compare against None: a truthiness test would hide 0.0 scores.
                sibling_desc = f'{sib.node_value:.2f}' if sib.node_value is not None else 'N/A'
                sibling_details.append(
                    f'  {sib.node_id}: {sibling_desc} - {sib.action_description[:CONTEXT_TRUNCATION_LENGTH]}...'
                )
            sibling_details = '\n'.join(sibling_details)
            sibling_info = f'Existing siblings:\n{sibling_details}'
        else:
            sibling_info = ''
        strategy_guidance = f'Tree Search Strategy:\n- Consider your position in the search tree ({parent_id})\n- Learn from parent performance ({parent.node_value:.2f})\n- Aim to exceed global best ({self.state.metadata.get("best_score", 0.0):.2f})\n- Use diversification strategy: {diversification_strategy}\n- Reference successful/failed approaches from other nodes'
        return (
            f'{scores_info}\n\n{sibling_info}\n\n{strategy_guidance}\n\n{self.construct_reference_context(parent_id)}'
        )

    def construct_judge_context(self, node_id: str) -> dict:
        """Assemble the template fields for the judge prompt of one node.

        A node whose ``simulation_success`` is still ``None`` is treated as
        successful, and a missing error is passed on as an empty string.
        """
        node = self.state.nodes[node_id]
        ran_ok = True if node.simulation_success is None else node.simulation_success
        failure_text = node.simulation_error or ''
        simulation_result = build_simulation_result(ran_ok, node.csv_output, failure_text)
        context = {'goal': self.goal, 'code': node.code, 'simulation_result': simulation_result}
        context['tree_context'] = self.construct_tree_context_for_judge(node_id)
        context['scoring_guidelines'] = self.get_scoring_guidelines()
        return context

    def construct_editor_context(self, parent_id: str, child_index: int) -> dict:
        """Assemble the template fields for the editor prompt of one child.

        The code and simulation result come from the parent node. A parent
        whose ``simulation_success`` is still ``None`` is treated as
        successful, and a missing error is passed on as an empty string.
        """
        parent = self.state.nodes[parent_id]
        ran_ok = True if parent.simulation_success is None else parent.simulation_success
        failure_text = parent.simulation_error or ''
        simulation_result = build_simulation_result(ran_ok, parent.csv_output, failure_text)
        context = {'goal': self.goal, 'code': parent.code, 'simulation_result': simulation_result}
        context['tree_context'] = self.construct_tree_context_for_editor(parent_id, child_index)
        context['scoring_guidelines'] = self.get_scoring_guidelines()
        return context

    async def _evaluate_node_with_judge_async(self, node_id: str, node: MCTSNode) -> float:
        """Run the node's simulation and have the judge LLM score the outcome.

        The simulation result (success flag, error, precision-limited CSV) is
        stored on ``node``. The judge prompt is built with adaptive CSV
        sampling and sent up to ``MAX_ATTEMPTS`` times. A response is accepted
        only if it has non-empty reasoning and a score within
        ``self.get_scoring_bounds()``; accepted reasoning is stored on
        ``node.judge_reasoning``.

        Args:
            node_id: Id of the node being evaluated.
            node: The node object; its simulation fields are updated in place.

        Returns:
            The accepted judge score.

        Raises:
            ValueError: If no attempt produced a valid response, either
                because the last attempt raised (chained to that error) or
                because every response failed validation.
        """
        (success, csv_output, error) = self.runner.run_simulation_from_code(node.code)
        node.simulation_success = success
        node.simulation_error = error if error else None
        if success:
            node.csv_output = limit_csv_precision(csv_output)
            self.logger.write(f'Simulation successful for {node_id}\n')
        else:
            node.csv_output = ''
            self.logger.write(f'Simulation failed for {node_id}: {error}\n')
        simulation_result = build_simulation_result(success, node.csv_output, error)

        def build_judge_prompt(ratio: float) -> str:
            """Render the judge prompt with the CSV sampled at ``ratio``."""
            judge_context = self.construct_judge_context(node_id)
            if success:
                sampled_csv = sample_csv_rows(
                    node.csv_output,
                    ratio,
                    add_header_to_value=self.csv_add_header_to_value,
                    add_timestamp_to_value=self.csv_add_timestamp_to_value,
                    transpose=self.csv_transpose,
                )
                judge_context['simulation_result'] = build_simulation_result(success, sampled_csv, error)
            else:
                judge_context['simulation_result'] = simulation_result
            return TREE_JUDGE_PROMPT_TEMPLATE.format(**judge_context)

        # Leave room for the response and a safety buffer within the token limit.
        max_tokens = LLM_MAX_TOKENS - LLM_BUFFER_TOKENS - JUDGE_MAX_TOKENS
        (prompt_content, final_ratio) = await self.adaptive_sample_and_build_prompt_async(
            build_judge_prompt, max_tokens, 'Judge'
        )
        for attempt in range(MAX_ATTEMPTS):
            try:
                # The context is rebuilt per attempt only for the saved prompt
                # record; the text sent to the LLM is always prompt_content.
                judge_context = self.construct_judge_context(node_id)
                if success:
                    sampled_csv = sample_csv_rows(
                        node.csv_output,
                        final_ratio,
                        add_header_to_value=self.csv_add_header_to_value,
                        add_timestamp_to_value=self.csv_add_timestamp_to_value,
                        transpose=self.csv_transpose,
                    )
                    judge_context['simulation_result'] = build_simulation_result(success, sampled_csv, error)
                else:
                    judge_context['simulation_result'] = simulation_result
                self.dialogue_saver.save_prompt_to_file('judge', node_id, prompt_content, judge_context, attempt)
                judge_response = await self.claude.chat_structured_async(
                    [{'role': 'user', 'content': prompt_content}],
                    JudgeResponse,
                    max_tokens=JUDGE_MAX_TOKENS,
                    temperature=JUDGE_TEMPERATURE,
                )
                bounds = self.get_scoring_bounds()
                if (
                    not judge_response.reasoning
                    or judge_response.score < bounds['min_score']
                    or judge_response.score > bounds['max_score']
                ):
                    # Never accept an invalid response, including on the final
                    # attempt: the loop then ends and the ValueError after it
                    # is raised.
                    self.logger.write(
                        f'Judge async response validation failed (attempt {attempt + 1}): invalid reasoning or score\n'
                    )
                    continue
                parsed_data = {'reasoning': judge_response.reasoning, 'score': judge_response.score}
                self.dialogue_saver.save_response_to_file(
                    'judge', node_id, str(judge_response.model_dump()), parsed_data, True, attempt
                )
                node.judge_reasoning = judge_response.reasoning
                return judge_response.score
            except Exception as e:
                self.dialogue_saver.save_failed_response('judge', node_id, attempt, str(e), len(prompt_content))
                self.logger.write(f'Judge LLM async error on attempt {attempt + 1}: {e}\n')
                if attempt == MAX_ATTEMPTS - 1:
                    raise ValueError(
                        f'Judge LLM async failed after {MAX_ATTEMPTS} attempts for node {node_id}: {e}'
                    ) from e
                continue
        raise ValueError(f'Judge LLM async failed after {MAX_ATTEMPTS} attempts for node {node_id}')

    async def _generate_child_code_async(
        self, parent_id: str, child_index: int, child_id: str
    ) -> Tuple[str, str, str, str]:
        """Ask the editor LLM for a modified version of the parent's code.

        The editor prompt is built with adaptive CSV sampling of the parent's
        simulation output and sent up to ``MAX_ATTEMPTS`` times. A response is
        accepted only if reasoning and code are both non-empty and the code
        passes ``validate_python_code``.

        Args:
            parent_id: Node whose code is being modified.
            child_index: Index of the child within this expansion.
            child_id: Id reserved for the new child; returned unchanged.

        Returns:
            ``(child_id, modified_code, action_description, reasoning)`` where
            the action description is the reasoning cut to
            ``CONTEXT_TRUNCATION_LENGTH`` characters.

        Raises:
            ValueError: If no attempt produced a usable response.
        """

        def build_editor_prompt(ratio: float) -> str:
            """Render the editor prompt with the parent CSV sampled at ``ratio``."""
            editor_context = self.construct_editor_context(parent_id, child_index)
            parent_node = self.state.nodes[parent_id]
            success = parent_node.simulation_success if parent_node.simulation_success is not None else True
            error = parent_node.simulation_error if parent_node.simulation_error else ''
            if success:
                sampled_csv = sample_csv_rows(
                    parent_node.csv_output,
                    ratio,
                    add_header_to_value=self.csv_add_header_to_value,
                    add_timestamp_to_value=self.csv_add_timestamp_to_value,
                    transpose=self.csv_transpose,
                )
                editor_context['simulation_result'] = build_simulation_result(success, sampled_csv, error)
            return TREE_EDITOR_PROMPT_TEMPLATE.format(**editor_context)

        # Leave room for the response and a safety buffer within the token limit.
        max_tokens = LLM_MAX_TOKENS - LLM_BUFFER_TOKENS - EDITOR_MAX_TOKENS
        (prompt_content, final_ratio) = await self.adaptive_sample_and_build_prompt_async(
            build_editor_prompt, max_tokens, 'Editor'
        )
        for attempt in range(MAX_ATTEMPTS):
            try:
                # The context is rebuilt per attempt only for the saved prompt
                # record; the text sent to the LLM is always prompt_content.
                editor_context = self.construct_editor_context(parent_id, child_index)
                parent_node = self.state.nodes[parent_id]
                success = parent_node.simulation_success if parent_node.simulation_success is not None else True
                error = parent_node.simulation_error if parent_node.simulation_error else ''
                if success:
                    sampled_csv = sample_csv_rows(
                        parent_node.csv_output,
                        final_ratio,
                        add_header_to_value=self.csv_add_header_to_value,
                        add_timestamp_to_value=self.csv_add_timestamp_to_value,
                        transpose=self.csv_transpose,
                    )
                    editor_context['simulation_result'] = build_simulation_result(success, sampled_csv, error)
                self.dialogue_saver.save_prompt_to_file(
                    'editor', parent_id, prompt_content, editor_context, attempt, child_index
                )
                editor_response = await self.claude.chat_structured_async(
                    [{'role': 'user', 'content': prompt_content}],
                    EditorResponse,
                    max_tokens=EDITOR_MAX_TOKENS,
                    temperature=EDITOR_TEMPERATURE,
                )
                if not editor_response.reasoning or not editor_response.modified_code:
                    # Never accept an empty response, including on the final
                    # attempt: the loop then ends and the ValueError after it
                    # is raised.
                    self.logger.write(
                        f'Editor async response validation failed (attempt {attempt + 1}): empty reasoning or code\n'
                    )
                    continue
                action_desc = editor_response.reasoning[:CONTEXT_TRUNCATION_LENGTH]
                (is_valid, error_msg) = validate_python_code(editor_response.modified_code)
                if is_valid:
                    parsed_data = {
                        'action_desc': action_desc,
                        'reasoning': editor_response.reasoning,
                        'modified_code': editor_response.modified_code,
                        'self_assessed_score': editor_response.self_assessed_score,
                    }
                    self.dialogue_saver.save_response_to_file(
                        'editor', parent_id, str(editor_response.model_dump()), parsed_data, True, attempt, child_index
                    )
                    return (child_id, editor_response.modified_code, action_desc, editor_response.reasoning)
                # Invalid code: record the rejected response, then retry or give up.
                parsed_data = {
                    'action_desc': action_desc,
                    'reasoning': editor_response.reasoning,
                    'modified_code': editor_response.modified_code,
                    'validation_error': error_msg,
                }
                self.dialogue_saver.save_response_to_file(
                    'editor', parent_id, str(editor_response.model_dump()), parsed_data, False, attempt, child_index
                )
                if attempt < MAX_ATTEMPTS - 1:
                    self.logger.write(f'Generated code has syntax error (async attempt {attempt + 1}): {error_msg}\n')
                    continue
                self.logger.write(f'Generated code has syntax error on final async attempt: {error_msg}\n')
                # Raised inside the try block, so the handler below logs it and
                # re-raises it as the final failure.
                raise ValueError(
                    f'Editor async failed to generate valid code after {MAX_ATTEMPTS} attempts for child {child_index} of node {parent_id}: {error_msg}'
                )
            except Exception as e:
                self.logger.write(f'Editor LLM async error on attempt {attempt + 1}: {e}\n')
                self.dialogue_saver.save_failed_response(
                    'editor', parent_id, attempt, str(e), len(prompt_content), child_index
                )
                if attempt == MAX_ATTEMPTS - 1:
                    raise ValueError(
                        f'Editor LLM async failed after {MAX_ATTEMPTS} attempts for child {child_index} of node {parent_id}: {e}'
                    ) from e
                continue
        raise ValueError(
            f'Editor LLM async failed after {MAX_ATTEMPTS} attempts for child {child_index} of node {parent_id}'
        )

    def expand_node(self, node_id: str) -> List[str]:
        """Expand a node synchronously and return the ids of its new children.

        Returns an empty list without logging when the node is not expandable;
        otherwise ``_expand_node_parallel`` is driven with ``asyncio.run``.
        """
        if self.can_expand_node(node_id):
            self.logger.write(
                f'Expanding node {node_id} (expansion #{self.state.nodes[node_id].expansion_count + 1})\n'
            )
            return asyncio.run(self._expand_node_parallel(node_id))
        return []

    async def _expand_node_parallel(self, node_id: str) -> List[str]:
        """Generate and evaluate ``self.expansion_width`` children of a node.

        Child code is generated concurrently with ``asyncio.gather``. All
        children are then registered in the tree and evaluated with a second
        ``asyncio.gather`` before their scores are stored, saved and logged.
        The node's ``expansion_count`` is incremented once per call.

        Returns:
            Ids of the newly created children, in creation order.
        """
        node = self.state.nodes[node_id]
        new_children = []
        current_child_count = len(node.children)
        # The strategy index restarts at 0 for every expansion, while child ids
        # continue after the children the node already has.
        generation_tasks = [
            self._generate_child_code_async(
                node_id, i, MCTSNode.create_child_id(node_id, current_child_count + i)
            )
            for i in range(self.expansion_width)
        ]
        child_data_list = await asyncio.gather(*generation_tasks)
        for i, (child_id, child_code, action_desc, editor_reasoning) in enumerate(child_data_list):
            child_node = MCTSNode(
                node_id=child_id,
                parent_id=node_id,
                children=[],
                depth=node.depth + 1,
                sibling_index=current_child_count + len(new_children),
                adding_order=len(self.state.nodes),
                strategy_name=self.get_strategy_name(i),
                code=child_code,
                action_description=action_desc,
                editor_reasoning=editor_reasoning,
            )
            self.state.nodes[child_id] = child_node
            node.children.append(child_id)
            new_children.append(child_id)
        evaluation_tasks = [
            self._evaluate_node_with_judge_async(child_id, self.state.nodes[child_id])
            for child_id in new_children
        ]
        child_scores = await asyncio.gather(*evaluation_tasks)
        # Metadata is not refreshed inside this method, so track the best score
        # locally; otherwise several children of one batch could each be
        # announced as the new global best.
        best_so_far = self.state.metadata.get('best_score', 0.0)
        for child_id, score in zip(new_children, child_scores):
            self.state.nodes[child_id].node_value = score
            if score > best_so_far:
                self.logger.write(f'NEW GLOBAL BEST: {score:.3f} at node {child_id}\n')
                best_so_far = score
            self.save_node_to_file(child_id)
            self.logger.write(f'Created child {child_id} with score {score:.3f}\n')
            self.logger.write(f'Action: {self.state.nodes[child_id].action_description}\n')
            self.log_node_creation(child_id)
        node.expansion_count += 1
        return new_children

    def load_tree_state_metadata(self) -> dict:
        """Load ``tree_state/tree_state_metadata.yaml`` from the output directory.

        Returns:
            A dict that always contains ``iterations``, ``last_iteration`` and
            ``state_files``. Values found in the file override the defaults.
            The defaults alone are returned when the file is missing, cannot
            be read or parsed (a message is logged), or does not hold a
            mapping.
        """
        defaults = {'iterations': [], 'last_iteration': 0, 'state_files': {}}
        metadata_file = self.output_dir / 'tree_state' / 'tree_state_metadata.yaml'
        if not metadata_file.exists():
            return defaults
        try:
            with open(metadata_file, 'r') as f:
                loaded = yaml.safe_load(f)
        except Exception as e:
            # Best effort: an unreadable file must not stop the run, but the
            # failure is logged instead of being dropped silently.
            self.logger.write(f'Could not load tree state metadata from {metadata_file}: {e}\n')
            return defaults
        if not isinstance(loaded, dict):
            # An empty file parses to None; anything else unexpected is ignored.
            return defaults
        return {**defaults, **loaded}

    def save_tree_state_metadata(self, metadata: dict):
        """Write ``metadata`` as block-style YAML to ``tree_state/tree_state_metadata.yaml``."""
        target = self.output_dir / 'tree_state' / 'tree_state_metadata.yaml'
        with target.open('w') as stream:
            yaml.dump(metadata, stream, default_flow_style=False, allow_unicode=True)

    def save_state(self):
        """Write the tree state for the current iteration and update the index.

        The state is dumped to ``tree_state/tree_state_iter_NNN.yaml``. The
        metadata index then records the iteration number, the highest
        iteration seen and the file name used for this iteration.
        """
        state_filename = f'tree_state_iter_{self.state.iteration:03d}.yaml'
        state_file = self.output_dir / 'tree_state' / state_filename
        state_file.parent.mkdir(parents=True, exist_ok=True)
        with open(state_file, 'w') as f:
            yaml.dump(self.state.model_dump(), f, default_flow_style=False, allow_unicode=True)
        metadata = self.load_tree_state_metadata()
        # The loaded index can lack these keys (for example an empty metadata
        # file), so create them on demand instead of failing with KeyError.
        iterations = metadata.setdefault('iterations', [])
        state_files = metadata.setdefault('state_files', {})
        if self.state.iteration not in iterations:
            iterations.append(self.state.iteration)
        iterations.sort()
        metadata['last_iteration'] = max(iterations)
        state_files[self.state.iteration] = state_filename
        self.save_tree_state_metadata(metadata)
        self.logger.write(f'Tree state saved: {state_file.name}\n')

    def run_mcts_optimization(self, iterations: int):
        """Run select/expand rounds until ``iterations`` have completed.

        Only a round that produced children counts as a completed iteration;
        after it the metadata is recomputed, the state is saved and a status
        block is logged. The loop stops early when no node can be selected.
        The final summary is written in every case.
        """
        while self.state.iteration < iterations:
            self.logger.write(f'\n=== MCTS ITERATION {self.state.iteration + 1} ===\n')
            target_id = self.select_node_for_expansion()
            if target_id is None:
                self.logger.write('No more nodes can be expanded. Stopping early.\n')
                break
            if self.expand_node(target_id):
                self.state.iteration += 1
                self.compute_metadata()
                self.save_state()
                self.logger.write('Tree status:\n')
                self.logger.write(f'Total nodes: {self.state.metadata.get("total_nodes", 0)}\n')
                self.logger.write(f'Maximum depth: {self.state.metadata.get("max_depth", 0)}\n')
                self.logger.write(f'Expandable nodes: {self.state.metadata.get("expandable_nodes", 0)}\n')
                self.logger.write(
                    f'Global best: {self.state.metadata.get("best_score", 0.0):.3f} at {self.state.metadata.get("best_node_id", "0")}\n'
                )
            else:
                # The iteration counter is not advanced for a failed expansion.
                self.logger.write(f'Failed to expand node {target_id}\n')
        self.logger.write('\n=== MCTS OPTIMIZATION COMPLETE ===\n')
        self.save_final_summary()

    def save_final_summary(self):
        """Write a plain-text run summary to ``run/mcts_final_summary.txt``.

        The summary lists the goal, iteration and node counts, the root and
        best scores, and the path from the root to the best node with each
        node's score and action description. Nothing is written when there is
        no output directory or the tree is empty.

        The file is written as UTF-8 because the text contains ``→``, which
        some locale default encodings cannot represent. The ``run`` directory
        is created if it is missing.
        """
        if not self.output_dir or not self.state.nodes:
            return
        best_path = self.get_path_to_root(self.state.metadata.get('best_node_id', '0'))
        best_path.reverse()  # root first
        header = f'MCTS OPTIMIZATION SUMMARY\n{"=" * 40}\nGoal: {self.goal}\nTotal iterations completed: {self.state.iteration}\nTotal nodes created: {len(self.state.nodes)}'
        results = f'FINAL RESULTS:\n  Initial score: {self.state.nodes["0"].node_value:.3f} (root)\n  Best score: {self.state.metadata.get("best_score", 0.0):.3f} (node {self.state.metadata.get("best_node_id", "0")})'
        path_info = f'BEST PATH:\n  {" → ".join(best_path)}\n\nPATH DETAILS:'
        path_details = '\n'.join(
            f'  {node_id}: {self.state.nodes[node_id].node_value:.3f} - {self.state.nodes[node_id].action_description}'
            for node_id in best_path
        )
        summary = f'{header}\n\n{results}\n\n{path_info}\n{path_details}'
        run_dir = self.output_dir / 'run'
        run_dir.mkdir(parents=True, exist_ok=True)
        (run_dir / 'mcts_final_summary.txt').write_text(summary, encoding='utf-8')
        self.logger.write('Final summary saved to: run/mcts_final_summary.txt\n')

    def log_node_creation(self, node_id: str):
        """Log a description of a node and persist it with ``save_node_to_file``.

        An unknown id only produces a "not found" log line. Otherwise one log
        entry is written with the parent, depth, sibling index, action, the
        score (when set), the children, the expansion count and whether the
        node can still be expanded.
        """
        node = self.state.nodes.get(node_id) if node_id in self.state.nodes else None
        if node is None:
            self.logger.write(f'Node {node_id} not found\n')
            return
        pieces = [
            f'NODE CREATED: {node_id}\n  Parent: {(node.parent_id if node.parent_id else "None (root)")}\n  Depth: {node.depth}\n  Sibling: {node.sibling_index}\n  Action: {node.action_description}'
        ]
        if getattr(node, 'node_value', None) is not None:
            pieces.append(f'\n  Score: {node.node_value:.3f}')
        pieces.append(f'\n--- NODE {node_id} DETAILS ---\nChildren: {node.children}')
        pieces.append(f'\nExpansion count: {node.expansion_count}\nCan expand: {self.can_expand_node(node_id)}')
        pieces.append('\n--- END NODE DETAILS ---\n\n')
        self.logger.write(''.join(pieces))
        self.save_node_to_file(node_id)

    def save_node_to_file(self, node_id: str):
        """Persist one node under ``<output_dir>/node/<node_id>/``.

        Writes ``metadata.yaml`` and ``code.py`` and, when the node has them,
        ``simulation_results.csv``, ``judge_reasoning.txt`` and
        ``editor_reasoning.txt``. Nothing is written when there is no output
        directory or the node id is unknown.

        All files are written as UTF-8 so that non-ASCII characters in the
        code or in the reasoning text do not depend on the locale's default
        encoding.
        """
        if not self.output_dir or node_id not in self.state.nodes:
            return
        node = self.state.nodes[node_id]
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        node_dir = self.output_dir / 'node' / node_id
        node_dir.mkdir(parents=True, exist_ok=True)
        metadata = {
            'node_id': node_id,
            'parent_id': node.parent_id,
            'children': node.children,
            'depth': node.depth,
            'sibling_index': node.sibling_index,
            'adding_order': node.adding_order,
            'strategy_name': node.strategy_name,
            'action_description': node.action_description,
            'node_value': node.node_value,
            'expansion_count': node.expansion_count,
            'timestamp': timestamp,
            'judge_reasoning': node.judge_reasoning if node.judge_reasoning else '',
            'editor_reasoning': node.editor_reasoning if node.editor_reasoning else '',
            'simulation_success': node.simulation_success,
            'simulation_error': node.simulation_error if node.simulation_error else '',
        }
        (node_dir / 'metadata.yaml').write_text(yaml.dump(metadata, default_flow_style=False), encoding='utf-8')
        (node_dir / 'code.py').write_text(node.code, encoding='utf-8')
        # Optional artifacts are only written when they have content.
        if node.csv_output:
            (node_dir / 'simulation_results.csv').write_text(node.csv_output, encoding='utf-8')
        if node.judge_reasoning:
            (node_dir / 'judge_reasoning.txt').write_text(node.judge_reasoning, encoding='utf-8')
        if node.editor_reasoning:
            (node_dir / 'editor_reasoning.txt').write_text(node.editor_reasoning, encoding='utf-8')


class MCTSSimulationRunner:
    """Run candidate simulation code by writing it to a temporary script.

    Each run writes the code to a uniquely named ``temp_mcts_code_*.py`` file
    inside ``working_dir``, passes that file to ``run_code`` and removes the
    file afterwards.
    """

    def __init__(self, working_dir: Path):
        """Store the directory where temporary scripts are created and run.

        Args:
            working_dir: Directory used for the temporary script and passed
                to ``run_code`` as its second argument.
        """
        self.working_dir = working_dir

    def _unique_temp_path(self) -> Path:
        """Return a fresh, collision-free path for a temporary script.

        A time stamp with one-second resolution is not unique on its own:
        two runs starting within the same second (for example concurrent
        expansions) would share a file, and one run could overwrite or delete
        the other's script. A random UUID suffix avoids that.
        """
        import uuid  # local import so the block does not depend on file-level imports

        stamp = datetime.now().strftime("%H%M%S")
        return self.working_dir / f'temp_mcts_code_{stamp}_{uuid.uuid4().hex}.py'

    def run_simulation_from_code(self, code: str) -> Tuple[bool, str, str]:
        """Execute ``code`` as a temporary script and report the outcome.

        Args:
            code: Python source text to run.

        Returns:
            ``(True, output, '')`` when ``run_code`` reports success, otherwise
            ``(False, '', output)`` where ``output`` is whatever ``run_code``
            returned as its third element.
        """
        temp_file = self._unique_temp_path()
        try:
            temp_file.write_text(code)
            # The first element of run_code's result is not used here.
            (_, success, output) = run_code(temp_file, self.working_dir)
            if success:
                return (True, output, '')
            return (False, '', output)
        finally:
            # Always clean up; the file may be absent if writing it failed.
            try:
                temp_file.unlink()
            except FileNotFoundError:
                pass


def _build_mcts_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the MCTS optimization script."""
    parser = argparse.ArgumentParser(
        description='MCTS-based optimization for Python simulation systems',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='',
    )
    parser.add_argument('source', help='Source Python simulation file (e.g., 2_2.py)')
    parser.add_argument('goal', help='Optimization goal as a string')
    parser.add_argument('--iterations', type=int, default=10, help='Number of MCTS iterations (default: 10)')
    parser.add_argument('--expansion-width', type=int, default=1, help='Number of children per expansion (default: 1)')
    parser.add_argument(
        '--expand-internal',
        type=int,
        default=0,
        help='Allow expanding internal nodes: 0=leaves only, non-zero=any node (default: 0)',
    )
    parser.add_argument(
        '--max-expansions-per-node', type=int, default=3, help='Maximum times each node can expand (default: 3)'
    )
    parser.add_argument(
        '--csv-sample-ratio', type=float, default=0.5, help='CSV row sampling ratio (0.0-1.0, default: 0.5)'
    )
    parser.add_argument(
        '--csv-add-header-to-value',
        type=int,
        default=0,
        help='Add header names to data values (0=disabled, 1=enabled, default: 0)',
    )
    parser.add_argument(
        '--csv-add-timestamp-to-value',
        type=int,
        default=0,
        help='Add timestamp info to data values (0=disabled, 1=enabled, default: 0)',
    )
    parser.add_argument(
        '--csv-transpose', type=int, default=0, help='Transpose CSV matrix (0=disabled, 1=enabled, default: 0)'
    )
    parser.add_argument(
        '--load-tree-state-from-iter',
        type=int,
        default=0,
        help='Iteration number to load tree state from (0 = start fresh)',
    )
    parser.add_argument('--output-dir', help='Output directory (default: auto-generated with opt_mcts in path)')
    parser.add_argument(
        '--scoring-strategy',
        type=str,
        default='v3',
        help='Scoring strategy for LLM evaluation: v1, v2, v3 (default: v3)',
    )
    return parser


def _default_mcts_output_dir(source_file: Path, goal: str) -> Path:
    """Build the auto-generated output directory from time, source name and goal."""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    parent_name = source_file.parent.name if source_file.parent.name != '.' else ''
    # Prefix the file name with its parent directory name when there is one.
    combined_source = f'{parent_name}_{source_file.name}' if parent_name else source_file.name
    sanitized_source = sanitize_name(combined_source)
    sanitized_goal = sanitize_name(goal)
    return Path(f'output/opt_mcts/{timestamp}-{sanitized_source}-{sanitized_goal}')


def _write_run_records(run_dir: Path, args: argparse.Namespace) -> None:
    """Record the invoking command line and the parsed arguments in ``run_dir``.

    Files are written as UTF-8 explicitly: the YAML dump uses
    ``allow_unicode=True`` and therefore emits non-ASCII characters verbatim,
    which a non-UTF-8 locale default encoding may be unable to encode.
    """
    cmdline_str = ' '.join(shlex.quote(arg) for arg in sys.argv)
    with open(run_dir / 'cmdline.txt', 'w', encoding='utf-8') as f:
        f.write(f'Generated at: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n')
        f.write(f'{cmdline_str}\n')
    with open(run_dir / 'args.yaml', 'w', encoding='utf-8') as f:
        f.write(f'Generated at: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n')
        yaml.dump(vars(args), f, default_flow_style=False, allow_unicode=True)


def main():
    """Command-line entry point: parse arguments and run the MCTS optimization.

    Steps: validate the source file, create the output/run/tree-state
    directories, record the command line and arguments, redirect ``sys.stdout``
    to the run logger, build the optimizer, then either load a saved tree state
    or initialise the root from the source file, and run the requested number
    of iterations.

    Exit status:
        * 2 if the arguments are invalid or the source file does not exist
          (reported through ``argparse``, before anything is written to disk).
        * 1 if the LLM client cannot be initialised or the optimization raises.
        * 130 if the run is interrupted with Ctrl-C.
        * Normal return (status 0) on success.

    ``sys.stdout`` is always restored and the logger closed before exiting.
    """
    parser = _build_mcts_arg_parser()
    args = parser.parse_args()
    source_file = Path(args.source)
    # The source is always copied into the run directory below, so fail fast
    # here instead of after directories, logger and LLM client were created.
    if not source_file.is_file():
        parser.error(f'source file not found: {source_file}')

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = _default_mcts_output_dir(source_file, args.goal)
    output_dir.mkdir(parents=True, exist_ok=True)
    run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = output_dir / 'run' / run_timestamp
    tree_state_dir = output_dir / 'tree_state'
    run_dir.mkdir(parents=True, exist_ok=True)
    tree_state_dir.mkdir(parents=True, exist_ok=True)
    _write_run_records(run_dir, args)

    logger = setup_logging(run_dir)
    original_stdout = sys.stdout
    sys.stdout = logger  # print() output below goes through the run logger
    exit_code = 0
    try:
        enhanced_goal = preprocess_goal(
            args.goal,
            logger,
            csv_add_header_to_value=bool(args.csv_add_header_to_value),
            csv_add_timestamp_to_value=bool(args.csv_add_timestamp_to_value),
            csv_transpose=bool(args.csv_transpose),
        )
        print('=== MCTS OPTIMIZATION START ===')
        print(f'Source: {source_file}')
        print(f'Goal: {args.goal}')
        if enhanced_goal != args.goal:
            print(f'Enhanced goal with file contents: {len(enhanced_goal)} characters')
        print(f'Iterations: {args.iterations}')
        print(f'Expansion width: {args.expansion_width}')
        print(f'Expand internal: {args.expand_internal}')
        print(f'Max expansions per node: {args.max_expansions_per_node}')
        print(f'CSV sample ratio: {args.csv_sample_ratio}')
        print(f'Output: {output_dir}')
        print(f'Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
        print()
        try:
            claude = ClaudeLLM()
        except ValueError as e:
            print(f'ERROR: Failed to initialize Claude LLM: {e}')
            sys.exit(1)
        runner = MCTSSimulationRunner(Path.cwd())
        optimizer = MCTSOptimizer(
            claude_llm=claude,
            runner=runner,
            goal=enhanced_goal,
            logger=logger,
            output_dir=output_dir,
            run_timestamp=run_timestamp,
            expansion_width=args.expansion_width,
            expand_internal=args.expand_internal,
            max_expansions_per_node=args.max_expansions_per_node,
            csv_sample_ratio=args.csv_sample_ratio,
            csv_add_header_to_value=bool(args.csv_add_header_to_value),
            csv_add_timestamp_to_value=bool(args.csv_add_timestamp_to_value),
            csv_transpose=bool(args.csv_transpose),
            scoring_strategy=args.scoring_strategy,
        )
        with open(run_dir / 'goal.txt', 'w', encoding='utf-8') as f:
            f.write(args.goal)
        with open(run_dir / 'enhanced_goal.txt', 'w', encoding='utf-8') as f:
            f.write(enhanced_goal)
        shutil.copy2(source_file, run_dir / 'original_code.py')
        if args.load_tree_state_from_iter > 0:
            optimizer.load_tree_state(args.load_tree_state_from_iter)
        else:
            initial_code = source_file.read_text()
            optimizer.initialize_root(initial_code)
        optimizer.run_mcts_optimization(args.iterations)
        print(
            f'Best score achieved: {optimizer.state.metadata.get("best_score", 0.0):.3f} at node {optimizer.state.metadata.get("best_node_id", "0")}'
        )
    except KeyboardInterrupt:
        print('\n=== MCTS OPTIMIZATION INTERRUPTED ===')
        exit_code = 130  # conventional status for termination by Ctrl-C
    except Exception as e:
        # Top-level boundary: report the failure, then signal it via exit status.
        print('\n=== MCTS OPTIMIZATION FAILED ===')
        print(f'ERROR: {e}')
        traceback.print_exc()
        exit_code = 1
    finally:
        logger.write(f'Results saved to: {output_dir}\n')
        sys.stdout = original_stdout
        logger.close()
    if exit_code:
        sys.exit(exit_code)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
