# run_skeleton_task.py (Completed)

import os
import re
import time
import numpy as np

# RLBench / simulation utilities (left untouched – assumed to be available)
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor
from env import setup_environment, shutdown_environment

# Low-level skills already implemented elsewhere
from skill_code import *          # noqa: F403,F401  (we just import; we will NOT redefine)

# Video helpers (recording can be disabled if not needed)
from video import (
    init_video_writers,
    recording_step,
    recording_get_observation,
)

# Utility that returns the current object positions in the scene
from object_positions import get_object_positions


########################################################################
# Helper ❶ :  PDDL domain-file analysis – find predicates that
#             are referenced in action definitions but never declared.
########################################################################
_PRED_DECL_RE = re.compile(r"^\s*\(\s*[\w-]+")          # “(predicate …”
_ACTION_RE    = re.compile(r"^\s*\(:action\s+(\w+)")    # “(:action …”
_TOKEN_RE     = re.compile(r"[A-Za-z0-9_-]+")


def _tokenize_pddl_line(line: str):
    """Split *line* into its bare PDDL tokens.

    A token is a maximal run of ASCII letters, digits, ``_`` or ``-``.
    Every other character (parentheses, whitespace, ``?``, ``:``) only
    separates tokens and is never part of the returned list.
    """
    return [hit.group(0) for hit in _TOKEN_RE.finditer(line)]


def _collect_predicates_from_action(action_lines):
    """Return the predicate symbols referenced in one action definition.

    Every innermost parenthesised group ``( name … )`` (a group containing
    no nested parentheses) is treated as a candidate atom, and its first
    token is taken as the predicate symbol.  Innermost groups that cannot be
    predicate atoms are skipped:

    * groups whose first non-blank character is ``?`` -- typed variable
      lists such as ``:parameters (?x - obj ?y)`` or a ``forall`` binder;
    * groups starting with ``:`` (PDDL keywords);
    * groups starting with ``=`` -- the built-in equality test.

    Text after ``;`` on a line is a PDDL comment and is ignored.

    Args:
        action_lines: iterable of strings making up the action block.  An
            element may itself contain newlines; atoms that span those
            newlines are still found.

    Returns:
        set[str]: the referenced predicate symbols.
    """
    referenced = set()
    for chunk in action_lines:
        # Drop ';' comments line by line so commented-out atoms are ignored.
        code = "\n".join(ln.split(";", 1)[0] for ln in chunk.splitlines())
        for group in re.finditer(r"\(([^()]*)\)", code):
            body = group.group(1).lstrip()
            # Empty "()", variable lists, keywords and equality are not atoms.
            if not body or body[0] in "?:=":
                continue
            tokens = _TOKEN_RE.findall(body)
            if tokens:
                referenced.add(tokens[0])
    return referenced


def find_undefined_predicates(domain_path: str):
    """
    Parse a PDDL domain file and return a set of predicate symbols that are
    referenced in :precondition / :effect sections of actions but never appear
    in the (:predicates …) declaration.
    """
    defined_preds = set()
    referenced_preds = set()

    if not os.path.exists(domain_path):
        print(f"[Warning] Domain file not found: {domain_path}")
        return set()

    with open(domain_path, "r") as f:
        lines = f.readlines()

    inside_pred_section = False
    inside_action_block = False
    current_action_lines = []

    for raw in lines:
        line = raw.strip()

        # Start / End of (:predicates …)  ---------------------------------
        if line.lower().startswith("(:predicates"):
            inside_pred_section = True
            # skip the very line containing ‘(:predicates’
            line = line[len("(:predicates") :]
        if inside_pred_section:
            # Search for predicate symbols on the current line
            while "(" in line:
                l_par = line.find("(")
                r_par = line.find(")", l_par)
                if r_par == -1:
                    break
                seg = line[l_par : r_par + 1]
                tokens = _tokenize_pddl_line(seg)
                if tokens:
                    defined_preds.add(tokens[0])
                line = line[r_par + 1 :]
            # End?
            if ")" in raw:
                # naive check – works because PDDL has balanced parenthesis
                if raw.count("(") <= raw.count(")"):
                    inside_pred_section = False
            continue  # done with predicates parsing

        # Collect :action blocks -------------------------------------------
        if _ACTION_RE.match(line):
            # Starting a new action – finish the previous block first
            if current_action_lines:
                referenced_preds |= _collect_predicates_from_action(
                    current_action_lines
                )
                current_action_lines = []
            inside_action_block = True

        if inside_action_block:
            current_action_lines.append(raw)
            # End of action block if ')' on its own line
            if raw.strip() == ")":
                inside_action_block = False
                # harvest predicates from this action
                referenced_preds |= _collect_predicates_from_action(
                    current_action_lines
                )
                current_action_lines = []

    # Edge case: last action may not have been flushed
    if current_action_lines:
        referenced_preds |= _collect_predicates_from_action(current_action_lines)

    undefined = referenced_preds - defined_preds
    return undefined


########################################################################
# Main entry
########################################################################
def _report_missing_predicates(missing_preds):
    """Print the result of the static domain check (symbols sorted for stable output)."""
    if not missing_preds:
        print("[Domain-Check] No undefined predicate symbols found.")
        return
    print("[Domain-Check] Undefined predicate symbols detected:")
    for p in sorted(missing_preds):
        print(f"  • {p}")
    print("--------------------------------------------------")
    print(" !!  The most likely culprit causing the planner  ")
    print(" !!  to time-out is the missing predicate above.  ")
    print("--------------------------------------------------\n")


def _fetch_object_positions():
    """Return the scene's object positions, or ``{}`` if they are unavailable.

    Anything that does not behave like a mapping (``None``, a list, ...) is
    treated as "unavailable", so the exploration loop never indexes into
    an unexpected type.
    """
    try:
        positions = get_object_positions()  # expected dict  # noqa: F405
        names = list(positions.keys())
    except Exception:
        print("[Warning] Could not fetch object positions – proceeding "
              "with generic exploration.")
        return {}
    print(f"[Info] Known object positions: {names}")
    return positions


def _explore_object(env, task, obj_name, position, safe_height=0.20):
    """Move above *obj_name*, try to pick it, then place it back.

    A skill that raises is reported and exploration of this object stops.

    Args:
        env, task: handles returned by ``setup_environment()``.
        obj_name: object name, used only for log messages.
        position: the object's position; index 2 is offset by
            *safe_height* to get the approach point.
        safe_height: vertical offset for the approach pose (0.20 = 20 cm).

    Returns:
        bool: True when a skill reported the task as done, meaning the
        caller should stop exploring.
    """
    # Force float dtype: with an integer array "+= 0.20" would be truncated.
    target_pos = np.asarray(position, dtype=float)
    approach = target_pos.copy()
    approach[2] += safe_height

    # ---------- 1)  Move gripper above the object  ----------
    print(f"[Exploration] Moving above '{obj_name}' …")
    try:
        _, _, done = move(                                    # noqa: F405
            env,
            task,
            target_pos=approach,
            max_steps=150,
            threshold=0.01,
            speed_ratio=1.0,
            timeout=8.0,
        )
    except Exception as ex:
        print(f"[Exploration-Warning] move() failed – {ex}")
        return False
    if done:
        print("[Exploration] Task ended unexpectedly.")
        return True

    # ---------- 2)  Attempt a gentle pick  ----------
    print(f"[Exploration] Attempting pick on '{obj_name}' …")
    try:
        _, _, done = pick(                                    # noqa: F405
            env,
            task,
            target_pos=target_pos,
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis="z",
            timeout=6.0,
        )
    except Exception as ex:
        print(f"[Exploration-Warning] pick() failed – {ex}")
        return False
    if done:
        print("[Exploration] Task ended unexpectedly.")
        return True

    # ---------- 3)  Place back at the original location  ----------
    try:
        _, _, done = place(                                   # noqa: F405
            env,
            task,
            target_pos=target_pos,
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis="z",
            timeout=6.0,
        )
    except Exception as ex:
        # Best effort, but no longer silent.
        print(f"[Exploration-Warning] place() failed – {ex}")
        return False
    if done:
        print("[Exploration] Task ended after place().")
        return True
    return False


def run_skeleton_task():
    """
    1)  Report predicate symbols the domain file references but never
        declares. This check was added because an undeclared predicate was
        suspected of causing an earlier planner/validation timeout; that
        link is unconfirmed.
    2)  Boot the environment and run a minimal, conservative exploration
        (move above / pick / place back) on at most one known object, to
        show how the low-level skills are called.  The simulator is always
        shut down afterwards.
    """
    print("\n===== Starting Skeleton Task =====\n")

    # Step-0: static domain analysis (path overridable via DOMAIN_PDDL_PATH).
    domain_file_path = os.getenv("DOMAIN_PDDL_PATH", "domain.pddl")
    _report_missing_predicates(find_undefined_predicates(domain_file_path))

    # Step-1: environment boot-up.
    try:
        env, task = setup_environment()
    except Exception as ex:
        print(f"[Error] Environment could not be started: {ex}")
        return

    # Make sure we always shut down the simulator.
    try:
        # Step-2: reset the task, start optional video recording.
        _descriptions, obs = task.reset()
        init_video_writers(obs)                # safe even if video turned off
        task.step = recording_step(task.step)  # wrap for recording

        # Step-3: scene information ({} if unavailable).
        positions = _fetch_object_positions()

        # Step-4: explore at most one object, deliberately conservative
        # to avoid collisions in an unknown scene.
        for obj_name in list(positions)[:1]:
            if _explore_object(env, task, obj_name, positions[obj_name]):
                break

        print("\n===== Exploration Phase Finished =====\n")

    finally:
        # Step-5: shutdown is best effort, but a failure is reported.
        try:
            shutdown_environment(env)
        except Exception as ex:
            print(f"[Warning] Environment shutdown failed: {ex}")

    print("===== End of Skeleton Task =====\n")


# Stand-alone execution: run the skeleton task only when this file is
# executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()