import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# NOTE: the skill functions (move, pick, place, rotate, pull, …) come from the
# `skill_code` module; the wildcard import below brings them into scope, so
# this file does not rely on the execution framework having imported them.
from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


def _safe_call(func, *args, **kwargs):
    """Invoke a skill function, logging any exception before re-raising it.

    Args:
        func: The callable to invoke (e.g. a skill such as ``move``).
        *args: Positional arguments forwarded unchanged to ``func``.
        **kwargs: Keyword arguments forwarded unchanged to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        Any exception raised by ``func`` is re-raised unchanged after logging.
    """
    try:
        return func(*args, **kwargs)
    except Exception as exc:                 # pylint: disable=broad-except
        # Not every callable has __name__ (functools.partial objects, callable
        # instances); reading it blindly would raise AttributeError inside this
        # handler and hide the real error from the caller.
        name = getattr(func, "__name__", repr(func))
        print(f"[ERROR] Exception during '{name}': {exc}")
        raise


def _lookup_position(positions, *names):
    """Return the position of the first of *names* found in *positions*.

    Names are tried in order and only until one matches, so a missing
    fallback name never raises while an earlier name is available.

    Args:
        positions: Mapping of object name -> position (anything accepted by
            ``np.asarray``).
        *names: One or more candidate object names, in priority order.

    Returns:
        The matching position as a float ``np.ndarray``.

    Raises:
        KeyError: If none of *names* has a (non-None) position; the message
            names the last candidate tried.
    """
    for name in names:
        pos = positions.get(name)
        if pos is not None:
            return np.asarray(pos, dtype=float)
    raise KeyError(f"[run_skeleton_task] No position for object '{names[-1]}'")


def _dominant_axis(vec):
    """Return 'x', '-x', 'y' or '-y' for the larger of *vec*'s x/y components.

    Ties go to x. The sign is taken from the selected component itself, and a
    zero component counts as positive.
    """
    idx = 0 if abs(vec[0]) >= abs(vec[1]) else 1
    axis = "xy"[idx]
    return axis if vec[idx] >= 0 else f"-{axis}"


def _execute_plan(env, task, positions):
    """Run the nine oracle-plan steps in order, stopping once the task is done.

    Args:
        env: Environment handle returned by ``setup_environment``.
        task: Task handle returned by ``setup_environment`` (already reset).
        positions: Mapping of object name -> position, as returned by
            ``get_object_positions``.

    Raises:
        KeyError: If a position needed by a step that is reached is missing.
    """
    def p(*names):
        # Positions are resolved lazily, right before the step that uses them.
        return _lookup_position(positions, *names)

    def finished_early(step_no, done):
        if done:
            print(f"[Task] Finished early after step‑{step_no}.")
        return done

    # STEP-1: move -> side-pos-bottom
    print("\n[Plan‑Step‑1] move gripper to side‑pos‑bottom")
    _, _, done = _safe_call(move, env, task, target_pos=p("bottom_side_pos"))
    if finished_early(1, done):
        return

    # STEP-2: rotate -> ninety_deg (about the world-Z axis)
    print("\n[Plan‑Step‑2] rotate gripper to 90°")
    target_quat = R.from_euler("z", 90, degrees=True).as_quat()  # xyzw
    _, _, done = _safe_call(rotate, env, task, target_quat=target_quat)
    if finished_early(2, done):
        return

    # STEP-3: move -> anchor-pos-bottom
    print("\n[Plan‑Step‑3] move gripper to anchor‑pos‑bottom")
    _, _, done = _safe_call(move, env, task, target_pos=p("bottom_anchor_pos"))
    if finished_early(3, done):
        return

    # STEP-4: pick the bottom drawer handle
    print("\n[Plan‑Step‑4] pick bottom drawer handle")
    _, _, done = _safe_call(
        pick, env, task, target_pos=p("bottom_anchor_pos"),
        approach_distance=0.05,        # smaller approach for handle
        approach_axis="-y",            # arbitrary; adjust if needed
    )
    if finished_early(4, done):
        return

    # STEP-5: pull the drawer open along the dominant horizontal axis of
    # the anchor -> side vector, by that vector's full length.
    print("\n[Plan‑Step‑5] pull bottom drawer")
    anchor = p("bottom_anchor_pos")
    side = p("bottom_side_pos")
    diff = side - anchor
    _, _, done = _safe_call(
        pull, env, task,
        pull_distance=float(np.linalg.norm(diff)),
        pull_axis=_dominant_axis(diff),
    )
    if finished_early(5, done):
        return

    # STEP-6: move above the table (towards the rubbish).
    # Prefer a dedicated 'table' entry, then 'rubbish', then 'item3'.
    print("\n[Plan‑Step‑6] move gripper to table / rubbish area")
    table_pos = p("table", "rubbish", "item3")
    # Raise Z a bit to avoid collision
    table_target = table_pos + np.array([0, 0, 0.10])
    _, _, done = _safe_call(move, env, task, target_pos=table_target)
    if finished_early(6, done):
        return

    # STEP-7: pick the rubbish ('rubbish' if known, otherwise 'item3')
    print("\n[Plan‑Step‑7] pick rubbish")
    _, _, done = _safe_call(
        pick, env, task,
        target_pos=p("rubbish", "item3"),
        approach_distance=0.15,
        approach_axis="z",
    )
    if finished_early(7, done):
        return

    # STEP-8: hover above the bin
    print("\n[Plan‑Step‑8] move gripper to bin")
    bin_pos = p("bin") + np.array([0, 0, 0.15])
    _, _, done = _safe_call(move, env, task, target_pos=bin_pos)
    if finished_early(8, done):
        return

    # STEP-9: place the rubbish into the bin
    print("\n[Plan‑Step‑9] place rubbish into bin")
    _, reward, done = _safe_call(
        place, env, task,
        target_pos=p("bin") + np.array([0, 0, 0.05]),
        approach_distance=0.10,
        approach_axis="-z",
    )

    if done:
        print(f"[Task] Completed successfully! Final reward = {reward}")
    else:
        print("[Task] Plan executed.  Environment reports done = False.")


def run_skeleton_task():
    """
    Execute the oracle plan described in the specification:

      1. move   gripper → side‑pos‑bottom
      2. rotate gripper → ninety_deg
      3. move   gripper → anchor‑pos‑bottom
      4. pick   (drawer handle at bottom drawer)
      5. pull   (open bottom drawer)
      6. move   gripper → table (rubbish location)
      7. pick   rubbish  at table
      8. move   gripper → bin
      9. place  rubbish  into bin

    The plan stops early as soon as a step reports ``done``. The environment
    is always shut down. The end banner is printed both when the plan
    completes and when it stops early; it is not printed if a step raises.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        _, obs = task.reset()
        init_video_writers(obs)

        # Wrap task step / get_observation so that video frames are recorded.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        _execute_plan(env, task, get_object_positions())
    finally:
        # Always attempt to shut down the environment, even if something failed
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the full plan once. Importing this module
    # does not trigger a run.
    run_skeleton_task()