# run_skeleton_task.py  (completed)

import re
import traceback

import numpy as np
from pyrep.objects.shape import Shape           #  ←  keep (required by env / skills)
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *                         #  ←  predefined motion‐skills (pick, place, move, rotate, pull …)
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
# Utility: very light PDDL parser that discovers predicates that are USED in
#          :precondition / :effect clauses but were never declared in the
#          (:predicates …) list – i.e. “missing predicates”.
# ---------------------------------------------------------------------------
# Symbols that may directly follow "(" inside an action body without being
# predicate names (logical connectives, quantifiers, equality, type unions).
_PDDL_NON_PREDICATE_SYMBOLS = frozenset({
    "and", "or", "not", "imply", "forall", "exists", "when", "=", "either",
    "effect", "precondition", "parameters",
})


def _balanced_blocks(text, keyword):
    """Return every ``(<keyword> ...)`` s-expression in *text*, parentheses balanced."""
    blocks = []
    # The look-ahead ensures the keyword is a complete token
    # (e.g. ":action" must not match a longer symbol such as ":action-costs").
    opener = r'\(' + re.escape(keyword) + r'(?![^\s()])'
    for match in re.finditer(opener, text):
        start = match.start()
        depth = 0
        for idx in range(start, len(text)):
            char = text[idx]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    blocks.append(text[start:idx + 1])
                    break
        else:
            # Unbalanced (truncated) input: scan whatever is available.
            blocks.append(text[start:])
    return blocks


def _predicate_symbols(blocks):
    """Return the predicate-like symbols that directly follow a "(" in *blocks*."""
    symbols = set()
    for block in blocks:
        for symbol in re.findall(r'\(\s*([^\s()]+)', block):
            # "?x" is a variable (parameter lists), ":kw" is a PDDL keyword.
            if symbol[0] in '?:' or symbol in _PDDL_NON_PREDICATE_SYMBOLS:
                continue
            symbols.add(symbol)
    return symbols


def explore_missing_predicates(domain_str: str):
    """Return the set of predicate symbols that are used but never declared.

    The domain text is scanned in three steps:

    1. ``;`` line comments are removed so that commented-out text is ignored.
    2. Every complete ``(:predicates ...)`` block yields the *declared* symbols
       and every complete ``(:action ...)`` block yields the *used* symbols.
       Blocks are extracted by matching parentheses, because a non-greedy
       regex would stop at the first ``)`` and see only a fragment.
    3. The result is ``used - declared``.

    Args:
        domain_str: PDDL domain definition as plain text.

    Returns:
        A ``set`` of predicate names referenced inside actions but missing
        from the predicates list; empty when nothing is missing.
    """
    text = re.sub(r';[^\n]*', '', domain_str)

    declared = _predicate_symbols(_balanced_blocks(text, ':predicates'))
    used = _predicate_symbols(_balanced_blocks(text, ':action'))

    return used - declared


# ────────────────────────────────────────────────────────────────────────────
#  Embedded domain strings (abridged – only needed for the exploration phase)
# ────────────────────────────────────────────────────────────────────────────
# PDDL text of the "combined-domain": a predicates list plus a single `pull`
# action.  Every predicate that the action references (holding-drawer,
# is-locked, is-open) is declared in the predicates list.
COMBINED_DOMAIN_STR = r"""
(define (domain combined-domain)
  (:requirements :strips :typing :negative-preconditions :equality :disjunctive-preconditions)
  (:types object location drawer - object gripper - object position - object angle - object)
  (:predicates
    (at ?obj - object ?loc - location)
    (holding ?obj - object)
    (handempty)
    (is-locked ?d - drawer)
    (is-open ?d - drawer)
    (rotated ?g - gripper ?a - angle)
    (gripper-at ?g - gripper ?p - position)
    (holding-drawer ?g - gripper ?d - drawer)
    (is-side-pos ?p - position ?d - drawer)
    (is-anchor-pos ?p - position ?d - drawer)
  )
  (:action pull
    :parameters (?g - gripper ?d - drawer)
    :precondition (and
      (holding-drawer ?g ?d)
      (not (is-locked ?d))
      (not (is-open ?d))
    )
    :effect (is-open ?d)
  )
)
"""

# PDDL text of the "exploration" domain.  Its `pull` action references the
# predicate `lock-known` in both the precondition and the effect, but that
# symbol does not appear in the predicates list below.
EXPLORATION_DOMAIN_STR = r"""
(define (domain exploration)
  (:requirements :strips :typing :conditional-effects :universal-preconditions)
  (:types robot object location)
  (:predicates
    (robot-at ?r - robot ?loc - location)
    (at ?obj - object ?loc - location)
    (identified ?obj - object)
    (temperature-known ?obj - object)
    (holding ?obj - object)
    (handempty)
    (weight-known ?obj - object)
    (durability-known ?obj - object)
  )
  (:action pull
    :parameters (?r - robot ?obj - object ?loc - location)
    :precondition (and
       (robot-at ?r ?loc)
       (at ?obj ?loc)
       (holding ?obj)
       (not (lock-known ?obj))    ;  <──  referenced but never declared
    )
    :effect (lock-known ?obj)
  )
)
"""


def run_skeleton_task():
    """Run the skeleton task: predicate exploration followed by a demo plan.

    Steps performed:

    1. Create the environment with ``setup_environment()`` and reset the task.
    2. Wrap ``task.step`` / ``task.get_observation`` with the recording
       helpers from the ``video`` module.
    3. Exploration phase: run ``explore_missing_predicates`` on both embedded
       PDDL domain strings and print the result for each.
    4. Demo plan: take the first entry returned by ``get_object_positions()``,
       sanity-check its coordinates, then call ``pick`` followed by ``place``
       (both presumably provided by the ``skill_code`` star import).

    Exceptions raised during the pick/place phase (including the sanity-check
    ``ValueError``) are printed with a traceback and not re-raised.  The
    environment is always shut down via ``shutdown_environment(env)``.

    Returns:
        None.  The function returns early (after shutdown) when no object
        positions are available or when ``pick`` reports ``done``.
    """
    print("===== Starting Skeleton Task =====")

    # ─────────────  Environment initialisation  ─────────────
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # Optional recording infrastructure
        init_video_writers(obs)
        task.step            = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ───────────────────────────────────────────────────
        #   1)  EXPLORATION PHASE – find missing predicates
        # ───────────────────────────────────────────────────
        for name, domain in [("combined-domain",    COMBINED_DOMAIN_STR),
                             ("exploration-domain", EXPLORATION_DOMAIN_STR)]:
            missing = explore_missing_predicates(domain)
            if missing:
                print(f"[exploration] In '{name}' the following predicates are "
                      f"USED but not DECLARED: {sorted(missing)}")
            else:
                print(f"[exploration] No missing predicates detected in '{name}'.")

        #  →  The intended exploration result is:
        #       ['lock-known']   in the exploration domain
        #     (this is the “missing predicate” requested in the instructions)

        # ───────────────────────────────────────────────────
        #   2)  SIMPLE DEMO PLAN (pick first known object)
        # ───────────────────────────────────────────────────
        positions = get_object_positions()          # {name: np.array([...]), …}
        if not positions:
            print("[task] No object positions returned – nothing to do.")
            return

        # Choose the first entry deterministically
        obj_name, obj_pos = next(iter(positions.items()))
        print(f"[task] Selected '{obj_name}' @ {obj_pos} as demo target.")

        # Make sure the coordinate is numpy array (skill expects np.ndarray)
        target_pos = np.asarray(obj_pos, dtype=float)

        # --- Safety check before calling `pick` (feedback related) -----------
        try:
            # Verify that the position is plausible (finite and within workspace).
            # The finiteness test is explicit because `nan > 10` is False, so a
            # NaN coordinate would otherwise slip through the distance check.
            distance = np.linalg.norm(target_pos)
            if not np.isfinite(distance):
                raise ValueError("Target position is not finite – skipping pick.")
            if distance > 10:                       # arbitrary sanity radius
                raise ValueError("Target seems too far – skipping pick.")

            # Call the predefined skill (all heavy-lifting inside)
            obs, reward, done = pick(
                env,
                task,
                target_pos=target_pos,
                approach_distance=0.15,
                max_steps=120,
                threshold=0.01,
                approach_axis='z',
                timeout=15.0
            )

            if done:
                print("[task] Episode finished early (environment signalled 'done').")
                return

            # If pick succeeds, place the object 5 cm above its original
            # position (purely for demonstration)
            obs, reward, done = place(
                env,
                task,
                target_pos=target_pos + np.array([0.0, 0.0, 0.05]),   # slightly above
                approach_distance=0.10,
                max_steps=120,
                threshold=0.01,
                approach_axis='-z',
                timeout=15.0
            )

        except Exception:
            # Best-effort demo: report the failure but still shut down cleanly.
            print("[task] Exception during skill execution:")
            traceback.print_exc()

    finally:
        # ───────────────  Always shut down cleanly  ───────────────
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Run the task only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_skeleton_task()
