# run_skeleton_task.py

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *          # move, pick, place, rotate, pull
from video import (
    init_video_writers,
    recording_step,
    recording_get_observation,
)
from object_positions import get_object_positions


def _safe_position(name: str, positions: dict):
    """
    Resolve the position of the object called `name` as a float32 array.

    The cached `positions` mapping is consulted first.  When it holds no
    usable entry for `name` (the key is absent or the stored value is
    `None`), the position is requested via `Shape(name).get_position()`.

    Raises:
        RuntimeError: if the fallback lookup fails for any reason; the
            underlying exception is attached as the cause.
    """
    cached = positions.get(name)
    if cached is None:
        # Nothing usable in the cache: query the object directly.
        try:
            live = Shape(name).get_position()
            return np.asarray(live, dtype=np.float32)
        except Exception as err:
            raise RuntimeError(
                f"[Task] Cannot find the position for object '{name}'."
            ) from err
    return np.asarray(cached, dtype=np.float32)


def execute_oracle_plan(env, task, positions):
    """
    Run the fixed, hand-written oracle plan one skill call at a time.

    The sequence is: rotate the gripper, move to 'side-pos-middle', move to
    'anchor-pos-middle', pick at the anchor, pull, then pick and place
    'tomato1' and 'tomato2' at the position of 'plate'.  All motion is
    delegated to the skill functions (`rotate`, `move`, `pick`, `pull`,
    `place`).  Object positions are resolved through `_safe_position`
    immediately before the first step that needs them; the plate position
    is resolved once and reused for both place steps.

    As soon as a skill reports `done`, a notice is printed and the
    remaining steps are skipped.  The function returns nothing.
    """

    def _halted(finished, notice):
        # Print `notice` and answer True when a skill signalled `done`.
        if finished:
            print(notice)
            return True
        return False

    # -- Step 1: rotate the gripper to a 90 degree yaw -------------------
    print("\n[Plan] Step-1: rotate gripper 0 → 90 degrees (domain ‘zero_deg’ → ‘ninety_deg’).")
    quarter_turn = R.from_euler("z", 90, degrees=True)
    quarter_turn_quat = quarter_turn.as_quat()  # scipy default order: x, y, z, w
    _obs, _reward, finished = rotate(env, task, quarter_turn_quat)
    if _halted(finished, "[Plan] Task finished prematurely after step-1."):
        return

    # -- Step 2: go to the side position ---------------------------------
    print("\n[Plan] Step-2: move gripper → ‘side-pos-middle’.")
    side_xyz = _safe_position("side-pos-middle", positions)
    _obs, _reward, finished = move(env, task, side_xyz)
    if _halted(finished, "[Plan] Task finished prematurely after step-2."):
        return

    # -- Step 3: go to the anchor position -------------------------------
    print("\n[Plan] Step-3: move gripper → ‘anchor-pos-middle’.")
    anchor_xyz = _safe_position("anchor-pos-middle", positions)
    _obs, _reward, finished = move(env, task, anchor_xyz)
    if _halted(finished, "[Plan] Task finished prematurely after step-3."):
        return

    # -- Step 4: grasp at the anchor using the generic pick skill --------
    print("\n[Plan] Step-4: pick drawer handle at anchor position.")
    _obs, _reward, finished = pick(env, task, anchor_xyz, approach_axis="z")
    if _halted(finished, "[Plan] Task finished prematurely after step-4."):
        return

    # -- Step 5: pull along the x axis -----------------------------------
    # NOTE(review): 0.15 is presumably metres (the log line says 15 cm);
    # confirm against the `pull` skill.
    print("\n[Plan] Step-5: pull drawer (along +X, 15 cm).")
    _obs, _reward, finished = pull(env, task, pull_distance=0.15, pull_axis="x")
    if _halted(finished, "[Plan] Task finished prematurely after step-5."):
        return

    # -- Steps 6-7: first tomato onto the plate --------------------------
    print("\n[Plan] Step-6: pick ‘tomato1’.")
    first_tomato_xyz = _safe_position("tomato1", positions)
    _obs, _reward, finished = pick(env, task, first_tomato_xyz, approach_axis="z")
    if _halted(finished, "[Plan] Task finished prematurely after picking tomato1."):
        return

    print("\n[Plan] Step-7: place ‘tomato1’ on ‘plate’.")
    plate_xyz = _safe_position("plate", positions)
    _obs, _reward, finished = place(env, task, plate_xyz, approach_axis="z")
    if _halted(finished, "[Plan] Task finished prematurely after placing tomato1."):
        return

    # -- Steps 8-9: second tomato onto the plate -------------------------
    print("\n[Plan] Step-8: pick ‘tomato2’.")
    second_tomato_xyz = _safe_position("tomato2", positions)
    _obs, _reward, finished = pick(env, task, second_tomato_xyz, approach_axis="z")
    if _halted(finished, "[Plan] Task finished prematurely after picking tomato2."):
        return

    # The plate position resolved for step 7 is reused here.
    print("\n[Plan] Step-9: place ‘tomato2’ on ‘plate’.")
    _obs, _reward, finished = place(env, task, plate_xyz, approach_axis="z")
    if _halted(finished, "[Plan] Task finished prematurely after placing tomato2."):
        return

    # -- All steps ran without an early `done` ---------------------------
    print("\n[Plan] Oracle plan executed.  Goal conditions should now be satisfied!")


def run_skeleton_task():
    """
    Script entry point: set up the environment, run the oracle plan, and
    shut the environment down again.

    `shutdown_environment(env)` runs in a `finally` clause, so it is called
    whether or not the reset, recording setup, or plan raises.  The closing
    banner is printed only when no exception propagated.
    """
    print("===== Starting Skeleton Task =====")

    # Environment creation happens outside the try: if it fails there is
    # no `env` to shut down.
    env, task = setup_environment()
    try:
        # Start the episode; only the observation is used below.
        _descriptions, first_obs = task.reset()

        # Hook recording into the task by wrapping its step and
        # observation methods.
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Snapshot object positions once, then hand them to the plan.
        object_positions = get_object_positions()
        execute_oracle_plan(env, task, object_positions)

    finally:
        # Always release the environment.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Run the task only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    run_skeleton_task()