# run_skeleton_task.py (Filled-in version)

import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
from typing import List, Set

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *          # noqa: F401 – we really need the star import for the skills
from video import (
    init_video_writers,
    recording_step,
    recording_get_observation,
)
from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
#                      Fast-Downward / PDDL Helper Routines                   #
# --------------------------------------------------------------------------- #
_PDDL_KEYWORDS = {
    "and",
    "not",
    "forall",
    "when",
    "exists",
    "=",
    "imply",
    "or",
}

# Heads of numeric (PDDL 2.1) expressions.  Anything parenthesised below one
# of these heads is a function term such as ``(total-cost)``, never a
# predicate.  ":functions" is listed because the heads inside the
# (:functions ...) section are function declarations, not predicates.
_NUMERIC_HEADS = {
    ":functions",
    "=",
    "increase",
    "decrease",
    "assign",
    "scale-up",
    "scale-down",
    "<",
    ">",
    "<=",
    ">=",
    "+",
    "-",
    "*",
    "/",
}

# Expression heads that are PDDL syntax rather than predicate symbols.
# ":"-prefixed section markers and "?"-prefixed variables are filtered
# separately by the collectors below.
_NON_PREDICATE_HEADS = _PDDL_KEYWORDS | _NUMERIC_HEADS | {"define", "domain", "either"}

# One PDDL token: a single parenthesis, or a run of characters that are
# neither whitespace nor parentheses.
_SEXPR_TOKEN = re.compile(r"[()]|[^\s()]+")


def _collect_defined_predicates(domain_txt: str) -> Set[str]:
    """Return the predicate names declared in the first (:predicates ...) section.

    The section is walked with parenthesis-depth tracking, so every
    declaration is seen, including ones with nested ``(either ...)`` types.
    (A lazy ``.*?)`` regex would stop at the ``)`` closing the first
    declaration and see nothing.)  ``;`` comments are ignored.  Returns an
    empty set when the domain has no predicates section.
    """
    result: Set[str] = set()
    tokens = _SEXPR_TOKEN.findall(re.sub(r";.*", "", domain_txt))

    for start in range(len(tokens) - 1):
        if tokens[start] == "(" and tokens[start + 1].lower() == ":predicates":
            break
    else:
        return result

    depth = 0  # nesting level relative to the inside of (:predicates ...)
    for idx in range(start + 2, len(tokens)):
        tok = tokens[idx]
        if tok == ")":
            if depth == 0:
                break  # this ")" closes the (:predicates ...) section itself
            depth -= 1
        elif tok == "(":
            # A "(" directly inside the section opens one declaration whose
            # first token is the predicate name.
            if depth == 0 and idx + 1 < len(tokens):
                name = tokens[idx + 1]
                if name not in ("(", ")") and name[0] != "?" and name not in _PDDL_KEYWORDS:
                    result.add(name)
            depth += 1
    return result


def _collect_used_predicates(domain_txt: str) -> Set[str]:
    """Return every predicate symbol *applied* anywhere in the domain.

    Only the head of a parenthesised expression can be a predicate (``at`` in
    ``(at ?r ?l)``); arguments, types, action names and section keywords are
    never counted.  Heads that are PDDL syntax are skipped: ``:``-prefixed
    section markers, ``?`` variables, logical connectives, numeric operators
    and the ``define`` / ``domain`` / ``either`` forms.  Function terms nested
    in a numeric expression or in ``(:functions ...)`` are skipped as well.
    Keyword comparison is case-insensitive; returned names keep their case.
    """
    result: Set[str] = set()
    tokens = _SEXPR_TOKEN.findall(re.sub(r";.*", "", domain_txt))
    # One entry per currently open "(": True when its sub-expressions are
    # numeric function terms rather than logical formulas.
    numeric_scope: List[bool] = []
    for idx, tok in enumerate(tokens):
        if tok == ")":
            if numeric_scope:
                numeric_scope.pop()
            continue
        if tok != "(":
            continue  # argument / keyword token, never an expression head
        nxt = tokens[idx + 1] if idx + 1 < len(tokens) else ")"
        head = "" if nxt in ("(", ")") else nxt
        key = head.lower()
        inside_numeric = bool(numeric_scope) and numeric_scope[-1]
        if head and not inside_numeric and head[0] not in "?:" and key not in _NON_PREDICATE_HEADS:
            result.add(head)
        numeric_scope.append(inside_numeric or key in _NUMERIC_HEADS)
    return result


def _detect_missing_predicates(domain_txt: str) -> List[str]:
    """Return, sorted, the predicates applied in the domain but never declared."""
    defined = _collect_defined_predicates(domain_txt)
    used = _collect_used_predicates(domain_txt)
    return sorted(used - defined)


def _augment_domain_with_predicates(domain_txt: str, new_preds: List[str]) -> str:
    """Insert declaration stubs for *new_preds* into the (:predicates …) list.

    Each stub types its arguments as ``object`` and takes the arity of the
    predicate's first application in the domain (comments ignored).  For
    example, ``(on ?a ?b)`` yields ``(on ?x1 - object ?x2 - object)`` and
    ``(handempty)`` yields ``(handempty)``.  A stub whose arity differs from
    the predicate's uses makes the domain invalid.  A predicate with no
    visible application falls back to one argument.

    Stubs go right after the first ``(:predicates`` header.  When the domain
    has no such section, a warning is printed and the text is returned
    unchanged.
    """
    if not new_preds:
        return domain_txt

    searchable = re.sub(r";.*", "", domain_txt)  # ignore commented-out uses

    def _arity(pred: str) -> int:
        # "(pred tok tok ...)" with plain (non-parenthesised) arguments only;
        # the lookahead-free form still rejects longer names like "(pred-x".
        use = re.search(
            r"\(\s*" + re.escape(pred) + r"((?:\s+[^\s()]+)*)\s*\)", searchable
        )
        return len(use.group(1).split()) if use else 1

    def _stub(pred: str) -> str:
        args = "".join(f" ?x{i} - object" for i in range(1, _arity(pred) + 1))
        return f"  ({pred}{args})"

    stub_block = "\n".join(_stub(p) for p in new_preds)
    augmented, count = re.subn(
        r"\(\s*:predicates\b",
        lambda m: m.group(0) + "\n" + stub_block,  # function repl: no backref parsing
        domain_txt,
        count=1,
        flags=re.I,
    )
    if not count:
        print("[Exploration] No (:predicates ...) section found; domain left unchanged.")
    return augmented


def _call_planner(
    domain_path: str,
    problem_path: str,
    plan_path: str,
    timeout: int = 40,
    *,
    planner: str = "INPUT_YOUR_PATH",
    alias: str = "seq-sat-lama-2011",
):
    """Run Fast-Downward once and report whether it exited successfully.

    Args:
        domain_path: PDDL domain file.
        problem_path: PDDL problem file.
        plan_path: value passed to ``--plan-file``.  NOTE(review): anytime
            aliases such as seq-sat-lama-2011 write numbered files
            (``plan_path.1``, ``plan_path.2``, ...); callers should look for
            those as well.
        timeout: wall-clock limit in seconds for the planner process.
        planner: planner executable (the default is a placeholder path).
        alias: Fast-Downward ``--alias`` configuration.

    Returns:
        True when the planner exits with status 0.  False on timeout, on a
        non-zero exit, or when the executable cannot be started at all.
        None of these cases raise.
    """
    cmd = [
        planner,
        "--alias",
        alias,
        "--plan-file",
        plan_path,
        domain_path,
        problem_path,
    ]
    try:
        subprocess.run(cmd, check=True, timeout=timeout, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except subprocess.TimeoutExpired:
        print("[Planner] Planner timed-out.")
        return False
    except subprocess.CalledProcessError:
        print("[Planner] Planner reported an error.")
        return False
    except OSError as exc:
        # Missing / non-executable planner binary (e.g. the placeholder path):
        # report it instead of letting FileNotFoundError escape to callers.
        print(f"[Planner] Could not start planner {planner!r}: {exc}")
        return False
    return True


def _ensure_domain_predicates(domain_src: str, problem_src: str) -> str:
    """
    Make sure the domain declares every predicate it uses.

    If some predicates are missing, write a patched copy of the domain, with
    stubs inserted by ``_augment_domain_with_predicates``, to a new temporary
    ``.pddl`` file.  Then run the planner on it once with a short timeout as
    a smoke test.  The patched file is intentionally NOT deleted, because its
    path is returned for later planner runs.

    Args:
        domain_src: path of the original domain file.
        problem_src: path of the problem file (only used by the smoke test).

    Returns:
        ``domain_src`` itself when nothing is missing, otherwise the path of
        the patched temporary domain file.

    Raises:
        OSError (e.g. FileNotFoundError) if ``domain_src`` cannot be opened.
    """
    with open(domain_src, "r") as f:
        domain_txt = f.read()

    missing = _detect_missing_predicates(domain_txt)
    if not missing:
        return domain_src

    print("[Exploration] Detected missing predicates:", ", ".join(missing))
    patched_domain = _augment_domain_with_predicates(domain_txt, missing)

    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".pddl", text=True)
    with os.fdopen(tmp_fd, "w") as f:  # write through the descriptor mkstemp opened
        f.write(patched_domain)

    # Smoke-test the patched domain (short timeout, not a real solve).  Plan
    # output goes to a private scratch directory instead of a race-prone
    # tempfile.mktemp() name, and the directory is removed afterwards.
    scratch = tempfile.mkdtemp(prefix="fd_check_")
    try:
        ok = _call_planner(tmp_path, problem_src, plan_path=os.path.join(scratch, "plan"), timeout=5)
    finally:
        shutil.rmtree(scratch, ignore_errors=True)
    if not ok:
        print("[Exploration] Quick planner check on the patched domain failed or timed out; continuing anyway.")

    print("[Exploration] Patched domain written to:", tmp_path)
    return tmp_path


# --------------------------------------------------------------------------- #
#                               Main Task Routine                             #
# --------------------------------------------------------------------------- #
def _run_drawer_routine(env, task, drawer_name, drawer_pos, gripper_name):
    """Rotate the gripper, then approach, grasp and pull *drawer_name*.

    ``move_to_side``, ``move_to_anchor`` and ``pick_drawer`` are optional:
    they are looked up in module globals and skipped when absent.  Every
    failure is printed as a warning rather than raised.
    """
    print(f"[Task] Handling drawer <{drawer_name}> at {drawer_pos}")

    # 1-a) rotate 0 → ninety_deg
    try:
        rotate(env, task, gripper_name, "zero_deg", "ninety_deg")
    except Exception as e:
        print("  [Warn] rotate failed:", e)

    # 1-b) move to side, anchor, grasp, pull
    side_pos = f"{drawer_name}_side"
    anchor_pos = f"{drawer_name}_anchor"
    try:
        move_to_side = globals().get("move_to_side")  # might not exist
        if move_to_side:
            move_to_side(env, task, gripper_name, drawer_name, "home", side_pos)
        move_to_anchor = globals().get("move_to_anchor")
        if move_to_anchor:
            move_to_anchor(env, task, gripper_name, drawer_name, side_pos, anchor_pos)
        # "pick-drawer" is not a valid identifier; it could only be present if
        # injected into globals() explicitly.
        pick_drawer = globals().get("pick_drawer") or globals().get("pick-drawer")
        if pick_drawer:
            pick_drawer(env, task, gripper_name, drawer_name, anchor_pos)
        pull(env, task, gripper_name, drawer_name)
    except Exception as e:
        print("  [Warn] Drawer routine failed:", e)


def _run_pick_place_fallback(env, task, positions):
    """Pick every non-drawer object and place it straight back; warn on failure."""
    for obj_name, pos in positions.items():
        if "drawer" in obj_name.lower():
            continue
        try:
            print(f"[Task] pick-&-place object {obj_name}")
            pick(env, task, target_pos=pos)
            place(env, task, target_pos=pos)  # put it right back
        except Exception as e:
            print(f"  [Warn] Could not manipulate {obj_name}:", e)


def _find_plan_files(plan_path: str) -> List[str]:
    """Return the non-empty plan files written for *plan_path*.

    Both ``plan_path`` itself and numbered files ``plan_path.N`` are
    collected.  The numbered form is what Fast-Downward's anytime aliases
    write.  The unnumbered file comes first, then numbered ones in ascending
    numeric order, so the last entry is the most recently numbered plan.
    """
    directory, base = os.path.split(plan_path)
    numbered = []
    for entry in os.listdir(directory or "."):
        suffix = entry[len(base) + 1:]
        if entry.startswith(base + ".") and suffix.isdigit():
            numbered.append((int(suffix), os.path.join(directory, entry)))
    candidates = [plan_path] + [path for _, path in sorted(numbered)]
    return [p for p in candidates if os.path.isfile(p) and os.path.getsize(p) > 0]


def _run_planner_and_report(domain_path: str, problem_path: str) -> None:
    """Run the planner once and print the last plan file it produced, if any.

    Plan files are written to a scratch directory that is removed afterwards.
    A plan written before a timeout / error is still shown, with a note.
    """
    scratch = tempfile.mkdtemp(prefix="fd_plan_")
    try:
        plan_path = os.path.join(scratch, "plan")
        solved = _call_planner(domain_path, problem_path, plan_path, timeout=40)
        plan_files = _find_plan_files(plan_path)
        if not plan_files:
            print("[Planner] No plan could be found (or planner failed).")
            return
        if not solved:
            print("[Planner] Planner did not finish cleanly; showing the last plan it wrote.")
        with open(plan_files[-1]) as f:
            plan_lines = [ln.strip() for ln in f if ln.strip()]
        print("[Planner] Plan found:")
        for ln in plan_lines:
            print("  ", ln)
    finally:
        shutil.rmtree(scratch, ignore_errors=True)


def run_skeleton_task():
    """Generic skeleton for running any task in your simulation.

    1. If domain/problem paths are configured, make sure the domain declares
       every predicate it uses (patching a temporary copy if needed).
    2. Run a scripted skill interaction in the simulator: a drawer routine
       when a drawer is present, otherwise pick-&-place of every object.
    3. If a domain is available, run the planner and print the plan found.
    """
    print("===== Starting Skeleton Task =====")

    # --- Path discovery for domain/problem (only if running a planner) ---- #
    # Both lookups use the placeholder variable name "INPUT_YOUR_PATH";
    # replace them with the real environment variable names.
    domain_path = os.getenv("INPUT_YOUR_PATH")
    problem_path = os.getenv("INPUT_YOUR_PATH")

    patched_domain = None
    if domain_path and problem_path:
        try:
            patched_domain = _ensure_domain_predicates(domain_path, problem_path)
        except FileNotFoundError:
            # In pure RLBench evaluation the PDDL files may not exist; skip.
            patched_domain = None
    else:
        # os.getenv returns None when unset, and open(None) raises TypeError
        # (not FileNotFoundError), so skip explicitly instead of crashing.
        print("[Planner] Domain/problem paths not set; skipping PDDL handling.")

    # ---------------------------------------------------------------------- #
    #                        RLBench / Skill Execution                       #
    # ---------------------------------------------------------------------- #
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()
        init_video_writers(obs)

        # Wrap step / get_observation so every call is also recorded.
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        positions = get_object_positions()
        print("[Scene] Known objects:", list(positions.keys()))

        # ------------------------------------------------------------------ #
        # Example HIGH-LEVEL strategy:
        # 1.  If a drawer exists, rotate gripper, move to side, anchor, grasp,
        #     and pull.  Otherwise, iterate through all objects and simply
        #     pick & place them back.  This code path exercises every
        #     available primitive so that integration tests can validate the
        #     bindings without requiring a single “correct” task.
        # ------------------------------------------------------------------ #
        drawer_name = next((n for n in positions if "drawer" in n.lower()), None)
        gripper_name = "gripper"  # skill_code primitives usually infer the robot gripper internally

        if drawer_name:
            _run_drawer_routine(env, task, drawer_name, positions[drawer_name], gripper_name)
        else:
            _run_pick_place_fallback(env, task, positions)

        print("[Task] Finished scripted interaction.")

    finally:
        shutdown_environment(env)

    # ---------------------------------------------------------------------- #
    # (Optional)  Run the planner on patched domain/problem if present
    # ---------------------------------------------------------------------- #
    if patched_domain:
        _run_planner_and_report(patched_domain, problem_path)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":  # executed as a script (not imported): run the task
    run_skeleton_task()