"""
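RLBench Loop - AST-based step-by-step / failure-only execution

Two modes:
1. step mode: after each run_skill statement, output the scene info as JSON
2. failure mode: run run_skill and output JSON only when a failure occurs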

Usage:
    python rlbench_loop.py --task BasketballInHoop --mode step --source_robot panda --target_robot ur5
    python rlbench_loop.py --task BasketballInHoop --mode failure --source_robot panda --target_robot ur5
    python rlbench_loop.py --task BasketballInHoop --mode step --max_steps 5  # limit the run to 5 skill calls
"""
import re
import argparse
import importlib
import ast
import astor
import sys
import json
import traceback
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
import os
from concurrent.futures import ThreadPoolExecutor, Future
import threading

from env import setup_environment, shutdown_environment
from utils.trigger_condition import SkillFailure, PathOutOfWorkspace
from rlbench.backend.exceptions import InvalidActionError
from utils.helper import object_names as get_object_names, to_camel_case, list_task_objects

# Prompt construction module
from policy_provider import PolicyProvider
from prompts.prompt_utils import get_prompt, get_prompt_for_codex

from rlbench.gym import RLBenchEnv

from pyrep.objects.shape import Shape
from pyrep.objects.joint import Joint
from pyrep.objects.dummy import Dummy
from pyrep.objects.proximity_sensor import ProximitySensor
from ours_utils.skill_conditions import check_preconditions, check_postconditions
from ours_utils.validate_conditions import (
    ProjectedStateTracker,
    InvalidStatementInfo,
    create_invalid_statements_payload,
    SymbolicEffectTable,
)
from grasp_info import get_graspable_objects_info, is_graspable_object, check_gripper_alignment_for_scene


def patched_close(self) -> None:
    """Do nothing when asked to close.

    Takes the instance like a regular method and always returns ``None``
    without touching it.
    """
    return None

# Monkey-patch: rebind RLBenchEnv.close to the no-op patched_close, so calling
# close() on an RLBenchEnv instance does nothing.
RLBenchEnv.close = patched_close

# ============================================================================
# JSON Output Helpers
# ============================================================================

def numpy_to_list(obj):
    """Recursively convert numpy values into plain Python values for JSON.

    Arrays become (nested) lists, lists and tuples become lists whose items
    are converted, dict values are converted (keys are left untouched), and
    numpy float / integer / bool scalars become ``float`` / ``int`` / ``bool``.
    Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (list, tuple)):
        return list(map(numpy_to_list, obj))
    if isinstance(obj, dict):
        return {key: numpy_to_list(value) for key, value in obj.items()}

    # Scalar kinds are checked in this order; the first match decides the cast.
    scalar_casts = (
        (np.floating, float),
        (np.integer, int),
        ((np.bool_, bool), bool),
    )
    for scalar_kind, cast in scalar_casts:
        if isinstance(obj, scalar_kind):
            return cast(obj)
    return obj

from headers import RUN_SKILL_INTERFACE, HEADER

def _strip_code_blocks(code: str) -> str:
    """  (```python, ```) """
    # ```python  ```    
    code = re.sub(r'^```(?:python)?\s*\n?', '', code, flags=re.MULTILINE)
    # ```    
    code = re.sub(r'\n?```\s*$', '', code, flags=re.MULTILINE)
    #   ```  
    code = re.sub(r'\n```\s*\n', '\n', code)
    return code


def _fix_orientation_to_quaternion(code: str) -> str:
    """reference_quat get_orientation()   get_quaternion() """
    # reference_quat = xxx.get_orientation() -> reference_quat = xxx.get_quaternion()
    code = re.sub(
        r'(reference_quat\s*=\s*\w+)\.get_orientation\(\)',
        r'\1.get_quaternion()',
        code
    )
    return code


def _wrap_with_run_skill(code: str, robot_name: str) -> str:
    """run_skill       

    LLM  body   (def run_skill    ):
    )
        target_pos = [0.0, 0.0, 0.3]
        ball = Shape('ball')
        ...
    ->
        def run_skill(env, task):
            target_pos = [0.0, 0.0, 0.3]
            ball = Shape('ball')
            ...
    """
    #  run_skill    
    if 'def run_skill' in code:
        return code

    #    
    stripped = code.strip()
    if not stripped:
        return code

    #      
    lines = code.splitlines()
    first_non_empty = None
    for line in lines:
        if line.strip():
            first_non_empty = line
            break

    if first_non_empty is None:
        return code

    # import/from    body  (   )
    if first_non_empty.lstrip().startswith(('import ', 'from ')):
        return code

    # run_skill  
    header = HEADER.get(robot_name, HEADER["panda"])
    return header + "\n\n" + RUN_SKILL_INTERFACE + code


def _sanitize_code(code: str) -> str:
    """Clean generated code: strip code-fence markers, then fix orientation calls.

    Fence removal runs first, and its output is passed to the
    ``get_orientation`` -> ``get_quaternion`` rewrite.
    """
    return _fix_orientation_to_quaternion(_strip_code_blocks(code))


def _reindent_codex_body(body: str) -> str:
    """Re-indent a bare function body so it sits one level (4 spaces) deep.

    The whole block is shifted uniformly (common indentation removed, 4 spaces
    added), which keeps nested ``if``/``for`` bodies intact. If the shifted
    text does not parse as a function body (for example the first line lost
    its indentation while the rest kept theirs), fall back to the legacy
    behaviour of forcing every non-blank line to exactly 4 spaces.
    """
    import textwrap  # local import: only this helper needs it

    lines = body.rstrip().splitlines()
    # Drop leading blank lines without touching the indentation of real code.
    while lines and not lines[0].strip():
        lines.pop(0)
    if not lines:
        return ""

    shifted = textwrap.indent(textwrap.dedent("\n".join(lines)), " " * 4)
    try:
        ast.parse("def _probe():\n" + shifted + "\n")
    except (SyntaxError, ValueError):
        # Legacy flattening: correct for flat bodies with ragged indentation.
        return "\n".join(
            " " * 4 + line.lstrip(" ") if line.strip() else line
            for line in lines
        )
    return shifted


def postprocess_initial_code_codex(llm_generated: str, robot_name: str) -> str:
    """Turn raw LLM output into a full module: header + ``run_skill`` + body.

    Args:
        llm_generated: Raw model output; either a complete ``def run_skill``
            definition or just the function body, optionally inside a
            markdown code fence.
        robot_name: Key into ``HEADER``; unknown names use the "panda" header.

    Returns:
        ``HEADER`` + blank line + ``RUN_SKILL_INTERFACE`` + body, passed
        through ``_sanitize_code``.
    """
    func_name = "run_skill"
    header = HEADER.get(robot_name, HEADER["panda"])
    prefix = header + "\n\n" + RUN_SKILL_INTERFACE

    if f"def {func_name}" in llm_generated:
        lines = llm_generated.splitlines()
        body = None
        for i, line in enumerate(lines):
            # No break on purpose: as before, the last matching "def" line wins.
            if line.lstrip().startswith(f"def {func_name}"):
                body = "\n".join(lines[i + 1:])
        if body is not None:
            return _sanitize_code(prefix + body)
        # "def run_skill" only occurred mid-line (e.g. inside a comment), so no
        # definition line was found. Previously this left `body` unbound and
        # raised UnboundLocalError; treat the text as a bare body instead.

    body = llm_generated
    # Strip a surrounding code fence, keeping the indentation of the code lines.
    if body.lstrip().startswith("```"):
        body = body.lstrip().split("```", 2)[1]
        if body.startswith("python"):
            body = body[6:]
    if body.rstrip().endswith("```"):
        body = body.rstrip().rsplit("```", 1)[0]

    return _sanitize_code(prefix + _reindent_codex_body(body))

def postprocess_initial_code(llm_generated: str, robot_name: str) -> str:
    """Turn raw LLM output (a ``run_skill`` body) into a full module.

    If the text contains ``[INITIAL_CODE_START]`` followed by
    ``[INITIAL_CODE_END]``, only the text between them is used; otherwise the
    whole text is used. Previously the markers were honoured only when the
    text also contained a backtick, so marker-delimited output without a code
    fence kept the marker lines inside the generated function.

    Only the FIRST body line is re-indented to 4 spaces; all later lines are
    kept exactly as the model produced them.

    Args:
        llm_generated: Raw model output.
        robot_name: Key into ``HEADER``; unknown names use the "panda" header.

    Returns:
        ``HEADER`` + blank line + ``RUN_SKILL_INTERFACE`` + body, passed
        through ``_sanitize_code``.
    """
    start_marker = "[INITIAL_CODE_START]"
    end_marker = "[INITIAL_CODE_END]"

    body = llm_generated
    start_idx = llm_generated.find(start_marker)
    if start_idx != -1:
        start_idx += len(start_marker)
        # Search for the end marker only after the start marker, so a stray
        # earlier end marker cannot produce an empty slice.
        end_idx = llm_generated.find(end_marker, start_idx)
        if end_idx != -1:
            body = llm_generated[start_idx:end_idx]
    body = body.strip("\n")

    body_lines = body.splitlines()
    if body_lines:
        first_line = body_lines[0]
        current_indent = len(first_line) - len(first_line.lstrip(" "))
        if current_indent != 4:
            body_lines[0] = " " * 4 + first_line.lstrip(" ")
        body = "\n".join(body_lines)

    header = HEADER.get(robot_name, HEADER["panda"])
    return _sanitize_code(header + "\n\n" + RUN_SKILL_INTERFACE + body)



def get_scene_info(task, task_name: str = "", robot_type: str = "panda", include_grasp_guidance: bool = True) -> Dict[str, Any]:
    """Collect a JSON-friendly snapshot of the scene objects and the gripper.

    Args:
        task: RLBench task environment
        task_name: Name of the current task
        robot_type: Type of robot (panda, ur5, sawyer, jaco)
        include_grasp_guidance: Whether to include grasp guidance for graspable objects

    Returns:
        Dictionary with ``"objects"`` and ``"gripper"``, plus
        ``"grasp_guidance"`` / ``"graspable_objects"`` when
        ``include_grasp_guidance`` is true (empty values if the lookup fails).
    """
    wrappers = {'Shape': Shape, 'Joint': Joint, 'Dummy': Dummy, 'ProximitySensor': ProximitySensor}
    # Objects whose name contains one of these substrings are left out,
    # with 'dirt_boundary' as the single exception.
    excluded_tokens = ('waypoint', 'visual', 'spawn', 'wrap', 'topPlate', 'boundary', 'base', 'dynamic')

    collected = []
    for scene_obj in list_task_objects(task):
        name = scene_obj.get_name()
        type_name = to_camel_case(str(scene_obj.get_type()).split(".")[-1])

        if name != 'dirt_boundary' and any(token in name for token in excluded_tokens):
            continue

        try:
            wrapper = wrappers.get(type_name)
            if not wrapper:
                # Types outside the registry are not reported at all.
                continue

            handle = wrapper(name)
            position = handle.get_position()
            quaternion = handle.get_quaternion()

            # Bounding box is only queried for Shape objects; failures are ignored.
            extents = None
            if type_name == "Shape":
                try:
                    raw = handle.get_bounding_box()  # [min_x, max_x, min_y, max_y, min_z, max_z]
                    extents = {
                        "min": [raw[0], raw[2], raw[4]],
                        "max": [raw[1], raw[3], raw[5]],
                        "size": [raw[1] - raw[0], raw[3] - raw[2], raw[5] - raw[4]],
                    }
                except Exception:
                    pass

            entry = {
                "name": name,
                "type": type_name,
                "position": numpy_to_list(position),
                "quaternion": numpy_to_list(quaternion),
            }
            if extents is not None:
                entry["bounding_box"] = extents
            collected.append(entry)
        except Exception:
            # Pose lookup failed: still list the object, without pose data.
            collected.append({
                "name": name,
                "type": type_name,
                "position": None,
                "quaternion": None,
            })

    # Gripper state from the current observation.
    observation = task.get_observation()
    pose = observation.gripper_pose
    snapshot = {
        "objects": collected,
        "gripper": {
            "position": numpy_to_list(pose[:3]),
            "quaternion": numpy_to_list(pose[3:7]),
            "is_open": bool(observation.gripper_open > 0.5),
        },
    }

    if not include_grasp_guidance:
        return snapshot

    try:
        grasp_infos, guidance_text = get_graspable_objects_info(
            collected, task_name, robot_type
        )
    except Exception:
        # Guidance is best-effort; fall back to empty values.
        snapshot["grasp_guidance"] = ""
        snapshot["graspable_objects"] = []
    else:
        snapshot["grasp_guidance"] = guidance_text
        snapshot["graspable_objects"] = grasp_infos

    return snapshot


def get_initial_info(task, task_name: str, robot_type: str, descriptions: List[str]) -> Dict[str, Any]:
    """Build the ``"initial"`` info payload for a task.

    The scene snapshot comes from ``get_scene_info``. The task description is
    every entry of *descriptions* with its first character upper-cased, joined
    by ``". "`` and terminated with ``"."`` (empty string when there are no
    descriptions).
    """
    scene = get_scene_info(task, task_name, robot_type)

    sentences = []
    for text in descriptions:
        sentences.append(text[0].upper() + text[1:] if text else '')

    if sentences:
        task_description = ". ".join(sentences) + "."
    else:
        task_description = ""

    return {
        "type": "initial",
        "task_name": task_name,
        "task_description": task_description,
        "robot_type": robot_type,
        "scene": scene,
    }

def output_json(data: Dict[str, Any]):
    """Print *data* to stdout as JSON between ``[JSON_START]`` and ``[JSON_END]`` lines.

    The data is first passed through ``numpy_to_list`` and then dumped with a
    2-space indent and ``ensure_ascii=False``.
    """
    payload = json.dumps(numpy_to_list(data), indent=2, ensure_ascii=False)
    print(f"[JSON_START]\n{payload}\n[JSON_END]")


# ============================================================================
# AST Parser for run_skill function
# ============================================================================

class RunSkillExtractor(ast.NodeVisitor):
    """AST visitor that pulls ``run_skill`` apart from the rest of a module.

    After ``visit(tree)``:
      * ``run_skill_body`` holds the statements of the ``run_skill`` function,
      * ``run_skill_args`` holds its positional parameter names,
      * ``helper_functions`` maps every other visited function name to its node,
      * ``imports`` lists every visited ``import`` / ``from ... import`` node.
    """

    def __init__(self):
        self.run_skill_body: List[ast.stmt] = []
        self.run_skill_args: List[str] = []
        self.helper_functions: Dict[str, ast.FunctionDef] = {}
        self.imports: List[ast.stmt] = []

    def _record_import(self, node):
        # Shared handler for both import statement forms.
        self.imports.append(node)
        self.generic_visit(node)

    visit_Import = _record_import
    visit_ImportFrom = _record_import

    def visit_FunctionDef(self, node):
        if node.name != "run_skill":
            self.helper_functions[node.name] = node
        else:
            self.run_skill_body = node.body
            self.run_skill_args = [param.arg for param in node.args.args]
        # Keep descending so nested definitions and imports are visited too.
        self.generic_visit(node)


def get_statement_code(stmt: ast.stmt, source_lines: List[str]) -> str:
    """Return the source text of *stmt*, taken from *source_lines*.

    Args:
        stmt: AST statement; must carry ``lineno`` (1-based).
        source_lines: The source split into lines.

    Returns:
        Lines ``lineno`` .. ``end_lineno`` joined with newlines. When
        ``end_lineno`` is missing or ``None`` (possible on hand-built nodes),
        only the ``lineno`` line is returned. Previously a ``None`` end made
        the slice run to the end of the file.
    """
    start_line = stmt.lineno - 1
    # `getattr` alone is not enough: the attribute can exist and still be None.
    end_line = getattr(stmt, 'end_lineno', None) or stmt.lineno
    return '\n'.join(source_lines[start_line:end_line])


def get_statements_to_execute_by_text_matching(
    run_skill_body: List[ast.stmt],
    executed_stmt_texts: List[str],
    source_lines: List[str],
    use_text_matching: bool = True
) -> Tuple[List[ast.stmt], int]:
    """
    Decide which statements of a (repaired) ``run_skill`` body still need to run.

    Args:
        run_skill_body: Statements of the new ``run_skill`` body.
        executed_stmt_texts: Source text of the statements executed so far,
            one entry per executed statement.
        source_lines: Source lines of the new code.
        use_text_matching: True for text matching, False for index matching.

    Returns:
        (statements to execute, number of statements treated as already done)

    In text-matching mode the leading statements whose text matches an
    executed statement are skipped; execution resumes at the first statement
    that does not match. Matching respects multiplicity: each executed text
    can account for only one skipped statement. The earlier set-based lookup
    skipped every repetition of a statement (e.g. two identical gripper calls)
    even if it had been executed only once.
    """
    if not use_text_matching:
        # Index mode: resume right after the number of executed statements.
        next_index = len(executed_stmt_texts)
        return run_skill_body[next_index:], next_index

    from collections import Counter  # local import: only needed in this mode

    remaining = Counter(executed_stmt_texts)
    statements_to_execute: List[ast.stmt] = []
    skip_count = 0

    for position, stmt in enumerate(run_skill_body):
        stmt_code = get_statement_code(stmt, source_lines)
        if remaining[stmt_code] > 0:
            # Already executed: consume one occurrence and skip it.
            remaining[stmt_code] -= 1
            skip_count += 1
            continue
        # First statement without a match: everything from here on runs.
        statements_to_execute = list(run_skill_body[position:])
        break

    # Nothing matched although statements were executed: the new code is
    # treated as a continuation, so run all of it.
    if skip_count == 0 and len(executed_stmt_texts) > 0 and len(statements_to_execute) == len(run_skill_body):
        print(f"[TEXT_MATCHING] No text match found. Appending new statements after {len(executed_stmt_texts)} executed statements.")
        return run_skill_body, len(executed_stmt_texts)

    # Every statement matched: nothing new to run.
    if len(statements_to_execute) == 0 and skip_count > 0:
        print("[TEXT_MATCHING] All statements matched. No new statements to execute.")
        return [], skip_count

    return statements_to_execute, skip_count


def is_skill_call(stmt: ast.stmt) -> Tuple[bool, Optional[str]]:
    """Tell whether *stmt* is a direct call to a known skill.

    A statement qualifies when it is a bare call (``pick(...)``) or an
    assignment whose value is a call (``ok = pick(...)``, including tuple
    targets such as ``a, b = pick(...)``), and the called name (plain name or
    attribute name, e.g. ``robot.move``) is one of the known skill names.

    Returns:
        ``(True, skill_name)`` for a skill call, otherwise ``(False, None)``.
    """
    skill_names = {'pick', 'place', 'move', 'push', 'open_gripper', 'close_gripper',
                   'align_two_axes', 'align_to_quaternion', 'normalize_quaternion',
                   'angle_diff',
                   # ur5 skills
                   'ur5_move_to', 'ur5_grasp_at', 'ur5_release_at', 'ur5_align_gripper',
                   'close_ur5_ee', 'open_ur5_ee',
                   # sawyer skills
                   'sawyer_move_to', 'sawyer_align_gripper', 'sawyer_open_gripper',
                   'sawyer_close_gripper', 'sawyer_pick', 'sawyer_place'}

    # Expression statements and assignments are the only forms inspected.
    # Tuple-unpacking assignments are ordinary ast.Assign nodes, so the former
    # separate "tuple unpacking" branch was a dead duplicate and is gone.
    if not isinstance(stmt, (ast.Expr, ast.Assign)):
        return False, None
    call = stmt.value
    if not isinstance(call, ast.Call):
        return False, None

    func = call.func
    if isinstance(func, ast.Name):
        callee = func.id
    elif isinstance(func, ast.Attribute):
        callee = func.attr
    else:
        callee = None

    if callee in skill_names:
        return True, callee
    return False, None


def contains_return(stmt: ast.stmt) -> bool:
    """Tell whether *stmt* contains a ``return`` that belongs to its own scope.

    The statement itself and everything nested in it (``if``/``for``/``try``
    bodies, ...) is searched. Nested function, lambda and class definitions
    are NOT entered: a ``return`` inside a helper defined within the statement
    returns from that helper, not from the enclosing code, so it is ignored.
    (The previous ``ast.walk`` based search reported those as well.)
    """
    nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef)
    pending = [stmt]
    while pending:
        node = pending.pop()
        if isinstance(node, ast.Return):
            return True
        if isinstance(node, nested_scopes):
            continue
        pending.extend(ast.iter_child_nodes(node))
    return False


def remove_returns_from_stmt(stmt: ast.stmt) -> ast.stmt:
    """Replace every ``return`` in *stmt*'s own scope with ``pass``.

    The tree is modified in place and the (possibly replaced) root node is
    returned. Each inserted ``pass`` inherits the source location of the
    ``return`` it replaces. Nested function, lambda and class definitions are
    left untouched: their ``return`` statements belong to those inner scopes,
    and rewriting them would make the helpers return ``None``.
    """
    class ReturnRemover(ast.NodeTransformer):
        def visit_Return(self, node):
            # Swap the return for a pass at the same source position.
            return ast.copy_location(ast.Pass(), node)

        def _keep_inner_scope(self, node):
            # Do not descend: returns in here are not ours to remove.
            return node

        visit_FunctionDef = _keep_inner_scope
        visit_AsyncFunctionDef = _keep_inner_scope
        visit_Lambda = _keep_inner_scope
        visit_ClassDef = _keep_inner_scope

    return ReturnRemover().visit(stmt)


# ============================================================================
# Execution Context
# ============================================================================

def create_execution_context(env, task, robot_type: str) -> Dict[str, Any]:
    """Build the globals dictionary that generated statements are executed in.

    The skill module is chosen from *robot_type* (case-insensitive):
    ``ur5`` / ``sawyer`` / ``jaco`` map to ``skill_code_<robot>``, anything
    else to ``skill_code``. If importing the robot-specific module fails with
    ``ImportError``, ``skill_code`` is imported instead.

    The returned dict holds ``env``, ``task``, ``np``, the PyRep object
    classes, three ``env_utils`` helpers, one entry per known skill name
    (``None`` when the chosen module does not define that skill), and
    ``math``, ``time``, ``sleep`` and ``R`` (scipy ``Rotation``).
    """
    module_by_robot = {
        "ur5": "skill_code_ur5",
        "sawyer": "skill_code_sawyer",
        "jaco": "skill_code_jaco",
    }
    skill_module_name = module_by_robot.get(robot_type.lower(), "skill_code")

    try:
        skill_module = importlib.import_module(skill_module_name)
    except ImportError:
        # Robot-specific module unavailable: use the default skills.
        skill_module = importlib.import_module("skill_code")

    # Imported here rather than at module level, as in the original code.
    import env_utils
    from scipy.spatial.transform import Rotation as R
    import math
    import time

    skill_names = (
        # panda-style names
        'pick', 'place', 'move', 'push', 'open_gripper', 'close_gripper',
        'align_two_axes', 'align_to_quaternion', 'normalize_quaternion',
        'angle_diff',
        # ur5-style names
        'ur5_move_to', 'ur5_grasp_at', 'ur5_release_at', 'ur5_align_gripper',
        'close_ur5_ee', 'open_ur5_ee',
        # sawyer-style names
        'sawyer_move_to', 'sawyer_align_gripper', 'sawyer_open_gripper',
        'sawyer_close_gripper', 'sawyer_pick', 'sawyer_place',
    )

    context = {
        'env': env,
        'task': task,
        'np': np,
        'Shape': Shape,
        'Joint': Joint,
        'Dummy': Dummy,
        'ProximitySensor': ProximitySensor,
        'get_bbox_sizes': env_utils.get_bbox_sizes,
        'quat_mul': env_utils.quat_mul,
        'normalize_vector': env_utils.normalize_vector,
    }
    # Skills missing from the chosen module are registered as None.
    for skill in skill_names:
        context[skill] = getattr(skill_module, skill, None)

    context['math'] = math
    context['time'] = time
    context['sleep'] = time.sleep
    context['R'] = R  # scipy.spatial.transform.Rotation

    return context

# ============================================================================
# Ours Mode Executor (spec-driven validation, no waypoint-based checks)
# ============================================================================

def extract_skill_call_info(stmt: ast.stmt) -> Tuple[Optional[str], Optional[ast.Call]]:
    """Return ``(called_name, call_node)`` for a call statement.

    Handles bare call statements and assignments whose value is a call. The
    name is the function identifier for plain calls and the attribute name for
    method-style calls. Returns ``(None, None)`` for every other statement,
    and for calls whose callee is neither a name nor an attribute.
    """
    if not isinstance(stmt, (ast.Expr, ast.Assign)):
        return None, None

    call = stmt.value
    if not isinstance(call, ast.Call):
        return None, None

    callee = call.func
    if isinstance(callee, ast.Name):
        return callee.id, call
    if isinstance(callee, ast.Attribute):
        return callee.attr, call
    return None, None


def execute_ours_mode(
    env, task, task_name: str, robot_type: str,
    source_code: str, source_lines: List[str],
    max_skill_calls: Optional[int] = None,
    provider: Optional["PolicyProvider"] = None,
    max_repair_attempts: int = 5,
    use_text_matching: bool = True
) -> Optional[str]:
    """
    Execute policy in 'ours' mode with background validation of future statements.

    Key features:
    1. Tracks executed/current/future statements
    2. Stores post-execution state from observations
    3. Runs background validation thread during skill execution
    4. Validates future statements using projected state (confirmed + predicted effects)
    5. Sends batch repair requests for invalid statements

    Args:
        use_text_matching: If True, use text-based matching to skip executed statements
                          after repair. If False, use index-based matching (legacy).

    Returns:
        Final code after all repairs and revisions (or None on early error)
    """
    current_code = source_code
    current_lines = source_lines
    executed_stmt_texts: List[str] = []
    skill_call_count = 0
    repair_attempt_count = 0  # Track repair attempts to prevent infinite loops

    context = create_execution_context(env, task, robot_type)

    # Add run_skill function arguments to context (so they can be referenced in code)
    initial_obs = task.get_observation()
    context['descriptions'] = None  # Will be set by actual run_skill call if needed
    context['obs'] = initial_obs
    context['variations_index'] = 0  # Default value

    # Initialize projected state tracker for symbolic validation
    state_tracker = ProjectedStateTracker()

    # Initialize state from current observation and scene
    initial_scene = get_scene_info(task, task_name, robot_type)
    state_tracker.initialize_from_observation(initial_obs, initial_scene.get("objects", []))

    # Pre-process: Fix object type mismatches before execution
    # This prevents WrongObjectTypeError (e.g., Shape('success') when it's ProximitySensor)
    try:
        from utils.code_postprocess import fix_object_types_from_scene, remove_unnecessary_success_checks
        current_code = fix_object_types_from_scene(current_code, initial_scene)
        current_code = remove_unnecessary_success_checks(current_code)
        current_lines = current_code.split('\n')
        print("[POSTPROCESS] Applied object type auto-correction and removed unnecessary success checks")
    except Exception as e:
        print(f"[POSTPROCESS] Warning: Code postprocessing failed: {e}")

    # Thread synchronization for background validation
    validation_lock = threading.Lock()
    pending_repair_code: Optional[str] = None
    validation_thread: Optional[threading.Thread] = None

    def apply_pending_repair():
        """Check and apply any pending code repair from background validation."""
        nonlocal current_code, current_lines, pending_repair_code
        with validation_lock:
            if pending_repair_code != None:
                print(f"[VALIDATION_REPAIR] Applying repaired code from background validation")
                current_code = pending_repair_code
                current_lines = pending_repair_code.split('\n')
                pending_repair_code = None
                # Reset projection after code change
                state_tracker.reset_projection()
                return True
        return False

    def eval_call_arguments(call_node, eval_context):
        args = []
        kwargs = {}

        for arg in call_node.args:
            args.append(eval(astor.to_source(arg), eval_context))

        for kw in call_node.keywords:
            kwargs[kw.arg] = eval(astor.to_source(kw.value), eval_context)

        return args, kwargs

    def extract_target_pos(args, kwargs):
        """Heuristic to extract target position argument from evaluated args/kwargs."""
        if "target_pos" in kwargs:
            return kwargs["target_pos"]
        if "grasp_pos" in kwargs:
            return kwargs["grasp_pos"]
        if "place_pos" in kwargs:
            return kwargs["place_pos"]
        # Skip env, task if present
        effective = args[2:] if len(args) > 2 else args
        if effective:
            return effective[0]
        return None

    def extract_target_obj(kwargs, args=None):
        """Heuristic to extract target object identifier from kwargs or first arg."""
        if "obj" in kwargs:
            return kwargs["obj"]
        if "object" in kwargs:
            return kwargs["object"]
        if "target_obj" in kwargs:
            return kwargs["target_obj"]
        return None

    def background_validate_future_statements(
        future_stmts: List[ast.stmt],
        src_lines: List[str],
        ctx: Dict[str, Any],
        prov,
        code: str,
        exec_stmts: List[str],
        start_idx: int,
        tracker: ProjectedStateTracker,
        current_stmt_info: Optional[Dict[str, Any]] = None,
        pre_captured_scene_info: Optional[Dict[str, Any]] = None,
    ):
        """
        Background thread function to validate future statements using projected state.

        Validates each future primitive-skill statement by:
        1. First applying current statement's effect (assuming success)
        2. For each future statement:
           a) Check preconditions on projected state
           b) Apply symbolic effect
           c) Check postconditions on updated projected state
        3. Collect all violations (pre + post) for batch repair

        Args:
            current_stmt_info: Dict with current statement's skill_name, target_pos, target_obj
                               If provided and is a primitive skill, its effect is applied first.
            pre_captured_scene_info: Scene info captured from main thread BEFORE starting background
                                     validation. This avoids CoppeliaSim thread-safety issues.
        """
        nonlocal pending_repair_code

        try:
            # Reset projection to confirmed state before validating
            tracker.reset_projection()

            primitive_skills = (
                SymbolicEffectTable.GRASP_SKILLS |
                SymbolicEffectTable.RELEASE_SKILLS |
                SymbolicEffectTable.MOVE_SKILLS |
                SymbolicEffectTable.PUSH_SKILLS |
                SymbolicEffectTable.ALIGN_SKILLS
            )

            # ===== STEP 1: Apply current statement's effect first =====
            # Since validation runs while current statement executes,
            # we assume it will succeed and apply its effect to projected state
            if current_stmt_info != None:
                curr_skill = current_stmt_info.get("skill_name")
                if curr_skill and curr_skill.lower() in primitive_skills:
                    curr_target_pos = current_stmt_info.get("target_pos")
                    curr_target_obj = current_stmt_info.get("target_obj")
                    tracker.apply_projection(curr_skill, curr_target_pos, curr_target_obj)
                    print(f"[BG_VALIDATE] Applied current statement effect: {curr_skill} "
                          f"-> gripper_open={tracker.projected_state.gripper.is_open}, "
                          f"held_obj={tracker.projected_state.gripper.held_object}")

            # ===== STEP 2: Validate each future statement =====
            invalid_statements: List[Dict[str, Any]] = []

            for offset, stmt in enumerate(future_stmts):
                skill_name, call_node = extract_skill_call_info(stmt)

                # For non-skill statements (variable assignments, etc.),
                # execute them in context so future statements can reference the variables
                if skill_name is None:
                    try:
                        stmt_module = ast.Module(body=[stmt], type_ignores=[])
                        ast.fix_missing_locations(stmt_module)
                        compiled = compile(stmt_module, '<string>', 'exec')
                        exec(compiled, ctx)
                    except Exception:
                        # Ignore errors - some statements may depend on runtime state
                        pass
                    continue

                skill_lower = skill_name.lower()
                if skill_lower not in primitive_skills:
                    # Non-primitive skill calls (like normalize_quaternion) - execute for side effects
                    try:
                        stmt_module = ast.Module(body=[stmt], type_ignores=[])
                        ast.fix_missing_locations(stmt_module)
                        compiled = compile(stmt_module, '<string>', 'exec')
                        exec(compiled, ctx)
                    except Exception:
                        pass
                    continue

                stmt_code = get_statement_code(stmt, src_lines)
                stmt_lineno = stmt.lineno
                stmt_index = start_idx + offset

                # Try to evaluate arguments
                target_pos = None
                target_obj = None

                if call_node != None:
                    try:
                        args_eval, kwargs_eval = eval_call_arguments(call_node, ctx)
                        target_pos = extract_target_pos(args_eval, kwargs_eval)
                        target_obj = extract_target_obj(kwargs_eval, args_eval)
                    except NameError:
                        # Variable not yet defined in context - skip silently
                        pass
                    except Exception as e:
                        print(f"[BG_VALIDATE] Cannot evaluate args for stmt {stmt_index}: {e}")

                # Full validation: precondition -> effect -> postcondition
                # This method applies the effect internally and checks both conditions
                validation_result = tracker.validate_statement_full(
                    skill_name,
                    target_pos=target_pos,
                    target_obj=target_obj,
                )

                if not validation_result["success"]:
                    violation_type = []
                    if not validation_result.get("precondition_success", True):
                        violation_type.append("precondition")
                    if not validation_result.get("postcondition_success", True):
                        violation_type.append("postcondition")

                    print(f"[BG_VALIDATE] Statement {stmt_index} ({skill_name}) INVALID "
                          f"[{', '.join(violation_type)}]: {validation_result['violations']}")

                    invalid_statements.append({
                        "line_number": stmt_lineno,
                        "statement_code": stmt_code,
                        "skill_name": skill_name,
                        "violations": validation_result["violations"],
                        "precondition_violations": validation_result.get("precondition_violations", []),
                        "postcondition_violations": validation_result.get("postcondition_violations", []),
                        "warnings": validation_result["warnings"],
                        "projected_state": {
                            "gripper_pos": validation_result.get("projected_gripper_pos"),
                            "gripper_open": validation_result.get("projected_gripper_open"),
                            "held_object": validation_result.get("projected_held_object"),
                        },
                    })
                else:
                    if validation_result["warnings"]:
                        print(f"[BG_VALIDATE] Statement {stmt_index} ({skill_name}) warnings: {validation_result['warnings']}")

                # Note: validate_statement_full already applies the effect,
                # so we don't need to call apply_projection again

            # ===== STEP 3: Batch repair if there are invalid statements =====
            if invalid_statements and prov != None:
                print(f"[BG_VALIDATE] Found {len(invalid_statements)} invalid statements, requesting batch repair")

                # Use pre-captured scene info (captured in main thread before background validation started)
                # This avoids CoppeliaSim thread-safety issues (signal 11 segfault)
                if pre_captured_scene_info is None:
                    print(f"[BG_VALIDATE] WARNING: No pre-captured scene_info available, skipping batch repair")
                    output_json({
                        "type": "validation_warning",
                        "invalid_statements": invalid_statements,
                        "repair_attempted": False,
                        "repair_success": False,
                        "reason": "scene_info not available in background thread",
                    })
                    return

                new_code = prov.batch_invalid_repair(
                    current_code=code,
                    executed_stmt_texts=exec_stmts,
                    invalid_statements=invalid_statements,
                    scene=pre_captured_scene_info,
                )

                if new_code:
                    # Apply post-processing to repaired code
                    from utils.code_postprocess import apply_code_postprocessing
                    new_code = apply_code_postprocessing(new_code, pre_captured_scene_info)
                    with validation_lock:
                        pending_repair_code = new_code
                    print(f"[BG_VALIDATE] Batch repair successful, code will be applied after current execution")
                else:
                    # Log the invalid statements even if repair fails
                    output_json({
                        "type": "validation_warning",
                        "invalid_statements": invalid_statements,
                        "repair_attempted": True,
                        "repair_success": False,
                    })

        except Exception as e:
            print(f"[BG_VALIDATE] Error during background validation: {e}")
            traceback.print_exc()

    last_valid_code = source_code  #    
    last_valid_lines = source_lines
    last_valid_executed_stmts: List[str] = []

    try:
        while True:
            # Check if there's pending repair from previous background validation
            if apply_pending_repair():
                continue  # Restart loop with repaired code

            # Sanitize the code before parsing (cleanup rules live in _sanitize_code, e.g. orientation handling)
            current_code = _sanitize_code(current_code)
            current_lines = current_code.split('\n')

            try:
                tree = ast.parse(current_code)
            except SyntaxError as e:
                print(f"[ERROR] SyntaxError in code: {e}")
                repair_attempt_count += 1
                if repair_attempt_count >= max_repair_attempts:
                    print(f"[ERROR] Max repair attempts ({max_repair_attempts}) exceeded, returning last valid code")
                    return last_valid_code
                print(f"[RESET] Reverting to last valid code and retrying... (attempt {repair_attempt_count}/{max_repair_attempts})")
                current_code = last_valid_code
                current_lines = last_valid_lines
                executed_stmt_texts = list(last_valid_executed_stmts)
                continue

            extractor = RunSkillExtractor()
            extractor.visit(tree)

            if not extractor.run_skill_body:
                repair_attempt_count += 1
                print(f"[ERROR] run_skill    . ( {repair_attempt_count}/{max_repair_attempts})")
                if repair_attempt_count >= max_repair_attempts:
                    print(f"[ERROR] Max repair attempts ({max_repair_attempts}) exceeded, returning last valid code")
                    return last_valid_code
                print(f"[RESET] Reverting to last valid code and retrying... (attempt {repair_attempt_count}/{max_repair_attempts})")
                current_code = last_valid_code
                current_lines = last_valid_lines
                executed_stmt_texts = list(last_valid_executed_stmts)
                continue

            # Parse succeeded and run_skill body was found - remember this code as the last valid version
            last_valid_code = current_code
            last_valid_lines = current_lines
            last_valid_executed_stmts = list(executed_stmt_texts)

            for _, func_node in extractor.helper_functions.items():
                func_code = compile(ast.Module(body=[func_node], type_ignores=[]), '<string>', 'exec')
                exec(func_code, context)

            # Execute top-level imports into context (for background validation)
            for import_node in extractor.imports:
                try:
                    import_code = compile(ast.Module(body=[import_node], type_ignores=[]), '<string>', 'exec')
                    exec(import_code, context)
                except Exception as e:
                    print(f"[WARN] Failed to exec import: {e}")

            statements_to_execute, next_index = get_statements_to_execute_by_text_matching(
                extractor.run_skill_body, executed_stmt_texts, current_lines, use_text_matching
            )

            if not statements_to_execute:
                success_check, _ = task._task.success()
                scene_info = get_scene_info(task, task_name, robot_type)

                if success_check:
                    output_json({
                        "type": "complete",
                        "total_steps": len(executed_stmt_texts),
                        "total_skill_calls": skill_call_count,
                        "success": True,
                        "scene": scene_info
                    })
                    return current_code

                if max_skill_calls != None and skill_call_count >= max_skill_calls:
                    print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                    output_json({
                        "type": "max_steps_reached",
                        "total_steps": len(executed_stmt_texts),
                        "skill_call_count": skill_call_count,
                        "max_skill_calls": max_skill_calls,
                        "success": False,
                        "scene": scene_info
                    })
                    return current_code

                print(f"[INFO] All statements executed but task incomplete. Requesting additional code...")
                if provider != None:
                    repair_attempt_count += 1
                    if repair_attempt_count > max_repair_attempts:
                        print(f"[MAX_REPAIRS] Reached maximum repair/revise attempts: {max_repair_attempts}")
                        output_json({
                            "type": "max_repairs_reached",
                            "step_index": len(executed_stmt_texts),
                            "repair_attempt_count": repair_attempt_count,
                            "max_repair_attempts": max_repair_attempts,
                            "last_error": "Task incomplete after all statements executed",
                            "last_error_type": "RevisionLimitReached",
                            "success": False,
                            "scene": scene_info
                        })
                        return current_code

                    last_step_payload = {
                        "step_index": len(executed_stmt_texts),
                        "skill_call_count": skill_call_count,
                        "statement": executed_stmt_texts[-1] if executed_stmt_texts else "",
                        "skill_name": None,
                        "success": False,
                        "need_more_steps": True,
                    }
                    new_code = provider.revise_on_step(
                        current_code=current_code,
                        executed_stmt_texts=executed_stmt_texts,
                        next_index=len(executed_stmt_texts),
                        scene=scene_info,
                        last_step_payload=last_step_payload,
                    )
                    if new_code:
                        # Apply post-processing to revised code
                        from utils.code_postprocess import apply_code_postprocessing
                        new_code = apply_code_postprocessing(new_code, scene_info)
                        # Only continue if code actually changed
                        if new_code.strip() != current_code.strip():
                            print(f"[REVISE] Additional code generated (attempt {repair_attempt_count}/{max_repair_attempts})")
                            current_code = new_code
                            current_lines = new_code.split('\n')
                            continue
                        else:
                            print(f"[REVISE] Code unchanged after revision, stopping.")

                output_json({
                    "type": "complete",
                    "total_steps": len(executed_stmt_texts),
                    "total_skill_calls": skill_call_count,
                    "success": False,
                    "scene": scene_info
                })
                return current_code

            stmt = statements_to_execute[0]
            statement_index = next_index + 1
            stmt_code = get_statement_code(stmt, current_lines)
            is_skill, skill_name = is_skill_call(stmt)

            # Skip pure return statements - they cannot be executed outside a function context
            if isinstance(stmt, ast.Return):
                print(f"[STEP {statement_index}] Skipping return statement (end of run_skill)")
                executed_stmt_texts.append(stmt_code)
                # Treat as end of execution - check final success
                success_check, _ = task._task.success()
                scene_info = get_scene_info(task, task_name, robot_type)
                output_json({
                    "type": "complete",
                    "total_steps": len(executed_stmt_texts),
                    "total_skill_calls": skill_call_count,
                    "success": bool(success_check),
                    "scene": scene_info
                })
                return current_code

            # Handle statements containing return (e.g., if blocks with return inside)
            if contains_return(stmt):
                print(f"[STEP {statement_index}] Removing return statements from: {stmt_code[:80]}...")
                stmt = remove_returns_from_stmt(stmt)
                ast.fix_missing_locations(stmt)

            # ===== START BACKGROUND VALIDATION THREAD =====
            # Validate future statements while current statement executes
            # Only run background validation if current statement is a primitive skill
            future_statements = statements_to_execute[1:]  # Skip current statement

            if provider != None and len(future_statements) > 0 and is_skill:
                # Wait for any previous validation thread to complete
                if validation_thread != None and validation_thread.is_alive():
                    validation_thread.join(timeout=0.5)

                # ===== CAPTURE SCENE INFO IN MAIN THREAD =====
                # IMPORTANT: CoppeliaSim is NOT thread-safe. We must capture scene_info
                # here in the main thread BEFORE starting the background validation thread.
                # Calling get_scene_info from background thread causes signal 11 (segfault).
                pre_captured_scene_info = get_scene_info(task, task_name, robot_type)

                # Create a copy of tracker state for background thread
                # (to avoid race conditions with main thread updates)
                tracker_copy = ProjectedStateTracker()
                if state_tracker.confirmed_state:
                    tracker_copy.confirmed_state = state_tracker.confirmed_state.copy()
                    tracker_copy.projected_state = state_tracker.confirmed_state.copy()

                # Prepare current statement info for projection
                # (its effect should be applied first, assuming success)
                current_stmt_info = None
                if is_skill:
                    call_skill_name, call_node = extract_skill_call_info(stmt)
                    if call_skill_name and call_node:
                        try:
                            args_eval, kwargs_eval = eval_call_arguments(call_node, context)
                            current_stmt_info = {
                                "skill_name": call_skill_name,
                                "target_pos": extract_target_pos(args_eval, kwargs_eval),
                                "target_obj": extract_target_obj(kwargs_eval, args_eval),
                            }
                        except Exception:
                            # If we can't evaluate, just pass skill name
                            current_stmt_info = {"skill_name": call_skill_name}

                # Create a safe context copy for background thread
                # IMPORTANT: Remove PyRep objects (Shape, Joint, etc.) to prevent
                # CoppeliaSim access from background thread (causes segfault)
                # ALSO: Deep copy mutable objects (numpy arrays) to prevent
                # background validation from modifying main thread's variables
                safe_context = {}
                for k, v in context.items():
                    if k in ('Shape', 'Joint', 'Dummy', 'ProximitySensor', 'env', 'task'):
                        continue
                    # Deep copy numpy arrays to prevent background thread from modifying originals
                    if isinstance(v, np.ndarray):
                        safe_context[k] = v.copy()
                    elif isinstance(v, list):
                        # Copy list with numpy arrays copied
                        safe_context[k] = [x.copy() if isinstance(x, np.ndarray) else x for x in v]
                    else:
                        # For other types (functions, modules, scalars), just reference
                        # Don't use deepcopy - it fails on PyRep/PyCapsule objects
                        safe_context[k] = v

                validation_thread = threading.Thread(
                    target=background_validate_future_statements,
                    args=(
                        future_statements,
                        current_lines.copy(),
                        safe_context,  # Use safe context without PyRep objects
                        provider,
                        current_code,
                        executed_stmt_texts.copy(),
                        next_index + 1,  # Future starts after current
                        tracker_copy,
                        current_stmt_info,  # Pass current statement info
                        pre_captured_scene_info,  # Pass pre-captured scene info (thread-safe)
                    ),
                    daemon=True,
                )
                validation_thread.start()
                print(f"[BG_VALIDATE] Started validating {len(future_statements)} future statements"
                      f"{' (with current: ' + current_stmt_info['skill_name'] + ')' if current_stmt_info else ''}")

            if is_skill:
                print(f"\n[STEP {statement_index}] Executing: {stmt_code[:100]}...")

            try:
                # Evaluate arguments for condition checks (best-effort)
                pre_obs = None
                precondition_result = None
                postcondition_result = None
                target_pos_for_checks = None
                target_obj_for_checks = None
                call_skill_name, call_node = extract_skill_call_info(stmt)

                if is_skill and call_node != None:
                    pre_obs = task.get_observation()
                    try:
                        args_eval, kwargs_eval = eval_call_arguments(call_node, context)
                        target_pos_for_checks = extract_target_pos(args_eval, kwargs_eval)
                        target_obj_for_checks = extract_target_obj(kwargs_eval, args_eval)
                        precondition_result = check_preconditions(
                            skill_name,
                            pre_obs,
                            target_pos=target_pos_for_checks,
                            held_obj=state_tracker.confirmed_state.gripper.held_object if state_tracker.confirmed_state else None,
                        )
                    except Exception as cond_err:
                        precondition_result = {"success": False, "error": str(cond_err)}

                # Execute the statement
                stmt_module = ast.Module(body=[stmt], type_ignores=[])
                ast.fix_missing_locations(stmt_module)
                compiled = compile(stmt_module, '<string>', 'exec')
                exec(compiled, context)

                executed_stmt_texts.append(stmt_code)

                # ===== STORE POST-EXECUTION STATE =====
                # After each statement finishes, store the environment state from observation
                post_obs = task.get_observation()

                if is_skill:
                    # Confirm state with observation data (trusted)
                    state_tracker.confirm_state(
                        obs=post_obs,
                        statement_index=statement_index,
                        skill_name=skill_name,
                        target_obj=target_obj_for_checks,
                    )
                    print(f"[STATE] Confirmed state after step {statement_index}: "
                          f"gripper_open={state_tracker.confirmed_state.gripper.is_open}, "
                          f"held_obj={state_tracker.confirmed_state.gripper.held_object}")

                    # Postcondition check using new observation
                    try:
                        postcondition_result = check_postconditions(
                            skill_name,
                            pre_obs,
                            post_obs,
                            target_pos=target_pos_for_checks,
                            held_obj=state_tracker.confirmed_state.gripper.held_object,
                        )
                    except Exception as cond_err:
                        postcondition_result = {"success": False, "error": str(cond_err)}

                    skill_call_count += 1
                    scene_info = get_scene_info(task, task_name, robot_type)
                    success_check, _ = task._task.success()

                    # Update object positions in tracker from scene
                    for obj_info in scene_info.get("objects", []):
                        if obj_info["name"] in state_tracker.confirmed_state.objects:
                            if obj_info.get("position"):
                                state_tracker.confirmed_state.objects[obj_info["name"]].position = np.array(obj_info["position"])

                    last_step_payload = {
                        "step_index": statement_index,
                        "skill_call_count": skill_call_count,
                        "statement": stmt_code,
                        "skill_name": skill_name,
                        "success": success_check,
                        "precondition": precondition_result,
                        "postcondition": postcondition_result,
                        "confirmed_state": {
                            "gripper_open": state_tracker.confirmed_state.gripper.is_open,
                            "held_object": state_tracker.confirmed_state.gripper.held_object,
                        }
                    }

                    output_json({
                        "type": "step",
                        **last_step_payload,
                        "scene": scene_info
                    })

                    # Check if background validation completed with repair
                    if apply_pending_repair():
                        continue  # Restart loop with repaired code

                    if success_check:
                        print(f"[SUCCESS] Task completed at step {statement_index}")
                        output_json({
                            "type": "complete",
                            "step_index": statement_index,
                            "skill_call_count": skill_call_count,
                            "success": True,
                            "scene": scene_info
                        })
                        return current_code

                    if max_skill_calls != None and skill_call_count >= max_skill_calls:
                        print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                        output_json({
                            "type": "max_steps_reached",
                            "step_index": statement_index,
                            "skill_call_count": skill_call_count,
                            "max_skill_calls": max_skill_calls,
                            "success": success_check,
                            "scene": scene_info
                        })
                        return current_code

                    if provider != None:
                        new_code = provider.revise_on_step(
                            current_code=current_code,
                            executed_stmt_texts=executed_stmt_texts,
                            next_index=len(executed_stmt_texts),
                            scene=scene_info,
                            last_step_payload=last_step_payload,
                        )
                        if new_code:
                            # Apply post-processing to revised code
                            from utils.code_postprocess import apply_code_postprocessing
                            new_code = apply_code_postprocessing(new_code, scene_info)
                            # Only continue if code actually changed
                            if new_code.strip() != current_code.strip():
                                repair_attempt_count += 1
                                if repair_attempt_count > max_repair_attempts:
                                    print(f"[MAX_REPAIRS] Reached maximum repair/revise attempts: {max_repair_attempts}")
                                    output_json({
                                        "type": "max_repairs_reached",
                                        "step_index": statement_index,
                                        "repair_attempt_count": repair_attempt_count,
                                        "max_repair_attempts": max_repair_attempts,
                                        "last_error": "Revision limit reached",
                                        "last_error_type": "RevisionLimitReached",
                                        "success": False,
                                        "scene": scene_info
                                    })
                                    return current_code
                                print(f"[REVISE] Code revised after step {statement_index} (attempt {repair_attempt_count}/{max_repair_attempts})")
                                current_code = new_code
                                current_lines = new_code.split('\n')
                                continue
                            else:
                                print(f"[REVISE] Code unchanged after revision, stopping revise loop.")

                else:
                    # Non-skill statement: just check for pending repairs
                    if apply_pending_repair():
                        continue

            except BaseException as e:
                # Any exception raised while executing the statement is reported and handed to the exception repair path below
                tb = traceback.format_exc()
                error_msg = getattr(e, 'message', str(e))

                # Capture scene_info for the failure report (fall back to an empty dict if capture fails)
                try:
                    scene_info = get_scene_info(task, task_name, robot_type)
                except Exception:
                    scene_info = {}

                # Check gripper alignment with graspable objects
                alignment_guidance = None
                gripper_quat = scene_info.get("gripper", {}).get("quaternion")
                if gripper_quat is not None:
                    try:
                        alignment_guidance = check_gripper_alignment_for_scene(
                            np.array(gripper_quat),
                            scene_info.get("objects", []),
                            task_name,
                            robot_type
                        )
                    except Exception:
                        pass

                # Classify whether this is a known skill failure or an unexpected error
                is_known_failure = isinstance(e, (SkillFailure, PathOutOfWorkspace, InvalidActionError))
                failure_type = "failure" if is_known_failure else "error"

                failure_payload = {
                    "type": failure_type,
                    "step_index": statement_index,
                    "statement": stmt_code,
                    "skill_name": skill_name if is_skill else None,
                    "error": error_msg,
                    "error_type": type(e).__name__,
                    "scene": scene_info
                }
                if not is_known_failure:
                    failure_payload["traceback"] = tb
                if alignment_guidance:
                    failure_payload["alignment_guidance"] = alignment_guidance

                output_json(failure_payload)
                print(f"[{failure_type.upper()}] {type(e).__name__}: {error_msg}")
                if alignment_guidance:
                    print(f"[ALIGNMENT] {alignment_guidance}")

                # Attempt repair via the provider (bounded by max_repair_attempts)
                if provider is not None:
                    repair_attempt_count += 1
                    if repair_attempt_count > max_repair_attempts:
                        print(f"[MAX_REPAIRS] Reached maximum repair attempts: {max_repair_attempts}")
                        output_json({
                            "type": "max_repairs_reached",
                            "step_index": statement_index,
                            "repair_attempt_count": repair_attempt_count,
                            "max_repair_attempts": max_repair_attempts,
                            "last_error": error_msg,
                            "last_error_type": type(e).__name__,
                            "success": False,
                            "scene": scene_info
                        })
                        return current_code

                    try:
                        error_payload = {"error": error_msg, "error_type": type(e).__name__}
                        if not is_known_failure:
                            error_payload["traceback"] = tb
                        new_code = provider.repair_on_failure(
                            current_code=current_code,
                            executed_stmt_texts=executed_stmt_texts,
                            failed_stmt_text=stmt_code,
                            error_payload=error_payload,
                            scene=scene_info,
                        )
                        if new_code:
                            # Apply post-processing to repaired code
                            from utils.code_postprocess import apply_code_postprocessing
                            new_code = apply_code_postprocessing(new_code, scene_info)
                            # Validate repaired code syntax before applying
                            try:
                                ast.parse(new_code)
                                print(f"[REPAIR] Code repaired after {type(e).__name__} (attempt {repair_attempt_count}/{max_repair_attempts})")
                                current_code = new_code
                                current_lines = new_code.split('\n')
                                continue
                            except SyntaxError as syn_err:
                                print(f"[REPAIR_INVALID] Repaired code has syntax error: {syn_err}")
                                # Don't apply invalid code, continue with next repair attempt
                    except Exception as repair_error:
                        print(f"[REPAIR_FAILED] {repair_error}")

                return current_code

    finally:
        # Wait for any running validation thread to complete
        # Use longer timeout to ensure LLM calls finish before provider.close()
        if validation_thread != None and validation_thread.is_alive():
            print("[INFO] Waiting for background validation thread to complete...")
            validation_thread.join(timeout=120.0)  # 2  (LLM   )
            if validation_thread.is_alive():
                print("[WARNING] Background validation thread did not complete in time")


# ============================================================================
# Step Mode Executor
# ============================================================================

def execute_step_mode(
    env, task, task_name: str, robot_type: str,
    source_code: str, source_lines: List[str],
    max_skill_calls: Optional[int] = None,
    provider: Optional["PolicyProvider"] = None,
    max_repair_attempts: int = 5,
    use_text_matching: bool = True
) -> Optional[str]:
    """
    Execute policy in 'step' mode.

    The policy code is re-parsed on every loop iteration, the body of its
    ``run_skill`` function is extracted, and exactly one not-yet-executed
    top-level statement is executed per iteration inside a shared ``context``
    namespace. Progress and outcomes are reported through ``output_json``:

    * ``step``                - after every successfully executed skill call
    * ``complete``            - task succeeded, a ``return`` was reached, or
                                nothing is left to run and no more code is
                                available
    * ``max_steps_reached``   - ``max_skill_calls`` skill calls were executed
    * ``failure`` / ``error`` - a statement raised (known skill failure vs.
                                any other exception)
    * ``max_repairs_reached`` - the repair/revise budget is exhausted

    When ``provider`` is given, it is asked to revise the code after each
    successful skill call (``revise_on_step``), to extend the code when all
    statements ran but the task is not yet successful, and to repair the code
    after a failing statement (``repair_on_failure``).

    Args:
        env: Environment handle, forwarded to ``create_execution_context``.
        task: Task handle; used for observations, scene queries and the
            success check (``task._task.success()``).
        task_name: Task name, forwarded to the scene/alignment helpers.
        robot_type: Robot type, forwarded to the scene/alignment helpers.
        source_code: Initial policy source code.
        source_lines: ``source_code`` split into lines.
        max_skill_calls: Optional upper bound on executed skill calls.
        provider: Optional policy provider used for revise/repair requests.
        max_repair_attempts: Budget shared by repairs, revisions and
            resets to the last valid code.
        use_text_matching: If True, use text-based matching to skip executed statements
                          after repair. If False, use index-based matching (legacy).

    Returns:
        The code that was in effect when execution stopped. When the reset
        budget is exhausted because the code cannot be parsed (or contains no
        ``run_skill`` body), the last code that did parse is returned instead.
        No path in this function returns None; the ``Optional`` annotation is
        kept only for interface compatibility.
    """
    current_code = source_code
    current_lines = source_lines
    executed_stmt_texts: List[str] = []
    skill_call_count = 0
    repair_attempt_count = 0  # Track repair attempts to prevent infinite loops

    context = create_execution_context(env, task, robot_type)

    # Add run_skill function arguments to context (so they can be referenced in code)
    context['descriptions'] = None
    context['obs'] = task.get_observation()
    context['variations_index'] = 0

    # Pre-process: Fix object type mismatches before execution
    # This prevents WrongObjectTypeError (e.g., Shape('success') when it's ProximitySensor)
    # Best effort: if post-processing is unavailable or fails, run the code as given.
    try:
        from utils.code_postprocess import fix_object_types_from_scene, remove_unnecessary_success_checks
        initial_scene_for_type_fix = get_scene_info(task, task_name, robot_type)
        current_code = fix_object_types_from_scene(current_code, initial_scene_for_type_fix)
        current_code = remove_unnecessary_success_checks(current_code)
        current_lines = current_code.split('\n')
        print("[POSTPROCESS] Applied object type auto-correction and removed unnecessary success checks")
    except Exception as e:
        print(f"[POSTPROCESS] Warning: Code postprocessing failed: {e}")

    # Rollback state: the most recent code that parsed and contained a
    # run_skill body, plus the statements known to have been executed.
    last_valid_code = source_code
    last_valid_lines = source_lines
    last_valid_executed_stmts: List[str] = []

    while True:
        # Normalise the code through _sanitize_code before every parse.
        current_code = _sanitize_code(current_code)
        current_lines = current_code.split('\n')

        # NOTE(review): the two reset paths below stop at `>=` the budget,
        # whereas the repair/revise paths stop at `>` - confirm this is intended.
        try:
            tree = ast.parse(current_code)
        except SyntaxError as e:
            print(f"[ERROR] SyntaxError in code: {e}")
            repair_attempt_count += 1
            if repair_attempt_count >= max_repair_attempts:
                print(f"[ERROR] Max repair attempts ({max_repair_attempts}) exceeded, returning last valid code")
                return last_valid_code
            print(f"[RESET] Reverting to last valid code and retrying... (attempt {repair_attempt_count}/{max_repair_attempts})")
            current_code = last_valid_code
            current_lines = last_valid_lines
            executed_stmt_texts = list(last_valid_executed_stmts)
            continue

        extractor = RunSkillExtractor()
        extractor.visit(tree)

        if not extractor.run_skill_body:
            repair_attempt_count += 1
            print(f"[ERROR] run_skill    . ( {repair_attempt_count}/{max_repair_attempts})")
            if repair_attempt_count >= max_repair_attempts:
                print(f"[ERROR] Max repair attempts ({max_repair_attempts}) exceeded, returning last valid code")
                return last_valid_code
            print(f"[RESET] Reverting to last valid code and retrying... (attempt {repair_attempt_count}/{max_repair_attempts})")
            current_code = last_valid_code
            current_lines = last_valid_lines
            executed_stmt_texts = list(last_valid_executed_stmts)
            continue

        # The code is valid: remember it as the rollback target.
        last_valid_code = current_code
        last_valid_lines = current_lines
        last_valid_executed_stmts = list(executed_stmt_texts)

        # (Re)define the helper functions of the current code in the shared context.
        for _, func_node in extractor.helper_functions.items():
            func_code = compile(ast.Module(body=[func_node], type_ignores=[]), '<string>', 'exec')
            exec(func_code, context)

        statements_to_execute, next_index = get_statements_to_execute_by_text_matching(
            extractor.run_skill_body, executed_stmt_texts, current_lines, use_text_matching
        )

        if not statements_to_execute:
            # Nothing left to run: report success, stop at the skill-call cap,
            # or ask the provider for additional code.
            success_check, _ = task._task.success()
            scene_info = get_scene_info(task, task_name, robot_type)

            if success_check:
                output_json({
                    "type": "complete",
                    "total_steps": len(executed_stmt_texts),
                    "total_skill_calls": skill_call_count,
                    "success": True,
                    "scene": scene_info
                })
                return current_code

            if max_skill_calls is not None and skill_call_count >= max_skill_calls:
                print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                output_json({
                    "type": "max_steps_reached",
                    "total_steps": len(executed_stmt_texts),
                    "skill_call_count": skill_call_count,
                    "max_skill_calls": max_skill_calls,
                    "success": False,
                    "scene": scene_info
                })
                return current_code

            print("[INFO] All statements executed but task incomplete. Requesting additional code...")
            if provider is not None:
                repair_attempt_count += 1
                if repair_attempt_count > max_repair_attempts:
                    print(f"[MAX_REPAIRS] Reached maximum repair/revise attempts: {max_repair_attempts}")
                    output_json({
                        "type": "max_repairs_reached",
                        "step_index": len(executed_stmt_texts),
                        "repair_attempt_count": repair_attempt_count,
                        "max_repair_attempts": max_repair_attempts,
                        "last_error": "Task incomplete after all statements executed",
                        "last_error_type": "RevisionLimitReached",
                        "success": False,
                        "scene": scene_info
                    })
                    return current_code

                last_step_payload = {
                    "step_index": len(executed_stmt_texts),
                    "skill_call_count": skill_call_count,
                    "statement": executed_stmt_texts[-1] if executed_stmt_texts else "",
                    "skill_name": None,
                    "success": False,
                    "need_more_steps": True,
                }
                new_code = provider.revise_on_step(
                    current_code=current_code,
                    executed_stmt_texts=executed_stmt_texts,
                    next_index=len(executed_stmt_texts),
                    scene=scene_info,
                    last_step_payload=last_step_payload,
                )
                if new_code:
                    # Apply post-processing to revised code
                    from utils.code_postprocess import apply_code_postprocessing
                    new_code = apply_code_postprocessing(new_code, scene_info)
                    print(f"[REVISE] Additional code generated (attempt {repair_attempt_count}/{max_repair_attempts})")
                    current_code = new_code
                    current_lines = new_code.split('\n')
                    continue

            # No provider, or the provider produced no new code: finish unsuccessfully.
            output_json({
                "type": "complete",
                "total_steps": len(executed_stmt_texts),
                "total_skill_calls": skill_call_count,
                "success": False,
                "scene": scene_info
            })
            return current_code

        # Step mode executes only the first pending statement per iteration.
        stmt = statements_to_execute[0]
        statement_index = next_index + 1
        stmt_code = get_statement_code(stmt, current_lines)
        is_skill, skill_name = is_skill_call(stmt)

        # Skip pure return statements - they cannot be executed outside a function context
        if isinstance(stmt, ast.Return):
            print(f"[STEP {statement_index}] Skipping return statement (end of run_skill)")
            executed_stmt_texts.append(stmt_code)
            # Treat as end of execution - check final success
            success_check, _ = task._task.success()
            scene_info = get_scene_info(task, task_name, robot_type)
            output_json({
                "type": "complete",
                "total_steps": len(executed_stmt_texts),
                "total_skill_calls": skill_call_count,
                "success": bool(success_check),
                "scene": scene_info
            })
            return current_code

        # Handle statements containing return (e.g., if blocks with return inside)
        if contains_return(stmt):
            print(f"[STEP {statement_index}] Removing return statements from: {stmt_code[:80]}...")
            stmt = remove_returns_from_stmt(stmt)
            ast.fix_missing_locations(stmt)

        if is_skill:
            print(f"\n[STEP {statement_index}] Executing: {stmt_code[:100]}...")

        try:
            stmt_module = ast.Module(body=[stmt], type_ignores=[])
            ast.fix_missing_locations(stmt_module)
            compiled = compile(stmt_module, '<string>', 'exec')
            exec(compiled, context)

            executed_stmt_texts.append(stmt_code)
            # The statement has actually run, so the rollback snapshot must
            # contain it. Otherwise a reset triggered by unparsable revised
            # code would drop it from the executed list and the same statement
            # would be executed a second time.
            last_valid_executed_stmts = list(executed_stmt_texts)

            if is_skill:
                skill_call_count += 1
                scene_info = get_scene_info(task, task_name, robot_type)
                success_check, _ = task._task.success()

                last_step_payload = {
                    "step_index": statement_index,
                    "skill_call_count": skill_call_count,
                    "statement": stmt_code,
                    "skill_name": skill_name,
                    "success": success_check,
                }

                output_json({
                    "type": "step",
                    **last_step_payload,
                    "scene": scene_info
                })

                if success_check:
                    print(f"[SUCCESS] Task completed at step {statement_index}")
                    output_json({
                        "type": "complete",
                        "step_index": statement_index,
                        "skill_call_count": skill_call_count,
                        "success": True,
                        "scene": scene_info
                    })
                    return current_code

                if max_skill_calls is not None and skill_call_count >= max_skill_calls:
                    print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                    output_json({
                        "type": "max_steps_reached",
                        "step_index": statement_index,
                        "skill_call_count": skill_call_count,
                        "max_skill_calls": max_skill_calls,
                        "success": success_check,
                        "scene": scene_info
                    })
                    return current_code

                if provider is not None:
                    new_code = provider.revise_on_step(
                        current_code=current_code,
                        executed_stmt_texts=executed_stmt_texts,
                        next_index=len(executed_stmt_texts),
                        scene=scene_info,
                        last_step_payload=last_step_payload,
                    )
                    if new_code:
                        # Apply post-processing to revised code
                        from utils.code_postprocess import apply_code_postprocessing
                        new_code = apply_code_postprocessing(new_code, scene_info)
                        # The budget is only charged when the provider returned code
                        repair_attempt_count += 1
                        if repair_attempt_count > max_repair_attempts:
                            print(f"[MAX_REPAIRS] Reached maximum repair/revise attempts: {max_repair_attempts}")
                            output_json({
                                "type": "max_repairs_reached",
                                "step_index": statement_index,
                                "repair_attempt_count": repair_attempt_count,
                                "max_repair_attempts": max_repair_attempts,
                                "last_error": "Revision limit reached",
                                "last_error_type": "RevisionLimitReached",
                                "success": False,
                                "scene": scene_info
                            })
                            return current_code
                        print(f"[REVISE] Code revised after step {statement_index} (attempt {repair_attempt_count}/{max_repair_attempts})")
                        # Not syntax-checked here; the parse at the top of the
                        # loop resets to the last valid code if it is broken.
                        current_code = new_code
                        current_lines = new_code.split('\n')

        except BaseException as e:
            # Catch everything raised by the statement so that a repair can be attempted.
            # NOTE(review): BaseException also covers KeyboardInterrupt and
            # SystemExit, so Ctrl-C is reported and repaired like any other
            # error - confirm this is intended.
            tb = traceback.format_exc()
            error_msg = getattr(e, 'message', str(e))

            # Collecting the scene may itself fail; fall back to an empty scene.
            try:
                scene_info = get_scene_info(task, task_name, robot_type)
            except Exception:
                scene_info = {}

            # Check gripper alignment with graspable objects
            alignment_guidance = None
            gripper_quat = scene_info.get("gripper", {}).get("quaternion")
            if gripper_quat is not None:
                try:
                    alignment_guidance = check_gripper_alignment_for_scene(
                        np.array(gripper_quat),
                        scene_info.get("objects", []),
                        task_name,
                        robot_type
                    )
                except Exception:
                    # Alignment guidance is optional extra context; ignore failures.
                    pass

            # Known skill failures are reported as "failure", anything else as "error".
            is_known_failure = isinstance(e, (SkillFailure, PathOutOfWorkspace, InvalidActionError))
            failure_type = "failure" if is_known_failure else "error"

            failure_payload = {
                "type": failure_type,
                "step_index": statement_index,
                "statement": stmt_code,
                "skill_name": skill_name if is_skill else None,
                "error": error_msg,
                "error_type": type(e).__name__,
                "scene": scene_info
            }
            # The traceback is only attached for unexpected errors.
            if not is_known_failure:
                failure_payload["traceback"] = tb
            if alignment_guidance:
                failure_payload["alignment_guidance"] = alignment_guidance

            output_json(failure_payload)
            print(f"[{failure_type.upper()}] {type(e).__name__}: {error_msg}")
            if alignment_guidance:
                print(f"[ALIGNMENT] {alignment_guidance}")

            # Attempt a repair through the provider.
            if provider is not None:
                repair_attempt_count += 1
                if repair_attempt_count > max_repair_attempts:
                    print(f"[MAX_REPAIRS] Reached maximum repair attempts: {max_repair_attempts}")
                    output_json({
                        "type": "max_repairs_reached",
                        "step_index": statement_index,
                        "repair_attempt_count": repair_attempt_count,
                        "max_repair_attempts": max_repair_attempts,
                        "last_error": error_msg,
                        "last_error_type": type(e).__name__,
                        "success": False,
                        "scene": scene_info
                    })
                    return current_code

                try:
                    error_payload = {"error": error_msg, "error_type": type(e).__name__}
                    if not is_known_failure:
                        error_payload["traceback"] = tb
                    new_code = provider.repair_on_failure(
                        current_code=current_code,
                        executed_stmt_texts=executed_stmt_texts,
                        failed_stmt_text=stmt_code,
                        error_payload=error_payload,
                        scene=scene_info,
                    )
                    if new_code:
                        # Apply post-processing to repaired code
                        from utils.code_postprocess import apply_code_postprocessing
                        new_code = apply_code_postprocessing(new_code, scene_info)
                        # Validate repaired code syntax before applying
                        try:
                            ast.parse(new_code)
                            print(f"[REPAIR] Code repaired after {type(e).__name__} (attempt {repair_attempt_count}/{max_repair_attempts})")
                            current_code = new_code
                            current_lines = new_code.split('\n')
                            continue
                        except SyntaxError as syn_err:
                            print(f"[REPAIR_INVALID] Repaired code has syntax error: {syn_err}")
                            # Invalid code is not applied; control falls through
                            # to the final return below (no further attempt here).
                except Exception as repair_error:
                    print(f"[REPAIR_FAILED] {repair_error}")

            # No provider, or the repair produced nothing usable: stop here.
            return current_code


# ============================================================================
# Failure Mode Executor
# ============================================================================

def execute_failure_mode(
    env, task, task_name: str, robot_type: str,
    source_code: str, source_lines: List[str],
    max_skill_calls: Optional[int] = None,
    provider: Optional["PolicyProvider"] = None,
    max_repair_attempts: int = 5,
    use_text_matching: bool = True
) -> Optional[str]:
    """
    Execute policy in 'failure' mode.

    Args:
        use_text_matching: If True, use text-based matching to skip executed statements
                          after repair. If False, use index-based matching (legacy).

    Returns:
        Final code after all repairs and revisions (or None on early error)
    """
    current_code = source_code
    current_lines = source_lines
    executed_stmt_texts: List[str] = []
    skill_call_count = 0
    repair_attempt_count = 0  # Track repair attempts to prevent infinite loops

    context = create_execution_context(env, task, robot_type)

    # Add run_skill function arguments to context (so they can be referenced in code)
    context['descriptions'] = None
    context['obs'] = task.get_observation()
    context['variations_index'] = 0

    # Pre-process: Fix object type mismatches before execution
    # This prevents WrongObjectTypeError (e.g., Shape('success') when it's ProximitySensor)
    try:
        from utils.code_postprocess import fix_object_types_from_scene, remove_unnecessary_success_checks
        initial_scene_for_type_fix = get_scene_info(task, task_name, robot_type)
        current_code = fix_object_types_from_scene(current_code, initial_scene_for_type_fix)
        current_code = remove_unnecessary_success_checks(current_code)
        current_lines = current_code.split('\n')
        print("[POSTPROCESS] Applied object type auto-correction and removed unnecessary success checks")
    except Exception as e:
        print(f"[POSTPROCESS] Warning: Code postprocessing failed: {e}")

    last_valid_code = source_code  #    
    last_valid_lines = source_lines
    last_valid_executed_stmts: List[str] = []

    while True:
        #   (  , orientation )
        current_code = _sanitize_code(current_code)
        current_lines = current_code.split('\n')

        try:
            tree = ast.parse(current_code)
        except SyntaxError as e:
            print(f"[ERROR] SyntaxError in code: {e}")
            repair_attempt_count += 1
            if repair_attempt_count >= max_repair_attempts:
                print(f"[ERROR] Max repair attempts ({max_repair_attempts}) exceeded, returning last valid code")
                return last_valid_code
            print(f"[RESET] Reverting to last valid code and retrying... (attempt {repair_attempt_count}/{max_repair_attempts})")
            current_code = last_valid_code
            current_lines = last_valid_lines
            executed_stmt_texts = list(last_valid_executed_stmts)
            continue

        extractor = RunSkillExtractor()
        extractor.visit(tree)

        if not extractor.run_skill_body:
            repair_attempt_count += 1
            print(f"[ERROR] run_skill    . ( {repair_attempt_count}/{max_repair_attempts})")
            if repair_attempt_count >= max_repair_attempts:
                print(f"[ERROR] Max repair attempts ({max_repair_attempts}) exceeded, returning last valid code")
                return last_valid_code
            print(f"[RESET] Reverting to last valid code and retrying... (attempt {repair_attempt_count}/{max_repair_attempts})")
            current_code = last_valid_code
            current_lines = last_valid_lines
            executed_stmt_texts = list(last_valid_executed_stmts)
            continue

        #   -    
        last_valid_code = current_code
        last_valid_lines = current_lines
        last_valid_executed_stmts = list(executed_stmt_texts)

        for _, func_node in extractor.helper_functions.items():
            func_code = compile(ast.Module(body=[func_node], type_ignores=[]), '<string>', 'exec')
            exec(func_code, context)

        statements_to_execute, next_index = get_statements_to_execute_by_text_matching(
            extractor.run_skill_body, executed_stmt_texts, current_lines, use_text_matching
        )

        if not statements_to_execute:
            success_check, _ = task._task.success()
            scene_info = get_scene_info(task, task_name, robot_type)

            if success_check:
                output_json({
                    "type": "complete",
                    "total_steps": len(executed_stmt_texts),
                    "total_skill_calls": skill_call_count,
                    "success": True,
                    "scene": scene_info
                })
                return current_code

            if max_skill_calls != None and skill_call_count >= max_skill_calls:
                print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                output_json({
                    "type": "max_steps_reached",
                    "total_steps": len(executed_stmt_texts),
                    "skill_call_count": skill_call_count,
                    "max_skill_calls": max_skill_calls,
                    "success": False,
                    "scene": scene_info
                })
                return current_code

            print(f"[INFO] All statements executed but task incomplete. Requesting additional code...")
            if provider != None:
                repair_attempt_count += 1
                if repair_attempt_count > max_repair_attempts:
                    print(f"[MAX_REPAIRS] Reached maximum repair/revise attempts: {max_repair_attempts}")
                    output_json({
                        "type": "max_repairs_reached",
                        "step_index": len(executed_stmt_texts),
                        "repair_attempt_count": repair_attempt_count,
                        "max_repair_attempts": max_repair_attempts,
                        "last_error": "Task incomplete after all statements executed",
                        "last_error_type": "RevisionLimitReached",
                        "success": False,
                        "scene": scene_info
                    })
                    return current_code

                last_step_payload = {
                    "step_index": len(executed_stmt_texts),
                    "skill_call_count": skill_call_count,
                    "statement": executed_stmt_texts[-1] if executed_stmt_texts else "",
                    "skill_name": None,
                    "success": False,
                    "need_more_steps": True,
                }
                new_code = provider.revise_on_step(
                    current_code=current_code,
                    executed_stmt_texts=executed_stmt_texts,
                    next_index=len(executed_stmt_texts),
                    scene=scene_info,
                    last_step_payload=last_step_payload,
                )
                if new_code:
                    # Apply post-processing to revised code
                    from utils.code_postprocess import apply_code_postprocessing
                    new_code = apply_code_postprocessing(new_code, scene_info)
                    print(f"[REVISE] Additional code generated (attempt {repair_attempt_count}/{max_repair_attempts})")
                    current_code = new_code
                    current_lines = new_code.split('\n')
                    print("current_line:\n", current_lines)
                    continue

            output_json({
                "type": "complete",
                "total_steps": len(executed_stmt_texts),
                "total_skill_calls": skill_call_count,
                "success": False,
                "scene": scene_info
            })
            return current_code

        repair_needed = False

        for statement_index, stmt in enumerate(statements_to_execute, start=next_index + 1):
            stmt_code = get_statement_code(stmt, current_lines)
            is_skill, skill_name = is_skill_call(stmt)

            # Skip pure return statements - they cannot be executed outside a function context
            if isinstance(stmt, ast.Return):
                print(f"[STEP {statement_index}] Skipping return statement (end of run_skill)")
                executed_stmt_texts.append(stmt_code)
                next_index += 1
                # Treat as end of execution - check final success after loop
                break

            # Handle statements containing return (e.g., if blocks with return inside)
            if contains_return(stmt):
                print(f"[STEP {statement_index}] Removing return statements from: {stmt_code[:80]}...")
                stmt = remove_returns_from_stmt(stmt)
                ast.fix_missing_locations(stmt)

            try:
                stmt_module = ast.Module(body=[stmt], type_ignores=[])
                ast.fix_missing_locations(stmt_module)
                compiled = compile(stmt_module, '<string>', 'exec')
                exec(compiled, context)

                executed_stmt_texts.append(stmt_code)
                next_index += 1

                if is_skill:
                    skill_call_count += 1
                    success_check, _ = task._task.success()
                    if success_check:
                        scene_info = get_scene_info(task, task_name, robot_type)
                        output_json({
                            "type": "complete",
                            "step_index": statement_index,
                            "skill_call_count": skill_call_count,
                            "success": True,
                            "scene": scene_info
                        })
                        print(f"[SUCCESS] Task completed at step {statement_index}")
                        return current_code

                    if max_skill_calls != None and skill_call_count >= max_skill_calls:
                        scene_info = get_scene_info(task, task_name, robot_type)
                        print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                        output_json({
                            "type": "max_steps_reached",
                            "step_index": statement_index,
                            "skill_call_count": skill_call_count,
                            "max_skill_calls": max_skill_calls,
                            "success": success_check,
                            "scene": scene_info
                        })
                        repair_needed = True
                        break

            except BaseException as e:
                #        exception repair  
                tb = traceback.format_exc()
                error_msg = getattr(e, 'message', str(e))

                # scene_info   (  )
                try:
                    scene_info = get_scene_info(task, task_name, robot_type)
                except Exception:
                    scene_info = {}

                # Check gripper alignment with graspable objects
                alignment_guidance = None
                gripper_quat = scene_info.get("gripper", {}).get("quaternion")
                if gripper_quat is not None:
                    try:
                        alignment_guidance = check_gripper_alignment_for_scene(
                            np.array(gripper_quat),
                            scene_info.get("objects", []),
                            task_name,
                            robot_type
                        )
                    except Exception:
                        pass

                #  skill failure  
                is_known_failure = isinstance(e, (SkillFailure, PathOutOfWorkspace, InvalidActionError))
                failure_type = "failure" if is_known_failure else "error"

                failure_payload = {
                    "type": failure_type,
                    "step_index": statement_index,
                    "statement": stmt_code,
                    "skill_name": skill_name if is_skill else None,
                    "error": error_msg,
                    "error_type": type(e).__name__,
                    "scene": scene_info
                }
                if not is_known_failure:
                    failure_payload["traceback"] = tb
                if alignment_guidance:
                    failure_payload["alignment_guidance"] = alignment_guidance

                output_json(failure_payload)
                print(f"[{failure_type.upper()}] {type(e).__name__}: {error_msg}")
                if alignment_guidance:
                    print(f"[ALIGNMENT] {alignment_guidance}")

                # repair 
                if provider is not None:
                    repair_attempt_count += 1
                    if repair_attempt_count > max_repair_attempts:
                        print(f"[MAX_REPAIRS] Reached maximum repair attempts: {max_repair_attempts}")
                        output_json({
                            "type": "max_repairs_reached",
                            "step_index": statement_index,
                            "repair_attempt_count": repair_attempt_count,
                            "max_repair_attempts": max_repair_attempts,
                            "last_error": error_msg,
                            "last_error_type": type(e).__name__,
                            "success": False,
                            "scene": scene_info
                        })
                        return current_code

                    try:
                        error_payload = {"error": error_msg, "error_type": type(e).__name__}
                        if not is_known_failure:
                            error_payload["traceback"] = tb
                        new_code = provider.repair_on_failure(
                            current_code=current_code,
                            executed_stmt_texts=executed_stmt_texts,
                            failed_stmt_text=stmt_code,
                            error_payload=error_payload,
                            scene=scene_info,
                        )
                        if new_code:
                            # Apply post-processing to repaired code
                            from utils.code_postprocess import apply_code_postprocessing
                            new_code = apply_code_postprocessing(new_code, scene_info)
                            # Validate repaired code syntax before applying
                            try:
                                ast.parse(new_code)
                                print(f"[REPAIR] Code repaired after {type(e).__name__} (attempt {repair_attempt_count}/{max_repair_attempts})")
                                current_code = new_code
                                current_lines = new_code.split('\n')
                                repair_needed = True
                                break
                            except SyntaxError as syn_err:
                                print(f"[REPAIR_INVALID] Repaired code has syntax error: {syn_err}")
                                # Don't apply invalid code, continue with next repair attempt
                    except Exception as repair_error:
                        print(f"[REPAIR_FAILED] {repair_error}")

                return current_code

        if not repair_needed:
            success_check, _ = task._task.success()
            scene_info = get_scene_info(task, task_name, robot_type)

            if success_check:
                output_json({
                    "type": "complete",
                    "total_steps": len(executed_stmt_texts),
                    "total_skill_calls": skill_call_count,
                    "success": True,
                    "scene": scene_info
                })
                return current_code

            if max_skill_calls != None and skill_call_count >= max_skill_calls:
                print(f"[MAX_STEPS] Reached maximum skill calls: {max_skill_calls}")
                output_json({
                    "type": "max_steps_reached",
                    "total_steps": len(executed_stmt_texts),
                    "skill_call_count": skill_call_count,
                    "max_skill_calls": max_skill_calls,
                    "success": False,
                    "scene": scene_info
                })
                return current_code

            print(f"[INFO] All statements executed but task incomplete. Requesting additional code...")
            if provider != None:
                repair_attempt_count += 1
                if repair_attempt_count > max_repair_attempts:
                    print(f"[MAX_REPAIRS] Reached maximum repair/revise attempts: {max_repair_attempts}")
                    output_json({
                        "type": "max_repairs_reached",
                        "step_index": len(executed_stmt_texts),
                        "repair_attempt_count": repair_attempt_count,
                        "max_repair_attempts": max_repair_attempts,
                        "last_error": "Task incomplete after all statements executed",
                        "last_error_type": "RevisionLimitReached",
                        "success": False,
                        "scene": scene_info
                    })
                    return current_code

                last_step_payload = {
                    "step_index": len(executed_stmt_texts),
                    "skill_call_count": skill_call_count,
                    "statement": executed_stmt_texts[-1] if executed_stmt_texts else "",
                    "skill_name": None,
                    "success": False,
                    "need_more_steps": True,
                }
                new_code = provider.revise_on_step(
                    current_code=current_code,
                    executed_stmt_texts=executed_stmt_texts,
                    next_index=len(executed_stmt_texts),
                    scene=scene_info,
                    last_step_payload=last_step_payload,
                )
                if new_code:
                    # Apply post-processing to revised code
                    from utils.code_postprocess import apply_code_postprocessing
                    new_code = apply_code_postprocessing(new_code, scene_info)
                    print(f"[REVISE] Additional code generated (attempt {repair_attempt_count}/{max_repair_attempts})")
                    current_code = new_code
                    current_lines = new_code.split('\n')
                    continue

            output_json({
                "type": "complete",
                "total_steps": len(executed_stmt_texts),
                "total_skill_calls": skill_call_count,
                "success": False,
                "scene": scene_info
            })
            return current_code


# ============================================================================
# Run Task Episode (Callable from external scripts)
# ============================================================================

class EpisodeResult:
    """Mutable record describing the outcome of one task episode.

    Attributes:
        success: True when the task's success check passed at the end of the episode.
        execution_time: Elapsed seconds for the whole episode.
        skill_call_count: Number of primitive skill calls (0 until a caller fills it in).
        total_steps: Number of executed statements (0 until a caller fills it in).
        error: Message of the error that ended the episode, if any.
        error_type: Short name classifying ``error`` (e.g. an exception class name).
        llm_calls: One dict per LLM call, appended by the provider timing wrapper.
        final_code: Policy code as it stood at the end of the run, after any repairs.
    """

    def __init__(self):
        self.success: bool = False
        self.execution_time: float = 0.0
        self.skill_call_count: int = 0
        self.total_steps: int = 0
        self.error: Optional[str] = None
        self.error_type: Optional[str] = None
        self.llm_calls: List[Dict[str, Any]] = []  # filled in by the policy provider wrapper
        self.final_code: Optional[str] = None  # last executed code (repairs included)

    def __repr__(self) -> str:
        # Keep the representation short: report how many LLM calls were logged and
        # whether final code exists instead of dumping both in full.
        return (
            f"{type(self).__name__}("
            f"success={self.success!r}, "
            f"execution_time={self.execution_time!r}, "
            f"skill_call_count={self.skill_call_count!r}, "
            f"total_steps={self.total_steps!r}, "
            f"error={self.error!r}, "
            f"error_type={self.error_type!r}, "
            f"llm_calls=<{len(self.llm_calls)} call(s)>, "
            f"final_code={'<set>' if self.final_code else None})"
        )


def run_task_episode(
    task_name: str,
    mode: str = "step",
    source_robot: str = "panda",
    target_robot: str = "panda",
    max_skill_calls: Optional[int] = 10,
    max_repair_attempts: int = 5,
    code_source: str = "remote_llm",
    model: str = "ours",
    remote_host: str = "172.17.0.1",
    remote_port: int = 5000,
    llm_call_logger: Optional[callable] = None,
) -> EpisodeResult:
    """
    Run one task episode end to end and return a structured result.

    Programmatic counterpart of main(): the same setup / execute / teardown flow,
    but the configuration arrives as arguments and the outcome is returned in an
    EpisodeResult instead of terminating the process.

    Args:
        task_name: RLBench task class name (e.g. BasketballInHoop)
        mode: Execution mode ("step" or "failure"); not consulted when model is "ours"
        source_robot: Source robot type (panda, ur5, sawyer, jaco)
        target_robot: Target robot type for environment setup
        max_skill_calls: Maximum number of primitive skill calls (None = unlimited)
        max_repair_attempts: Maximum number of repair attempts before giving up (default: 5)
        code_source: Where to get task skill code (static, local_llm, remote_llm)
        model: Model type (code_agent, ours, codex)
        remote_host: Remote LLM server host
        remote_port: Remote LLM server port
        llm_call_logger: Optional callback invoked with one dict per LLM call
            (task, subtask, action, duration_sec and, when available, num_tokens)

    Returns:
        EpisodeResult: success flag, elapsed time, error / error_type when the
        episode failed, the recorded LLM calls and the final policy code.
        skill_call_count and total_steps are not filled in by this function.
    """
    import time as time_module

    result = EpisodeResult()
    # Monotonic clock: the elapsed time must not be affected by system clock changes.
    start_time = time_module.perf_counter()

    source_robot = source_robot.lower()
    target_robot = target_robot.lower()
    code_source = code_source.lower()
    model = model.lower()

    print(f"[INFO] Task: {task_name}")
    print(f"[INFO] Mode: {mode}")
    print(f"[INFO] Source Robot: {source_robot}")
    print(f"[INFO] Target Robot: {target_robot}")
    # Only None means "no limit"; 0 is a real limit and must be shown as such.
    print(f"[INFO] Max Skill Calls: {max_skill_calls if max_skill_calls is not None else 'unlimited'}")

    # Resolve the task class by name from rlbench.tasks.
    tasks_module = importlib.import_module("rlbench.tasks")
    try:
        task_cls = getattr(tasks_module, task_name)
    except AttributeError:
        print(f"[ERROR] '{task_name}'     .")
        result.error = f"Task class '{task_name}' not found"
        result.error_type = "TaskNotFound"
        result.execution_time = time_module.perf_counter() - start_time
        return result

    # Created inside the try block below; kept here so `finally` can clean up
    # whatever was actually constructed.
    env, task = None, None
    provider = None

    try:
        # Environment setup uses the target robot.
        print(f"[INFO] Setting up environment with {target_robot} robot...")
        env, task = setup_environment(task_cls, target_robot)

        # Reset the task to obtain its language descriptions.
        descriptions, _ = task.reset()

        # Emit the initial scene information as JSON.
        initial_info = get_initial_info(task, task_name, target_robot, descriptions)
        output_json(initial_info)

        policy_file_path = f"./tasks_{source_robot}/{task_name}.py"

        objects_info_str, objects_info, object_names = get_object_names(task, task_name)
        print("object_names: !!!", object_names)
        scene = initial_info.get("scene", {})
        grasp_guidance = scene.get("grasp_guidance", "")
        if model == 'codex':
            # The per-object grasp summary is only consumed by the codex prompt, so it
            # is built only on this branch.
            graspable_obj_list = scene.get("graspable_objects", [])
            graspable_obj_list_str = [
                f"{obj['object_name']}: can_grasp={obj['can_grasp']}, recommended_approach_axis_world={obj['grasp_axis_world'].lower()}"
                for obj in (graspable_obj_list or [])
            ]
            initial_prompt = get_prompt_for_codex(task, task_name, object_names, target_robot, descriptions, model, scene, grasp_guidance, graspable_obj_list_str)
        else:
            initial_prompt = get_prompt(task, task_name, object_names, target_robot, descriptions, model, grasp_guidance)

        print("[INFO] Initial prompt prepared.")
        print(initial_prompt)
        print("----- End of Initial Prompt -----")

        # Build the PolicyProvider.
        provider = PolicyProvider(
            mode=code_source,
            policy_path=policy_file_path,
            task_name=task_name,
            source_robot=source_robot,
            target_robot=target_robot,
            model=model,
            initial_prompt=initial_prompt,
            descriptions=descriptions,
            remote_host=remote_host,
            remote_port=remote_port,
            target_scene_info=objects_info,
            grasp_guidance=grasp_guidance,
            object_names=object_names
        )

        # Always wrap the provider so every LLM call is recorded on `result`;
        # the logger callback itself is optional.
        provider = _wrap_provider_with_timing(provider, task_name, llm_call_logger, result)

        print(f"[INFO] PolicyProvider initialized with mode: {code_source}")

        # Fetch the initial policy code and post-process it.
        try:
            source_code = provider.get_initial_code()
            if model == 'codex':
                source_code = postprocess_initial_code_codex(source_code, target_robot)
            elif code_source != 'static':
                source_code = postprocess_initial_code(source_code, target_robot)
            source_lines = source_code.split('\n')
            print(f"[INFO] Loaded policy from: {policy_file_path}")
        except Exception as e:
            print(f"[ERROR] Failed to get initial code: {e}")
            result.error = str(e)
            result.error_type = type(e).__name__
            result.execution_time = time_module.perf_counter() - start_time
            return result

        print(source_code)

        # Dispatch to the executor for the selected model / mode; each returns the
        # code as it stood when execution ended.
        final_code = None
        if model == "ours":
            final_code = execute_ours_mode(env, task, task_name, target_robot, source_code, source_lines, max_skill_calls, provider, max_repair_attempts)
        elif mode == "step":
            final_code = execute_step_mode(env, task, task_name, target_robot, source_code, source_lines, max_skill_calls, provider, max_repair_attempts)
        else:  # failure mode
            final_code = execute_failure_mode(env, task, task_name, target_robot, source_code, source_lines, max_skill_calls, provider, max_repair_attempts)

        # Keep the final (possibly repaired / revised) code; fall back to the initial code.
        result.final_code = final_code if final_code else source_code

        # The task's own success check decides the episode outcome.
        success_check, _ = task._task.success()
        result.success = success_check

    except Exception as e:
        print(f"[ERROR] Episode execution failed: {e}")
        traceback.print_exc()
        result.error = str(e)
        result.error_type = type(e).__name__
        result.success = False
    finally:
        # Each cleanup step is guarded separately: a failing provider.close() must
        # neither skip the environment shutdown nor replace the returned result
        # with an exception.
        if provider is not None:
            try:
                # Close the underlying provider when it has been wrapped.
                actual_provider = getattr(provider, '_wrapped_provider', provider)
                actual_provider.close()
            except Exception as close_err:
                print(f"[WARNING] Failed to close provider: {close_err}")

        if env is not None:
            try:
                print("[INFO] Shutting down environment...")
                shutdown_environment(env)
                print("[INFO] Done.")
            except Exception as shutdown_err:
                print(f"[WARNING] Failed to shutdown environment: {shutdown_err}")

        result.execution_time = time_module.perf_counter() - start_time

    return result


class _TimingProviderWrapper:
    """Proxy around a PolicyProvider that times each LLM call and records it.

    Every wrapped call appends a dict (task, subtask, action, duration_sec and,
    when reported, num_tokens) to ``result.llm_calls`` and passes the same dict to
    ``logger`` if one was given. Attributes not defined here are forwarded to the
    wrapped provider.
    """

    def __init__(self, provider: "PolicyProvider", task_name: str, logger: callable, result: "EpisodeResult"):
        self._wrapped_provider = provider
        self._task_name = task_name
        self._logger = logger
        self._result = result

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If the wrapped provider
        # itself is missing (bare __new__, copy, unpickling), looking it up again
        # would recurse forever, so report a plain AttributeError instead.
        if name == "_wrapped_provider":
            raise AttributeError(name)
        return getattr(self._wrapped_provider, name)

    def _log_call(self, action: str, subtask: str, duration: float, num_tokens: Optional[int] = None):
        """Record one call on the result object and forward it to the logger, if any."""
        call_info = {
            "task": self._task_name,
            "subtask": subtask,
            "action": action,
            "duration_sec": duration,
        }
        if num_tokens is not None:
            call_info["num_tokens"] = num_tokens
        self._result.llm_calls.append(call_info)
        if self._logger:
            self._logger(call_info)

    def _timed_provider_call(self, method_name: str, *args) -> Tuple[Any, float]:
        """Call ``method_name`` on the wrapped provider; return (value, elapsed seconds)."""
        import time
        # perf_counter is monotonic, so the duration is immune to system clock changes.
        started = time.perf_counter()
        value = getattr(self._wrapped_provider, method_name)(*args)
        return value, time.perf_counter() - started

    def _last_num_tokens(self) -> Optional[int]:
        """Token count the provider reported for its latest response, if it exposes one."""
        return getattr(self._wrapped_provider, 'last_response_num_tokens', None)

    def get_initial_code(self) -> str:
        """Timed pass-through to the provider's ``get_initial_code``."""
        code, duration = self._timed_provider_call("get_initial_code")
        self._log_call("generate", "initial_code", duration, self._last_num_tokens())
        return code

    def revise_on_step(self, current_code, executed_stmt_texts, next_index, scene, last_step_payload) -> Optional[str]:
        """Timed pass-through to ``revise_on_step``; the log entry is keyed by step index."""
        code, duration = self._timed_provider_call(
            "revise_on_step", current_code, executed_stmt_texts, next_index, scene, last_step_payload
        )
        num_tokens = self._last_num_tokens()
        # Fall back to next_index when the payload carries no step index.
        step_idx = last_step_payload.get("step_index", next_index)
        self._log_call("revise", f"step_{step_idx}", duration, num_tokens)
        return code

    def repair_on_failure(self, current_code, executed_stmt_texts, failed_stmt_text, error_payload, scene) -> str:
        """Timed pass-through to ``repair_on_failure``; the log entry is keyed by error type."""
        code, duration = self._timed_provider_call(
            "repair_on_failure", current_code, executed_stmt_texts, failed_stmt_text, error_payload, scene
        )
        num_tokens = self._last_num_tokens()
        self._log_call("repair", f"failure_{error_payload.get('error_type', 'unknown')}", duration, num_tokens)
        return code

    def prefetch_infeasibility_repair(self, current_code, executed_stmt_texts, infeasible_stmt_text, infeasible_stmt_lineno, infeasibility_info, scene) -> Optional[str]:
        """Timed pass-through to ``prefetch_infeasibility_repair``; no token count is logged."""
        code, duration = self._timed_provider_call(
            "prefetch_infeasibility_repair",
            current_code, executed_stmt_texts, infeasible_stmt_text,
            infeasible_stmt_lineno, infeasibility_info, scene,
        )
        self._log_call("prefetch_repair", f"line_{infeasible_stmt_lineno}", duration)
        return code

    def batch_invalid_repair(self, current_code, executed_stmt_texts, invalid_statements, scene) -> Optional[str]:
        """Timed pass-through to ``batch_invalid_repair``; keyed by number of statements."""
        code, duration = self._timed_provider_call(
            "batch_invalid_repair", current_code, executed_stmt_texts, invalid_statements, scene
        )
        num_tokens = self._last_num_tokens()
        self._log_call("batch_repair", f"{len(invalid_statements)}_statements", duration, num_tokens)
        return code


def _wrap_provider_with_timing(provider: PolicyProvider, task_name: str, logger: callable, result: EpisodeResult):
    """Return ``provider`` wrapped in a ``_TimingProviderWrapper``.

    The wrapper records each LLM call on ``result`` and forwards it to ``logger``
    when one is supplied.
    """
    timed_provider = _TimingProviderWrapper(
        provider=provider,
        task_name=task_name,
        logger=logger,
        result=result,
    )
    return timed_provider


# ============================================================================
# Main
# ============================================================================

def main():
    """Command-line entry point: parse arguments, run one episode, then clean up.

    Sets up the environment for the target robot, builds a PolicyProvider, fetches
    the initial policy code and hands it to the executor selected by --model /
    --mode. Exits with status 1 for an unknown task or an unhandled error (after
    emitting a "fatal_error" JSON record) and with 130 on Ctrl+C. The provider and
    the environment are always released on the way out.
    """
    parser = argparse.ArgumentParser(
        description="RLBench Loop - AST  step-by-step  failure-only "
    )
    parser.add_argument(
        "--task", "-t",
        required=True,
        help="RLBench task class name (: BasketballInHoop, PushButton)"
    )
    parser.add_argument(
        "--mode", "-m",
        choices=["step", "failure"],
        default="step",
        help=" : step( statement )  failure(  )"
    )
    parser.add_argument(
        "--source_robot", "-s",
        default="panda",
        help="Source robot type (panda, ur5, sawyer, jaco)"
    )
    parser.add_argument(
        "--target_robot", "-r",
        default="panda",
        help="Target robot type for environment setup (panda, ur5, sawyer, jaco)"
    )
    parser.add_argument(
        "--max_steps", "-n",
        type=int,
        default=None,
        help="Maximum number of primitive skill calls (default: unlimited)"
    )
    parser.add_argument(
        "--code_source", "-c",
        type=str,
        default="static",
        choices=["static", "local_llm", "remote_llm"],
        help="Where to get task skill code: static | local_llm | remote_llm",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="code_agent",
        choices=["code_agent", "ours"],
    )
    parser.add_argument(
        "--host_ip",
        type=str,
        default="172.17.0.1",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=5000,
    )
    parser.add_argument(
        "--max_repairs",
        type=int,
        default=5,
        help="Maximum number of repair attempts before giving up (default: 5)"
    )

    args = parser.parse_args()

    task_name = args.task
    mode = args.mode
    source_robot = args.source_robot.lower()
    target_robot = args.target_robot.lower()
    max_skill_calls = args.max_steps
    max_repair_attempts = args.max_repairs
    code_source = args.code_source.lower()
    model = args.model.lower()

    print(f"[INFO] Task: {task_name}")
    print(f"[INFO] Mode: {mode}")
    print(f"[INFO] Source Robot: {source_robot}")
    print(f"[INFO] Target Robot: {target_robot}")
    # Only None means "no limit"; 0 is a real limit and must be shown as such.
    print(f"[INFO] Max Skill Calls: {max_skill_calls if max_skill_calls is not None else 'unlimited'}")

    # Resolve the task class by name from rlbench.tasks.
    tasks_module = importlib.import_module("rlbench.tasks")
    try:
        task_cls = getattr(tasks_module, task_name)
    except AttributeError:
        print(f"[ERROR] '{task_name}'     .")
        sys.exit(1)

    # Created inside the try block below; kept here so `finally` can clean up
    # whatever was actually constructed.
    env = None
    task = None
    provider = None

    try:
        # Environment setup uses the target robot.
        print(f"[INFO] Setting up environment with {target_robot} robot...")
        env, task = setup_environment(task_cls, target_robot)

        # Reset the task to obtain its language descriptions.
        descriptions, _ = task.reset()

        # Emit the initial scene information as JSON.
        initial_info = get_initial_info(task, task_name, target_robot, descriptions)
        output_json(initial_info)

        policy_file_path = f"./tasks_{source_robot}/{task_name}.py"

        # objects_info is handed to the provider exactly as returned here, the same
        # way run_task_episode() does; it must not be overwritten with object_names.
        objects_info_str, objects_info, object_names = get_object_names(task, task_name)
        grasp_guidance = initial_info.get("scene", {}).get("grasp_guidance", "")
        initial_prompt = get_prompt(task, task_name, object_names, target_robot, descriptions, model, grasp_guidance)

        # Build the PolicyProvider.
        provider = PolicyProvider(
            mode=code_source,
            policy_path=policy_file_path,
            task_name=task_name,
            source_robot=source_robot,
            target_robot=target_robot,
            model=model,
            initial_prompt=initial_prompt,
            remote_host=args.host_ip,
            remote_port=args.port,
            object_names=object_names,
            descriptions=descriptions,
            target_scene_info=objects_info,
            grasp_guidance=grasp_guidance,
        )

        # Wrap the provider so LLM call timings are printed. main() has no
        # EpisodeResult, so a minimal stand-in collects the call records.
        class _DummyResult:
            def __init__(self):
                self.llm_calls = []
        dummy_result = _DummyResult()
        provider = _wrap_provider_with_timing(provider, task_name, lambda x: print(f"[LLM_CALL] {x}"), dummy_result)

        print(f"[INFO] PolicyProvider initialized with mode: {code_source}")

        # Fetch the initial policy code; statically loaded code is used as-is.
        source_code = provider.get_initial_code()
        if code_source != 'static':
            source_code = postprocess_initial_code(source_code, target_robot)
        source_lines = source_code.split('\n')
        print(f"[INFO] Loaded policy from: {policy_file_path}")

        print(source_code)
        # Dispatch to the executor for the selected model / mode.
        if model == "ours":
            execute_ours_mode(env, task, task_name, target_robot, source_code, source_lines, max_skill_calls, provider, max_repair_attempts)
        elif mode == "step":
            execute_step_mode(env, task, task_name, target_robot, source_code, source_lines, max_skill_calls, provider, max_repair_attempts)
        else:  # failure mode
            execute_failure_mode(env, task, task_name, target_robot, source_code, source_lines, max_skill_calls, provider, max_repair_attempts)

    except KeyboardInterrupt:
        print("\n[INFO] Interrupted by user (Ctrl+C)")
        sys.exit(130)
    except Exception as e:
        print(f"[FATAL_ERROR] Unhandled exception: {type(e).__name__}: {e}")
        traceback.print_exc()
        output_json({
            "type": "fatal_error",
            "error": str(e),
            "error_type": type(e).__name__,
            "traceback": traceback.format_exc()
        })
        sys.exit(1)
    finally:
        # Cleanup steps are guarded separately so one failure cannot skip the other.
        if provider is not None:
            try:
                # Close the underlying provider when it has been wrapped.
                actual_provider = getattr(provider, '_wrapped_provider', provider)
                actual_provider.close()
            except Exception as close_err:
                print(f"[WARNING] Failed to close provider: {close_err}")
        if env is not None:
            try:
                print("[INFO] Shutting down environment...")
                shutdown_environment(env)
                print("[INFO] Done.")
            except Exception as shutdown_err:
                print(f"[WARNING] Failed to shutdown environment: {shutdown_err}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported (e.g. to call run_task_episode()).
if __name__ == "__main__":
    main()
