# run_skeleton_task.py (Completed Version)

import re

import sys

import numpy as np

from pyrep.objects.shape import Shape            # do NOT remove (needed by env)

from pyrep.objects.proximity_sensor import ProximitySensor   # do NOT remove

from env import setup_environment, shutdown_environment

from skill_code import *          # DO NOT redefine any skill functions

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

# --------------------------------------------------------------------------- #

# -------------------------  SIMPLE PDDL PARSER HELPERS  -------------------- #

# --------------------------------------------------------------------------- #

IGNORED_PDDL_TOKENS = {

    'and', 'or', 'not', 'forall', 'exists', 'when',

    ':parameters', ':precondition', ':effect', ':types',

    ':action', ':predicates', ':requirements', ':strips',

    ':typing', ':negative-preconditions', ':conditional-effects',

    ':universal-preconditions', ':equality', ':disjunctive-preconditions'

}

def _clean_token(tok: str) -> str:

    """Remove parentheses and extra characters, return lowercase token."""

    return tok.replace('(', '').replace(')', '').strip().lower()

def extract_declared_predicates(domain_str: str) -> set:

    """Return the set of predicate names declared in (:predicates …)."""

    declared = set()

    # Locate the (:predicates section

    predicates_section = re.search(r'\(:predicates(.*?)\)', domain_str, re.S)

    if not predicates_section:

        return declared

    section_text = predicates_section.group(1)

    for line in section_text.splitlines():

        line = line.strip()

        if not line or line.startswith(';'):

            continue

        # A predicate line starts with '(' followed by the name

        if line[0] != '(':

            continue

        predicate_name = _clean_token(line.split()[0])

        if predicate_name and predicate_name not in IGNORED_PDDL_TOKENS:

            declared.add(predicate_name)

    return declared

def extract_used_predicates(domain_str: str) -> set:

    """Return the set of predicate names that appear anywhere inside '(' … ')' tokens

       minus ignored keywords."""

    used = set()

    # remove comments

    domain_str_no_comments = re.sub(r';.*', '', domain_str)

    tokens = re.findall(r'\([^\)]+\)', domain_str_no_comments)

    for tok_group in tokens:

        first_token = _clean_token(tok_group.split()[0])

        if first_token and first_token not in IGNORED_PDDL_TOKENS:

            used.add(first_token)

    return used

# --------------------------------------------------------------------------- #

# ---------------------  EXPLORATION: FIND MISSING PREDICATE  --------------- #

# --------------------------------------------------------------------------- #

def find_missing_predicates(domain_str: str) -> set:
    """Return predicate names that *domain_str* uses but never declares.

    Computed as ``extract_used_predicates(domain_str)`` minus
    ``extract_declared_predicates(domain_str)``.
    """
    used_names = extract_used_predicates(domain_str)
    declared_names = extract_declared_predicates(domain_str)
    return used_names.difference(declared_names)

# Exploration domain analysed by the exploration phase. The text is kept verbatim.
# The `pull` action uses `lock-known`, which is not declared in (:predicates …).
# NOTE(review): `move` and `pick` are each defined twice with different effects.
# Planners typically reject duplicate action names; here only predicate names
# are extracted, so this is harmless. Confirm that is the only use.
_EXPLORATION_DOMAIN_STR = """

        (define (domain exploration)

          (:requirements :strips :typing :conditional-effects :universal-preconditions)

          (:types 

            robot object location

          )

          (:predicates

            (robot-at ?r - robot ?loc - location)

            (at ?obj - object ?loc - location)

            (identified ?obj - object)

            (temperature-known ?obj - object)

            (holding ?obj - object)

            (handempty)

            (weight-known ?obj - object)

            (durability-known ?obj - object)

          )

          (:action move

            :parameters (?r - robot ?from - location ?to - location)

            :precondition (robot-at ?r ?from)

            :effect (and

              (not (robot-at ?r ?from))

              (robot-at ?r ?to)

              (forall (?obj - object)

                (when (at ?obj ?to)

                  (identified ?obj)

                )

              )

            )

          )

          (:action move

            :parameters (?r - robot ?from - location ?to - location)

            :precondition (robot-at ?r ?from)

            :effect (and

              (not (robot-at ?r ?from))

              (robot-at ?r ?to)

              (forall (?obj - object)

                (when (at ?obj ?to)

                  (temperature-known ?obj)

                )

              )

            )

          )

          (:action pick

            :parameters (?r - robot ?obj - object ?loc - location)

            :precondition (and

               (robot-at ?r ?loc)

               (at ?obj ?loc)

               (handempty)

            )

            :effect (and

              (holding ?obj)

              (not (handempty))

              (not (at ?obj ?loc))

              (weight-known ?obj)

            )

          )

          (:action pick

            :parameters (?r - robot ?obj - object ?loc - location)

            :precondition (and

               (robot-at ?r ?loc)

               (at ?obj ?loc)

               (handempty)

            )

            :effect (and

              (holding ?obj)

              (not (handempty))

              (not (at ?obj ?loc))

              (durability-known ?obj)

            )

          )

          (:action pull

            :parameters (?r - robot ?obj - object ?loc - location)

            :precondition (and

               (robot-at ?r ?loc)

               (at ?obj ?loc)

               (holding ?obj)           

               (not (lock-known ?obj)) 

            )

            :effect (lock-known ?obj)   

          )

        )

        """


def _install_recording_wrappers(task) -> None:
    """Wrap task.step and get_observation so that every frame is recorded."""
    task.step = recording_step(task.step)
    task.get_observation = recording_get_observation(task.get_observation)


def _run_exploration_phase() -> set:
    """Report predicates used but not declared in the exploration domain.

    Returns:
        The set of missing predicate names (empty if none).
    """
    print("\n[Exploration] Parsing exploration domain to locate "
          "undefined predicates used in actions …")
    missing_predicates = find_missing_predicates(_EXPLORATION_DOMAIN_STR)
    if missing_predicates:
        print(f"[Exploration] Missing predicate(s) detected: {missing_predicates}")
    else:
        print("[Exploration] No missing predicate detected.")
    return missing_predicates


def _run_pick_demo(env, task) -> None:
    """Try one pick on the first object reported by get_object_positions().

    The target position is sanity-checked before the skill is invoked, and
    any skill failure is reported rather than propagated.
    """
    positions = get_object_positions()      # presumably Dict[str, np.ndarray]
    if not positions:
        print("[Demo] No objects returned by get_object_positions(), "
              "so skipping skill demonstration.")
        return

    # Pick the first object for demonstration only.
    obj_name, obj_pos = next(iter(positions.items()))
    print(f"[Demo] Will attempt to pick the first detected object '{obj_name}'.")

    # Sanity check: the position must convert to a non-empty, all-finite
    # float array. A membership test on `positions` would be meaningless,
    # since obj_name was just taken from it.
    try:
        target_pos = np.asarray(obj_pos, dtype=float)
    except (TypeError, ValueError):
        target_pos = np.empty(0)
    if target_pos.size == 0 or not np.all(np.isfinite(target_pos)):
        print(f"[Demo] Position for '{obj_name}' is empty or non-finite – skipping.")
        return

    try:
        # The pick skill signature is expected to be:
        # pick(env, task, target_pos, **kwargs)
        # The following values are generic; individual skill
        # implementations internally clamp speed/limits.
        _obs, _reward, done = pick(
            env=env,
            task=task,
            target_pos=target_pos,
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )
    except Exception as e:  # best-effort demo: report and carry on to shutdown
        print(f"[Demo] Pick attempt failed gracefully: {e}")
        return

    if done:
        print("[Demo] Task reported completion after pick.")


def run_skeleton_task():
    """Generic skeleton for running any task in your simulation.

    Steps:
    1. Create the environment.
    2. Reset the task and start video recording.
    3. Run the predicate-exploration phase and a single pick demo.
    4. Shut the environment down, even if an exception occurs.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()          # RLBench environment creation
    try:
        _descriptions, obs = task.reset()
        init_video_writers(obs)
        _install_recording_wrappers(task)

        # 1) Exploration phase – find missing predicates.
        _run_exploration_phase()

        # 2) Demo – safe call to a skill (with sanity check).
        _run_pick_demo(env, task)
    finally:
        # Guarantee clean shutdown even if an exception occurs.
        shutdown_environment(env)
    print("===== End of Skeleton Task =====")

if __name__ == '__main__':
    # Script entry point: runs only when executed directly, not on import.
    run_skeleton_task()
