import ast
import json
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    from ontology_modules.generate_ontology import (
        load_skills_as_dict,
        infer_parameter_mapping,
        generate_function_signature,
        generate_interface_example,
        generate_param_mapping_text,
        compute_quality_grade,
    )
    from ontology_modules.parameter_transforms import (
        has_transform,
        apply_special_transform,
        generate_transformed_interface,
    )
except Exception:
    from generate_ontology import (
        load_skills_as_dict,
        infer_parameter_mapping,
        generate_function_signature,
        generate_interface_example,
        generate_param_mapping_text,
        compute_quality_grade,
    )
    try:
        from parameter_transforms import (
            has_transform,
            apply_special_transform,
            generate_transformed_interface,
        )
    except ImportError:
        def has_transform(src, tgt): return False
        def apply_special_transform(src, tgt, params, ctx=None): return None
        def generate_transformed_interface(src, tgt, params): return None


# Candidate locations of the skill-substitution table, tried in order by
# _read_json_any (via _load_skill_substitutions); the first one that exists and
# loads wins. Relative paths resolve against the current working directory.
# NOTE(review): Path("./x") normalizes to Path("x"), so the last entry is a
# duplicate of the second one.
_SKILL_SUB_PATHS = [
    Path("ontology_modules/output_mapping/skill_substitutions.json"),
    Path("skill_substitutions.json"),
    Path("./skill_substitutions.json"),
]

# Per-robot candidate locations of the skill-spec JSON, tried in order by
# _robot_skills_path. The keys are the only robot names accepted there (any
# other name raises ValueError).
# NOTE(review): in each list the "./<file>" entry duplicates the bare "<file>" entry.
_ROBOT_SKILL_PATHS = {
    "panda": [
        Path("ontology_modules/data/panda_skills.json"),
        Path("panda_skills.json"),
        Path("./panda_skills.json"),
    ],
    "ur5": [
        Path("ontology_modules/data/ur5_skills.json"),
        Path("ur5_skills.json"),
        Path("./ur5_skills.json"),
    ],
    "sawyer": [
        Path("ontology_modules/data/sawyer_skills.json"),
        Path("sawyer_skills.json"),
        Path("./sawyer_skills.json"),
    ],
}

def _read_json_any(paths: List[Path]) -> Dict[str, Any]:
    last_err = None
    for p in paths:
        try:
            if p.exists():
                with open(p, "r", encoding="utf-8") as f:
                    return json.load(f)
        except Exception as e:
            last_err = e
    raise FileNotFoundError(f"Could not read json from paths={paths}. last_err={last_err}")


def _sim(a: str, b: str) -> float:
    a = (a or "").strip().lower()
    b = (b or "").strip().lower()
    if not a and not b:
        return 1.0
    if not a or not b:
        return 0.0
    return SequenceMatcher(None, a, b).ratio()


def _safe_unparse(node: ast.AST) -> str:
    try:
        return ast.unparse(node)
    except Exception:
        return ""


def _load_skill_substitutions() -> Dict[str, Any]:
    """Load the skill-substitution table from the first loadable entry of _SKILL_SUB_PATHS."""
    candidates = _SKILL_SUB_PATHS
    return _read_json_any(candidates)


def _robot_skills_path(robot: str) -> Path:
    """Return the first existing skills-JSON path registered for *robot*.

    Raises:
        ValueError: if *robot* is not a key of _ROBOT_SKILL_PATHS.
        FileNotFoundError: if none of the robot's candidate paths exists.
    """
    if robot not in _ROBOT_SKILL_PATHS:
        raise ValueError(f"Unknown robot={robot}. supported={list(_ROBOT_SKILL_PATHS.keys())}")
    candidates = _ROBOT_SKILL_PATHS[robot]
    existing = next((candidate for candidate in candidates if candidate.exists()), None)
    if existing is None:
        raise FileNotFoundError(f"Skills file not found for robot={robot}. Tried={_ROBOT_SKILL_PATHS[robot]}")
    return existing


def _load_robot_skills_as_dict(robot: str) -> Dict[str, Dict[str, Any]]:
    """Load *robot*'s skill specs through ``load_skills_as_dict``.

    get_initial_guidance looks entries up by called function name, so the
    result is presumably keyed by skill name (TODO confirm in generate_ontology).
    """
    skills_file = _robot_skills_path(robot)
    return load_skills_as_dict(skills_file)

def _extract_skill_calls(tree: ast.AST) -> List[Dict[str, Any]]:
    calls: List[Dict[str, Any]] = []

    class V(ast.NodeVisitor):
        def visit_Call(self, node: ast.Call):
            if isinstance(node.func, ast.Name):
                kw = {}
                for k in node.keywords or []:
                    if k.arg:
                        kw[k.arg] = _safe_unparse(k.value)
                calls.append(
                    {
                        "name": node.func.id,
                        "lineno": getattr(node, "lineno", None),
                        "end_lineno": getattr(node, "end_lineno", None),
                        "keywords": kw,
                        "args_count": len(node.args),
                    }
                )
            self.generic_visit(node)

    V().visit(tree)
    return calls


def _extract_objects_in_code(tree: ast.AST) -> List[Dict[str, Any]]:
    objs: List[Dict[str, Any]] = []
    allowed_ctors = {"Shape", "Joint", "Dummy", "ProximitySensor"}

    class V(ast.NodeVisitor):
        def visit_Assign(self, node: ast.Assign):
            if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
                var = node.targets[0].id
                val = node.value
                if isinstance(val, ast.Call) and isinstance(val.func, ast.Name):
                    ctor = val.func.id
                    if ctor in allowed_ctors:
                        obj_id = None
                        if val.args:
                            a0 = val.args[0]
                            if isinstance(a0, ast.Constant) and isinstance(a0.value, str):
                                obj_id = a0.value
                        objs.append(
                            {
                                "var_name": var,
                                "ctor": ctor,
                                "object_id_str": obj_id,
                                "lineno": getattr(node, "lineno", None),
                            }
                        )
            self.generic_visit(node)

    V().visit(tree)
    return objs


def _best_equivalence(
    subs: Dict[str, Any],
    source_robot: str,
    target_robot: str,
    source_skill_name: str,
) -> Optional[Dict[str, Any]]:
    key_a = f"{source_robot}:{source_skill_name}"
    best = None
    best_score = -1.0
    for e in subs.get("equivalences", []):
        if e.get("skill_a") != key_a:
            continue
        b = e.get("skill_b", "")
        if not b.startswith(f"{target_robot}:"):
            continue
        score = float(e.get("score", 0.0))
        if score > best_score:
            best = e
            best_score = score
    return best


def _get_decomposition(
    subs: Dict[str, Any],
    source_robot: str,
    source_skill_name: str,
) -> Optional[Dict[str, Any]]:
    key = f"{source_robot}:{source_skill_name}"
    for d in subs.get("decompositions", []):
        if d.get("target_skill") == key:
            return d
    return None


def _expand_decomposition_to_target(
    subs: Dict[str, Any],
    source_robot: str,
    target_robot: str,
    decomposition: List[str],
) -> List[Dict[str, Any]]:
    expanded = []
    for sub_skill in decomposition:
        if ":" in sub_skill:
            _, sub_name = sub_skill.split(":", 1)
        else:
            sub_name = sub_skill

        sub_eq = _best_equivalence(subs, source_robot, target_robot, sub_name)
        if sub_eq:
            tgt_full = sub_eq.get("skill_b", "")
            _, tgt_name = tgt_full.split(":", 1) if ":" in tgt_full else ("", tgt_full)
            expanded.append({
                "source_skill": sub_skill,
                "target_skill": tgt_full,
                "target_skill_name": tgt_name,
                "equivalence": sub_eq,
            })
        else:
            expanded.append({
                "source_skill": sub_skill,
                "target_skill": None,
                "target_skill_name": None,
                "equivalence": None,
            })
    return expanded


def _flatten_params_for_mapping(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    out: List[Dict[str, Any]] = []
    for p in params:
        if p.get("type") == "dict" and isinstance(p.get("schema"), dict):
            parent = p.get("name", "dictparam")
            schema = p["schema"]
            for k, v in schema.items():
                out.append(
                    {
                        "name": f"{parent}.{k}",
                        "type": v.get("type", ""),
                        "semantic": v.get("semantic", ""),
                        "units": v.get("units", ""),
                    }
                )
        else:
            out.append(p)
    return out


def _required_keywords_for_match(source_skill_spec: Dict[str, Any], callsite_keywords: Dict[str, str]) -> List[str]:
    req = []
    for p in source_skill_spec.get("parameters", []):
        if p.get("required") and p.get("name"):
            if p["name"] == "target_quat":
                continue
            req.append(p["name"])
    for k in callsite_keywords.keys():
        if k not in req:
            req.append(k)
    return req


def _map_object_to_target(
    src_name: str,
    src_type: str,
    target_scene_info: List[Dict[str, str]],
) -> Dict[str, Any]:
    def norm(s: str) -> str:
        return (s or "").strip().lower()

    for t in target_scene_info:
        if norm(t.get("name")) == norm(src_name):
            type_match = 1.0 if norm(t.get("type")) == norm(src_type) else 0.6
            return {
                "target_object": t.get("name"),
                "target_type": t.get("type"),
                "confidence": round(0.95 * type_match, 3),
                "method": "exact_name",
            }

    same_type = [t for t in target_scene_info if norm(t.get("type")) == norm(src_type)]
    best_t, best_s = None, -1.0
    for t in same_type:
        s = _sim(src_name, t.get("name", ""))
        if s > best_s:
            best_s, best_t = s, t
    if best_t and best_s >= 0.45:
        conf = 0.55 + 0.4 * best_s
        return {
            "target_object": best_t.get("name"),
            "target_type": best_t.get("type"),
            "confidence": round(float(conf), 3),
            "method": "same_type_fuzzy",
        }

    best_t, best_s = None, -1.0
    for t in target_scene_info:
        s = _sim(src_name, t.get("name", ""))
        if s > best_s:
            best_s, best_t = s, t
    if best_t:
        conf = 0.25 + 0.45 * best_s
        return {
            "target_object": best_t.get("name"),
            "target_type": best_t.get("type"),
            "confidence": round(float(conf), 3),
            "method": "global_fuzzy",
        }

    return {
        "target_object": src_name,
        "target_type": None,
        "confidence": 0.1,
        "method": "fallback_identity",
    }


# Direct equivalences scoring below this are replaced by an available
# decomposition, provided every sub-skill has an equivalence of its own.
_DECOMPOSITION_THRESHOLD = 0.8


def _unmapped_entry(
    call: Dict[str, Any], source_robot: str, src_skill_name: str, reason: str
) -> Dict[str, Any]:
    """Build the ``unmapped_skills`` record for a call that could not be mapped."""
    return {
        "line_start": call["lineno"],
        "line_end": call["end_lineno"],
        "source_skill": f"{source_robot}:{src_skill_name}",
        "reason": reason,
        "callsite_keywords": call["keywords"],
    }


def _call_match_spec(src_skill_name: str, keywords_required: List[str]) -> Dict[str, Any]:
    """Build the ``match`` clause that identifies a source call site by callee name and keywords."""
    return {
        "node": "call",
        "func": {"type": "Name", "id": src_skill_name},
        "keywords_required": keywords_required,
    }


def _decomposition_mapping(
    call: Dict[str, Any],
    source_robot: str,
    src_skill_name: str,
    decomp: Dict[str, Any],
    expanded: List[Dict[str, Any]],
    src_skill_index: Dict[str, Dict[str, Any]],
    tgt_skill_index: Dict[str, Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Build a ``"decomposition"`` primitive mapping for one call site.

    One step is produced per expanded sub-skill whose target skill has a spec
    in *tgt_skill_index*; other sub-skills are skipped. Returns None when no
    step survives, so the caller can fall back instead of emitting an empty
    sequence.
    """
    sub_guidances: List[Dict[str, Any]] = []
    all_interfaces: List[Any] = []
    all_param_mappings: List[Any] = []
    combined_risk_flags: List[Any] = []

    for exp in expanded:
        sub_src_skill = exp["source_skill"]
        sub_tgt_skill_name = exp.get("target_skill_name")
        sub_eq = exp.get("equivalence")
        if not sub_tgt_skill_name or sub_tgt_skill_name not in tgt_skill_index:
            continue

        sub_src_name = sub_src_skill.split(":", 1)[1] if ":" in sub_src_skill else sub_src_skill
        sub_src_spec = src_skill_index.get(sub_src_name, {})
        sub_tgt_spec = tgt_skill_index[sub_tgt_skill_name]

        sub_param_mappings, sub_risk_flags = infer_parameter_mapping(
            sub_src_spec.get("parameters", []),
            _flatten_params_for_mapping(sub_tgt_spec.get("parameters", [])),
        )
        combined_risk_flags.extend(sub_risk_flags)
        all_param_mappings.extend(sub_param_mappings)

        sub_signature = generate_function_signature(sub_tgt_spec)
        sub_interface = generate_interface_example(sub_tgt_spec, sub_param_mappings)
        all_interfaces.append(sub_interface)

        sub_guidances.append({
            "source_skill": sub_src_skill,
            "target_skill": exp["target_skill"],
            "target_skill_name": sub_tgt_skill_name,
            "signature": sub_signature,
            "interface": sub_interface,
            "equivalence_score": float(sub_eq.get("score", 0.0)) if sub_eq else 0.0,
            "parameter_mappings": sub_param_mappings,
        })

    if not sub_guidances:
        return None

    avg_equiv_score = sum(sg["equivalence_score"] for sg in sub_guidances) / len(sub_guidances)
    quality_grade = compute_quality_grade(avg_equiv_score, all_param_mappings, combined_risk_flags)

    target_calls = " -> ".join(sg["target_skill_name"] for sg in sub_guidances)
    content = f"Decompose the `{src_skill_name}(env, task, ...)` call into: {target_calls}"
    steps = "".join(
        f"Step {i}: `{sg['target_skill_name']}`\n```python\n{sg['signature']}\n```\n\n"
        for i, sg in enumerate(sub_guidances, 1)
    )
    full_content = (
        f"Decompose the `{src_skill_name}(env, task, ...)` call into a sequence of calls:\n\n"
        f"{steps}"
    )

    keywords_required = _required_keywords_for_match(src_skill_index[src_skill_name], call["keywords"])

    return {
        "line_start": call["lineno"],
        "line_end": call["end_lineno"],
        "match": _call_match_spec(src_skill_name, keywords_required),
        "type": "decomposition",
        "content": content,
        "full_content": full_content,
        "interface": "\n".join(all_interfaces),
        "equivalence_score": avg_equiv_score,
        "decomposition": {
            "original_skill": f"{source_robot}:{src_skill_name}",
            "sub_skills": decomp.get("decomposition", []),
            "coverage_score": decomp.get("coverage_score", 0),
            "expanded_guidance": sub_guidances,
        },
        "parameter_mappings": all_param_mappings,
        "quality_grade": quality_grade,
    }


def _equivalence_mapping(
    call: Dict[str, Any],
    src_skill_name: str,
    tgt_skill_name: str,
    eq: Dict[str, Any],
    src_skill_spec: Dict[str, Any],
    tgt_skill_spec: Dict[str, Any],
    grasp_guidance: Optional[str],
) -> Dict[str, Any]:
    """Build a ``"capabilities"`` mapping that replaces a call with its one-to-one target skill."""
    src_params = src_skill_spec.get("parameters", [])
    tgt_params_flat = _flatten_params_for_mapping(tgt_skill_spec.get("parameters", []))

    special_transform_applied = False
    transformed_interface = None
    transformed_params = None

    if has_transform(src_skill_name, tgt_skill_name):
        # Best-effort literal evaluation of the call-site keyword text;
        # anything that is not a plain literal is kept as source text.
        source_params_dict: Dict[str, Any] = {}
        for kw_name, kw_value in call["keywords"].items():
            try:
                source_params_dict[kw_name] = ast.literal_eval(kw_value)
            except (ValueError, SyntaxError, TypeError):
                # TypeError: e.g. "{[1]: 2}" parses but builds an unhashable key.
                source_params_dict[kw_name] = kw_value

        transformed_params = apply_special_transform(
            src_skill_name, tgt_skill_name, source_params_dict
        )
        if transformed_params:
            transformed_interface = generate_transformed_interface(
                src_skill_name, tgt_skill_name, source_params_dict
            )
            special_transform_applied = True
            print(f"[GUIDANCE] Applied special transform: {src_skill_name} -> {tgt_skill_name}")
            print(f"[GUIDANCE]   Transformed params: {transformed_params}")
            print(f"[GUIDANCE]   Interface: {transformed_interface}")

    param_mappings, risk_flags = infer_parameter_mapping(src_params, tgt_params_flat)

    # Attach the literal call-site expression to each mapping whose source keyword was passed.
    for m in param_mappings:
        src_key = m.get("source")
        if src_key in call["keywords"]:
            m["callsite_value_expr"] = call["keywords"][src_key]

    equiv_score = float(eq.get("score", 0.0))
    quality_grade = compute_quality_grade(equiv_score, param_mappings, risk_flags)

    target_signature = generate_function_signature(tgt_skill_spec)
    param_mapping_text = generate_param_mapping_text(param_mappings)

    if special_transform_applied and transformed_interface:
        interface_example = transformed_interface
        if transformed_params:
            param_mapping_text = "Parameter Mapping (via special transform):\n" + "".join(
                f"- {key}: {value!r}\n"
                for key, value in transformed_params.items()
                if value is not None
            )
    else:
        interface_example = generate_interface_example(tgt_skill_spec, param_mappings)

    grasp_prefix_templates = None
    if grasp_guidance and src_skill_name.lower() in ("pick", "grasp"):
        pos_expr = call["keywords"].get("target_pos", "grasp_pos")
        # NOTE(review): these templates hard-code ur5_* skill names regardless of the target robot.
        grasp_prefix_templates = [
            f"ur5_move_to(env, task, target_pos={pos_expr} + np.array([0, 0, 0.15]), target_quat=grasp_quat, gripper=1.0, n_waypoints=10)",
            "ur5_align_gripper(env, task, reference_quat=object_quat, approach_direction='down', yaw_mode='parallel')",
        ]
        print(f"[GUIDANCE] Created grasp_prefix_templates for grasp skill: {src_skill_name}")

    keywords_required = _required_keywords_for_match(src_skill_spec, call["keywords"])

    content = f"Replace the `{src_skill_name}(env, task, ...)` call with a `{tgt_skill_name}(...)` call."
    full_content = (
        f"Replace the `{src_skill_name}(env, task, ...)` call with a `{tgt_skill_name}(...)` call using the following interface:\n\n"
        f"```python\n{target_signature}\n```\n\n"
        f"{param_mapping_text}"
    )

    mapping_entry = {
        "line_start": call["lineno"],
        "line_end": call["end_lineno"],
        "match": _call_match_spec(src_skill_name, keywords_required),
        "type": "capabilities",
        "content": content,
        "full_content": full_content,
        "interface": interface_example,
        "equivalence_score": equiv_score,
        "parameter_mappings": param_mappings,
        "quality_grade": quality_grade,
    }
    if grasp_prefix_templates:
        mapping_entry["grasp_prefix_templates"] = grasp_prefix_templates
    return mapping_entry


def get_initial_guidance(
    source_robot: str,
    target_robot: str,
    reference_code: str,
    target_scene_info: List[Dict[str, str]],
    grasp_guidance: Optional[str] = None,
) -> Dict[str, Any]:
    """Build porting guidance that turns *reference_code* for *source_robot* into *target_robot* code.

    Every call to a known source-robot skill is handled as follows:
      - Decomposition: when the substitution table has no direct equivalence,
        or only a weak one (below _DECOMPOSITION_THRESHOLD) whose sub-skills
        all have equivalences, the call is mapped to a sequence of target
        skills.
      - Direct equivalence: otherwise the call is mapped to its best direct
        target equivalence.
      - Unmapped: calls that cannot be mapped are listed in
        ``unmapped_skills`` with a reason.

    Simulator objects constructed in the code are matched against
    *target_scene_info*.

    Args:
        source_robot / target_robot: keys of _ROBOT_SKILL_PATHS.
        reference_code: Python source using source-robot skills.
        target_scene_info: target scene objects as dicts with "name"/"type".
        grasp_guidance: optional free text. When non-blank it is copied into
            the result, and pick/grasp skills get ``grasp_prefix_templates``.

    Returns:
        A dict with the keys ``source_robot``, ``target_robot``,
        ``primitive_mappings``, ``object_mappings``, ``unmapped_skills`` and
        ``meta``, plus ``grasp_guidance`` when provided.

    Raises:
        SyntaxError: if *reference_code* does not parse.
        ValueError: if a robot name is unknown.
        FileNotFoundError: if a substitution or skill file cannot be loaded.
    """
    tree = ast.parse(reference_code)

    subs = _load_skill_substitutions()
    src_skill_index = _load_robot_skills_as_dict(source_robot)
    tgt_skill_index = _load_robot_skills_as_dict(target_robot)

    calls = _extract_skill_calls(tree)

    primitive_mappings: List[Dict[str, Any]] = []
    unmapped_skills: List[Dict[str, Any]] = []

    for c in calls:
        src_skill_name = c["name"]
        if src_skill_name not in src_skill_index:
            continue  # not a source-robot skill (e.g. numpy or helper calls)

        eq = _best_equivalence(subs, source_robot, target_robot, src_skill_name)
        decomp = _get_decomposition(subs, source_robot, src_skill_name)

        expanded: Optional[List[Dict[str, Any]]] = None
        if decomp and (not eq or float(eq.get("score", 0.0)) < _DECOMPOSITION_THRESHOLD):
            candidate = _expand_decomposition_to_target(
                subs, source_robot, target_robot, decomp.get("decomposition", [])
            )
            # Without a direct equivalence any decomposition beats nothing; with a
            # weak one, switch only if every sub-skill has an equivalence of its own.
            if not eq or all(e.get("equivalence") is not None for e in candidate):
                expanded = candidate

        if expanded is not None:
            mapping = _decomposition_mapping(
                c, source_robot, src_skill_name, decomp, expanded, src_skill_index, tgt_skill_index
            )
            if mapping is not None:
                primitive_mappings.append(mapping)
                continue
            if not eq:
                unmapped_skills.append(
                    _unmapped_entry(
                        c, source_robot, src_skill_name,
                        "Decomposition found but none of its sub-skills map to a target skill spec",
                    )
                )
                continue
            # Otherwise fall back to the (weak) direct equivalence below rather
            # than emitting an empty decomposition.

        if not eq:
            unmapped_skills.append(
                _unmapped_entry(c, source_robot, src_skill_name, "No equivalence found in skill_substitutions")
            )
            continue

        tgt_full = eq["skill_b"]
        _, tgt_skill_name = tgt_full.split(":", 1)

        if tgt_skill_name not in tgt_skill_index:
            unmapped_skills.append(
                _unmapped_entry(
                    c, source_robot, src_skill_name,
                    f"Equivalence points to missing target skill spec: {tgt_full}",
                )
            )
            continue

        primitive_mappings.append(
            _equivalence_mapping(
                c,
                src_skill_name,
                tgt_skill_name,
                eq,
                src_skill_index[src_skill_name],
                tgt_skill_index[tgt_skill_name],
                grasp_guidance,
            )
        )

    code_objs = _extract_objects_in_code(tree)
    object_mappings: List[Dict[str, Any]] = []
    for o in code_objs:
        # Prefer the literal object id (e.g. Shape('cube')) over the Python variable name.
        src_obj_name = o["object_id_str"] or o["var_name"]
        src_type = o["ctor"]
        mapped = _map_object_to_target(src_obj_name, src_type, target_scene_info)
        object_mappings.append(
            {
                "source_object": src_obj_name,
                "source_type": src_type,
                "target_object": mapped["target_object"],
                "target_type": mapped["target_type"],
                "confidence": mapped["confidence"],
                "method": mapped["method"],
                "var_name": o["var_name"],
                "lineno": o["lineno"],
            }
        )

    result = {
        "source_robot": source_robot,
        "target_robot": target_robot,
        "primitive_mappings": primitive_mappings,
        "object_mappings": object_mappings,
        "unmapped_skills": unmapped_skills,
        "meta": {
            "total_skill_calls_found": sum(1 for call in calls if call["name"] in src_skill_index),
            "total_primitives_mapped": len(primitive_mappings),
            "total_objects_found": len(code_objs),
            "target_scene_objects": len(target_scene_info),
        },
    }

    if grasp_guidance and isinstance(grasp_guidance, str) and grasp_guidance.strip():
        result["grasp_guidance"] = grasp_guidance.strip()
        print(f"[GUIDANCE] Including grasp_guidance in result: {grasp_guidance[:100]}...")

    return result

from infilling_modules.infilling import (
    run_full_infilling_pipeline,
    load_model_and_tokenizer,
    generate_fim_completion,
    extract_generated_code,
    get_full_interface_signature,
    FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE,
)

import textwrap
import re

def _parse_grasp_guidance(grasp_guidance: str) -> List[Dict[str, Any]]:
    if not grasp_guidance:
        return []

    results = []

    object_pattern = r"-\s*['\"](\w+)['\"]\s*:"
    approach_pattern = r"approach_direction\s*=\s*['\"](\w+)['\"]"
    yaw_pattern = r"yaw_mode\s*=\s*['\"](\w+)['\"]"

    lines = grasp_guidance.split('\n')
    current_object = None
    current_block = ""

    for line in lines:
        obj_match = re.search(object_pattern, line)
        if obj_match:
            if current_object and current_block:
                approach_match = re.search(approach_pattern, current_block)
                yaw_match = re.search(yaw_pattern, current_block)

                results.append({
                    "object": current_object,
                    "approach_direction": approach_match.group(1) if approach_match else "down",
                    "yaw_mode": yaw_match.group(1) if yaw_match else "parallel",
                })

            current_object = obj_match.group(1)
            current_block = line
        elif current_object:
            current_block += " " + line

    if current_object and current_block:
        approach_match = re.search(approach_pattern, current_block)
        yaw_match = re.search(yaw_pattern, current_block)

        results.append({
            "object": current_object,
            "approach_direction": approach_match.group(1) if approach_match else "down",
            "yaw_mode": yaw_match.group(1) if yaw_match else "parallel",
        })

    print(f"[GRASP PARSE] Parsed {len(results)} grasp objects: {results}")
    return results


def _find_grasp_locations(code: str, grasp_objects: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    if not grasp_objects:
        return []

    results = []
    lines = code.split('\n')

    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        print(f"[GRASP FIND] Failed to parse code: {e}")
        return []

    grasp_obj_names = {obj["object"].lower(): obj for obj in grasp_objects}

    var_definitions = {}
    shape_var_map = {}

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    var_name = target.id
                    try:
                        val_str = ast.unparse(node.value) if hasattr(ast, 'unparse') else ""
                        var_definitions[var_name] = val_str.lower()

                        if isinstance(node.value, ast.Call):
                            if isinstance(node.value.func, ast.Name) and node.value.func.id == "Shape":
                                if node.value.args and isinstance(node.value.args[0], ast.Constant):
                                    shape_obj_name = node.value.args[0].value
                                    if isinstance(shape_obj_name, str):
                                        shape_var_map[shape_obj_name.lower()] = var_name
                                        print(f"[GRASP FIND] Found Shape variable: {var_name} = Shape('{shape_obj_name}')")
                    except Exception:
                        pass

    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id == "ur5_grasp_at":
                line_number = node.lineno
                line_content = lines[line_number - 1] if line_number <= len(lines) else ""

                indent = len(line_content) - len(line_content.lstrip())
                indent_str = line_content[:indent]

                grasp_pos_expr = None
                matched_object = None

                for kw in node.keywords:
                    if kw.arg == "grasp_pos":
                        grasp_pos_expr = ast.unparse(kw.value) if hasattr(ast, 'unparse') else str(kw.value)

                        for obj_name, obj_info in grasp_obj_names.items():
                            if obj_name in grasp_pos_expr.lower():
                                matched_object = obj_info
                                break

                        if matched_object is None and isinstance(kw.value, ast.Name):
                            var_name = kw.value.id
                            var_def = var_definitions.get(var_name, "")
                            for obj_name, obj_info in grasp_obj_names.items():
                                if obj_name in var_def:
                                    matched_object = obj_info
                                    print(f"[GRASP FIND] Found object '{obj_name}' via variable '{var_name}' definition: {var_def}")
                                    break
                        break

                if matched_object is None and grasp_objects:
                    line_lower = line_content.lower()
                    for obj_name, obj_info in grasp_obj_names.items():
                        if obj_name in line_lower:
                            matched_object = obj_info
                            break

                if matched_object is None and grasp_objects:
                    matched_object = grasp_objects[0]
                    print(f"[GRASP FIND] Fallback: using first grasp object '{matched_object['object']}' for ur5_grasp_at at line {line_number}")

                if matched_object:
                    obj_name = matched_object["object"]
                    shape_var_name = shape_var_map.get(obj_name.lower())

                    results.append({
                        "line_number": line_number,
                        "object": obj_name,
                        "shape_var_name": shape_var_name,
                        "grasp_pos_expr": grasp_pos_expr or "grasp_pos",
                        "approach_direction": matched_object.get("approach_direction", "down"),
                        "yaw_mode": matched_object.get("yaw_mode", "parallel"),
                        "indent": indent_str,
                        "original_line": line_content,
                    })
                    if shape_var_name:
                        print(f"[GRASP FIND] Found ur5_grasp_at at line {line_number} for object '{obj_name}' (Shape var: {shape_var_name})")
                    else:
                        print(f"[GRASP FIND] Found ur5_grasp_at at line {line_number} for object '{obj_name}' (Shape var not found, will use fallback)")

    print(f"[GRASP FIND] Found {len(results)} grasp locations")
    return results


def _build_grasp_alignment_guidance(location: Dict[str, Any], grasp_guidance: str) -> Dict[str, Any]:
    obj_name = location["object"]
    shape_var_name = location.get("shape_var_name")
    grasp_pos_expr = location["grasp_pos_expr"]
    approach_direction = location.get("approach_direction", "down")
    yaw_mode = location.get("yaw_mode", "parallel")

    interface = f"ur5_align_gripper(env, task, reference_quat=..., approach_direction='{approach_direction}', yaw_mode='{yaw_mode}')"

    content = f"Get the object's quaternion and align gripper before grasping."

    if shape_var_name:
        quat_get_expr = f"{shape_var_name}.get_quaternion()"
        quat_var_name = f"{shape_var_name}_quat"
        example_quat_line = f"  {quat_var_name} = {quat_get_expr}"
    else:
        quat_get_expr = f"Shape('{obj_name}').get_quaternion()"
        quat_var_name = f"{obj_name}_quat"
        example_quat_line = f"  {quat_var_name} = {quat_get_expr}"
        print(f"[GRASP GUIDANCE] Using fallback for '{obj_name}': Shape('{obj_name}').get_quaternion()")

    full_content = (
        f"Insert ur5_align_gripper before ur5_grasp_at:\n\n"
        f"IMPORTANT: First, get the object's quaternion using {quat_get_expr} and store it in a variable.\n"
        f"Then call ur5_align_gripper with that quaternion as reference_quat.\n\n"
        f"Example:\n"
        f"{example_quat_line}\n"
        f"  obs, reward, done = ur5_align_gripper(env, task, reference_quat={quat_var_name}, approach_direction='{approach_direction}', yaw_mode='{yaw_mode}')"
    )

    return {
        "line_number": location["line_number"],
        "object": obj_name,
        "shape_var_name": shape_var_name,
        "grasp_pos_expr": grasp_pos_expr,
        "approach_direction": approach_direction,
        "yaw_mode": yaw_mode,
        "indent": location["indent"],
        "interface": interface,
        "content": content,
        "full_content": full_content,
    }


def _run_grasp_prefix_infilling(
    code: str,
    grasp_alignment_guidance: Dict[str, Any],
    model: Any,
    tokenizer: Any,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    target_robot: str = "ur5",
) -> Tuple[str, int]:
    """Insert an approach move and a model-generated gripper alignment before a grasp.

    ``grasp_alignment_guidance`` describes one grasp location. Keys read:
    ``line_number`` (1-based), ``indent``, ``interface``, ``full_content`` and
    ``grasp_pos_expr``. Immediately before line ``line_number`` this inserts:

      1. an ``ur5_move_to`` call to ``grasp_pos_expr`` raised by 0.15 along z
         (written directly, not generated), then
      2. the non-blank lines of a FIM completion. The completion is prompted
         with the alignment instructions (as comments) and the full signature
         of the interface's function.

    Generated lines that do not already start with the grasp line's
    indentation are re-indented to it.

    Returns:
        ``(updated_code, num_generated_tokens)``. If ``line_number`` is not a
        valid 1-based line of ``code``, the code is returned unchanged with 0.
    """
    # Imported locally so this function does not depend on a module-level `re`.
    import re

    line_number = grasp_alignment_guidance["line_number"]
    indent_str = grasp_alignment_guidance["indent"]
    interface = grasp_alignment_guidance["interface"]
    full_content = grasp_alignment_guidance["full_content"]
    grasp_pos_expr = grasp_alignment_guidance["grasp_pos_expr"]

    lines = code.split('\n')

    # line_number is 1-based; 0 or negative values would silently slice from the end.
    if not 1 <= line_number <= len(lines):
        print(f"[GRASP PREFIX] Invalid line number {line_number}, skipping")
        return code, 0

    prefix_lines = lines[:line_number - 1]
    prefix_code = '\n'.join(prefix_lines)
    if prefix_code and not prefix_code.endswith('\n'):
        prefix_code += '\n'

    suffix_lines = lines[line_number - 1:]
    suffix_code = '\n'.join(suffix_lines)

    # Show the model the complete signature rather than the abbreviated template.
    func_name_match = re.match(r'(\w+)\s*\(', interface.strip())
    if func_name_match:
        full_sig = get_full_interface_signature(func_name_match.group(1), target_robot)
        formatted_interface = f"{indent_str}# obs, reward, done = {full_sig}"
    else:
        formatted_interface = f"{indent_str}# obs, reward, done = {interface}"

    formatted_instruction = '\n'.join(
        f"# {line}" if line.strip() else "#" for line in full_content.split('\n')
    )

    hint_for_FIM = f"""{indent_str}# Generate ur5_align_gripper call:
{formatted_interface}
{indent_str}"""

    fim_prompt = f"{formatted_instruction}\n{FIM_PREFIX}{prefix_code}{hint_for_FIM}{FIM_SUFFIX}\n{suffix_code}{FIM_MIDDLE}"

    print(f"[GRASP PREFIX] FIM prompt for ur5_align_gripper at line {line_number}:")
    print(f"[GRASP PREFIX] Interface hint:\n{hint_for_FIM}")

    generated_text, _, num_tokens = generate_fim_completion(
        model=model,
        tokenizer=tokenizer,
        fim_prompt=fim_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=1.0,
        seed=42,
        past_key_values=None,
        use_cache=False,
    )

    print(f"[GRASP PREFIX] Raw generated: {generated_text}")

    generated_code = extract_generated_code(generated_text)
    print(f"[GRASP PREFIX] Extracted code: {generated_code}")

    indented_lines = []
    for gen_line in generated_code.strip().split('\n'):
        stripped = gen_line.strip()
        if not stripped:
            continue
        if gen_line.startswith(indent_str) or gen_line.startswith(' ' * len(indent_str)):
            indented_lines.append(gen_line)
        else:
            indented_lines.append(indent_str + stripped)

    move_to_line = f"{indent_str}obs, reward, done = ur5_move_to(env, task, target_pos={grasp_pos_expr} + np.array([0, 0, 0.15]))"

    print(f"[GRASP PREFIX] Directly inserting ur5_move_to: {move_to_line}")

    updated_code = '\n'.join(prefix_lines + [move_to_line] + indented_lines + suffix_lines)

    print(f"[GRASP PREFIX] Inserted ur5_move_to (direct) + {len(indented_lines)} line(s) from FIM before ur5_grasp_at")

    return updated_code, num_tokens


# Process-wide cache for the FIM infilling model, its tokenizer and device.
# All three stay None until init_infilling_llm() loads a model.
_INFILLING_MODEL = _INFILLING_TOKENIZER = _INFILLING_DEVICE = None


def init_infilling_llm(model_name: str = "Qwen/Qwen2.5-Coder-7B"):
    """Load the infilling model once and keep it in the module-level cache.

    The call is a no-op whenever a model is already cached, even if
    ``model_name`` differs from the cached one.
    """
    global _INFILLING_MODEL, _INFILLING_TOKENIZER, _INFILLING_DEVICE
    if _INFILLING_MODEL is not None:
        return
    print(f"[OURS - ONLINE] Loading infilling model: {model_name}")
    loaded = load_model_and_tokenizer(model_name)
    _INFILLING_MODEL, _INFILLING_TOKENIZER, _INFILLING_DEVICE = loaded
    print(f"[OURS - ONLINE] Infilling model loaded on device: {_INFILLING_DEVICE}")


def get_infilling_model():
    """Return the cached ``(model, tokenizer, device)`` triple (possibly all ``None``)."""
    cached = (_INFILLING_MODEL, _INFILLING_TOKENIZER, _INFILLING_DEVICE)
    return cached

def extract_function_body(code_str, func_name):
    """Return the source text of the body of the top-level function ``func_name``.

    The text runs from the first body statement, including any decorators on
    it, through the last line of the last body statement. Original
    indentation and line endings are kept. Comment or blank lines before the
    first statement or after the last one are not included. For a body that
    shares a line with the ``def`` header (e.g. ``def f(): return 1``), only
    the text after the header is returned.

    Args:
        code_str: Python source to search.
        func_name: Name of a module-level ``def`` or ``async def``.

    Returns:
        The body source as a string.

    Raises:
        SyntaxError: If ``code_str`` does not parse.
        ValueError: If no top-level function named ``func_name`` exists.
    """
    tree = ast.parse(code_str)
    lines = code_str.splitlines(keepends=True)

    for node in tree.body:
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or node.name != func_name:
            continue

        first = node.body[0]
        # A decorated def/class reports its `def`/`class` line as lineno, so
        # start at the earliest decorator to keep the decorator lines.
        start_lineno = min([first.lineno] + [d.lineno for d in getattr(first, "decorator_list", [])])
        body_lines = lines[start_lineno - 1:node.body[-1].end_lineno]

        # Body on the same line as the header: drop the header text.
        # col_offset is a UTF-8 byte offset, hence the encode/decode.
        if start_lineno == first.lineno:
            first_line = body_lines[0].encode("utf-8")
            if first_line[:first.col_offset].strip():
                body_lines[0] = first_line[first.col_offset:].decode("utf-8")

        return "".join(body_lines)

    raise ValueError(f"Function '{func_name}' not found")

def _apply_skill_substitutions_fallback(code: str, primitive_mappings: List[Dict[str, Any]]) -> str:
    """Rename source-robot skill calls in ``code`` by plain text substitution.

    Only mappings whose ``match.func`` has ``type == "Name"`` are used. For
    each one, the target name is the leading identifier of the mapping's
    ``interface`` string (e.g. ``"ur5_grasp_at(env, ...)"`` -> ``ur5_grasp_at``).

    A name is renamed only where it is called as a bare function. It must be
    followed by ``(`` (optionally after whitespace) and must not be the tail
    of a longer identifier or an attribute access. So ``pick(`` is renamed,
    but ``sawyer_pick(`` and ``obj.pick(`` are left alone.

    All renames happen in a single pass, so one mapping's output is never
    re-matched by another mapping. When several mappings share a source name,
    the first one wins. Call arguments are not rewritten.

    Args:
        code: Python source to rewrite.
        primitive_mappings: Guidance mappings; keys read are ``match`` and
            ``interface``.

    Returns:
        The rewritten code, or ``code`` unchanged when nothing applies.
    """
    import re

    renames: Dict[str, str] = {}
    for mapping in primitive_mappings:
        func_spec = (mapping.get("match") or {}).get("func") or {}
        if func_spec.get("type") != "Name":
            continue
        source_skill = func_spec.get("id")
        interface = mapping.get("interface") or ""
        if not source_skill or not interface:
            continue
        target_match = re.match(r"(\w+)\s*\(", interface)
        if target_match:
            renames.setdefault(source_skill, target_match.group(1))

    if not renames:
        return code

    # The lookbehind rejects identifier tails and attribute calls; the
    # lookahead requires a call without consuming the whitespace or "(".
    alternatives = "|".join(re.escape(name) for name in sorted(renames, key=len, reverse=True))
    call_pattern = re.compile(rf"(?<![\w.])({alternatives})(?=\s*\()")

    substituted = set()

    def _rename(match):
        name = match.group(1)
        substituted.add(name)
        return renames[name]

    result = call_pattern.sub(_rename, code)

    for source_skill, target_skill in renames.items():
        if source_skill in substituted:
            print(f"[FALLBACK] Substituted {source_skill} -> {target_skill}")
    if substituted:
        print("[FALLBACK] Applied skill name substitutions to code")

    return result


def get_initial_code(skill_name, source_robot, target_robot, reference_code, target_scene_info, grasp_guidance=None) -> Tuple[str, int]:
    """Generate the target-robot ``run_skill`` body from source-robot reference code.

    Steps (each is timed and the timings are printed):
      1. Drop lines starting with ``import ``/``from `` and everything before
         ``def run_skill``.
      2. Build structured guidance with ``get_initial_guidance``. Write the
         preprocessed reference to ``./references/`` and the guidance to
         ``./guidances/``; the directories are created if missing.
      3. Run ``run_full_infilling_pipeline`` on those two files.
      4. If ``grasp_guidance`` is a non-empty string, insert an approach move
         and gripper alignment before every grasp location found in the
         infilled code. Locations are processed bottom-up, so inserting lines
         does not shift the line numbers of those still to be processed.
      5. Extract the ``run_skill`` body.

    Returns:
        ``(run_skill body source, total generated tokens)``.

    Raises:
        ValueError: From ``extract_function_body`` when the infilled code has
            no top-level ``run_skill``.
    """
    import time

    timings: Dict[str, float] = {}
    total_start = time.perf_counter()

    # 1. Preprocess the reference code.
    step_start = time.perf_counter()
    non_import_lines = [
        line for line in reference_code.splitlines()
        if not line.lstrip().startswith(("import ", "from "))
    ]
    filtered_code: List[str] = []
    recording = False
    for line in non_import_lines:
        if line.lstrip().startswith("def run_skill"):
            recording = True
        if recording:
            filtered_code.append(line)
    preprocessed_reference_code = "\n".join(filtered_code).strip() + "\n"
    timings['preprocessing'] = time.perf_counter() - step_start
    print(f"[TIMING] Preprocessing: {timings['preprocessing']:.2f}s")

    # 2. Build and persist guidance plus the reference file.
    step_start = time.perf_counter()
    guidance = get_initial_guidance(
        source_robot=source_robot,
        target_robot=target_robot,
        reference_code=preprocessed_reference_code,
        target_scene_info=target_scene_info,
        # Grasp guidance is applied after infilling (step 4), not here.
        grasp_guidance=None,
    )
    timings['get_initial_guidance'] = time.perf_counter() - step_start
    print(f"[TIMING] get_initial_guidance: {timings['get_initial_guidance']:.2f}s")
    print(f"[OURS - ONLINE] Auto-generated structured guidance:\n{guidance}")

    step_start = time.perf_counter()
    initial_guidance = {
        "skill_name": skill_name,
        "source_robot": source_robot,
        "target_robot": target_robot,
        "guidance": guidance.get('primitive_mappings', []),
    }

    reference_code_path_str = f"./references/{source_robot}_{skill_name}.py"
    reference_code_path = Path(reference_code_path_str)
    # Path.write_text does not create missing parent directories.
    reference_code_path.parent.mkdir(parents=True, exist_ok=True)
    reference_code_path.write_text(preprocessed_reference_code, encoding="utf-8")

    guidance_path_str = f"./guidances/{target_robot}_{skill_name}_initial.json"
    guidance_path = Path(guidance_path_str)
    guidance_path.parent.mkdir(parents=True, exist_ok=True)
    guidance_path.write_text(json.dumps(initial_guidance, ensure_ascii=False, indent=2), encoding="utf-8")
    timings['save_guidance_and_reference'] = time.perf_counter() - step_start
    print(f"[TIMING] Save guidance and reference: {timings['save_guidance_and_reference']:.2f}s")

    # 3. Infill.
    step_start = time.perf_counter()
    model, tokenizer, _device = get_infilling_model()
    timings['get_infilling_model'] = time.perf_counter() - step_start
    print(f"[TIMING] get_infilling_model: {timings['get_infilling_model']:.2f}s")

    step_start = time.perf_counter()
    infilled_code, num_tokens = run_full_infilling_pipeline(
        source_path=reference_code_path_str,
        guidance_path=guidance_path_str,
        model=model,
        tokenizer=tokenizer,
    )
    timings['run_full_infilling_pipeline'] = time.perf_counter() - step_start
    print(f"[TIMING] run_full_infilling_pipeline: {timings['run_full_infilling_pipeline']:.2f}s")

    # 4. Optional grasp alignment prefixes.
    step_start = time.perf_counter()
    if isinstance(grasp_guidance, str) and grasp_guidance.strip():
        print(f"[OURS - ONLINE] Processing grasp_guidance: {grasp_guidance[:100]}...")

        grasp_objects = _parse_grasp_guidance(grasp_guidance)

        if grasp_objects:
            grasp_locations = _find_grasp_locations(infilled_code, grasp_objects)

            # Bottom-up: insertions below never shift the lines still to be processed.
            for location in sorted(grasp_locations, key=lambda x: x["line_number"], reverse=True):
                print(f"[OURS - ONLINE] Processing grasp location at line {location['line_number']} for '{location['object']}'")

                grasp_alignment_guidance = _build_grasp_alignment_guidance(location, grasp_guidance)

                infilled_code, extra_tokens = _run_grasp_prefix_infilling(
                    infilled_code,
                    grasp_alignment_guidance,
                    model,
                    tokenizer,
                    target_robot=target_robot,
                )
                num_tokens += extra_tokens
                print(f"[OURS - ONLINE] Added grasp prefix for '{location['object']}' (tokens: {extra_tokens})")

    timings['grasp_guidance_processing'] = time.perf_counter() - step_start
    print(f"[TIMING] grasp_guidance_processing: {timings['grasp_guidance_processing']:.2f}s")

    # 5. Extract the body.
    step_start = time.perf_counter()
    initial_code_body = extract_function_body(infilled_code, "run_skill")
    timings['extract_function_body'] = time.perf_counter() - step_start
    print(f"[TIMING] extract_function_body: {timings['extract_function_body']:.2f}s")

    timings['total'] = time.perf_counter() - total_start
    print(f"[INITIAL_CODE_START]\n{initial_code_body}\n[INITIAL_CODE_END]")
    print(f"[INITIAL_CODE] Generated tokens: {num_tokens}")

    print(f"\n[TIMING SUMMARY] get_initial_code total: {timings['total']:.2f}s")
    for label, key in (
        ("Preprocessing", 'preprocessing'),
        ("get_initial_guidance", 'get_initial_guidance'),
        ("Save guidance and reference", 'save_guidance_and_reference'),
        ("get_infilling_model", 'get_infilling_model'),
        ("run_full_infilling_pipeline", 'run_full_infilling_pipeline'),
        ("grasp_guidance_processing", 'grasp_guidance_processing'),
        ("extract_function_body", 'extract_function_body'),
    ):
        print(f"  - {label}: {timings.get(key, 0):.2f}s")

    return initial_code_body, num_tokens

def get_repair_guidance(
    source_robot: str,
    target_robot: str,
    invalid_statements: List[Dict[str, Any]],
    current_code: str,
    executed_stmt_texts: List[str],
) -> Dict[str, Any]:
    """Turn validator-reported invalid statements into repair mappings for infilling.

    Each entry of ``invalid_statements`` produces one ``type="repair"`` mapping
    targeting that single line. Keys read from an entry: ``line_number``,
    ``statement_code``, ``skill_name``, ``violations``,
    ``precondition_violations``, ``postcondition_violations``, ``warnings`` and
    ``projected_state``; missing or ``None`` values are treated as empty.

    The mapping describes the violations and the projected gripper state. It
    carries an interface example when the skill is a known target-robot
    skill. If the reported skill is not a target-robot skill, the best entry
    from the skill-substitution table is used, provided it resolves to a known
    target skill.

    ``current_code`` is accepted for interface compatibility but not read.

    Returns:
        Dict with ``source_robot``, ``target_robot``, ``repair_mappings``,
        ``executed_stmt_count`` and ``meta`` counts.
    """
    subs = _load_skill_substitutions()
    # NOTE(review): loaded but not used below.
    src_skill_index = _load_robot_skills_as_dict(source_robot)
    tgt_skill_index = _load_robot_skills_as_dict(target_robot)

    def _join(items: List[Any]) -> str:
        # str() so non-string violation entries cannot make join() raise.
        return ", ".join(str(item) for item in items)

    repair_mappings: List[Dict[str, Any]] = []

    for inv_stmt in invalid_statements:
        line_number = inv_stmt.get("line_number")
        stmt_code = inv_stmt.get("statement_code") or ""
        skill_name = inv_stmt.get("skill_name") or ""
        violations = inv_stmt.get("violations") or []
        precond_violations = inv_stmt.get("precondition_violations") or []
        postcond_violations = inv_stmt.get("postcondition_violations") or []
        warnings = inv_stmt.get("warnings") or []
        projected_state = inv_stmt.get("projected_state") or {}

        if skill_name not in tgt_skill_index:
            eq = _best_equivalence(subs, source_robot, target_robot, skill_name)
            tgt_full = eq.get("skill_b") if eq else None
            if tgt_full:
                # Expected as "<prefix>:<skill>"; a value without ':' is taken as the skill itself.
                tgt_skill_name = str(tgt_full).split(":", 1)[-1]
                if tgt_skill_name in tgt_skill_index:
                    skill_name = tgt_skill_name

        tgt_skill_spec = tgt_skill_index.get(skill_name, {})

        violation_desc = []
        if precond_violations:
            violation_desc.append(f"Precondition violations: {_join(precond_violations)}")
        if postcond_violations:
            violation_desc.append(f"Postcondition violations: {_join(postcond_violations)}")
        if violations:
            violation_desc.append(f"Violations: {_join(violations)}")
        if warnings:
            violation_desc.append(f"Warnings: {_join(warnings)}")

        state_desc = [
            f"{key}={projected_state[key]}"
            for key in ("gripper_open", "held_object", "gripper_pos")
            if projected_state.get(key) is not None
        ]

        content = f"Fix the `{skill_name}` call that violates conditions."
        if violation_desc:
            content += " " + "; ".join(violation_desc) + "."
        if state_desc:
            content += f" Projected state: {', '.join(state_desc)}."

        full_content = (
            f"The statement `{str(stmt_code).strip()}` at line {line_number} is invalid.\n\n"
            f"Skill: {skill_name}\n"
        )
        if violation_desc:
            full_content += "Issues:\n" + "\n".join(f"  - {v}" for v in violation_desc) + "\n"
        if state_desc:
            full_content += f"\nProjected State: {', '.join(state_desc)}\n"

        interface_example = generate_interface_example(tgt_skill_spec, []) if tgt_skill_spec else ""

        repair_mappings.append({
            "line_start": line_number,
            "line_end": line_number,
            "match": {
                "node": "call",
                "func": {"type": "Name", "id": skill_name},
                "keywords_required": [],
            },
            "type": "repair",
            "content": content,
            "full_content": full_content,
            "interface": interface_example,
            "violations": violations,
            "precondition_violations": precond_violations,
            "postcondition_violations": postcond_violations,
            "warnings": warnings,
            "projected_state": projected_state,
            "original_statement": stmt_code,
        })

    return {
        "source_robot": source_robot,
        "target_robot": target_robot,
        "repair_mappings": repair_mappings,
        "executed_stmt_count": len(executed_stmt_texts),
        "meta": {
            "total_invalid_statements": len(invalid_statements),
            "total_repair_mappings": len(repair_mappings),
        },
    }


def get_repaired_statements(
    skill_name: str,
    source_robot: str,
    target_robot: str,
    current_code: str,
    executed_stmt_texts: List[str],
    invalid_statements: List[Dict[str, Any]],
    scene_info: Optional[List[Dict[str, str]]] = None,
) -> Tuple[str, int]:
    """Repair statements flagged by validation using guided FIM infilling.

    Builds repair guidance with ``get_repair_guidance``. It then writes
    ``current_code`` to ``repair_temp/`` and the guidance to ``guidances/``
    (both created if missing) and runs ``run_full_infilling_pipeline`` on them.

    ``scene_info`` is accepted for interface compatibility but not read.

    Returns:
        ``(code, num_tokens)`` on every path. It is ``(current_code, 0)`` when
        there is nothing to repair, when no repair mappings are produced, or
        when infilling raises; the exception is printed with its traceback.
    """
    if not invalid_statements:
        print("[REPAIR] No invalid statements to repair")
        return current_code, 0

    print(f"[REPAIR] Processing {len(invalid_statements)} invalid statement(s)")
    print(f"[REPAIR] Executed statements count: {len(executed_stmt_texts)}")

    repair_guidance = get_repair_guidance(
        source_robot=source_robot,
        target_robot=target_robot,
        invalid_statements=invalid_statements,
        current_code=current_code,
        executed_stmt_texts=executed_stmt_texts,
    )

    guidance_for_infilling = {
        "skill_name": skill_name,
        "source_robot": source_robot,
        "target_robot": target_robot,
        "guidance": repair_guidance.get("repair_mappings", []),
    }

    if not guidance_for_infilling["guidance"]:
        print("[REPAIR] No repair mappings generated")
        return current_code, 0

    repair_code_dir = Path("repair_temp")
    repair_code_dir.mkdir(parents=True, exist_ok=True)

    repair_code_path = repair_code_dir / f"{target_robot}_{skill_name}_repair_source.py"
    repair_code_path.write_text(current_code, encoding="utf-8")

    guidance_dir = Path("guidances")
    guidance_dir.mkdir(parents=True, exist_ok=True)

    guidance_path = guidance_dir / f"{target_robot}_{skill_name}_repair.json"
    guidance_path.write_text(
        json.dumps(guidance_for_infilling, ensure_ascii=False, indent=2),
        encoding="utf-8"
    )

    print(f"[REPAIR] Saved repair source to: {repair_code_path}")
    print(f"[REPAIR] Saved repair guidance to: {guidance_path}")

    try:
        model, tokenizer, _device = get_infilling_model()
        repaired_code, num_tokens = run_full_infilling_pipeline(
            source_path=str(repair_code_path),
            guidance_path=str(guidance_path),
            model=model,
            tokenizer=tokenizer,
        )

        # Logging only: failing to isolate run_skill must not discard the repair.
        try:
            repaired_body = extract_function_body(repaired_code, "run_skill")
            print(f"[REPAIR_CODE_START]\n{repaired_body}\n[REPAIR_CODE_END]")
        except ValueError:
            print("[REPAIR] Could not extract run_skill function body for logging")

        print(f"[REPAIR] Infilling completed successfully (tokens: {num_tokens})")
        return repaired_code, num_tokens

    except Exception as e:
        print(f"[REPAIR] Infilling failed: {e}")
        import traceback
        traceback.print_exc()
        return current_code, 0


def _format_scene_xyz(pos: Any) -> Optional[str]:
    """Return ``"[x, y, z]"`` (3 decimals) for the first three components of ``pos``.

    Returns ``None`` when ``pos`` is ``None``, has no length, or has fewer than
    three components. The check is explicit rather than by truthiness, so
    numpy arrays (whose truth value is ambiguous) work.
    """
    if pos is None:
        return None
    try:
        if len(pos) < 3:
            return None
    except TypeError:
        return None
    return f"[{pos[0]:.3f}, {pos[1]:.3f}, {pos[2]:.3f}]"


def build_scene_description(
    scene: Dict[str, Any],
    gripper_state: Optional[Dict[str, Any]] = None,
    executed_stmt_texts: Optional[List[str]] = None,
) -> str:
    """Render the scene as ``#``-comment lines for an LLM prompt.

    Sections, joined by blank lines:
      * gripper position and open/closed state (``scene["gripper"]``), plus
        the held object from ``gripper_state["held_object"]`` when given;
      * each object's name, type and position (``scene["objects"]``);
      * the five most recent ``executed_stmt_texts``. Each is truncated to 80
        characters. When more than five exist, the section header is replaced
        by an omission note.

    Positions are lists or arrays; those with fewer than three components are
    treated as unknown.
    """
    desc_parts = []

    gripper = scene.get("gripper") or {}
    is_open = gripper.get("is_open", True)

    gripper_desc = "# Current gripper state:\n"
    gripper_xyz = _format_scene_xyz(gripper.get("position"))
    if gripper_xyz is not None:
        gripper_desc += f"#   Position: {gripper_xyz}\n"
    gripper_desc += f"#   Gripper is {'OPEN' if is_open else 'CLOSED'}\n"

    if gripper_state:
        held_obj = gripper_state.get("held_object")
        if held_obj:
            gripper_desc += f"#   Currently holding: '{held_obj}'\n"
        elif not is_open:
            gripper_desc += "#   Gripper closed but not holding any object\n"

    desc_parts.append(gripper_desc)

    objects = scene.get("objects") or []
    if objects:
        obj_desc = "# Current object positions:\n"
        for obj in objects:
            name = obj.get("name", "unknown")
            obj_type = obj.get("type", "unknown")
            xyz = _format_scene_xyz(obj.get("position"))
            obj_desc += f"#   {name} ({obj_type}): {xyz if xyz is not None else 'position unknown'}\n"
        desc_parts.append(obj_desc)

    if executed_stmt_texts:
        total = len(executed_stmt_texts)
        shown_from = max(0, total - 5)
        if total > 5:
            exec_desc = f"# ... ({total - 5} earlier statements omitted)\n"
        else:
            exec_desc = f"# Already executed {total} statement(s):\n"
        for i, stmt in enumerate(executed_stmt_texts[shown_from:], start=shown_from + 1):
            stripped = stmt.strip()
            stmt_short = stripped.replace('\n', ' ')[:80]
            exec_desc += f"#   [{i}] {stmt_short}{'...' if len(stripped) > 80 else ''}\n"
        desc_parts.append(exec_desc)

    return "\n".join(desc_parts)


def get_revise_prompt(
    task_name: str,
    target_robot: str,
    scene: Dict[str, Any],
    executed_stmt_texts: List[str],
    last_step_payload: Dict[str, Any],
    object_names: List[str],
) -> str:
    """Build the prompt asking the model to continue an incomplete task.

    The prompt lists the available objects and the target robot's primitive
    skills with reference call shapes. It includes the scene description
    (gripper state from ``last_step_payload["confirmed_state"]`` and recently
    executed statements) and asks for only the remaining steps.

    Raises:
        ValueError: If ``target_robot`` (case-insensitive) is not ``ur5``,
            ``panda`` or ``sawyer``.
    """
    obj_list_str = ", ".join(f"'{name}'" for name in object_names)

    # Resolve the robot first so an unsupported robot fails with a clear error
    # instead of an unbound skill list.
    robot_key = target_robot.lower()
    if robot_key == "ur5":
        skill_list = "ur5_move_to, ur5_grasp_at, ur5_release_at, ur5_align_gripper"
        skill_examples = """
# Primitive call shapes (reference)
# ur5_grasp_at:  obs, reward, done = ur5_grasp_at(env, task, grasp_pos=[x,y,z], approach={'axis':'z', 'distance':FLOAT}, timeout_s=FLOAT)
# ur5_release_at: obs, reward, done = ur5_release_at(env, task, place_pos=[x,y,z], approach={'axis':'z', 'distance':FLOAT}, timeout_s=FLOAT)
# ur5_move_to:  obs, reward, done = ur5_move_to(env, task, target_pos=[x,y,z], timeout_s=FLOAT)
# ur5_align_gripper:
#   q = obj.get_quaternion()
#   obs, reward, done = ur5_align_gripper(env, task, reference_quat=q, yaw_mode='parallel', approach_direction=APPROACH_DIR, timeout_s=FLOAT)
"""
    elif robot_key == "panda":
        skill_list = "move, pick, place, push, open_gripper, close_gripper, align_to_quaternion, align_two_axes"
        skill_examples = """
# Primitive call shapes (reference)
# pick:  obs, reward, done = pick(env, task, target_pos=[x,y,z], approach_axis='z', approach_distance=FLOAT, timeout=FLOAT)
# place: obs, reward, done = place(env, task, target_pos=[x,y,z], approach_axis='z', approach_distance=FLOAT, timeout=FLOAT)
# move:  obs, reward, done = move(env, task, target_pos=[x,y,z], timeout=FLOAT)
# push:  obs, reward, done = push(env, task, target_pos=[x,y,z], approach_axis='z', approach_distance=FLOAT, timeout=FLOAT)
# grip:  obs, reward, done = open_gripper(env, task) / close_gripper(env, task)
# align:
#   q = obj.get_quaternion()
#   obs, reward, done = align_to_quaternion(env, task, quaternion=q, yaw_align='parallel', approach_dir=APPROACH_DIR, timeout=FLOAT)
"""
    elif robot_key == "sawyer":
        skill_list = "sawyer_move_to, sawyer_pick, sawyer_place, sawyer_align_gripper, sawyer_open_gripper, sawyer_close_gripper"
        skill_examples = """
# Primitive call shapes (reference)
# sawyer_pick:  obs, reward, done = sawyer_pick(env, task, target_object=obj, target_pos=[x,y,z], grasp_offset=[dx,dy,dz])
# sawyer_place: obs, reward, done = sawyer_place(env, task, place_pos=[x,y,z], place_offset=[dx,dy,dz])
# sawyer_move_to: obs, reward, done = sawyer_move_to(env, task, target_pos=[x,y,z], timeout_s=FLOAT)
# sawyer_align_gripper:
#   q = obj.get_quaternion()
#   obs, reward, done = sawyer_align_gripper(env, task, approach_direction='down', reference_quat=q, yaw_mode='parallel', timeout_s=FLOAT)
# sawyer_open_gripper:  obs, reward, done = sawyer_open_gripper(env, task, amount=1.0, velocity=0.2)
# sawyer_close_gripper: obs, reward, done = sawyer_close_gripper(env, task, amount=0.0, velocity=0.2)
"""
    else:
        raise ValueError(
            f"Unsupported target robot for revise prompt: {target_robot!r} "
            "(expected 'ur5', 'panda' or 'sawyer')"
        )

    gripper_state = last_step_payload.get("confirmed_state", {})

    scene_desc = build_scene_description(
        scene=scene,
        gripper_state=gripper_state,
        executed_stmt_texts=executed_stmt_texts,
    )

    prompt = f"""# TASK INCOMPLETE - NEED MORE STEPS
# Task: {task_name}
# The task is NOT yet complete. Generate additional code to finish the task.

# Available objects: {obj_list_str}
# Available primitive skills: {skill_list}
# OUTPUT: Only write valid Python code (continuation of run_skill body). No markdown. No comments except guidance.

{scene_desc}

# IMPORTANT: The following statements have already been executed successfully.
# DO NOT repeat them. Generate ONLY the remaining steps needed to complete the task.

{skill_examples}

# Continue from here (add the remaining steps):
"""
    return prompt


def get_repair_prompt(
    task_name: str,
    target_robot: str,
    current_code: str,
    executed_stmt_texts: List[str],
    failed_stmt_text: str,
    error_payload: Dict[str, Any],
    scene: Dict[str, Any],
    object_names: List[str],
    failed_line_number: Optional[int] = None,
) -> str:
    """Build the prompt asking the model to rewrite ``run_skill`` after a runtime failure.

    The prompt contains the error type and message (``error_payload["error_type"]``
    and ``error_payload["error"]``) and the failed statement, optionally with
    its line. It also contains the scene description and constraints that keep
    the already-executed statements, which are listed, unchanged.

    Skill lists: ``ur5`` and ``sawyer`` get their own primitives; any other
    robot gets the panda primitive list. Multi-line error messages and failed
    statements are prefixed with ``# `` on every line, so the prompt stays
    one comment block. ``current_code`` is accepted for interface
    compatibility but not read.
    """
    def _as_comment(text: Any) -> str:
        # Continuation lines need their own "# " or they leak into the prompt as code.
        return "\n# ".join(str(text).split("\n"))

    obj_list_str = ", ".join(f"'{name}'" for name in object_names)

    error_msg = _as_comment(error_payload.get("error", "Unknown error"))
    error_type = error_payload.get("error_type", "Exception")
    failed_stmt = _as_comment(failed_stmt_text.strip())

    scene_desc = build_scene_description(
        scene=scene,
        executed_stmt_texts=executed_stmt_texts,
    )

    robot_skill_lists = {
        "ur5": "ur5_move_to, ur5_grasp_at, ur5_release_at, ur5_align_gripper",
        "sawyer": "sawyer_move_to, sawyer_pick, sawyer_place, sawyer_align_gripper, sawyer_open_gripper, sawyer_close_gripper",
    }
    skill_list = robot_skill_lists.get(
        target_robot.lower(),
        "move, pick, place, push, open_gripper, close_gripper, align_to_quaternion, align_two_axes",
    )

    line_info = f" (line {failed_line_number})" if failed_line_number else ""

    prompt = f"""# EXECUTION FAILURE - REPAIR NEEDED
# Task: {task_name}
# An error occurred during execution. Fix the code to handle this error.

# Available objects: {obj_list_str}
# Available primitive skills: {skill_list}
# OUTPUT: Only write valid Python code for the ENTIRE run_skill body. No markdown.

# ERROR INFORMATION:
# Error Type: {error_type}
# Error Message: {error_msg}
# Failed Statement{line_info}: {failed_stmt}

{scene_desc}

# CONSTRAINTS:
# 1. Keep the first {len(executed_stmt_texts)} statement(s) UNCHANGED (already executed successfully)
# 2. Fix or replace the failed statement and subsequent code
# 3. The fix should address the error while still achieving the task goal

# Already executed statements (DO NOT MODIFY):
"""
    for i, stmt in enumerate(executed_stmt_texts, start=1):
        stmt_short = stmt.strip().replace('\n', ' ')[:100]
        prompt += f"# [{i}] {stmt_short}\n"

    prompt += f"""
# Failed statement to fix:
# {failed_stmt}

# Generate the complete run_skill body with the fix:
"""
    return prompt


def _split_generated_body(generated_text: str, indent: str = "    ") -> Tuple[List[str], List[str]]:
    """Turn raw FIM output into candidate ``run_skill`` body lines.

    Returns ``(nested, flat)``. Both drop blank lines and prefix every line
    with ``indent``:

    * ``flat`` strips every line to a single indentation level.
    * ``nested`` keeps the lines' relative indentation, so compound statements
      (``for``/``if``/``with`` bodies) survive. The revise FIM prompt leaves
      the cursor after ``indent`` on a fresh line. So for output that does not
      begin with a newline, the first line is re-anchored at ``indent`` before
      the common leading whitespace is removed.
    """
    import textwrap

    flat = [indent + ln.strip() for ln in generated_text.strip().split('\n') if ln.strip()]

    if generated_text.startswith(('\n', '\r')):
        anchored = generated_text
    else:
        anchored = indent + generated_text
    dedented = textwrap.dedent(anchored.strip('\r\n'))
    nested = [indent + ln.rstrip() for ln in dedented.split('\n') if ln.strip()]
    return nested, flat


def get_additional_code(
    skill_name: str,
    source_robot: str,
    target_robot: str,
    current_code: str,
    executed_stmt_texts: List[str],
    scene: Dict[str, Any],
    last_step_payload: Dict[str, Any],
) -> Tuple[str, int]:
    """Ask the infilling model for the remaining steps of an incomplete task.

    Builds a revise prompt with ``get_revise_prompt`` and runs one FIM
    completion positioned after ``current_code``. The generated lines are
    inserted right after the last line of the top-level ``run_skill``
    function. Generated lines keep their relative indentation when the
    resulting code parses; otherwise every line is flattened to one
    indentation level. If the infilling cache is empty, the model is loaded
    with ``load_model_and_tokenizer()`` defaults. ``source_robot`` is
    accepted for interface compatibility but not read.

    Returns:
        ``(new_code, num_tokens)``, or ``(current_code, 0)`` if anything in the
        generation/insertion step raises; the exception is printed with its
        traceback. Errors from ``get_revise_prompt`` propagate.
    """
    print("[REVISE] Generating additional code for incomplete task")
    print(f"[REVISE] Executed statements: {len(executed_stmt_texts)}")

    object_names = [obj.get("name", "") for obj in scene.get("objects", []) if obj.get("name")]

    prompt = get_revise_prompt(
        task_name=skill_name,
        target_robot=target_robot,
        scene=scene,
        executed_stmt_texts=executed_stmt_texts,
        last_step_payload=last_step_payload,
        object_names=object_names,
    )

    print(f"[REVISE] Generated prompt:\n{prompt[:500]}...")

    try:
        model, tokenizer, _device = get_infilling_model()
        if model is None:
            model, tokenizer, _device = load_model_and_tokenizer()

        fim_prompt = f"{prompt}\n{FIM_PREFIX}{current_code}\n    # Continue here:\n    {FIM_SUFFIX}\n{FIM_MIDDLE}"

        generated_text, _, num_tokens = generate_fim_completion(
            model=model,
            tokenizer=tokenizer,
            fim_prompt=fim_prompt,
            max_new_tokens=512,
            temperature=0.0,
        )

        print(f"[REVISE] Generated additional code:\n{generated_text[:300]}... (tokens: {num_tokens})")

        lines = current_code.split('\n')

        tree = ast.parse(current_code)
        for node in tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == "run_skill":
                head, tail = lines[:node.end_lineno], lines[node.end_lineno:]
                nested_lines, flat_lines = _split_generated_body(generated_text)
                lines = head + flat_lines + tail
                # Prefer the nesting-preserving version when it yields valid Python.
                if nested_lines != flat_lines:
                    candidate = head + nested_lines + tail
                    try:
                        ast.parse('\n'.join(candidate))
                    except (SyntaxError, ValueError):
                        pass  # keep the flattened version
                    else:
                        lines = candidate
                break

        new_code = '\n'.join(lines)
        print("[REVISE] Code updated successfully")
        return new_code, num_tokens

    except Exception as e:
        print(f"[REVISE] Failed to generate additional code: {e}")
        import traceback
        traceback.print_exc()
        return current_code, 0


def _locate_failed_line(code: str, failed_stmt_text: Optional[str]) -> Optional[int]:
    """Return the 1-based line number in ``code`` where the failed statement starts.

    The statement is identified by its first non-blank line, so multi-line
    statements can still be located. A line whose stripped text equals that
    needle is preferred; otherwise the first line containing the needle as a
    substring is used. Returns ``None`` when the needle is empty or nothing
    matches.
    """
    needle = next(
        (part.strip() for part in (failed_stmt_text or "").splitlines() if part.strip()),
        "",
    )
    if not needle:
        # An empty needle is a substring of every line and would "match" line 1,
        # truncating the entire program.
        return None
    code_lines = code.split("\n")
    # Exact match first, so a comment or a longer line that merely contains
    # the statement text does not win over the statement itself.
    for lineno, line in enumerate(code_lines, start=1):
        if line.strip() == needle:
            return lineno
    for lineno, line in enumerate(code_lines, start=1):
        if needle in line:
            return lineno
    return None


def _reindent_repair_lines(generated_text: str, *, keep_nesting: bool) -> List[str]:
    """Turn a raw FIM completion into function-body lines, dropping blank lines.

    With ``keep_nesting=False`` every line is stripped and indented by exactly
    four spaces; this flattens any nested blocks. With ``keep_nesting=True``
    only the first line is re-indented and later lines keep the indentation
    the model emitted. The FIM prompt places the insertion point after four
    spaces of indentation, so the first completion line presumably carries no
    indent of its own.
    """
    body = [line.rstrip() for line in generated_text.strip().split("\n") if line.strip()]
    if not body:
        return []
    if keep_nesting:
        return ["    " + body[0].lstrip()] + body[1:]
    return ["    " + line.strip() for line in body]


def get_repaired_code_on_failure(
    skill_name: str,
    source_robot: str,
    target_robot: str,
    current_code: str,
    executed_stmt_texts: List[str],
    failed_stmt_text: str,
    error_payload: Dict[str, Any],
    scene: Dict[str, Any],
) -> Tuple[str, int]:
    """Ask the infilling model to rewrite ``current_code`` from the failed statement on.

    Lines before the failing statement are kept verbatim. The failing line and
    everything after it are replaced by the model's completion, re-indented to
    the function-body level. If the failing line cannot be located, the
    completion is appended after the whole of ``current_code`` instead.

    Args:
        skill_name: Task/skill name, forwarded to the repair prompt.
        source_robot: Robot the skill was originally written for. It is not
            used by the repair itself.
        target_robot: Robot the code is being repaired for.
        current_code: Source of the program that failed.
        executed_stmt_texts: Statements that ran before the failure.
        failed_stmt_text: Source text of the statement that failed.
        error_payload: Error details. Its ``"error"`` entry is echoed as a
            comment in the FIM prompt.
        scene: Scene description. Non-empty ``name`` values of
            ``scene["objects"]`` entries are passed to the prompt as object names.

    Returns:
        ``(code, num_tokens)`` with the repaired code and the number of
        generated tokens. Returns ``(current_code, 0)`` if generation raised or
        no re-indentation of the completion parses as Python.
    """
    print(f"[REPAIR_FAILURE] Repairing code after failure")
    print(f"[REPAIR_FAILURE] Failed statement: {(failed_stmt_text or '')[:100]}...")
    print(f"[REPAIR_FAILURE] Error: {error_payload}")

    failed_line_number = _locate_failed_line(current_code, failed_stmt_text)

    object_names = [
        obj["name"] for obj in (scene.get("objects") or []) if obj.get("name")
    ]

    prompt = get_repair_prompt(
        task_name=skill_name,
        target_robot=target_robot,
        current_code=current_code,
        executed_stmt_texts=executed_stmt_texts,
        failed_stmt_text=failed_stmt_text,
        error_payload=error_payload,
        scene=scene,
        object_names=object_names,
        failed_line_number=failed_line_number,
    )

    print(f"[REPAIR_FAILURE] Generated prompt:\n{prompt[:500]}...")

    try:
        # NOTE(review): this directory is created but never written to here;
        # kept only to preserve the existing filesystem side effect.
        Path("repair_temp").mkdir(parents=True, exist_ok=True)

        model, tokenizer, device = get_infilling_model()
        if model is None:
            model, tokenizer, device = load_model_and_tokenizer()

        lines = current_code.split("\n")

        if failed_line_number:
            # Keep everything strictly before the failing line.
            prefix_lines = lines[:failed_line_number - 1]
            prefix_code = "\n".join(prefix_lines) + "\n"
        else:
            # Failing line unknown: keep the whole program and let the model
            # continue after it. Previously prefix_lines was left unbound here,
            # so this path always ended in a NameError.
            prefix_lines = lines
            prefix_code = current_code if current_code.endswith("\n") else current_code + "\n"

        # Every line of a multi-line error must stay a comment; otherwise it
        # would inject stray text into the code prefix the model completes.
        error_msg = str(error_payload.get("error", "Unknown"))
        error_comment = "\n    # ".join(error_msg.splitlines() or [""])

        fim_prompt = (
            f"{prompt}\n{FIM_PREFIX}{prefix_code}"
            "    # Fix the following error and complete the task:\n"
            f"    # Error: {error_comment}\n"
            f"    {FIM_SUFFIX}\n{FIM_MIDDLE}"
        )

        generated_text, _, num_tokens = generate_fim_completion(
            model=model,
            tokenizer=tokenizer,
            fim_prompt=fim_prompt,
            max_new_tokens=512,
            temperature=0.0,
        )

        print(f"[REPAIR_FAILURE] Generated repair code:\n{generated_text[:300]}... (tokens: {num_tokens})")

        # Try the flattened form first, so results that used to succeed are
        # unchanged. Fall back to keeping the model's nesting, which rescues
        # completions containing if/for/with blocks.
        last_error: Optional[SyntaxError] = None
        for keep_nesting in (False, True):
            new_lines = _reindent_repair_lines(generated_text, keep_nesting=keep_nesting)
            new_code = "\n".join(prefix_lines + new_lines)
            try:
                ast.parse(new_code)
            except SyntaxError as e:
                last_error = e
                continue
            print(f"[REPAIR_FAILURE] Code repaired successfully")
            return new_code, num_tokens

        print(f"[REPAIR_FAILURE] Generated code has syntax error: {last_error}")
        return current_code, 0

    except Exception as e:
        print(f"[REPAIR_FAILURE] Failed to repair code: {e}")
        import traceback
        traceback.print_exc()
        return current_code, 0
