# run_skeleton_task.py
#
# Executable implementation that follows the exact oracle plan contained in the
# specification (steps 1 – 9).  Only the predefined low-level skills from
# `skill_code.py` are invoked (rotate, move, pick, place, pull).
#
# Goal:
#   1) Rotate the gripper to 90 deg.
#   2) Move to the bottom-drawer handle, grasp it and pull the drawer open.
#   3) Pick the two tomatoes and place them on the plate.
#
# The code is intentionally defensive: if a named object position is not found
# inside the cached dictionary returned by `get_object_positions()`, it queries
# the simulator directly via `pyrep.objects.shape.Shape`.

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape

from env import setup_environment, shutdown_environment

# predefined skills
from skill_code import rotate, move, pick, place, pull

# optional video helpers (no-ops if they are stubbed out)
from video import init_video_writers, recording_step, recording_get_observation

# cached object pose dictionary (may be empty if not pre-generated)
from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
# Helper utilities                                                            #
# --------------------------------------------------------------------------- #
def _safe_pose_lookup(name: str, pose_dict: dict):
    """
    Try two ways of fetching an object's position:
      (1) look-up in a cached dictionary,
      (2) fall back to querying the simulator directly.
    If neither path succeeds, returns None.
    """
    if name in pose_dict:
        return np.asarray(pose_dict[name], dtype=float)

    try:
        return np.asarray(Shape(name).get_position(), dtype=float)
    except Exception as ex:
        print(f"[Warning] Could not resolve pose for '{name}': {ex}")
        return None


def _resolve_tomato_positions(pose_dict):
    """
    Resolve the positions of the two tomatoes under either naming scheme.

    Different RLBench variants name the tomato objects either
    "tomato1 / tomato2" or "item1 / item2".  For each tomato the "tomatoN"
    name is tried first and "itemN" is used only if that lookup yields None.

    Returns:
        (tomato1_pos, tomato2_pos): each a float NumPy array, or None when
        neither name could be resolved.
    """
    def _first_resolved(*names):
        # Explicit None test: the lookup returns NumPy arrays, and using
        # ``a or b`` on a multi-element array raises ValueError (ambiguous
        # truth value), so the fallback must not rely on truthiness.
        for candidate in names:
            position = _safe_pose_lookup(candidate, pose_dict)
            if position is not None:
                return position
        return None

    tomato1_pos = _first_resolved("tomato1", "item1")
    tomato2_pos = _first_resolved("tomato2", "item2")
    return tomato1_pos, tomato2_pos


# --------------------------------------------------------------------------- #
# Main task runner                                                            #
# --------------------------------------------------------------------------- #
def run_skeleton_task():
    """
    Run the nine-step oracle plan end to end.

    Sets up the environment, resets the task and wraps ``task.step`` /
    ``task.get_observation`` with the video-recording hooks.  It then resolves
    the required object positions and executes, in order:

      1.   rotate the gripper to 90 deg about z,
      2-3. move to the bottom drawer's side pose, then its anchor (handle) pose,
      4.   grasp the handle,
      5.   pull the drawer open (0.20 along x),
      6-7. pick tomato #1 and place it on the plate,
      8-9. pick tomato #2 and place it on the plate.

    Each skill returns ``(obs, reward, done)``.  If ``done`` is set before the
    last step, the plan stops and an early-exit message is printed.  ``done``
    on the last step does not suppress the success report.  The environment
    is shut down in all cases.

    Raises:
        RuntimeError: if any required object position cannot be resolved.
    """
    print("\n================  START TASK  ================\n")

    # --------------------------------------------------
    # 1) environment initialisation
    # --------------------------------------------------
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # optional video recording hooks (safe even if stubs)
        init_video_writers(obs)
        task.step            = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --------------------------------------------------
        # 2) resolve important object poses
        # --------------------------------------------------
        pose_dict = get_object_positions()

        bottom_side_pos   = _safe_pose_lookup("bottom_side_pos",   pose_dict)
        bottom_anchor_pos = _safe_pose_lookup("bottom_anchor_pos", pose_dict)
        plate_pos         = _safe_pose_lookup("plate",             pose_dict)
        tomato1_pos, tomato2_pos = _resolve_tomato_positions(pose_dict)

        required = {
            "bottom_side_pos":   bottom_side_pos,
            "bottom_anchor_pos": bottom_anchor_pos,
            "plate_pos":         plate_pos,
            "tomato1_pos":       tomato1_pos,
            "tomato2_pos":       tomato2_pos,
        }
        missing = [k for k, v in required.items() if v is None]
        if missing:
            raise RuntimeError(f"[Fatal] Cannot continue – missing pose(s): {missing}")

        # --------------------------------------------------
        # 3) execute oracle plan (specification-compliant)
        # --------------------------------------------------
        # Target orientation for Step 1: 90 deg about z, as an [x, y, z, w]
        # quaternion (scipy's as_quat ordering).
        quat_90deg = R.from_euler('xyz', [0, 0, np.deg2rad(90)]).as_quat()

        # (label, action) pairs executed in order.  Every lambda closes over
        # values fixed above (no loop variables), so late binding is safe.
        plan = [
            ("[Step 1] Rotate gripper to 90 deg",
             lambda: rotate(env, task, target_quat=quat_90deg)),
            ("[Step 2] Move to drawer side pose",
             lambda: move(env, task, target_pos=bottom_side_pos)),
            ("[Step 3] Move to drawer anchor pose (handle)",
             lambda: move(env, task, target_pos=bottom_anchor_pos)),
            ("[Step 4] Grasp drawer handle",
             lambda: pick(env, task, target_pos=bottom_anchor_pos,
                          approach_distance=0.10, approach_axis='-z')),
            ("[Step 5] Pull to open drawer",
             lambda: pull(env, task,
                          pull_distance=0.20,   # 20 cm outward
                          pull_axis='x')),
            ("[Step 6] Pick Tomato #1",
             lambda: pick(env, task, target_pos=tomato1_pos,
                          approach_distance=0.12, approach_axis='-z')),
            ("[Step 7] Place Tomato #1 on plate",
             lambda: place(env, task, target_pos=plate_pos,
                           approach_distance=0.12, approach_axis='-z')),
            ("[Step 8] Pick Tomato #2",
             lambda: pick(env, task, target_pos=tomato2_pos,
                          approach_distance=0.12, approach_axis='-z')),
            ("[Step 9] Place Tomato #2 on plate",
             lambda: place(env, task, target_pos=plate_pos,
                           approach_distance=0.12, approach_axis='-z')),
        ]

        reward = 0.0
        last_step = len(plan)
        for step_no, (label, action) in enumerate(plan, start=1):
            print(f"\n{label}")
            obs, reward, done = action()
            # NOTE(review): `done` presumably mirrors the task's terminate
            # flag (set on success or failure) — confirm against skill_code.
            # Only a `done` before the final step is treated as premature.
            if done and step_no < last_step:
                print(f"[Early-Exit] Task finished unexpectedly after Step {step_no}.")
                print("Reward at exit:", reward)
                return

        # --------------------------------------------------
        # SUCCESS
        # --------------------------------------------------
        print("\n================  TASK SUCCESS  ================")
        print("Final reward:", reward)

    finally:
        shutdown_environment(env)
        print("\n================  END TASK  ====================\n")


if __name__ == "__main__":
    # Script entry point: run the full plan only when this file is executed
    # directly, not when it is imported.
    run_skeleton_task()