# run_task.py

import numpy as np
from scipy.spatial.transform import Rotation as R

# Skeleton-prescribed imports – **must not be removed**
from pyrep.objects.shape import Shape                    # noqa: F401
from pyrep.objects.proximity_sensor import ProximitySensor  # noqa: F401

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# Pre-implemented skills (DO NOT redefine or re-implement)
from skill_code import rotate, move, pull, pick, place


# --------------------------------------------------------------------------- #
#                               Helper Functions                              #
# --------------------------------------------------------------------------- #
def _safe_call(skill_fn, *args, **kwargs):
    """
    Execute a skill, printing an informative message if it raises.

    The original exception is re-raised unchanged so that the outer
    try/finally still shuts the simulator down cleanly.

    Args:
        skill_fn: Callable implementing the skill.
        *args: Positional arguments forwarded verbatim to ``skill_fn``.
        **kwargs: Keyword arguments forwarded verbatim to ``skill_fn``.

    Returns:
        Whatever ``skill_fn`` returns.

    Raises:
        Exception: any exception raised by ``skill_fn`` is propagated as-is.
    """
    try:
        return skill_fn(*args, **kwargs)
    except Exception as exc:            # noqa: BLE001 – intentional for robustness
        # Not every callable has ``__name__`` (e.g. functools.partial or a
        # callable instance).  An unguarded lookup would raise AttributeError
        # here and mask ``exc``, so fall back to the callable's repr.
        skill_name = getattr(skill_fn, "__name__", repr(skill_fn))
        print(f"[ERROR] Skill <{skill_name}> failed:\n{exc}")
        raise


def _get_pos(positions: dict, name: str) -> np.ndarray:
    """
    Fetch the position stored under ``name`` and return it as a float array.

    Args:
        positions: Mapping produced by
            ``object_positions.get_object_positions()``; values are
            presumably (x, y, z) sequences.
        name: Key of the object or location to look up.

    Returns:
        ``np.ndarray`` with dtype ``float`` holding the stored coordinates.

    Raises:
        KeyError: if ``name`` is not a key of ``positions``.
    """
    if name in positions:
        return np.asarray(positions[name], dtype=float)
    raise KeyError(f"[object_positions] Missing key '{name}'")


# --------------------------------------------------------------------------- #
#                               Main Task Logic                               #
# --------------------------------------------------------------------------- #
def _find_trash_key(positions: dict) -> str:
    """
    Return the key in ``positions`` naming the single trash object to dispose of.

    Preferred names are tried in order ("rubbish", "trash", "item3").  If none
    is present, the first key (in the mapping's iteration order) that starts
    with "item" is used.

    Raises:
        RuntimeError: if no candidate key exists.
    """
    for preferred in ("rubbish", "trash", "item3"):
        if preferred in positions:
            return preferred
    # Fallback: first object that looks like an 'item'
    fallback = next((k for k in positions if k.startswith("item")), None)
    if fallback is None:
        raise RuntimeError("No trash object found in object_positions!")
    return fallback


def run_task() -> None:
    """
    Execute the seven-step oracle plan described in the specification:

        1. rotate  (gripper  0°  →  90° about world-Z)
        2. move    (to bottom_side_pos)
        3. move    (to bottom_anchor_pos)
        4. pick    (grasp drawer handle)
        5. pull    (open drawer)
        6. pick    (lift trash from table)
        7. place   (throw trash into bin)

    The plan is an ordered table of ``(title, skill, kwargs)`` entries run by
    a single loop.  If any step before the last reports ``done``, the run
    exits early.  Once the environment is set up, it is always shut down on
    exit, including when a skill raises.

    Only the predefined skills from `skill_code` are used.
    """
    print("\n================  RUN_TASK START  ================\n")

    # ------------------------------ Environment ------------------------------ #
    env, task = setup_environment()
    try:
        # Initialise episode and recording helpers (task descriptions unused)
        _, obs = task.reset()
        init_video_writers(obs)

        task.step            = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------------------- World State ---------------------------- #
        positions = get_object_positions()        # {name: (x, y, z), …}

        # Locations required by the oracle plan
        side_pos   = _get_pos(positions, "bottom_side_pos")
        handle_pos = _get_pos(positions, "bottom_anchor_pos")
        bin_pos    = _get_pos(positions, "bin")

        # Detect a single trash object that must be disposed of
        trash_key = _find_trash_key(positions)
        trash_pos = _get_pos(positions, trash_key)

        # ------------------------  Oracle Execution  ------------------------- #
        # (banner title, skill, keyword arguments) — executed strictly in order.
        plan = (
            # STEP-1 : rotate gripper to 90°
            ("rotate  →  90° (about +Z)", rotate, {
                "target_quat": R.from_euler("z", 90, degrees=True).as_quat(),
                "max_steps": 120, "threshold": 0.05, "timeout": 10.0,
            }),
            # STEP-2 : move to drawer's side position (pre-aligned orientation)
            ("move  →  bottom_side_pos", move, {
                "target_pos": side_pos, "max_steps": 120,
            }),
            # STEP-3 : move straight to drawer anchor position (handle)
            ("move  →  bottom_anchor_pos", move, {
                "target_pos": handle_pos, "max_steps": 120,
            }),
            # STEP-4 : pick the drawer handle
            ("pick  –  grasp drawer handle", pick, {
                "target_pos": handle_pos,
                "approach_distance": 0.05, "approach_axis": "z",
            }),
            # STEP-5 : pull the drawer open (0.20 m along −X)
            ("pull  –  open drawer (0.20 m along −x)", pull, {
                "pull_distance": 0.20, "pull_axis": "-x", "max_steps": 120,
            }),
            # STEP-6 : pick up the trash object
            (f"pick  –  trash object: '{trash_key}'", pick, {
                "target_pos": trash_pos,
                "approach_distance": 0.15, "approach_axis": "z",
            }),
            # STEP-7 : place trash into bin
            ("place –  trash into bin", place, {
                "target_pos": bin_pos,
                "approach_distance": 0.15, "approach_axis": "-z",
            }),
        )

        last_step = len(plan)
        for step_no, (title, skill_fn, skill_kwargs) in enumerate(plan, start=1):
            print(f"\n[Step-{step_no}] {title}")
            _, reward, done = _safe_call(skill_fn, env, task, **skill_kwargs)
            # Completion before the final step ends the run early; the final
            # step's result is reported in the outcome section below.
            if done and step_no < last_step:
                when = "immediately after" if step_no == 1 else "after"
                print(f"[Early Exit] Task finished {when} step-{step_no}")
                return

        # -----------------------------  Outcome ----------------------------- #
        if done:
            print("\n[SUCCESS] Goal achieved!  Reward:", reward)
        else:
            print("\n[INFO] Plan executed, but environment returned done=False")

    finally:
        shutdown_environment(env)
        print("\n================  RUN_TASK END  =================\n")


if __name__ == "__main__":
    # Run the oracle plan only when this file is executed as a script,
    # not when it is imported as a module.
    run_task()