# run_skeleton_task.py (Completed Version)

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# pull in every predefined skill exactly as provided
from skill_code import *          # noqa: F401, F403

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


def _pick_at(env, task, target_pos):
    """Run the predefined ``pick`` skill at ``target_pos`` with this script's settings.

    Returns whatever ``pick`` returns (used elsewhere in this script as an
    ``(obs, reward, done)`` tuple). Exceptions raised by the skill propagate.
    """
    return pick(
        env,
        task,
        target_pos=target_pos,
        approach_distance=0.15,
        threshold=0.01,
        approach_axis='z',
        timeout=10.0,
    )


def _explore_handempty(env, task, positions):
    """Probe for the ``handempty`` precondition by picking twice without placing.

    Picks the first listed object, then tries a second pick: on the next
    object, or on the same one when only a single object is known. If the
    second pick raises, that failure is reported as evidence that
    ``handempty`` is required.

    Args:
        env, task: Simulation handles passed straight through to the skills.
        positions: Non-empty mapping ``{object_name: position array}``.

    Returns:
        ``(first_obj, first_pos)`` for the object picked first, or ``None``
        when that first pick failed and exploration was aborted.
    """
    print("\n===== Exploration Phase =====")
    object_names = list(positions)
    first_obj = object_names[0]
    first_pos = positions[first_obj]

    print(f"[Exploration] 1st pick attempt on '{first_obj}' …")
    try:
        _pick_at(env, task, first_pos)
    except Exception as e:
        # Without a first successful pick the probe is meaningless.
        print(f"[Exploration] 1st pick failed unexpectedly → {e}")
        return None
    print("[Exploration] 1st pick succeeded.")

    # With a single object, the probe re-picks the same object; the result
    # is then weaker evidence for the precondition.
    second_obj = object_names[1] if len(object_names) > 1 else first_obj
    second_pos = positions[second_obj]
    print(f"[Exploration] 2nd pick attempt on '{second_obj}' (while still holding) …")
    try:
        _pick_at(env, task, second_pos)
    except Exception as e:
        print(f"[Exploration] 2nd pick failed as expected → {e}")
        print("[Exploration] ⇒ Missing predicate identified: handempty\n")
    else:
        print("[Exploration] 2nd pick unexpectedly succeeded ‑ predicate already handled.")

    return first_obj, first_pos


def _demonstrate(env, task, held_obj, held_pos):
    """Place the held object slightly above its original spot, then rotate the gripper.

    The place target is ``held_pos`` offset by +0.05 on the z axis. Rotation
    is attempted only if the place succeeds. A rotation failure is logged,
    not raised.
    """
    print("===== Demonstration Phase =====")
    target_place = held_pos + np.array([0.0, 0.0, 0.05], dtype=np.float32)
    print(f"[Task] Placing '{held_obj}' to {target_place} …")
    try:
        place(
            env,
            task,
            target_pos=target_place,
            approach_distance=0.15,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )
    except Exception as e:
        # If we cannot place there is little we can do – stop the demo here.
        print(f"[Task] Place failed → {e}")
        return
    print("[Task] Place succeeded. Robot’s hand is now empty again.")

    print("[Task] Rotating gripper to neutral quaternion …")
    # NOTE(review): assumes the skill expects (x, y, z, w) ordering, making
    # this the identity rotation — confirm against rotate()'s contract.
    neutral_quat = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)
    try:
        rotate(
            env,
            task,
            target_quat=neutral_quat,
            max_steps=100,
            threshold=0.05,
            timeout=10.0,
        )
    except Exception as e:
        print(f"[Task] Rotation failed → {e}")

    print("===== Demonstration Completed =====")


def run_skeleton_task():
    """Generic skeleton for running any task in the simulation.

    Flow:
      1) reset the task and wrap ``task.step`` / ``task.get_observation``
         for video recording;
      2) an *exploration phase* that probes the ``handempty`` predicate by
         attempting two picks in a row;
      3) a short demonstration that uses only the predefined skills
         (place → rotate) while respecting that predicate.

    The environment is always shut down, even if a step raises. The final
    banner is printed on every run that does not raise, including runs
    that stop early.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Reset task to its initial state
        _descriptions, obs = task.reset()

        # --- Video initialisation and recording wrappers ---
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Mapping {object_name: (x, y, z)}; nothing reported -> finish early.
        positions = get_object_positions()
        if not positions:
            print("[Task] No object positions reported – ending run.")
        else:
            # Build a new dict of float32 arrays (vector-math safe) instead
            # of mutating the helper's returned object in place.
            positions = {
                name: np.asarray(pos, dtype=np.float32)
                for name, pos in positions.items()
            }
            held = _explore_handempty(env, task, positions)
            if held is not None:
                _demonstrate(env, task, *held)
    finally:
        # Always ensure the environment shuts down properly
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: runs only when executed directly, never on import.
    run_skeleton_task()