# run_skeleton_task.py

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape          # (kept; required for some RLBench scenes)
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Avoid a star-import (`from skill_code import *`); import only the skills this script uses.
from skill_code import rotate, move, pick, place, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def _safe_call(skill_fn, *args, **kwargs):
    """
    Wrapper that executes a skill and gracefully aborts the task if the
    environment signals `done=True` or an exception is raised.
    """
    try:
        obs, reward, done = skill_fn(*args, **kwargs)
        if done:
            print(f"[Abort] Environment signalled done during: {skill_fn.__name__}")
        return obs, reward, done
    except Exception as e:
        print(f"[Error] {skill_fn.__name__} raised an exception: {e}")
        raise


def run_skeleton_task():
    """
    Boot the environment and execute the fixed oracle-style plan.

    Plan: rotate the gripper 90 deg about z, move to the drawer side
    position, move to the anchor position, pick the handle, pull 0.15 m
    along x, pick the rubbish, and place it at the bin position.  Each skill
    runs through ``_safe_call``; as soon as one reports ``done=True`` the
    remaining steps are skipped and the final status is printed.  The
    environment is shut down in every case, including when a skill raises.

    Raises:
        RuntimeError: if ``get_object_positions()`` has no entry (or a
            ``None`` entry) for any required position.
    """
    print("===== Starting Skeleton Task =====")

    # --------------------------------------------------
    # 1) Environment boot-up
    # --------------------------------------------------
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # ----------  optional video recorder ----------
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)
        # ----------------------------------------------

        # --------------------------------------------------
        # 2) Fetch all relevant object positions
        # --------------------------------------------------
        positions = get_object_positions()
        required_keys = ('bottom_side_pos', 'bottom_anchor_pos', 'rubbish', 'bin')
        # Test each value with ``is None`` rather than ``None in [...]``: the
        # latter compares elements with ``==``, and for a numpy-array position
        # that produces an elementwise array whose truth value is ambiguous
        # (ValueError), so the guard itself would crash.
        missing = [key for key in required_keys if positions.get(key) is None]
        if missing:
            raise RuntimeError(
                "Missing one or more required object positions: "
                + ", ".join(missing)
            )
        bottom_side_pos = positions['bottom_side_pos']
        bottom_anchor_pos = positions['bottom_anchor_pos']
        rubbish_pos = positions['rubbish']
        bin_pos = positions['bin']

        # --------------------------------------------------
        # 3) Follow the oracle-style plan
        # --------------------------------------------------
        # Step-1 target: the current gripper orientation composed on the right
        # with a 90 deg z rotation, i.e. a turn about the gripper's local z.
        # NOTE(review): assumes gripper_pose[3:7] is an (x, y, z, w)
        # quaternion, matching scipy's from_quat convention -- confirm.
        current_quat = task.get_observation().gripper_pose[3:7]
        target_quat = (
            R.from_quat(current_quat) * R.from_euler('z', 90, degrees=True)
        ).as_quat()

        # Ordered (skill, positional args, keyword args).  A fresh np.array is
        # built for every step so no two skills share one array object.
        plan = [
            # Step-1: rotate 90 deg about z
            (rotate, (env, task, target_quat), {}),
            # Step-2: move to the drawer's side handle position
            (move, (env, task, np.array(bottom_side_pos)), {}),
            # Step-3: move into the anchor (front-and-centre) pose
            (move, (env, task, np.array(bottom_anchor_pos)), {}),
            # Step-4: pick the drawer handle
            (pick, (env, task),
             {'target_pos': np.array(bottom_anchor_pos), 'approach_axis': 'z'}),
            # Step-5: pull the drawer outwards along x by 0.15 m
            (pull, (env, task), {'pull_distance': 0.15, 'pull_axis': 'x'}),
            # Step-6: pick the rubbish from the table
            (pick, (env, task),
             {'target_pos': np.array(rubbish_pos), 'approach_axis': 'z'}),
            # Step-7: place the rubbish in the bin
            (place, (env, task),
             {'target_pos': np.array(bin_pos), 'approach_axis': 'z'}),
        ]

        done = False
        for skill_fn, args, kwargs in plan:
            _, _, done = _safe_call(skill_fn, *args, **kwargs)
            if done:
                # The environment ended the episode; skip the remaining steps.
                break

        # --------------------------------------------------
        # 4) Final status
        # --------------------------------------------------
        if done:
            print("[Task] Environment finished early (done=True).")
        else:
            print("[Task] Plan executed – Goal should be satisfied!")
    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the full plan once when this file is executed
    # directly (nothing runs on import).
    run_skeleton_task()