# run_skeleton_task.py  (Completed Implementation)

import re
import time
import numpy as np

# ===== Keep all original skeleton imports =====
from pyrep.objects.shape import Shape                     # noqa: F401  (kept for consistency – might be used inside skills)
from pyrep.objects.proximity_sensor import ProximitySensor  # noqa: F401
from env import setup_environment, shutdown_environment
from skill_code import rotate, pick, place, move, pull    # we never redefine any of these
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ----------------------------------------------------------
#  Exploration helper: find predicates that are referenced in
#  actions but NOT declared in the (:predicates ...) section.
# ----------------------------------------------------------
# Exploration domain scanned by find_missing_predicates() at the start of
# run_skeleton_task().  Notes on its (apparently deliberate) quirks:
#   * `move` and `pick` are each defined twice; each variant reveals a
#     different object property (identified / temperature-known for move,
#     weight-known / durability-known for pick).  Strict PDDL parsers may
#     reject duplicate action names, so this text is only used for the
#     textual predicate check here.
#   * `pull` uses `lock-known`, which is NOT declared under :predicates --
#     presumably the predicate the exploration is meant to surface.
#     NOTE(review): confirm this omission is intentional.
EXPLORATION_DOMAIN_PDDL = """
(define (domain exploration)
  (:requirements :strips :typing :conditional-effects :universal-preconditions)
  (:types robot object location)
  (:predicates
    (robot-at ?r - robot ?loc - location)
    (at ?obj - object ?loc - location)
    (identified ?obj - object)
    (temperature-known ?obj - object)
    (holding ?obj - object)
    (handempty)
    (weight-known ?obj - object)
    (durability-known ?obj - object)
  )

  (:action move
    :parameters (?r - robot ?from - location ?to - location)
    :precondition (robot-at ?r ?from)
    :effect (and
      (not (robot-at ?r ?from))
      (robot-at ?r ?to)
      (forall (?obj - object)
        (when (at ?obj ?to)
          (identified ?obj)
        )
      )
    )
  )

  (:action move
    :parameters (?r - robot ?from - location ?to - location)
    :precondition (robot-at ?r ?from)
    :effect (and
      (not (robot-at ?r ?from))
      (robot-at ?r ?to)
      (forall (?obj - object)
        (when (at ?obj ?to)
          (temperature-known ?obj)
        )
      )
    )
  )

  (:action pick
    :parameters (?r - robot ?obj - object ?loc - location)
    :precondition (and
       (robot-at ?r ?loc)
       (at ?obj ?loc)
       (handempty)
    )
    :effect (and
      (holding ?obj)
      (not (handempty))
      (not (at ?obj ?loc))
      (weight-known ?obj)
    )
  )

  (:action pick
    :parameters (?r - robot ?obj - object ?loc - location)
    :precondition (and
       (robot-at ?r ?loc)
       (at ?obj ?loc)
       (handempty)
    )
    :effect (and
      (holding ?obj)
      (not (handempty))
      (not (at ?obj ?loc))
      (durability-known ?obj)
    )
  )

  (:action pull
    :parameters (?r - robot ?obj - object ?loc - location)
    :precondition (and
       (robot-at ?r ?loc)
       (at ?obj ?loc)
       (holding ?obj)
       (not (lock-known ?obj))
    )
    :effect (lock-known ?obj)
  )
)
"""


def _extract_predicates(domain_str):
    """Return a set of predicate names declared in the (:predicates …) block."""
    m = re.search(r'\(:predicates([^)]*)\)', domain_str, re.S | re.I)
    if not m:
        return set()
    predicates_block = m.group(1)
    # first token immediately after each '('
    names = re.findall(r'\(\s*([^\s()]+)', predicates_block)
    return set(names)


def _extract_all_tokens(domain_str):
    """Return every token that appears between parentheses in the file."""
    return re.findall(r'\(\s*([^\s()]+)', domain_str)


def find_missing_predicates(domain_str):
    """Return names used like predicates in *domain_str* but never declared.

    A candidate is any token that immediately follows an opening parenthesis.
    A candidate is discarded when it is any of the following:

    * declared in the ``(:predicates ...)`` section;
    * a PDDL structural keyword (``and``, ``not``, ``forall`` ...), compared
      case-insensitively because PDDL is case-insensitive;
    * a section/requirement keyword, i.e. a token starting with ``:``
      (``:action``, ``:parameters``, ``:functions`` ...);
    * a variable, i.e. a token starting with ``?``.  Variables show up as the
      head of parameter lists such as ``(?r - robot ...)`` and must not be
      reported as predicates;
    * a type name listed in ``(:types ...)``.

    Args:
        domain_str: Full PDDL domain text.

    Returns:
        Sorted list of the remaining, undeclared names.
    """
    declared = _extract_predicates(domain_str)
    # Heads of PDDL expressions that are never predicates (lower-case).
    keywords = {
        'define', 'domain',
        'and', 'or', 'not', 'imply', 'forall', 'exists', 'when', 'either', '=',
    }

    # Type names appear as heads only in unusual layouts, but filter them anyway.
    type_names = set()
    for part in re.findall(r':types([^)]*)', domain_str, re.S | re.I):
        type_names.update(part.split())

    missing = set()
    for tok in _extract_all_tokens(domain_str):
        if tok.startswith((':', '?')):
            continue
        if tok.lower() in keywords or tok in declared or tok in type_names:
            continue
        missing.add(tok)
    return sorted(missing)


# ----------------------------------------------------------
#  Quaternion utility functions used to validate target orientations
# ----------------------------------------------------------
def _quat_norm(q):
    return np.linalg.norm(q)


def _is_valid_quaternion(q, eps=1e-3, norm_tol=5e-2):
    """Return True if *q* looks like a usable unit quaternion.

    *q* is accepted when all of the following hold:

    * it converts to a float array of shape ``(4,)``;
    * every component lies within ``[-(1 + eps), 1 + eps]``;
    * its Euclidean norm is within ``norm_tol`` of 1.

    Args:
        q: Four-component array-like (list, tuple or ndarray).
        eps: Slack allowed on each component's magnitude beyond 1.
        norm_tol: Allowed absolute deviation of ``|q|`` from 1.  The default
            equals the tolerance this check has always used.

    Returns:
        ``False`` for input that cannot be converted to a float array, has
        the wrong shape, or fails either bound (NaN components fail the norm
        test).  ``True`` otherwise.
    """
    # Accept plain sequences too; the original ``q.shape`` crashed on lists.
    try:
        q = np.asarray(q, dtype=float)
    except (TypeError, ValueError):
        return False
    if q.shape != (4,):
        return False
    if np.any(np.abs(q) > 1.0 + eps):
        return False
    return bool(abs(_quat_norm(q) - 1.0) < norm_tol)


# ----------------------------------------------------------
#  Safe wrapper around the provided rotate() skill
# ----------------------------------------------------------
def safe_rotate(env, task, target_quat, max_steps=100, threshold=0.05, timeout=10.0):
    """Rotate the gripper to *target_quat* via the ``rotate`` skill, after validating it.

    The quaternion is converted to a float32 array and checked with
    ``_is_valid_quaternion``.  An invalid quaternion is never forwarded to
    ``rotate``.  In that case the current observation is returned, together
    with a reward of 0.0 and ``done=False``.

    The wall-clock measurement does NOT interrupt ``rotate``.  It only prints
    a warning after ``rotate`` has returned, if the call took longer than
    ``timeout + 2`` seconds.  Termination relies on ``rotate``'s own
    ``max_steps`` / ``timeout`` handling.

    Returns:
        ``(obs, reward, done)`` from ``rotate``, or the fallback described above.
    """
    print("[safe_rotate] Requested quaternion:", target_quat)
    quat = np.asarray(target_quat).astype(np.float32)

    if _is_valid_quaternion(quat):
        started_at = time.time()
        obs, reward, done = rotate(
            env, task, quat,
            max_steps=max_steps, threshold=threshold, timeout=timeout,
        )
        slack_seconds = 2.0
        if time.time() - started_at > timeout + slack_seconds:
            print("[safe_rotate] WARNING: Rotation exceeded expected timeout.")
        return obs, reward, done

    print("[safe_rotate] ERROR: Target quaternion invalid – skipping rotation.")
    current_obs = task.get_observation()
    return current_obs, 0.0, False


# ----------------------------------------------------------
#  Main routine that brings everything together
# ----------------------------------------------------------
def _report_missing_predicates():
    """Print which predicates EXPLORATION_DOMAIN_PDDL references but never declares."""
    missing_preds = find_missing_predicates(EXPLORATION_DOMAIN_PDDL)
    if missing_preds:
        print("[Exploration] Detected predicates referenced but NOT declared:", missing_preds)
    else:
        print("[Exploration] No missing predicates detected.")


def _find_handle_name(positions):
    """Return the first name in *positions* that mentions a handle or drawer.

    Matching is a case-insensitive substring test for "handle" or "drawer".
    Returns None when no name matches.
    """
    for name in positions:
        lowered = name.lower()
        if "handle" in lowered or "drawer" in lowered:
            return name
    return None


def _run_demo_plan(env, task):
    """Reset *task*, enable video recording and execute the illustrative demo plan.

    Returns early, without raising, as soon as a skill reports ``done``.
    The caller is responsible for shutting the environment down.
    """
    _descriptions, obs = task.reset()

    # (Optional) video recording
    init_video_writers(obs)
    original_step = task.step
    task.step = recording_step(original_step)
    original_get_obs = task.get_observation
    task.get_observation = recording_get_observation(original_get_obs)

    # 3) Gather object positions from helper module
    positions = get_object_positions()        # dict: name -> 3-D position
    print("[Info] Known object positions:", positions)

    # 4) Minimal, purely illustrative plan: rotate the gripper to its
    #    *current* orientation to exercise the safe_rotate wrapper.
    current_quat = np.asarray(obs.gripper_pose[3:7]).copy()
    print("[Plan] Rotating gripper to its *current* orientation (should be no-op).")
    obs, reward, done = safe_rotate(
        env,
        task,
        target_quat=current_quat,
        max_steps=50,
        threshold=0.01,
        timeout=5.0
    )
    if done:
        print("[Plan] Task ended during safe_rotate – exiting early.")
        return

    # 5) If there is a drawer handle available, move to it and try pull
    handle_name = _find_handle_name(positions)
    if handle_name is None:
        print("[Plan] No drawer/handle object recognised – skipping pull demo.")
    else:
        print(f"[Plan] Attempting to move to and pull handle: {handle_name}")

        # Step A: move the robot toward the handle
        obs, reward, done = move(
            env,
            task,
            target_pos=positions[handle_name],
            approach_distance=0.10,
            max_steps=120,
            threshold=0.02,
            approach_axis='z',
            timeout=10.0
        )
        if done:
            print("[Plan] Task ended during move – exiting early.")
            return

        # Step B: try to pull
        obs, reward, done = pull(env, task)
        if done:
            print("[Plan] Task ended immediately after pull – exiting early.")
            return

    # 6) Final message
    print("[Plan] Demo plan completed – you would add the full oracle plan here.")


def run_skeleton_task():
    """Run the skeleton demo end to end.

    1. Report predicates used but not declared in EXPLORATION_DOMAIN_PDDL.
    2. Set up the environment and run the illustrative demo plan.
    3. Always shut the environment down, even if the plan raises.

    The closing banner is printed whenever the plan finishes without raising,
    including when a skill ends the task early.  Previously, early exits
    returned before reaching the banner.
    """
    print("===== Starting Skeleton Task =====")

    # 1) Detect missing predicates in exploration PDDL
    _report_missing_predicates()

    # 2) Standard RLBench environment handling
    env, task = setup_environment()
    try:
        _run_demo_plan(env, task)
    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    run_skeleton_task()
