import traceback

import numpy as np
from scipy.spatial.transform import Rotation as R

# RLBench needs these two imports even though the objects are never referenced directly
from pyrep.objects.shape import Shape                # noqa: F401
from pyrep.objects.proximity_sensor import ProximitySensor   # noqa: F401

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ---------------------------------------------------------------------
#  Pre-implemented low-level skills  (DO NOT redefine!  Just import.)
# ---------------------------------------------------------------------
from skill_code import rotate, move, pick, pull, place


# =============================================================================
#                               ORACLE CONTROLLER
# =============================================================================
def _require_position(positions, key, what):
    """Return ``positions[key]`` converted to a NumPy array.

    Args:
        positions: Mapping returned by ``get_object_positions()``.
        key:       Name of the required entry.
        what:      Human-readable category used in the error message
                   (e.g. ``"drawer waypoint"``).

    Raises:
        RuntimeError: If *key* is missing; the original ``KeyError`` is chained.
    """
    try:
        value = positions[key]
    except KeyError as err:
        raise RuntimeError(f"[Setup] Missing {what} in object_positions: {err}") from err
    return np.array(value)


def _execute_step(env, task, header, skill, skill_kwargs, finish_msg):
    """Run one oracle-plan step.

    Prints *header*, calls ``skill(env, task, **skill_kwargs)`` and, when the
    skill's returned ``done`` flag is truthy, prints *finish_msg*.

    Returns:
        The ``done`` flag returned by the skill (truthy => stop the plan).
    """
    print(header)
    _obs, _reward, done = skill(env, task, **skill_kwargs)
    if done:
        print(finish_msg)
    return done


def run_skeleton_task() -> None:
    """
    Executes the oracle plan required for this task:

        1) rotate(gripper, zero_deg, ninety_deg)
        2) move-to-side   (nowhere-pos      ➜  side-pos-bottom)
        3) move-to-anchor (side-pos-bottom  ➜  anchor-pos-bottom)
        4) pick-drawer    (grasp handle at anchor-pos-bottom)
        5) pull           (open drawer along −X)
        6) pick           (rubbish from table)
        7) place          (rubbish into bin)

    Only the predefined skills from `skill_code` are used.

    The plan stops early (without error) as soon as any skill reports
    ``done``.  Exceptions raised after the environment has been set up are
    printed together with their traceback instead of being propagated, and
    the environment is always shut down.
    """

    print("\n==========  STARTING ORACLE TASK  ==========\n")

    env, task = setup_environment()                          # RL-Bench helper
    try:
        _descriptions, obs = task.reset()                    # fresh initial state

        # Optional video capture: wrap task.step / task.get_observation with
        # the recording helpers from `video`.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --------------------------------------------------------------
        #  Retrieve all required object / waypoint positions
        # --------------------------------------------------------------
        positions = get_object_positions()

        side_pos_bottom = _require_position(positions, 'bottom_side_pos', "drawer waypoint")
        anchor_pos_bottom = _require_position(positions, 'bottom_anchor_pos', "drawer waypoint")
        # A generic "nowhere" waypoint (safe hovering pose); fall back to side-pos.
        nowhere_pos = np.array(positions.get('waypoint1', side_pos_bottom))
        rubbish_pos = _require_position(positions, 'rubbish', "required object")
        bin_pos = _require_position(positions, 'bin', "required object")

        # 90° about Z; scipy's as_quat() is scalar-last (x, y, z, w).
        ninety_deg_quat = R.from_euler('xyz', [0.0, 0.0, np.pi / 2]).as_quat()

        # --------------------------------------------------------------
        #  ORACLE PLAN
        #  Each entry: (header printed before the step, skill, keyword
        #  arguments for the skill, message printed if the skill reports done)
        # --------------------------------------------------------------
        plan = [
            # STEP-1  Rotate gripper from 0° to 90° around Z
            ("\n[STEP-1] rotate(gripper, zero_deg, ninety_deg)",
             rotate, {'target_quat': ninety_deg_quat},
             "[Finish] Episode ended during rotation."),
            # STEP-2  Move to the drawer's side position (via the hover pose)
            ("\n[STEP-2a] move(gripper ➜ nowhere_pos)",
             move, {'target_pos': nowhere_pos},
             "[Finish] Episode ended while moving to nowhere_pos."),
            ("\n[STEP-2b] move-to-side (nowhere_pos ➜ side-pos-bottom)",
             move, {'target_pos': side_pos_bottom},
             "[Finish] Episode ended while moving to side-pos-bottom."),
            # STEP-3  Move to the drawer's anchor position
            ("\n[STEP-3] move-to-anchor (side-pos-bottom ➜ anchor-pos-bottom)",
             move, {'target_pos': anchor_pos_bottom},
             "[Finish] Episode ended while moving to anchor-pos-bottom."),
            # STEP-4  Grasp the drawer handle with the generic pick skill
            #         (approach drawer handle along +X)
            ("\n[STEP-4] pick-drawer (implemented via generic pick)",
             pick, {'target_pos': anchor_pos_bottom,
                    'approach_distance': 0.10,
                    'approach_axis': 'x'},
             "[Finish] Episode ended while picking drawer handle."),
            # STEP-5  Pull the drawer open (15 cm along −X)
            ("\n[STEP-5] pull(gripper, bottom)  – 15 cm along −X",
             pull, {'pull_distance': 0.15, 'pull_axis': '-x'},
             "[Finish] Episode ended while pulling drawer."),
            # STEP-6  Hover 10 cm above the rubbish, then descend and pick it
            ("\n[STEP-6a] move above rubbish (safety hover)",
             move, {'target_pos': rubbish_pos + np.array([0.0, 0.0, 0.10])},
             "[Finish] Episode ended while hovering above rubbish."),
            ("\n[STEP-6] pick(rubbish, table)",
             pick, {'target_pos': rubbish_pos,
                    'approach_distance': 0.10,
                    'approach_axis': '-z'},
             "[Finish] Episode ended while picking rubbish."),
            # STEP-7  Hover 15 cm above the bin, then place the rubbish in it
            ("\n[STEP-7a] move above bin (safety hover)",
             move, {'target_pos': bin_pos + np.array([0.0, 0.0, 0.15])},
             "[Finish] Episode ended while hovering above bin."),
            ("\n[STEP-7] place(rubbish, bin)",
             place, {'target_pos': bin_pos,
                     'approach_distance': 0.10,
                     'approach_axis': '-z'},
             "[Finish] Episode ended after placing rubbish."),
        ]

        for header, skill, skill_kwargs, finish_msg in plan:
            if _execute_step(env, task, header, skill, skill_kwargs, finish_msg):
                return

        print("\n[Success] All oracle plan steps executed (done flag is False, "
              "but target sequence completed).")

    except Exception as exc:
        # Top-level boundary: report the failure instead of propagating it.
        # The traceback shows which step failed, which str(exc) alone does not.
        print(f"\n[ERROR] Exception during task execution: {exc}")
        traceback.print_exc()

    finally:
        shutdown_environment(env)
        print("\n==========  ORACLE TASK FINISHED  ==========\n")


# -----------------------------------------------------------------------------
#  Script entry point: execute the oracle plan only when this file is run
#  directly, never as a side effect of importing it.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    run_skeleton_task()