
import os
import time
import threading
from typing import Any, Dict, List, Optional

from result_types import SubtaskResult, RepairRecord
from exceptions import SubtaskFailure, SubtaskSkip
from remote_llm_socket_client import RemoteLLMSocketClient
from scene_info import get_scene_info
from ours_utils import (
    extract_subtask_statements,
    extract_function_args_as_dict,
    basic_validate_skill_call,
    get_statement_code,
    extract_skill_call_info,
    is_primitive_skill,
    ProjectedStateTracker,
    GENESIS_PRIMITIVE_SKILLS,
)
from background_validator import (
    BackgroundValidator,
    simulate_future_code_and_extract_positions,
    validate_release_positions,
)
import ast
import re

def strip_markdown_artifacts(code: str) -> str:
    """Remove Markdown code-fence residue from a block of source text.

    Processing steps:

    1. If the text contains a complete fenced block (opening fence tagged
       with nothing, ``python`` or ``py``), only the content of the first
       such block is kept.
    2. Otherwise a leading opening fence and a trailing closing fence are
       removed individually.
    3. Remaining lines that consist solely of a fence, or of the literal
       text ``obj['code']``, are dropped; a fence glued to the end of a
       line is cut off.
    4. Trailing blank lines are removed.

    Args:
        code: Raw text; may be empty or ``None``.

    Returns:
        The cleaned text, or ``""`` when ``code`` is falsy.
    """
    if not code:
        return ""

    text = code.strip()

    # Optional language tag after the opening fence.  ``py\b`` accepts the
    # short tag without swallowing the start of an identifier such as
    # ``pyx``; previously ``py`` fences left a stray "py" line behind.
    fence_tag = r'(?:python|py\b)?'
    fenced_blocks = re.findall(r'```' + fence_tag + r'\s*\n(.*?)```', text, re.DOTALL)

    if fenced_blocks:
        text = fenced_blocks[0].strip()
    else:
        # No complete fenced block: peel off a dangling opener/closer.
        text = re.sub(r'^```' + fence_tag + r'\s*\n?', '', text)
        text = re.sub(r'\n?```\s*$', '', text)

    cleaned_lines = []
    for line in text.split('\n'):
        stripped = line.strip()
        if stripped in ('```', '```python', '```py'):
            continue
        if stripped == "obj['code']":
            continue
        if line.rstrip().endswith('```'):
            # Fence stuck to the end of a code line: keep the code part only.
            line = line.rstrip()[:-3].rstrip()
            if not line.strip():
                continue
        cleaned_lines.append(line)

    while cleaned_lines and not cleaned_lines[-1].strip():
        cleaned_lines.pop()

    return '\n'.join(cleaned_lines)

def contains_return(stmt: ast.stmt) -> bool:
    """Report whether ``stmt`` is, or contains anywhere below it, a ``return``.

    ``ast.walk`` yields ``stmt`` itself as well as all of its descendants,
    so a bare ``ast.Return`` node is covered by the same scan.
    """
    return any(isinstance(node, ast.Return) for node in ast.walk(stmt))

def remove_returns_from_stmt(stmt: ast.stmt) -> ast.stmt:
    """Replace every ``return`` inside ``stmt`` with ``pass``.

    The tree is transformed in place by an ``ast.NodeTransformer``; the
    (possibly replaced) root node is returned.  Any value expression of a
    removed ``return`` is discarded together with it.

    Args:
        stmt: Statement node to rewrite.

    Returns:
        ``stmt`` itself, or a new ``ast.Pass`` node when ``stmt`` was a
        ``return`` statement.
    """
    class ReturnRemover(ast.NodeTransformer):
        def visit_Return(self, node):
            # Carry the source location over so the rewritten tree can still
            # be passed to compile(); a bare ast.Pass() has no lineno.
            return ast.copy_location(ast.Pass(), node)

    return ReturnRemover().visit(stmt)

def is_compound_statement(stmt: ast.stmt) -> bool:
    """Tell whether ``stmt`` is an ``if``, ``for``, ``while``, ``try`` or ``with`` node."""
    compound_node_types = (ast.If, ast.For, ast.While, ast.Try, ast.With)
    return isinstance(stmt, compound_node_types)

def contains_skill_call(stmt: ast.stmt, primitive_skills: set) -> bool:
    """Check whether ``stmt`` contains a call to one of ``primitive_skills``.

    Both plain calls (``pick(...)``) and attribute calls (``robot.pick(...)``)
    are recognised; for the latter only the final attribute name is
    compared.  Name matching is case-insensitive.

    Args:
        stmt: Statement node to search (including all nested nodes).
        primitive_skills: Names of the skills to look for.

    Returns:
        ``True`` if at least one matching call node is found.
    """
    # Lower-case the skill names once instead of once per call node.
    skill_names = {skill.lower() for skill in primitive_skills}

    for node in ast.walk(stmt):
        if not isinstance(node, ast.Call):
            continue
        callee = node.func
        if isinstance(callee, ast.Name):
            called_name = callee.id
        elif isinstance(callee, ast.Attribute):
            called_name = callee.attr
        else:
            # e.g. a call on a subscript or on another call's result
            continue
        if called_name.lower() in skill_names:
            return True
    return False

def flatten_compound_statement(stmt: ast.stmt) -> list:
    """Collect the direct child statements of a compound statement.

    Returned order per node type:

    * ``if`` / ``for`` / ``while``: body, then ``else`` branch.
    * ``try``: body, every handler body, ``else`` branch, ``finally`` block.
    * ``with``: body.

    Any other node type yields an empty list.  Only one nesting level is
    unpacked; the returned list is a new list holding the original nodes.
    """
    if isinstance(stmt, (ast.If, ast.For, ast.While)):
        return [*stmt.body, *stmt.orelse]

    if isinstance(stmt, ast.Try):
        children = list(stmt.body)
        for handler in stmt.handlers:
            children += handler.body
        children += stmt.orelse
        children += stmt.finalbody
        return children

    if isinstance(stmt, ast.With):
        return list(stmt.body)

    return []

def sha1(code: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``code`` encoded as UTF-8."""
    from hashlib import sha1 as _new_sha1

    digest = _new_sha1()
    digest.update(code.encode('utf-8'))
    return digest.hexdigest()

class PolicyProvider:

    def __init__(
        self,
        mode: str,                 # "static" | "remote_llm"
        task_name: str,
        model: str = "code_agent",
        initial_prompt: str = "",
        remote_host: str = "127.0.0.1",
        remote_port: int = 9000,
        object_names: list = None,
        descriptions: str = "",
        scene_info: dict = None,
        source_robot: str = "panda",
        target_robot: str = "suction",
        enable_background_validation: bool = True,
        enable_revise_on_step: bool = False,
        enable_code_cache: bool = False,
        llm_call_logger: Optional[callable] = None,
    ):
        self.mode = mode
        self.task_name = task_name
        self.model = model
        self.initial_prompt = initial_prompt
        self.object_names = object_names or []
        self.descriptions = descriptions
        self.scene_info = scene_info
        self.source_robot = source_robot
        self.target_robot = target_robot
        self.enable_background_validation = enable_background_validation
        self.enable_revise_on_step = enable_revise_on_step
        self.enable_code_cache = enable_code_cache
        self.llm_call_logger = llm_call_logger
        self.last_response_num_tokens: Optional[int] = None

        # Tracking fields for LLM calls and execution metrics
        self.llm_calls: List[Dict[str, Any]] = []
        self.total_llm_tokens: int = 0
        self.total_llm_time: float = 0.0
        self.total_failure_count: int = 0
        self.total_repair_count: int = 0

        self.client: Optional[RemoteLLMSocketClient] = None
        self.bg_client: Optional[RemoteLLMSocketClient] = None
        self._bg_client_lock = threading.Lock()
        self._remote_host = remote_host
        self._remote_port = remote_port

        self.output_path = None
        self.reference_code = None

        # Cache for generated subtask code (subtask_name -> code)
        self._code_cache: Dict[str, str] = {}

        if self.mode == "remote_llm":
            self.client = RemoteLLMSocketClient(
                remote_host,
                remote_port,
                connect_timeout=30.0,
                recv_timeout=100.0
            )
            print(f"[INFO] Connecting to remote LLM server at {remote_host}:{remote_port}...")
            self.client.connect()
            print(f"[INFO] Successfully connected to {remote_host}:{remote_port}")

    def close(self):
        if self.client:
            self.client.close()
            self.client = None
        if self.bg_client:
            self.bg_client.close()
            self.bg_client = None

    # ========================================================================
    # Metrics
    # ========================================================================

    def get_metrics(self) -> Dict[str, Any]:
        return {
            "total_llm_tokens": self.total_llm_tokens,
            "total_llm_time": self.total_llm_time,
            "total_failure_count": self.total_failure_count,
            "total_repair_count": self.total_repair_count,
            "llm_calls": self.llm_calls,
        }

    def reset_metrics(self) -> None:
        self.llm_calls = []
        self.total_llm_tokens = 0
        self.total_llm_time = 0.0
        self.total_failure_count = 0
        self.total_repair_count = 0

    def _log_llm_call(
        self,
        action: str,
        subtask_name: str,
        duration: float,
        num_tokens: Optional[int] = None,
    ) -> None:
        call_info = {
            "task": self.task_name,
            "subtask": subtask_name,
            "action": action,
            "duration_sec": duration,
        }
        if num_tokens is not None:
            call_info["num_tokens"] = num_tokens
            self.total_llm_tokens += num_tokens
        self.total_llm_time += duration
        self.llm_calls.append(call_info)

        # Call external logger (for CSV logging)
        if self.llm_call_logger is not None:
            try:
                self.llm_call_logger(call_info)
            except Exception as e:
                print(f"[WARN] Failed to log LLM call to external logger: {e}")

    # ========================================================================
    # LLM Communication: Generate & Repair
    # ========================================================================

    def generate_subtask_code(
        self,
        subtask_name: str,
        obj_name: Optional[str] = None,
        target_name: Optional[str] = None,
        scene_info: Optional[Dict[str, Any]] = None,
    ) -> str:
        print(f"[DEBUG] generate_subtask_code called: subtask={subtask_name}, obj={obj_name}, target={target_name}")
        print(f"[DEBUG]   mode: {self.mode}")

        if self.mode == "static":
            subtask_path = f"./subtasks_{self.target_robot}/{subtask_name}.py"
            with open(subtask_path, "r", encoding="utf-8") as f:
                return f.read()

        # Check cache first (only if enabled)
        if self.enable_code_cache and subtask_name in self._code_cache:
            print(f"[CACHE HIT] Using cached code for {subtask_name}")
            return self._code_cache[subtask_name]

        assert self.client is not None

        reference_path = f"./subtasks_{self.source_robot}/{subtask_name}.py"
        reference_code = ""
        try:
            with open(reference_path, "r", encoding="utf-8") as f:
                full_code = f.read()
            # Extract only the main function (matching subtask_name), not helper functions
            reference_code = self._extract_main_function(full_code, subtask_name)
        except FileNotFoundError:
            print(f"[WARN] Reference subtask not found: {reference_path}")

        print(f"[CACHE MISS] Generating code for {subtask_name}")
        resp = self.client.request({
            "type": "generate",
            "task_name": self.task_name,
            "skill_name": subtask_name,
            "simulator": "genesis",
            "robot_type": self.target_robot,
            "source_robot": self.source_robot,
            "target_robot": self.target_robot,
            "prompt": self.initial_prompt,
            "reference_code": reference_code,
            "available_objects": self.object_names,
            "descriptions": self.descriptions,
            "target_scene_info": (scene_info or self.scene_info or {}).get("objects", []),
            "obj_name": obj_name,
            "target_name": target_name,
        })

        print(f"[DEBUG] generate response type: {resp.get('type')}, ok: {resp.get('ok')}")
        if resp.get("type") != "generate_result" or not resp.get("ok", False):
            print(f"[DEBUG] generate FAILED: {resp}")
            raise RuntimeError(f"Remote subtask generation failed: {resp}")

        code = resp.get("code", "")
        if not code.strip():
            print(f"[DEBUG] generate returned EMPTY code")
            raise RuntimeError(f"Remote subtask generation returned empty code for {subtask_name}")

        # Log received code from server
        print(f"[GENERATE] Received code from server for {subtask_name}:")
        print(code[:800] + "..." if len(code) > 800 else code)

        self.last_response_num_tokens = resp.get("num_tokens")
        print(f"[DEBUG] generate returned code length: {len(code)}, tokens: {self.last_response_num_tokens}")
        print(f"[POLICY_PROVIDER] Subtask code(before preprocessing):\n{code}")

        # Cache the generated code (only if enabled)
        if self.enable_code_cache:
            self._code_cache[subtask_name] = code
            print(f"[CACHE SAVE] Cached code for {subtask_name}")

        return code

    def repair_subtask(
        self,
        subtask_name: str,
        obj_name: Optional[str],
        target_name: Optional[str],
        error_info: Dict[str, Any],
        scene: Dict[str, Any],
        task_name: str,
    ) -> Optional[str]:
        print(f"[DEBUG] repair_subtask called for {subtask_name}")
        print(f"[DEBUG]   obj_name: {obj_name}, target_name: {target_name}")
        print(f"[DEBUG]   error_info: {error_info}")
        print(f"[DEBUG]   scene keys: {list(scene.keys()) if isinstance(scene, dict) else type(scene)}")

        if self.mode == "static":
            print(f"[DEBUG] repair_subtask: static mode, raising error")
            raise RuntimeError("static mode cannot repair subtasks. Use remote_llm.")

        assert self.client is not None

        subtask_path = f"./subtasks_{self.target_robot}/{subtask_name}.py"
        current_subtask_code = ""
        try:
            with open(subtask_path, "r", encoding="utf-8") as f:
                current_subtask_code = f.read()
        except FileNotFoundError:
            print(f"[WARN] Subtask file not found: {subtask_path}")

        # Clean up markdown artifacts before sending to server
        current_subtask_code = strip_markdown_artifacts(current_subtask_code)

        # Get reference code from source robot for repair hints
        reference_code = self._read_reference(self.source_robot, subtask_name)

        resp = self.client.request({
            "type": "repair",
            "task_name": task_name,
            "skill_name": subtask_name,
            "simulator": "genesis",
            "robot_type": self.target_robot,
            "source_robot": self.source_robot,
            "target_robot": self.target_robot,
            "model": self.model,
            "meta": {"model": self.model, "source_robot": self.source_robot},
            "constraints": {
                "must_keep_executed_prefix": False,
                "executed_prefix_count": 0,
            },
            "inputs": {
                "current_code": current_subtask_code,
                "reference_code": reference_code,
                "executed_stmt_texts": [],
                "failed_stmt_text": f"subtask {subtask_name}(obj={obj_name}, target={target_name}) failed",
                "error": error_info,
                "scene": scene,
            }
        })

        print(f"[DEBUG] repair response type: {resp.get('type')}, ok: {resp.get('ok')}")
        if resp.get("type") != "repair_result" or not resp.get("ok", False):
            print(f"[WARN] Remote subtask repair failed: {resp}")
            return None

        new_code = resp.get("code", "")
        if not new_code.strip():
            print("[WARN] Remote subtask repair returned empty code.")
            return None

        print(f"[DEBUG] repair returned code length: {len(new_code)}")
        print(f"[POLICY_PROVIDER] Repaired Subtask code(before preprocessing):\n{new_code}")
        self.last_response_num_tokens = resp.get("num_tokens")
        print(f"[DEBUG] repair num_tokens: {self.last_response_num_tokens}")
        self._save_subtask(subtask_name, new_code)
        return new_code

    def revise_subtask(
        self,
        subtask_name: str,
        current_code: str,
        executed_stmt_texts: List[str],
        scene: Dict[str, Any],
        last_step: Dict[str, Any],
    ) -> Optional[str]:
        print(f"[DEBUG] revise_subtask called for {subtask_name}")
        print(f"[DEBUG]   current_code length: {len(current_code)}")
        print(f"[DEBUG]   executed_stmt_texts count: {len(executed_stmt_texts)}")
        print(f"[DEBUG]   last_step: {last_step}")

        if self.mode == "static":
            print(f"[DEBUG] revise_subtask: static mode, returning None")
            return None

        assert self.client is not None

        print(f"[DEBUG] Sending revise request to remote LLM...")
        resp = self.client.request({
            "type": "revise",
            "task_name": self.task_name,
            "skill_name": subtask_name,
            "simulator": "genesis",
            "robot_type": self.target_robot,
            "source_robot": self.source_robot,
            "target_robot": self.target_robot,
            "model": self.model,
            "meta": {"model": self.model, "source_robot": self.source_robot},
            "constraints": {
                "must_keep_executed_prefix": True,
                "executed_prefix_count": len(executed_stmt_texts),
            },
            "inputs": {
                "current_code": current_code,
                "executed_stmt_texts": executed_stmt_texts,
                "scene": scene,
                "last_step": last_step,
            }
        })

        print(f"[DEBUG] revise response type: {resp.get('type')}, ok: {resp.get('ok')}")
        if resp.get("type") != "revise_result" or not resp.get("ok", False):
            print(f"[WARN] Remote subtask revise failed: {resp}")
            return None

        new_code = resp.get("code", "")
        if not new_code.strip():
            print("[WARN] Remote subtask revise returned empty code.")
            return None

        print(f"[DEBUG] revise returned code length: {len(new_code)}")
        print(f"[POLICY_PROVIDER] Revised Subtask code(before preprocessing):\n{new_code}")

        self.last_response_num_tokens = resp.get("num_tokens")
        print(f"[DEBUG] revise num_tokens: {self.last_response_num_tokens}")
        return new_code

    def batch_invalid_repair(
        self,
        current_code: str,
        executed_stmt_texts: List[str],
        invalid_statements: List[Dict[str, Any]],
        scene: Dict[str, Any],
        subtask_name: str = "",
    ) -> Optional[str]:
        print(f"[DEBUG] batch_invalid_repair called for {subtask_name}")
        print(f"[DEBUG]   current_code length: {len(current_code)}")
        print(f"[DEBUG]   executed_stmt_texts count: {len(executed_stmt_texts)}")
        print(f"[DEBUG]   invalid_statements count: {len(invalid_statements)}")
        for i, stmt in enumerate(invalid_statements):
            print(f"[DEBUG]   invalid[{i}]: lines {stmt.get('start_line')}-{stmt.get('end_line')}, skill={stmt.get('skill_name')}, violations={stmt.get('violations')}")

        if self.mode != "remote_llm":
            print(f"[DEBUG] batch_invalid_repair: not remote_llm mode, returning None")
            return None

        if not invalid_statements:
            print(f"[DEBUG] batch_invalid_repair: no invalid statements, returning None")
            return None

        repair_start = time.time()

        with self._bg_client_lock:
            if self.bg_client is None or getattr(self.bg_client, '_sock', None) is None:
                self.bg_client = RemoteLLMSocketClient(
                    self._remote_host,
                    self._remote_port,
                    connect_timeout=30.0,
                    recv_timeout=100.0
                )
                print(f"[INFO] Connecting background client to {self._remote_host}:{self._remote_port}...")
                try:
                    self.bg_client.connect()
                    print(f"[INFO] Background client connected successfully")
                except Exception as e:
                    print(f"[ERROR] Failed to connect background client: {e}")
                    self.bg_client = None
                    return None

            try:
                resp = self.bg_client.request({
                    "type": "batch_invalid_repair",
                    "task_name": self.task_name,
                    "skill_name": subtask_name,
                    "simulator": "genesis",
                    "robot_type": self.target_robot,
                    "source_robot": self.source_robot,
                    "target_robot": self.target_robot,
                    "model": self.model,
                    "meta": {"model": self.model},
                    "constraints": {
                        "must_keep_executed_prefix": True,
                        "executed_prefix_count": len(executed_stmt_texts),
                    },
                    "inputs": {
                        "current_code": current_code,
                        "executed_stmt_texts": executed_stmt_texts,
                        "invalid_statements": invalid_statements,
                        "scene": scene,
                    }
                })
            except (OSError, ConnectionError) as e:
                print(f"[BATCH_INVALID_REPAIR] Socket error, closing bg_client: {e}")
                if self.bg_client:
                    self.bg_client.close()
                self.bg_client = None
                repair_duration = time.time() - repair_start
                self._log_llm_call("bg_validation_repair_failed", subtask_name, repair_duration, None)
                return None

        repair_duration = time.time() - repair_start
        print(f"[DEBUG] batch_invalid_repair response type: {resp.get('type')}, ok: {resp.get('ok')}")
        print(f"[DEBUG] batch_invalid_repair response full: {resp}")
        if resp.get("type") != "batch_invalid_repair_result" or not resp.get("ok", False):
            print(f"[BATCH_INVALID_REPAIR] Remote batch repair failed: {resp}")
            self._log_llm_call("bg_validation_repair_failed", subtask_name, repair_duration, resp.get("num_tokens"))
            return None

        repaired_statement = resp.get("code", "")
        if not repaired_statement.strip():
            print(f"[DEBUG] batch_invalid_repair: empty code returned")
            self._log_llm_call("bg_validation_repair_empty", subtask_name, repair_duration, resp.get("num_tokens"))
            return None

        self.last_response_num_tokens = resp.get("num_tokens")
        print(f"[DEBUG] batch_invalid_repair num_tokens: {self.last_response_num_tokens}")

        # Get line numbers from response or from first invalid statement
        start_line = resp.get("start_line")
        end_line = resp.get("end_line")
        if start_line is None and invalid_statements:
            start_line = invalid_statements[0].get("start_line")
            end_line = invalid_statements[0].get("end_line", start_line)

        if start_line is None:
            print(f"[DEBUG] batch_invalid_repair: no line number info, cannot replace")
            return None

        # Replace the specific lines in original code
        lines = current_code.split('\n')
        print(f"[DEBUG] batch_invalid_repair: replacing lines {start_line}-{end_line}")
        print(f"[DEBUG]   original line(s): {lines[start_line-1:end_line]}")
        print(f"[DEBUG]   replacement: {repaired_statement}")

        # Preserve indentation from original line (handles both tabs and spaces)
        original_line = lines[start_line - 1]
        indent = original_line[:len(original_line) - len(original_line.lstrip())]

        # Handle multi-line replacement: apply indent to each line
        replacement_lines = repaired_statement.strip().split('\n')
        indented_lines = []
        for i, line in enumerate(replacement_lines):
            if i == 0:
                # First line gets the original indentation
                indented_lines.append(indent + line.strip())
            else:
                # Subsequent lines: preserve their relative indentation
                stripped = line.lstrip()
                if stripped:
                    # Calculate relative indent from first line of replacement
                    first_line_indent = len(replacement_lines[0]) - len(replacement_lines[0].lstrip())
                    current_indent = len(line) - len(stripped)
                    relative_indent = max(0, current_indent - first_line_indent)
                    indented_lines.append(indent + ' ' * relative_indent + stripped)
                else:
                    indented_lines.append('')

        # Replace lines (start_line and end_line are 1-indexed)
        lines[start_line - 1:end_line] = indented_lines
        new_code = '\n'.join(lines)

        if sha1(new_code) == sha1(current_code):
            print(f"[DEBUG] batch_invalid_repair: new code same as current (hash match)")
            self._log_llm_call("bg_validation_repair_no_change", subtask_name, repair_duration, self.last_response_num_tokens)
            return None

        print(f"[DEBUG] batch_invalid_repair: code updated, new length: {len(new_code)}")
        self._log_llm_call("bg_validation_repair", subtask_name, repair_duration, self.last_response_num_tokens)
        return new_code

    # ========================================================================
    # File I/O
    # ========================================================================

    def _save_subtask(self, subtask_name: str, code: str) -> None:
        """Clean ``code`` and write it to the target robot's subtask file.

        The code is stripped of Markdown artifacts first.  If the cleaned
        code does not parse, every indented ``def`` line is moved to column
        0 as a best-effort repair of a common formatting mistake.  Code that
        already parses is written unchanged, so valid nested functions and
        methods are no longer corrupted by that repair.

        The file ``./subtasks_<target_robot>/<subtask_name>.py`` is created
        (including its directory) or overwritten.

        Args:
            subtask_name: Name of the subtask; used as the file stem.
            code: Source text to save.
        """
        # First, clean up markdown artifacts
        code = strip_markdown_artifacts(code)

        # Only attempt the indentation repair when the code is actually
        # broken; an indented 'def' in code that parses is legitimate.
        try:
            ast.parse(code)
            needs_indent_fix = False
        except (SyntaxError, ValueError):
            needs_indent_fix = True

        if needs_indent_fix:
            # Fix common indentation issues where 'def' lines have leading spaces
            fixed_lines = []
            for line in code.split('\n'):
                stripped = line.lstrip()
                if stripped.startswith('def ') and line != stripped:
                    # This is an indented def statement - move it to column 0
                    print(f"[FIX] Correcting indentation for: {stripped[:50]}...")
                    fixed_lines.append(stripped)
                else:
                    fixed_lines.append(line)
            code = '\n'.join(fixed_lines)

        subtask_path = f"./subtasks_{self.target_robot}/{subtask_name}.py"
        os.makedirs(os.path.dirname(subtask_path), exist_ok=True)
        with open(subtask_path, "w", encoding="utf-8") as f:
            f.write(code)
        print(f"[INFO] Saved repaired subtask to: {subtask_path}")

    def _save(self, code: str) -> None:
        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
        with open(self.output_path, "w", encoding="utf-8") as f:
            f.write(code)

    def _read_reference(self, source_robot: str, subtask_name: str) -> str:
        path = f"./subtasks_{source_robot}/{subtask_name}.py"
        try:
            with open(path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            print(f"[WARN] Reference code not found: {path}")
            return ""

    def _extract_main_function(self, code: str, function_name: str) -> str:
        import ast
        try:
            tree = ast.parse(code)
            # Find the function with matching name
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == function_name:
                    # Get the source code for this function only
                    lines = code.split('\n')
                    start_line = node.lineno - 1
                    end_line = node.end_lineno if hasattr(node, 'end_lineno') else len(lines)

                    # Extract function code
                    func_lines = lines[start_line:end_line]
                    return '\n'.join(func_lines)

            # If function not found, return full code (fallback)
            print(f"[WARN] Function '{function_name}' not found in reference, using full code")
            return code
        except Exception as e:
            print(f"[WARN] Failed to extract function '{function_name}': {e}, using full code")
            return code

    def _extract_helper_functions(self, code: str, main_function_name: str) -> Dict[str, callable]:
        """Compile the top-level helper functions found in ``code``.

        Every top-level ``def`` other than ``main_function_name`` is executed
        with ``exec`` in its own fresh namespace containing ``np``/``numpy``,
        ``math``, ``time`` and the robot skills from ``skill_code``.  Skills
        looked up with a ``None`` default may be ``None`` in the namespace if
        ``skill_code`` does not define them.  Because each helper gets a
        separate namespace, helpers do not see each other.

        NOTE(review): ``exec`` runs the given source text; make sure ``code``
        comes from a trusted origin.

        Returns:
            Mapping ``helper name -> function object``.  On any exception a
            warning is printed and the helpers loaded so far are returned.
        """
        import ast
        import math
        import time as time_module

        import numpy as np
        import skill_code

        # Common skills: read with plain attribute access (must exist).
        required_skills = (
            'move_gripper_to',
            'move_to_position',
            'move_parallel',
            'rotate_gripper',
        )
        # Robot-specific skills and aliases: name in namespace -> attribute
        # on skill_code (None if the attribute is absent).
        optional_skills = {
            # Panda
            'open_gripper': 'open_gripper',
            'close_gripper': 'close_gripper',
            'pick': 'pick',
            'place': 'place',
            'grasp_handle': 'grasp_handle',
            'release_handle': 'release_handle',
            'open_panda': 'open_gripper',
            'close_panda': 'close_gripper',
            # Robotiq85
            'open_robotiq85': 'open_robotiq85',
            'close_robotiq85': 'close_robotiq85',
            'pick_robotiq85': 'pick_robotiq85',
            'place_robotiq85': 'place_robotiq85',
            'grasp_handle_robotiq85': 'grasp_handle_robotiq85',
            'release_handle_robotiq85': 'release_handle_robotiq85',
            # Suction
            'activate_vacuum': 'activate_vacuum',
            'deactivate_vacuum': 'deactivate_vacuum',
            'attach_vacuum_handle': 'attach_vacuum_handle',
            'detach_vacuum_handle': 'detach_vacuum_handle',
            'activate_suction': 'activate_vacuum',
            'deactivate_suction': 'deactivate_vacuum',
            'attach_suction': 'attach_vacuum_handle',
            'detach_suction': 'detach_vacuum_handle',
        }

        def build_namespace():
            # A new dict per helper: exec() writes the function into it.
            namespace = {
                'np': np,
                'numpy': np,
                'math': math,
                'time': time_module,
            }
            for skill in required_skills:
                namespace[skill] = getattr(skill_code, skill)
            for exposed_name, attribute in optional_skills.items():
                namespace[exposed_name] = getattr(skill_code, attribute, None)
            return namespace

        loaded = {}
        try:
            module = ast.parse(code)
            source_lines = code.split('\n')

            # Collect all function definitions except the main one
            for node in ast.iter_child_nodes(module):
                if not isinstance(node, ast.FunctionDef):
                    continue
                if node.name == main_function_name:
                    continue

                last = getattr(node, 'end_lineno', len(source_lines))
                snippet = '\n'.join(source_lines[node.lineno - 1:last])

                # Compile and execute the helper function
                namespace = build_namespace()
                exec(snippet, namespace)
                if node.name in namespace:
                    loaded[node.name] = namespace[node.name]
                    print(f"[HELPER] Loaded helper function: {node.name}")

        except Exception as e:
            print(f"[WARN] Failed to extract helper functions: {e}")

        return loaded

    # ========================================================================
    # Scene Info
    # ========================================================================

    def _get_scene_info(self, env) -> Dict[str, Any]:
        """Return ``get_scene_info(env, <target robot>)`` for the given environment."""
        robot = self.target_robot
        return get_scene_info(env, robot)

    # ========================================================================
    # Subtask Execution
    # ========================================================================

    def execute_subtask(
        self,
        subtask_name: str,
        env,
        obj_name: Optional[str] = None,
        target_name: Optional[str] = None,
        check_func: Optional[callable] = None,
        max_repairs: int = 5,
        **kwargs,
    ) -> "SubtaskResult":
        """Log the call and dispatch to the executor matching ``self.model``.

        When ``self.model == "ours"`` the call is forwarded to
        ``_execute_subtask_ours_mode``; any other value forwards it to
        ``_execute_subtask_on_failure``. All arguments (including ``kwargs``)
        are passed through unchanged as keyword arguments, and the chosen
        executor's return value is returned as-is.
        """
        rule = '=' * 60
        print(f"\n{rule}")
        print("[DEBUG] execute_subtask ENTRY")

        # Arguments are dumped in signature order before the instance settings.
        call_args = (
            ("subtask_name", subtask_name),
            ("obj_name", obj_name),
            ("target_name", target_name),
            ("check_func", 'provided' if check_func is not None else 'None'),
            ("max_repairs", max_repairs),
        )
        for label, shown in call_args:
            print(f"[DEBUG]   {label}: {shown}")
        print(f"[DEBUG]   model: {self.model}")
        print(f"[DEBUG]   mode: {self.mode}")
        print(f"[DEBUG]   kwargs: {list(kwargs)}")
        print(rule)

        # Both executors share one signature, so pick the handler and call once.
        if self.model == "ours":
            runner = self._execute_subtask_ours_mode
        else:
            runner = self._execute_subtask_on_failure

        return runner(
            subtask_name=subtask_name,
            env=env,
            obj_name=obj_name,
            target_name=target_name,
            check_func=check_func,
            max_repairs=max_repairs,
            **kwargs,
        )

    def _execute_subtask_ours_mode(
        self,
        subtask_name: str,
        env,
        obj_name: Optional[str] = None,
        target_name: Optional[str] = None,
        check_func: Optional[callable] = None,
        max_repairs: int = 5,
        **kwargs,
    ) -> "SubtaskResult":
        """Generate a subtask and run it statement by statement with repair.

        Flow:
          1. Generate code with ``self.generate_subtask_code`` and save it.
          2. Repeatedly parse the current code with
             ``extract_subtask_statements`` and ``exec`` the next unexecuted
             statement in a shared context dict.
          3. Before a statement that is (or contains) a primitive skill call,
             optionally start a background thread that validates the
             remaining statements and may post a batch-repaired version of
             the code (``self.batch_invalid_repair``).
          4. If a statement raises, ask ``self.repair_subtask`` for new code
             (at most ``max_repairs`` times) and resume at the same statement
             index in the new code.
          5. Once every statement has run, call ``check_func`` (if given);
             on a falsy result in ``remote_llm`` mode ask
             ``self.revise_subtask`` for more code.

        Args:
            subtask_name: Name of the subtask function to generate and run.
            env: Environment handle passed to scene/state helpers and exposed
                to the executed code.
            obj_name: Exposed to the executed code as ``obj_name``.
            target_name: Exposed to the executed code as ``target_name``.
            check_func: Optional ``check_func(env, obj_name, target_name)``
                called after the last statement; falsy means "not done".
            max_repairs: Budget for failure repairs and check-failure
                revisions. Background-validation repairs are counted
                separately and do not consume it.
            **kwargs: Copied into the execution context. ``task``,
                ``pointing_to`` and ``height_offset`` are read explicitly.

        Returns:
            The ``SubtaskResult``. ``success`` is set to ``True`` only when
            all statements ran and the check passed (or no check was given).
            If the loop ends any other way (no statements parsed, or the
            repair budget ran out through revisions) the result is returned
            with ``error`` set to "Unexpected exit from execution loop".

        Raises:
            SubtaskFailure: Code generation failed, or the check failed and
                no revised code was obtained.
            SubtaskSkip: A statement failed in ``static`` mode, the repair
                budget was exhausted, or the repair call returned nothing or
                raised.
        """
        print(f"[DEBUG] _execute_subtask_ours_mode called for {subtask_name}")
        print(f"[DEBUG]   obj_name: {obj_name}, target_name: {target_name}")
        print(f"[DEBUG]   check_func: {'provided' if check_func is not None else 'None'}")
        print(f"[DEBUG]   max_repairs: {max_repairs}")
        print(f"[DEBUG]   enable_background_validation: {self.enable_background_validation}")
        print(f"[DEBUG]   enable_revise_on_step: {self.enable_revise_on_step}")

        result = SubtaskResult(
            subtask_name=subtask_name,
            obj_name=obj_name,
            target_name=target_name,
        )

        scene_info = self._get_scene_info(env)
        print(f"[DEBUG] Initial scene_info keys: {list(scene_info.keys()) if isinstance(scene_info, dict) else type(scene_info)}")

        # Generate subtask code
        start_time = time.time()
        try:
            subtask_code = self.generate_subtask_code(
                subtask_name=subtask_name,
                obj_name=obj_name,
                target_name=target_name,
                scene_info=scene_info,
            )
            result.generated_code = subtask_code
            gen_duration = time.time() - start_time
            self._log_llm_call("generate", subtask_name, gen_duration, self.last_response_num_tokens)

            # Save generated code to file
            self._save_subtask(subtask_name, subtask_code)

        except Exception as e:
            gen_duration = time.time() - start_time
            self._log_llm_call("generate_failed", subtask_name, gen_duration, None)
            result.error = f"Failed to generate subtask code: {e}"
            self.total_failure_count += 1
            # Chain explicitly so the original generation error stays visible.
            raise SubtaskFailure(
                subtask_name=subtask_name,
                obj_name=obj_name,
                target_name=target_name,
                message=result.error,
            ) from e

        repair_count = 0
        bg_repair_count = 0  # Separate counter for background validation repairs
        current_code = subtask_code
        original_code = subtask_code
        repair_history: List[RepairRecord] = []  # Track all repair attempts
        # Source text of statements already executed; its length is the index
        # of the next statement to run in ``current_code``.
        executed_stmt_texts: List[str] = []
        skill_call_count = 0

        # Initialize ProjectedStateTracker
        print(f"[DEBUG] Initializing ProjectedStateTracker...")
        state_tracker = ProjectedStateTracker()
        state_tracker.initialize_from_env(env)
        print(f"[DEBUG] Initial state: gripper_open={state_tracker.confirmed_state.gripper.is_open}, held_obj={state_tracker.confirmed_state.gripper.held_object}")

        # Create execution context for statement-by-statement execution
        # Try to get task from kwargs or from env
        task = kwargs.get('task', None)
        if task is None and hasattr(env, 'env') and hasattr(env.env, 'task'):
            task = env.env.task
        context = self._create_subtask_execution_context(env, subtask_name, task=task)

        # Add function parameters to context (critical for statement-by-statement execution)
        context['obj_name'] = obj_name
        context['target_name'] = target_name
        # Add common default parameters that may be used in generated code
        context['pointing_to'] = kwargs.get('pointing_to', 'down')
        context['height_offset'] = kwargs.get('height_offset', 0.1)
        for key, value in kwargs.items():
            context[key] = value

        # Extract function arguments from generated code and register them in context
        # This ensures all parameters (including those with default values) are available
        func_args = extract_function_args_as_dict(current_code, subtask_name)
        for arg_name, default_value in func_args.items():
            if arg_name not in context:
                # Only add if not already set (preserve explicitly passed values)
                if default_value is not None and not str(default_value).startswith('<'):
                    context[arg_name] = default_value
                    print(f"[DEBUG] Registered function arg: {arg_name}={default_value}")
                else:
                    print(f"[DEBUG] Function arg '{arg_name}' has no default, needs explicit value")

        # Thread synchronization for background validation.
        # The ``pending_repair_*`` variables are the hand-off slot between the
        # validation thread (writer) and the main loop (reader); every access
        # is made while holding ``validation_lock``.
        validation_lock = threading.Lock()
        pending_repair_code: Optional[str] = None
        pending_repair_violations: List[str] = []
        pending_repair_duration: Optional[float] = None  # LLM call duration
        pending_repair_num_tokens: Optional[int] = None
        # Code the pending repair was computed against; used to reject repairs
        # that arrive after ``current_code`` has already been replaced.
        pending_repair_base_code: Optional[str] = None

        def apply_pending_repair() -> bool:
            """Install a posted background repair; return True if applied.

            A repair whose base code no longer equals ``current_code`` is
            stale (the code was repaired/revised while the validator was
            still running) and is dropped instead of overwriting newer code.
            """
            nonlocal current_code, pending_repair_code, pending_repair_violations, bg_repair_count
            nonlocal pending_repair_duration, pending_repair_num_tokens, pending_repair_base_code
            with validation_lock:
                if pending_repair_code is not None:
                    is_stale = (
                        pending_repair_base_code is not None
                        and pending_repair_base_code != current_code
                    )
                    if is_stale:
                        print(f"[VALIDATION_REPAIR] Discarding stale background repair (code changed since validation started)")
                    else:
                        print(f"[VALIDATION_REPAIR] Applying repaired code from background validation")
                        # Record repair in history (background validation is separate from repair_count)
                        repair_record = RepairRecord(
                            repair_type="background_validation",
                            violations=pending_repair_violations.copy(),
                            code_before=current_code,
                            code_after=pending_repair_code,
                            statement_index=len(executed_stmt_texts),
                            duration_sec=pending_repair_duration,
                            num_tokens=pending_repair_num_tokens,
                        )
                        repair_history.append(repair_record)
                        bg_repair_count += 1

                        current_code = pending_repair_code
                        # Keep the result in sync with the code actually being
                        # executed, as the other repair paths do.
                        result.generated_code = pending_repair_code

                    pending_repair_code = None
                    pending_repair_violations = []
                    pending_repair_duration = None
                    pending_repair_num_tokens = None
                    pending_repair_base_code = None
                    if not is_stale:
                        state_tracker.reset_projection()
                    return not is_stale
            return False

        def background_validate_future_statements(
            future_stmts: List[ast.stmt],
            src_lines: List[str],
            exec_stmts: List[str],
            start_idx: int,
            pre_captured_scene: Dict[str, Any],
            current_stmt_info: Optional[Dict[str, Any]] = None,
            subtask_kwargs: Optional[Dict[str, Any]] = None,
            code_snapshot: Optional[str] = None,
        ):
            """Validate not-yet-executed statements and post a batch repair.

            Runs on a worker thread. ``code_snapshot`` is the code the
            statements were parsed from; it is used instead of the enclosing
            ``current_code`` so the validator never sees code that the main
            thread swapped in after this thread was started. Any exception is
            logged and swallowed (validation is best-effort).
            """
            nonlocal pending_repair_code, pending_repair_violations
            nonlocal pending_repair_duration, pending_repair_num_tokens, pending_repair_base_code

            # Fall back to the live value only if no snapshot was supplied.
            validated_code = code_snapshot if code_snapshot is not None else current_code

            try:
                # Create a copy of tracker state for background thread
                tracker_copy = ProjectedStateTracker()
                if state_tracker.confirmed_state:
                    tracker_copy.confirmed_state = state_tracker.confirmed_state.copy()
                    tracker_copy.projected_state = state_tracker.confirmed_state.copy()

                # Apply current statement's effect first (assuming success)
                if current_stmt_info:
                    curr_skill = current_stmt_info.get("skill_name")
                    if curr_skill and is_primitive_skill(curr_skill):
                        curr_target_pos = current_stmt_info.get("target_pos")
                        curr_target_obj = current_stmt_info.get("target_obj")
                        tracker_copy.apply_projection(curr_skill, curr_target_pos, curr_target_obj)
                        print(f"[BG_VALIDATE] Applied current statement effect: {curr_skill}")

                # Validate future statements
                invalid_statements: List[Dict[str, Any]] = []

                # ================================================================
                # Phase 1: AST-based validation (existing logic)
                # ================================================================
                for offset, stmt in enumerate(future_stmts):
                    skill_name, call_node = extract_skill_call_info(stmt)

                    if skill_name is None or not is_primitive_skill(skill_name):
                        continue

                    stmt_code = get_statement_code(stmt, src_lines)
                    stmt_index = start_idx + offset

                    # Validate preconditions on projected state
                    validation_result = tracker_copy.validate_preconditions_on_projected(
                        skill_name,
                        target_pos=None,
                        target_obj=None,
                    )

                    if not validation_result.get("success", True):
                        print(f"[BG_VALIDATE] Statement {stmt_index} ({skill_name}) INVALID: {validation_result.get('violations', [])}")
                        invalid_statements.append({
                            "start_line": stmt.lineno,
                            "end_line": getattr(stmt, "end_lineno", stmt.lineno),
                            "statement_code": stmt_code,
                            "skill_name": skill_name,
                            "violations": validation_result.get("violations", []),
                            "warnings": validation_result.get("warnings", []),
                        })

                    # Apply effect for next iteration
                    tracker_copy.apply_projection(skill_name, None, None)

                # ================================================================
                # Phase 2: Simulation-based validation for release patterns
                # ================================================================
                try:
                    # Get currently held object from scene
                    gripper_info = pre_captured_scene.get("gripper", {})
                    held_object = gripper_info.get("held_object")

                    # Simulate code execution to get actual position values
                    simulated_calls = simulate_future_code_and_extract_positions(
                        validated_code,
                        subtask_name,
                        pre_captured_scene,
                        subtask_kwargs,
                    )

                    if simulated_calls:
                        # Check for release position collisions
                        release_violations = validate_release_positions(
                            simulated_calls,
                            pre_captured_scene,
                            held_object,
                            subtask_kwargs=subtask_kwargs,
                        )

                        for viol in release_violations:
                            print(f"[BG_VALIDATE] Release position collision: {viol['violation']}")
                            # Find the corresponding statement (first future
                            # statement calling the same skill)
                            for offset, stmt in enumerate(future_stmts):
                                skill_name, _ = extract_skill_call_info(stmt)
                                if skill_name == viol["skill_name"]:
                                    stmt_code = get_statement_code(stmt, src_lines)
                                    invalid_statements.append({
                                        "start_line": stmt.lineno,
                                        "end_line": getattr(stmt, "end_lineno", stmt.lineno),
                                        "statement_code": stmt_code,
                                        "skill_name": viol["skill_name"],
                                        "violation_type": "collision",
                                        "violations": [viol["violation"]],
                                        "warnings": [],
                                        "collision_info": {
                                            "target_pos": viol["target_pos"],
                                            "next_skill": viol["next_skill"],
                                            "alternative_pos": viol.get("alternative_pos"),
                                        },
                                    })
                                    break

                except Exception as sim_error:
                    # Simulation is best-effort; Phase 1 findings are still used.
                    print(f"[BG_VALIDATE] Simulation-based validation error: {sim_error}")
                    import traceback
                    traceback.print_exc()

                # Log validation summary
                print(f"[BG_VALIDATE] === Validation Summary for {subtask_name} ===")
                print(f"[BG_VALIDATE]   Subtask kwargs: {subtask_kwargs}")
                print(f"[BG_VALIDATE]   Total invalid statements: {len(invalid_statements)}")
                for inv in invalid_statements:
                    print(f"[BG_VALIDATE]   - Lines {inv.get('start_line')}-{inv.get('end_line')}: {inv.get('skill_name')}")
                    for v in inv.get('violations', []):
                        print(f"[BG_VALIDATE]       {v}")

                # ================================================================
                # Phase 3: Batch repair if invalid statements found
                # ================================================================
                if invalid_statements and self.mode == "remote_llm":
                    print(f"[BG_VALIDATE] Found {len(invalid_statements)} invalid statements, requesting batch repair")
                    for inv in invalid_statements:
                        print(f"  - {inv['skill_name']}: {inv['violations']}")

                    bg_repair_start = time.time()
                    new_code = self.batch_invalid_repair(
                        current_code=validated_code,
                        executed_stmt_texts=exec_stmts,
                        invalid_statements=invalid_statements,
                        scene=pre_captured_scene,
                        subtask_name=subtask_name,
                    )
                    bg_repair_duration = time.time() - bg_repair_start
                    bg_repair_tokens = self.last_response_num_tokens

                    if new_code:
                        # Collect all violations for repair history
                        all_violations = []
                        for inv in invalid_statements:
                            for v in inv.get('violations', []):
                                all_violations.append(f"{inv['skill_name']}: {v}")

                        with validation_lock:
                            pending_repair_code = new_code
                            pending_repair_violations = all_violations
                            pending_repair_duration = bg_repair_duration
                            pending_repair_num_tokens = bg_repair_tokens
                            pending_repair_base_code = validated_code
                        print(f"[BG_VALIDATE] Batch repair successful (duration={bg_repair_duration:.2f}s, tokens={bg_repair_tokens})")

            except Exception as e:
                print(f"[BG_VALIDATE] Error during background validation: {e}")

        # Main execution loop. ``repair_count`` is advanced by failure repairs
        # and by check-failure revisions; background repairs do not touch it.
        loop_iteration = 0
        while repair_count <= max_repairs:
            loop_iteration += 1
            print(f"\n[DEBUG] === Main loop iteration {loop_iteration} ===")
            print(f"[DEBUG] repair_count: {repair_count}/{max_repairs}")
            print(f"[DEBUG] executed_stmt_texts count: {len(executed_stmt_texts)}")

            # Check for pending repair from background validation
            if apply_pending_repair():
                print(f"[DEBUG] Applied pending repair, resetting executed_stmt_texts")
                executed_stmt_texts = []
                continue

            # Parse current code to get statements
            statements, src_lines = extract_subtask_statements(current_code, subtask_name)
            print(f"[DEBUG] Parsed {len(statements)} statements from current code")

            if not statements:
                print(f"[WARN] No statements found in subtask {subtask_name}")
                break

            # Find next statement to execute
            next_index = len(executed_stmt_texts)
            print(f"[DEBUG] next_index: {next_index}, total statements: {len(statements)}")

            if next_index >= len(statements):
                # All statements executed successfully
                print(f"[DEBUG] All {len(statements)} statements executed. next_index={next_index}")
                print(f"[DEBUG] executed_stmt_texts count: {len(executed_stmt_texts)}")
                print(f"[DEBUG] check_func is {'provided' if check_func is not None else 'None'}")

                if check_func is not None:
                    print(f"[DEBUG] Calling check_func for {subtask_name}(obj={obj_name}, target={target_name})...")
                    check_passed = check_func(env, obj_name, target_name)
                    print(f"[DEBUG] check_func returned: {check_passed} (type: {type(check_passed).__name__})")

                    if not check_passed:
                        print(f"[DEBUG] check_func FAILED - triggering revise")
                        # Task not complete, request more code via revise
                        if self.mode == "remote_llm":
                            scene = self._get_scene_info(env)
                            print(f"[DEBUG] Scene info for revise: {list(scene.keys()) if isinstance(scene, dict) else type(scene)}")
                            last_step = {
                                "step_index": next_index,
                                "skill_call_count": skill_call_count,
                                "statement": executed_stmt_texts[-1] if executed_stmt_texts else "",
                                "success": False,
                                "need_more_steps": True,
                            }
                            print(f"[DEBUG] last_step for revise: {last_step}")
                            revise_start = time.time()
                            new_code = self.revise_subtask(
                                subtask_name=subtask_name,
                                current_code=current_code,
                                executed_stmt_texts=executed_stmt_texts,
                                scene=scene,
                                last_step=last_step,
                            )
                            revise_duration = time.time() - revise_start
                            revise_tokens = self.last_response_num_tokens
                            self._log_llm_call("revise_on_step", subtask_name, revise_duration, revise_tokens)

                            if new_code:
                                print(f"[DEBUG] revise_subtask returned new code (len={len(new_code)})")
                                # Record repair in history
                                repair_record = RepairRecord(
                                    repair_type="revise_on_step",
                                    violations=["Check failed after all statements executed"],
                                    code_before=current_code,
                                    code_after=new_code,
                                    statement_index=len(executed_stmt_texts),
                                    duration_sec=revise_duration,
                                    num_tokens=revise_tokens,
                                )
                                repair_history.append(repair_record)
                                repair_count += 1
                                current_code = new_code
                                result.generated_code = new_code
                                continue
                            else:
                                print(f"[DEBUG] revise_subtask returned None/empty")

                        # No revision available (non-remote mode or empty reply).
                        self.total_failure_count += 1
                        raise SubtaskFailure(
                            subtask_name=subtask_name,
                            obj_name=obj_name,
                            target_name=target_name,
                            message=f"Check failed for {subtask_name} after all statements",
                        )
                else:
                    print(f"[DEBUG] No check_func provided, marking success")

                print(f"[DEBUG] Subtask {subtask_name} completed successfully. repair_count={repair_count}, bg_repair_count={bg_repair_count}")
                result.success = True
                result.repair_count = repair_count
                result.bg_repair_count = bg_repair_count
                result.original_code = original_code
                result.repair_history = repair_history
                return result

            stmt = statements[next_index]
            stmt_code = get_statement_code(stmt, src_lines)
            skill_name, call_node = extract_skill_call_info(stmt)
            is_skill = skill_name is not None and is_primitive_skill(skill_name)

            # Check if this is a compound statement with skill calls inside
            is_compound = is_compound_statement(stmt)
            compound_has_skill = is_compound and contains_skill_call(stmt, GENESIS_PRIMITIVE_SKILLS)

            print(f"[DEBUG] Statement {next_index}: type={type(stmt).__name__}, is_skill={is_skill}, skill_name={skill_name}")
            print(f"[DEBUG]   is_compound={is_compound}, compound_has_skill={compound_has_skill}")
            print(f"[DEBUG]   code preview: {stmt_code[:100]}...")

            # Start background validation for future statements
            # Trigger on direct skill call OR compound statement with skill calls
            future_statements = statements[next_index + 1:]
            validation_thread = None
            should_validate = (
                self.enable_background_validation
                and (is_skill or compound_has_skill)
                and self.mode == "remote_llm"
                and len(future_statements) > 0
            )

            if should_validate:
                # Scene is captured on the main thread, before the statement
                # runs, and handed to the worker.
                pre_captured_scene = self._get_scene_info(env)

                current_stmt_info = None
                if is_skill:
                    current_stmt_info = {"skill_name": skill_name}
                elif compound_has_skill:
                    current_stmt_info = {"skill_name": "compound_with_skills"}

                validation_thread = threading.Thread(
                    target=background_validate_future_statements,
                    args=(
                        future_statements,
                        src_lines.copy(),
                        executed_stmt_texts.copy(),
                        next_index + 1,
                        pre_captured_scene,
                        current_stmt_info,
                        {"obj_name": obj_name, "target_name": target_name},
                        current_code,  # snapshot: the code these statements came from
                    ),
                    daemon=True,
                )
                validation_thread.start()
                print(f"[BG_VALIDATE] Started validating {len(future_statements)} future statements (obj_name={obj_name}, target_name={target_name})")

            # Skip pure return statements
            if isinstance(stmt, ast.Return):
                print(f"[STEP {next_index + 1}] Skipping return statement (end of subtask)")
                executed_stmt_texts.append(stmt_code)
                continue

            # Handle statements containing return (e.g., if blocks with return inside)
            # A bare ``return`` is not valid in module-level exec, so nested
            # returns are rewritten before compiling.
            if contains_return(stmt):
                print(f"[STEP {next_index + 1}] Removing return statements from: {stmt_code[:80]}...")
                stmt = remove_returns_from_stmt(stmt)
                ast.fix_missing_locations(stmt)

            # Log what we're executing
            if is_compound:
                if compound_has_skill:
                    print(f"[STEP {next_index + 1}] Executing compound statement with skill calls ({type(stmt).__name__}): {stmt_code[:80]}...")
                else:
                    print(f"[STEP {next_index + 1}] Executing compound statement ({type(stmt).__name__}): {stmt_code[:80]}...")
            elif is_skill:
                print(f"\n[STEP {next_index + 1}] Executing skill: {stmt_code[:100]}...")

            try:
                # Compile and execute single statement
                stmt_module = ast.Module(body=[stmt], type_ignores=[])
                ast.fix_missing_locations(stmt_module)
                compiled = compile(stmt_module, '<string>', 'exec')
                exec(compiled, context)

                executed_stmt_texts.append(stmt_code)

                # Update state tracker after execution
                # For compound statements with skill calls, also update state
                if is_skill or compound_has_skill:
                    state_tracker.confirm_state(
                        env=env,
                        statement_index=next_index,
                        skill_name=skill_name if is_skill else "compound",
                        target_obj=None,
                    )
                    if is_skill:
                        skill_call_count += 1
                    print(f"[STATE] Confirmed state after step {next_index + 1}: "
                          f"gripper_open={state_tracker.confirmed_state.gripper.is_open}, "
                          f"held_obj={state_tracker.confirmed_state.gripper.held_object}")

                # Wait for background validation to complete before checking repair
                # This ensures batch repair from validation is applied before next statement
                if validation_thread is not None and validation_thread.is_alive():
                    print(f"[BG_VALIDATE] Waiting for background validation to complete...")
                    validation_thread.join(timeout=10.0)  # Wait up to 10 seconds
                    if validation_thread.is_alive():
                        # The worker keeps running; a repair it posts later is
                        # applied at a later check only if the code it was
                        # based on is still current.
                        print(f"[BG_VALIDATE] Timeout waiting for validation, continuing...")
                    else:
                        print(f"[BG_VALIDATE] Background validation completed")

                # Check for pending repair from background validation
                if apply_pending_repair():
                    executed_stmt_texts = []
                    continue

                # Optional: revise_on_step after each skill (or compound with skills)
                # Disabled by default as it causes excessive LLM calls
                if self.enable_revise_on_step and self.mode == "remote_llm":
                    scene = self._get_scene_info(env)
                    last_step = {
                        "step_index": next_index + 1,
                        "skill_call_count": skill_call_count,
                        "statement": stmt_code,
                        "skill_name": skill_name if is_skill else "compound_with_skills",
                        "success": True,
                        "confirmed_state": {
                            "gripper_open": state_tracker.confirmed_state.gripper.is_open,
                            "held_object": state_tracker.confirmed_state.gripper.held_object,
                        }
                    }
                    new_code = self.revise_subtask(
                        subtask_name=subtask_name,
                        current_code=current_code,
                        executed_stmt_texts=executed_stmt_texts,
                        scene=scene,
                        last_step=last_step,
                    )
                    if new_code:
                        print(f"[REVISE] Code revised after step {next_index + 1}")
                        current_code = new_code
                        result.generated_code = new_code

            except Exception as e:
                # Execution failed
                error_msg = str(e)
                print(f"[FAILURE] Step {next_index + 1} failed: {error_msg}")

                if self.mode == "static":
                    # Static mode has no repair path: skip the subtask.
                    result.error = f"Subtask failed (static mode): {error_msg}"
                    result.repair_count = 0
                    result.skipped = True
                    raise SubtaskSkip(
                        subtask_name=subtask_name,
                        obj_name=obj_name,
                        repair_attempts=0,
                        result=result,
                    )

                if repair_count >= max_repairs:
                    result.error = f"Max repairs ({max_repairs}) exceeded: {error_msg}"
                    result.repair_count = repair_count
                    result.bg_repair_count = bg_repair_count
                    result.original_code = original_code
                    result.repair_history = repair_history
                    result.skipped = True
                    raise SubtaskSkip(
                        subtask_name=subtask_name,
                        obj_name=obj_name,
                        repair_attempts=repair_count,
                        result=result,
                    )

                repair_count += 1
                self.total_repair_count += 1
                self.total_failure_count += 1
                print(f"[REPAIR] Attempting repair {repair_count}/{max_repairs}")

                repair_start = time.time()
                try:
                    scene = self._get_scene_info(env)
                    error_payload = {
                        "step_index": next_index + 1,
                        "statement": stmt_code,
                        "skill_name": skill_name if is_skill else None,
                        "error": error_msg,
                        "error_type": type(e).__name__,
                    }

                    # Use repair_subtask with statement context
                    new_code = self.repair_subtask(
                        subtask_name=subtask_name,
                        obj_name=obj_name,
                        target_name=target_name,
                        error_info=error_payload,
                        scene=scene,
                        task_name=self.task_name,
                    )
                    repair_duration = time.time() - repair_start
                    self._log_llm_call("repair", subtask_name, repair_duration, self.last_response_num_tokens)

                    if new_code:
                        # Record repair in history
                        repair_record = RepairRecord(
                            repair_type="repair_on_failure",
                            violations=[f"{error_payload.get('error_type', 'Error')}: {error_msg}"],
                            code_before=current_code,
                            code_after=new_code,
                            statement_index=next_index,
                            duration_sec=repair_duration,
                            num_tokens=self.last_response_num_tokens,
                        )
                        repair_history.append(repair_record)
                        current_code = new_code
                        result.generated_code = new_code
                        # Note: Keep executed_stmt_texts to preserve already executed prefix
                        # (execution resumes at the same index in the new code).
                    else:
                        result.error = "Repair returned empty code"
                        result.repair_count = repair_count
                        result.bg_repair_count = bg_repair_count
                        result.original_code = original_code
                        result.repair_history = repair_history
                        result.skipped = True
                        raise SubtaskSkip(
                            subtask_name=subtask_name,
                            obj_name=obj_name,
                            repair_attempts=repair_count,
                            result=result,
                        )

                except SubtaskSkip:
                    # Raised just above; must not be re-wrapped as a repair error.
                    raise
                except Exception as repair_error:
                    repair_duration = time.time() - repair_start
                    self._log_llm_call("repair_failed", subtask_name, repair_duration, None)
                    print(f"[REPAIR_FAILED] {repair_error}")
                    result.error = f"Repair failed: {repair_error}"
                    result.repair_count = repair_count
                    result.bg_repair_count = bg_repair_count
                    result.original_code = original_code
                    result.repair_history = repair_history
                    result.skipped = True
                    raise SubtaskSkip(
                        subtask_name=subtask_name,
                        obj_name=obj_name,
                        repair_attempts=repair_count,
                        result=result,
                    )

        # Reached when no statements could be parsed or the repair budget was
        # used up by check-failure revisions.
        result.error = "Unexpected exit from execution loop"
        result.repair_count = repair_count
        result.bg_repair_count = bg_repair_count
        result.original_code = original_code
        result.repair_history = repair_history
        return result

    def _create_subtask_execution_context(
        self,
        env,
        subtask_name: str,
        task=None,
    ) -> Dict[str, Any]:
        """Assemble the namespace in which generated subtask code is executed.

        The namespace holds ``env``/``task``, the numeric helpers (``np``,
        ``numpy``, ``math``, ``time``), the motion skills, a set of gripper
        skills selected by ``self.target_robot`` and robot-specific aliases.
        Helper functions extracted from the source robot's subtask file are
        merged in last, so they win on name clashes.

        Args:
            env: Environment object exposed to the code as ``env``.
            subtask_name: Name of the subtask; selects the source subtask file
                ``./subtasks_<source_robot>/<subtask_name>.py``.
            task: Optional task object exposed as ``task``.

        Returns:
            The populated namespace dictionary.

        Raises:
            AttributeError: If ``skill_code`` lacks one of the motion skills or
                one of the skills required for the selected target robot.
        """
        import numpy as np
        import math
        import time as time_module
        import skill_code

        namespace: Dict[str, Any] = {
            'env': env,
            'task': task,
            'np': np,
            'numpy': np,  # generated code may use either spelling
            'math': math,
            'time': time_module,
        }

        # Motion skills are mandatory: a missing attribute raises here.
        for skill in ('move_gripper_to', 'move_to_position', 'move_parallel', 'rotate_gripper'):
            namespace[skill] = getattr(skill_code, skill)

        # Names bound for the active robot, as (namespace key, skill_code attribute).
        # These lookups are strict as well.
        if self.target_robot == "panda":
            robot_bindings = (
                ('open_gripper', 'open_gripper'),
                ('close_gripper', 'close_gripper'),
                ('pick', 'pick'),
                ('place', 'place'),
                ('grasp_handle', 'grasp_handle'),
                ('release_handle', 'release_handle'),
            )
        elif self.target_robot == "robotiq":
            robot_bindings = (
                ('open_gripper', 'open_robotiq85'),
                ('close_gripper', 'close_robotiq85'),
                ('pick', 'pick_robotiq85'),
                ('place', 'place_robotiq85'),
                ('grasp_handle', 'grasp_handle_robotiq85'),
                ('release_handle', 'release_handle_robotiq85'),
            )
        else:
            robot_bindings = (
                ('activate_vacuum', 'activate_vacuum'),
                ('deactivate_vacuum', 'deactivate_vacuum'),
                ('attach_vacuum_handle', 'attach_vacuum_handle'),
                ('detach_vacuum_handle', 'detach_vacuum_handle'),
            )
        for key, attr in robot_bindings:
            namespace[key] = getattr(skill_code, attr)

        # Robot-specific aliases are registered for every robot so that code
        # written against another gripper's names still resolves; a skill that
        # skill_code does not define is bound to None instead of raising.
        lenient_aliases = (
            # Panda
            ('open_panda', 'open_gripper'),
            ('close_panda', 'close_gripper'),
            # Robotiq85
            ('open_robotiq85', 'open_robotiq85'),
            ('close_robotiq85', 'close_robotiq85'),
            ('pick_robotiq85', 'pick_robotiq85'),
            ('place_robotiq85', 'place_robotiq85'),
            ('grasp_handle_robotiq85', 'grasp_handle_robotiq85'),
            ('release_handle_robotiq85', 'release_handle_robotiq85'),
            # Suction
            ('activate_suction', 'activate_vacuum'),
            ('deactivate_suction', 'deactivate_vacuum'),
            ('attach_suction', 'attach_vacuum_handle'),
            ('detach_suction', 'detach_vacuum_handle'),
        )
        for key, attr in lenient_aliases:
            namespace[key] = getattr(skill_code, attr, None)

        # Helpers from the source robot's version of this subtask; a missing
        # file simply means there are none to add.
        helper_path = f"./subtasks_{self.source_robot}/{subtask_name}.py"
        try:
            with open(helper_path, "r", encoding="utf-8") as handle:
                helper_source = handle.read()
            namespace.update(self._extract_helper_functions(helper_source, subtask_name))
        except FileNotFoundError:
            pass

        return namespace

    def _execute_subtask_on_failure(
        self,
        subtask_name: str,
        env,
        obj_name: Optional[str] = None,
        target_name: Optional[str] = None,
        check_func: Optional[callable] = None,
        max_repairs: int = 5,
        **kwargs,
    ) -> "SubtaskResult":
        """Generate a subtask, execute it as a whole, and repair it on failure.

        The code is generated once via ``generate_subtask_code`` and saved with
        ``_save_subtask``. Each loop iteration executes the current code; a
        falsy return value, a failing ``check_func`` or any exception counts
        as a failure and triggers ``repair_subtask`` (unless ``self.mode`` is
        ``"static"``), up to ``max_repairs`` times.

        Args:
            subtask_name: Name of the subtask function to generate and run.
            env: Environment passed to the subtask and to ``check_func``.
            obj_name: Optional object the subtask acts on.
            target_name: Optional target of the subtask.
            check_func: Optional callable ``(env, obj_name, target_name)``
                whose falsy return marks the execution as failed.
            max_repairs: Maximum number of repair attempts.
            **kwargs: Forwarded to ``_execute_subtask_code``.

        Returns:
            The ``SubtaskResult`` with ``success`` set when execution (and the
            check, if any) passed.

        Raises:
            SubtaskFailure: If the initial code generation raises.
            SubtaskSkip: In static mode after the first failure, when
                ``max_repairs`` is exhausted, or when a repair raises or
                returns empty code. The attached result carries
                ``original_code`` and ``repair_history``. A ``SubtaskSkip``
                raised by the executed code itself propagates unchanged.
        """
        print(f"[DEBUG] _execute_subtask_on_failure called for {subtask_name}")
        print(f"[DEBUG]   obj_name: {obj_name}, target_name: {target_name}")
        print(f"[DEBUG]   check_func: {'provided' if check_func is not None else 'None'}")

        result = SubtaskResult(
            subtask_name=subtask_name,
            obj_name=obj_name,
            target_name=target_name,
        )

        scene_info = self._get_scene_info(env)
        print(f"[DEBUG] Scene info keys: {list(scene_info.keys()) if isinstance(scene_info, dict) else type(scene_info)}")

        # Generate subtask code
        start_time = time.time()
        try:
            subtask_code = self.generate_subtask_code(
                subtask_name=subtask_name,
                obj_name=obj_name,
                target_name=target_name,
                scene_info=scene_info,
            )
            result.generated_code = subtask_code
            gen_duration = time.time() - start_time
            self._log_llm_call("generate", subtask_name, gen_duration, self.last_response_num_tokens)

            # Save generated code to file
            self._save_subtask(subtask_name, subtask_code)

        except Exception as e:
            gen_duration = time.time() - start_time
            self._log_llm_call("generate_failed", subtask_name, gen_duration, None)
            result.error = f"Failed to generate subtask code: {e}"
            self.total_failure_count += 1
            raise SubtaskFailure(
                subtask_name=subtask_name,
                obj_name=obj_name,
                target_name=target_name,
                message=result.error,
            ) from e

        repair_count = 0
        original_code = subtask_code
        repair_history: List[RepairRecord] = []

        def _make_skip(error: str, attempts: int) -> SubtaskSkip:
            # Exceptions raised inside an ``except`` clause are never seen by
            # sibling clauses of the same ``try``, so all bookkeeping on the
            # result has to happen here, before the skip is raised.
            result.error = error
            result.repair_count = attempts
            result.original_code = original_code
            result.repair_history = repair_history
            result.skipped = True
            return SubtaskSkip(
                subtask_name=subtask_name,
                obj_name=obj_name,
                repair_attempts=attempts,
                result=result,
            )

        while repair_count <= max_repairs:
            print(f"\n[DEBUG] (on_failure mode) Repair loop iteration, repair_count={repair_count}/{max_repairs}")
            try:
                print(f"[DEBUG] (on_failure mode) Executing subtask code...")
                exec_success = self._execute_subtask_code(
                    subtask_code=subtask_code,
                    subtask_name=subtask_name,
                    env=env,
                    obj_name=obj_name,
                    target_name=target_name,
                    **kwargs,
                )
                print(f"[DEBUG] (on_failure mode) _execute_subtask_code returned: {exec_success}")

                if not exec_success:
                    self.total_failure_count += 1
                    raise SubtaskFailure(
                        subtask_name=subtask_name,
                        obj_name=obj_name,
                        target_name=target_name,
                        message=f"{subtask_name} returned False",
                    )

                if check_func is not None:
                    print(f"[DEBUG] (on_failure mode) Calling check_func for {subtask_name}...")
                    check_passed = check_func(env, obj_name, target_name)
                    print(f"[DEBUG] (on_failure mode) check_func returned: {check_passed}")
                    if not check_passed:
                        print(f"[DEBUG] (on_failure mode) check_func FAILED")
                        self.total_failure_count += 1

                        # Gather detailed failure context
                        failure_context = self._get_check_failure_context(
                            env, subtask_name, obj_name, target_name
                        )

                        raise SubtaskFailure(
                            subtask_name=subtask_name,
                            obj_name=obj_name,
                            target_name=target_name,
                            context=failure_context,
                            message=f"Check failed for {subtask_name}: {failure_context.get('reason', 'unknown')}",
                        )
                else:
                    print(f"[DEBUG] (on_failure mode) No check_func provided")

                print(f"[DEBUG] (on_failure mode) Subtask {subtask_name} completed successfully")
                result.success = True
                result.repair_count = repair_count
                result.original_code = original_code
                result.repair_history = repair_history
                return result

            except SubtaskSkip:
                # A skip raised by the executed code carries its own result;
                # let it propagate untouched.
                raise
            except (SubtaskFailure, Exception) as e:
                # Extract error message based on exception type
                if isinstance(e, SubtaskFailure):
                    error_message = e.message
                    error_info = e.to_dict()
                else:
                    # SyntaxError, RuntimeError and any other exception
                    import traceback
                    error_message = f"{type(e).__name__}: {str(e)}"
                    error_info = {
                        "subtask_name": subtask_name,
                        "obj_name": obj_name,
                        "target_name": target_name,
                        "message": error_message,
                        "traceback": traceback.format_exc(),
                    }

                print(f"[SUBTASK_FAILURE] {subtask_name}: {error_message}")

                if self.mode == "static":
                    raise _make_skip(
                        f"Subtask failed (static mode, no repair): {error_message}",
                        0,
                    )

                if repair_count >= max_repairs:
                    raise _make_skip(
                        f"Max repairs ({max_repairs}) exceeded: {error_message}",
                        repair_count,
                    )

                repair_count += 1
                self.total_repair_count += 1
                print(f"[SUBTASK_REPAIR] Attempting repair {repair_count}/{max_repairs} for {subtask_name}")

                repair_start = time.time()
                try:
                    new_code = self.repair_subtask(
                        subtask_name=subtask_name,
                        obj_name=obj_name,
                        target_name=target_name,
                        error_info=error_info,
                        scene=self._get_scene_info(env),
                        task_name=self.task_name,
                    )
                    repair_duration = time.time() - repair_start
                    self._log_llm_call("repair", subtask_name, repair_duration, self.last_response_num_tokens)

                    if new_code:
                        # Record repair in history
                        repair_record = RepairRecord(
                            repair_type="repair_on_failure",
                            violations=[error_message],
                            code_before=subtask_code,
                            code_after=new_code,
                            statement_index=None,
                            duration_sec=repair_duration,
                            num_tokens=self.last_response_num_tokens,
                        )
                        repair_history.append(repair_record)
                        subtask_code = new_code
                        result.generated_code = new_code
                    else:
                        raise _make_skip("Repair returned empty code", repair_count)
                except SubtaskSkip:
                    raise
                except Exception as repair_error:
                    repair_duration = time.time() - repair_start
                    self._log_llm_call("repair_failed", subtask_name, repair_duration, None)
                    print(f"[SUBTASK_REPAIR_FAILED] {repair_error}")
                    raise _make_skip(f"Repair failed: {repair_error}", repair_count)

        result.error = "Unexpected exit from repair loop"
        result.repair_count = repair_count
        result.original_code = original_code
        result.repair_history = repair_history
        return result

    def _get_check_failure_context(
        self,
        env,
        subtask_name: str,
        obj_name: Optional[str],
        target_name: Optional[str],
    ) -> Dict[str, Any]:
        context = {}

        try:
            # Gripper state
            gripper_open = env.gripper_is_open()
            context["gripper_open"] = gripper_open

            # Object in gripper
            if obj_name:
                in_gripper = env.obj_in_gripper(obj_name)
                context["obj_in_gripper"] = in_gripper
                obj_pos = env.get_obj_pos(obj_name)
                context["obj_pos"] = [round(p, 3) for p in obj_pos]

            # Target info
            if target_name:
                target_pos = env.get_obj_pos(target_name)
                target_bbox = env.get_obj_bbox(target_name)
                context["target_pos"] = [round(p, 3) for p in target_pos]
                context["target_bbox"] = [
                    [round(p, 3) for p in target_bbox[0]],
                    [round(p, 3) for p in target_bbox[1]]
                ]

                # Check if object is inside target
                if obj_name:
                    obj_bbox = env.get_obj_bbox(obj_name)
                    obj_min, obj_max = obj_bbox[0], obj_bbox[1]
                    target_min, target_max = target_bbox[0], target_bbox[1]

                    # Check each axis
                    inside_x = obj_min[0] >= target_min[0] and obj_max[0] <= target_max[0]
                    inside_y = obj_min[1] >= target_min[1] and obj_max[1] <= target_max[1]
                    inside_z = obj_min[2] >= target_min[2] - 0.01  # Allow slight tolerance

                    context["obj_inside_target_x"] = inside_x
                    context["obj_inside_target_y"] = inside_y
                    context["obj_inside_target_z"] = inside_z

            # Determine failure reason based on subtask type
            if "pick" in subtask_name.lower():
                if gripper_open:
                    context["reason"] = "Gripper is open - object not grasped"
                elif obj_name and not context.get("obj_in_gripper", False):
                    context["reason"] = f"Object '{obj_name}' not in gripper after close"
                else:
                    context["reason"] = "Pick verification failed"

            elif "place" in subtask_name.lower():
                if obj_name and context.get("obj_in_gripper", False) and not gripper_open:
                    context["reason"] = f"Object '{obj_name}' still in gripper - not released"
                elif target_name and obj_name:
                    if not context.get("obj_inside_target_x", True):
                        context["reason"] = f"Object not inside target on X axis"
                    elif not context.get("obj_inside_target_y", True):
                        context["reason"] = f"Object not inside target on Y axis"
                    else:
                        context["reason"] = f"Object '{obj_name}' not inside '{target_name}'"
                else:
                    context["reason"] = "Place verification failed"

            elif "unstack" in subtask_name.lower():
                if obj_name:
                    obj_z = context.get("obj_pos", [0, 0, 0.6])[2]
                    if obj_z > 0.52 and not context.get("obj_in_gripper", False):
                        context["reason"] = f"Object '{obj_name}' still stacked (z={obj_z:.3f})"
                    else:
                        context["reason"] = "Unstack verification failed"
                else:
                    context["reason"] = "Unstack verification failed"
            else:
                context["reason"] = "Check verification failed"

        except Exception as e:
            context["reason"] = f"Error gathering context: {str(e)}"

        return context

    # ========================================================================
    # Code Execution
    # ========================================================================

    def _execute_subtask_code(
        self,
        subtask_code: str,
        subtask_name: str,
        env,
        obj_name: Optional[str] = None,
        target_name: Optional[str] = None,
        **kwargs,
    ) -> bool:
        """Execute generated subtask code and call the function it defines.

        The code is cleaned with ``_postprocess_subtask_code``, run with
        ``exec`` in a namespace holding ``env``, ``task``, numeric helpers,
        the skill functions for every robot type and any helpers extracted
        from the source robot's subtask file. The function named
        ``subtask_name`` is then called as ``func(env, **call_kwargs)``.

        NOTE(review): ``exec`` runs the generated code with full interpreter
        privileges; only pass code from a trusted generator.

        Args:
            subtask_code: Source text expected to define ``subtask_name``.
            subtask_name: Name of the function to call after ``exec``.
            env: Environment passed as the first positional argument.
            obj_name: Bound to an ``obj_name`` parameter if present; otherwise
                to the second parameter unless that is ``target_name``.
            target_name: Bound to a ``target_name`` parameter if present;
                otherwise to parameters whose name contains ``target`` or
                ``dest``.
            **kwargs: ``task`` is removed and exposed in the namespace; the
                rest is forwarded to the subtask function.

        Returns:
            ``True`` when the function returns ``None``; otherwise the
            truthiness of its return value.

        Raises:
            RuntimeError: If the executed code does not leave a
                ``subtask_name`` entry in the namespace.
            Exception: Anything raised by ``exec`` or by the subtask function.
        """
        print(f"[DEBUG] _execute_subtask_code called for {subtask_name}")
        print(f"[DEBUG]   subtask_code length: {len(subtask_code)}")
        print(f"[DEBUG]   obj_name: {obj_name}, target_name: {target_name}")

        import inspect
        import numpy as np
        import math
        import time as time_module
        import skill_code

        # Show the code as received
        print(f"[EXEC] Raw subtask code for {subtask_name}:")
        print(subtask_code)

        subtask_code = self._postprocess_subtask_code(subtask_code)
        print(f"[DEBUG] After postprocess, code length: {len(subtask_code)}")

        # An explicit ``task`` kwarg wins; otherwise fall back to env.env.task
        task = kwargs.pop('task', None)
        if task is None and hasattr(env, 'env') and hasattr(env.env, 'task'):
            task = env.env.task

        # Show the code that will actually be executed
        print(f"[EXEC] Preprocessed subtask code for {subtask_name}:")
        print(subtask_code)

        namespace = {
            'env': env,
            'task': task,
            'np': np,
            'numpy': np,
            'math': math,
            'time': time_module,
        }

        # Motion skills must exist; a missing one raises AttributeError.
        for skill in ('move_gripper_to', 'move_to_position', 'move_parallel', 'rotate_gripper'):
            namespace[skill] = getattr(skill_code, skill)

        # Gripper skills for all robot types, as (namespace key, skill_code
        # attribute); skills that skill_code does not define are bound to None.
        gripper_bindings = (
            # Panda gripper functions (generic + specific names)
            ('open_gripper', 'open_gripper'),
            ('close_gripper', 'close_gripper'),
            ('pick', 'pick'),
            ('place', 'place'),
            ('grasp_handle', 'grasp_handle'),
            ('release_handle', 'release_handle'),
            ('open_panda', 'open_gripper'),
            ('close_panda', 'close_gripper'),
            # Robotiq85 gripper functions
            ('open_robotiq85', 'open_robotiq85'),
            ('close_robotiq85', 'close_robotiq85'),
            ('pick_robotiq85', 'pick_robotiq85'),
            ('place_robotiq85', 'place_robotiq85'),
            ('grasp_handle_robotiq85', 'grasp_handle_robotiq85'),
            ('release_handle_robotiq85', 'release_handle_robotiq85'),
            # Suction gripper functions
            ('activate_vacuum', 'activate_vacuum'),
            ('deactivate_vacuum', 'deactivate_vacuum'),
            ('attach_vacuum_handle', 'attach_vacuum_handle'),
            ('detach_vacuum_handle', 'detach_vacuum_handle'),
            ('activate_suction', 'activate_vacuum'),
            ('deactivate_suction', 'deactivate_vacuum'),
            ('attach_suction', 'attach_vacuum_handle'),
            ('detach_suction', 'detach_vacuum_handle'),
        )
        for key, attr in gripper_bindings:
            namespace[key] = getattr(skill_code, attr, None)

        # Helpers from the source robot's version of this subtask, if present
        helper_path = f"./subtasks_{self.source_robot}/{subtask_name}.py"
        try:
            with open(helper_path, "r", encoding="utf-8") as handle:
                helper_source = handle.read()
            namespace.update(self._extract_helper_functions(helper_source, subtask_name))
        except FileNotFoundError:
            pass

        print(f"[DEBUG] Executing subtask code via exec()...")
        exec(subtask_code, namespace)

        subtask_func = namespace.get(subtask_name)
        if subtask_func is None:
            print(f"[DEBUG] ERROR: subtask function '{subtask_name}' not found in context")
            print(f"[DEBUG] Available context keys: {[k for k in namespace.keys() if not k.startswith('_')]}")
            raise RuntimeError(f"Subtask function '{subtask_name}' not found in generated code")

        print(f"[DEBUG] Found subtask function: {subtask_name}")
        param_names = list(inspect.signature(subtask_func).parameters)
        print(f"[DEBUG] Function params: {param_names}")

        # Explicitly named parameters take the object/target directly.
        call_kwargs = dict(kwargs)
        named_inputs = {
            'obj_name': obj_name,
            'target_name': target_name,
        }
        for key, value in named_inputs.items():
            if value is not None and key in param_names:
                call_kwargs[key] = value

        # Otherwise infer where they go: the object fills the parameter right
        # after ``env``; the target fills parameters that look like a target.
        infer_obj = obj_name is not None and 'obj_name' not in param_names
        infer_target = target_name is not None and 'target_name' not in param_names
        for position, candidate in enumerate(param_names[1:], start=1):
            if candidate in call_kwargs:
                continue
            if infer_obj and position == 1 and candidate not in named_inputs:
                call_kwargs[candidate] = obj_name
            elif infer_target and ('target' in candidate or 'dest' in candidate):
                call_kwargs[candidate] = target_name

        print(f"[DEBUG] Calling subtask function with kwargs: {list(call_kwargs.keys())}")
        result = subtask_func(env, **call_kwargs)
        print(f"[DEBUG] Subtask function returned: {result} (type: {type(result).__name__})")

        # A subtask without an explicit return value counts as success.
        if result is None:
            print(f"[DEBUG] Result is None, returning True")
            return True
        final_result = bool(result)
        print(f"[DEBUG] Returning bool(result): {final_result}")
        return final_result

    def _postprocess_subtask_code(self, code: str) -> str:
        code = code.strip()

        if code.startswith("```python"):
            code = code[len("```python"):]
        elif code.startswith("```"):
            code = code[len("```"):]

        if code.endswith("```"):
            code = code[:-len("```")]

        return code.strip()
