import numpy as np
from scipy.spatial.transform import Rotation as R

# Environment / helper modules delivered with the benchmark
from env import setup_environment, shutdown_environment
from skill_code import move, pick, place, rotate, pull
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------#
# Utility helpers                                                            #
# ---------------------------------------------------------------------------#
def _lookup_position(name_candidates, positions_dict):
    """
    Return the first valid 3-D position found for any of the keys given in
    `name_candidates`.

    Candidates are tried in order. A candidate is skipped when its key is
    missing, its value is ``None``, its value has an unsupported type, or its
    value does not contain exactly three numbers.

    • If the dictionary value is an np.ndarray / list / tuple, it is converted
      to a float NumPy array.
    • If the value exposes ``get_position()`` (e.g. a PyRep Shape or similar
      object), the current world position is queried through it.

    Parameters
    ----------
    name_candidates : iterable of str
        Keys to try, in priority order.
    positions_dict : mapping
        Object name -> position (sequence/array) or position-providing object.

    Returns
    -------
    np.ndarray
        Float array of shape ``(3,)``.

    Raises
    ------
    KeyError
        If no candidate key yields a valid 3-D position.
    """
    for key in name_candidates:
        if key not in positions_dict or positions_dict[key] is None:
            continue
        val = positions_dict[key]
        if isinstance(val, (list, tuple, np.ndarray)):
            raw = val
        elif hasattr(val, "get_position"):
            # Lazy check for PyRep Shape-like API: query the live position.
            raw = val.get_position()
        else:
            # Unsupported value type; fall through to the next candidate.
            continue
        # Flatten so e.g. a (1, 3) array becomes (3,); reject anything that is
        # not exactly three numbers (such as a 7-D pose) instead of handing
        # callers a vector they cannot offset by a 3-D displacement.
        pos = np.asarray(raw, dtype=float).reshape(-1)
        if pos.size == 3:
            return pos
    raise KeyError(f"[lookup] No valid 3-D position for any of the keys: {name_candidates}")


# ---------------------------------------------------------------------------#
# Main task logic                                                            #
# ---------------------------------------------------------------------------#
def _query_task_positions():
    """
    Resolve every world position the oracle plan needs.

    Returns
    -------
    dict
        ``{"side": ..., "anchor": ..., "rubbish": ..., "bin": ...}``; each value
        is the NumPy array returned by `_lookup_position`.

    Raises
    ------
    KeyError
        If one of the required objects cannot be resolved.
    """
    positions = get_object_positions()
    # Dict literals evaluate left to right, so lookups happen in the same
    # order as before: drawer side, drawer anchor, rubbish, bin.
    return {
        # Drawer positions (we choose the "bottom" drawer for this run)
        "side": _lookup_position(
            ["bottom_side_pos", "side-pos-bottom", "bottom_side", "side_pos_bottom"], positions
        ),
        "anchor": _lookup_position(
            ["bottom_anchor_pos", "anchor-pos-bottom", "bottom_anchor", "anchor_pos_bottom"], positions
        ),
        # Rubbish & bin
        "rubbish": _lookup_position(["rubbish", "item3", "item_3"], positions),
        "bin": _lookup_position(["bin", "trash_bin", "trash"], positions),
    }


def _execute_oracle_plan(env, task, pos):
    """
    Execute the 7-step oracle plan, stopping as soon as a skill reports done.

    Parameters
    ----------
    env, task
        Handles returned by `setup_environment()`.
    pos : dict
        Positions produced by `_query_task_positions()`.
    """
    # 1) rotate 0 → 90 deg about Z axis
    print("\n[PLAN] Step 1 — Rotate gripper 90° about Z-axis")
    ninety_deg_quat = R.from_euler("xyz", [0, 0, np.deg2rad(90)]).as_quat()
    _, _, done = rotate(env, task, ninety_deg_quat)
    if done:
        print("[Early Exit] Task finished during rotate.")
        return

    # 2) move-to-side of bottom drawer
    print("\n[PLAN] Step 2 — Move to drawer side position")
    _, _, done = move(env, task, pos["side"])
    if done:
        print("[Early Exit] Task finished during move-to-side.")
        return

    # 3) move-to-anchor (handle) of drawer
    print("\n[PLAN] Step 3 — Move to drawer anchor (handle)")
    _, _, done = move(env, task, pos["anchor"])
    if done:
        print("[Early Exit] Task finished during move-to-anchor.")
        return

    # 4) pick-drawer (close gripper on handle)
    print("\n[PLAN] Step 4 — Grasp the drawer handle")
    _, _, done = pick(
        env,
        task,
        target_pos=pos["anchor"],
        approach_distance=0.05,   # small approach distance
        approach_axis='z',
    )
    if done:
        print("[Early Exit] Task finished while grasping handle.")
        return

    # 5) pull drawer fully open
    print("\n[PLAN] Step 5 — Pull the drawer open")
    _, _, done = pull(
        env,
        task,
        pull_distance=0.20,       # positive X-direction pull
        pull_axis='x',
    )
    if done:
        print("[Early Exit] Task finished during pull.")
        return

    # NOTE(review): nothing here explicitly releases the drawer handle before
    # step 6 — confirm that `pull` (or `pick`) opens the gripper.

    # 6) pick rubbish from the table
    print("\n[PLAN] Step 6 — Pick the rubbish object")
    # Hover above first for safer descent
    pre_grasp_pos = pos["rubbish"] + np.array([0.0, 0.0, 0.10])
    _, _, done = move(env, task, pre_grasp_pos)
    if done:
        print("[Early Exit] Task finished before picking rubbish.")
        return
    _, _, done = pick(
        env,
        task,
        target_pos=pos["rubbish"],
        approach_distance=0.15,
        approach_axis='z',
    )
    if done:
        print("[Early Exit] Task finished during rubbish pick.")
        return

    # 7) place rubbish into the bin
    print("\n[PLAN] Step 7 — Place rubbish into bin")
    pre_place_pos = pos["bin"] + np.array([0.0, 0.0, 0.15])
    _, _, done = move(env, task, pre_place_pos)
    if done:
        print("[Early Exit] Task finished before placing rubbish.")
        return
    _, reward, done = place(
        env,
        task,
        target_pos=pos["bin"],
        approach_distance=0.15,
        approach_axis='z',
    )

    # Task outcome
    if done:
        print("[SUCCESS] Goal accomplished!  Reward:", reward)
    else:
        print("[INFO] All plan steps executed; environment signalled done=False.")


def run_skeleton_task():
    """
    Set up the environment, run the oracle plan once, and always shut down.

    Only a failed position lookup is reported and swallowed here. Any other
    exception propagates after the environment has been shut down, including
    a ``KeyError`` raised by a skill or by the environment. It is therefore no
    longer mislabeled as a lookup failure.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Reset episode (the task descriptions are not used by this plan)
        _, obs = task.reset()

        # Optional video writer for debugging / grading
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Keep the KeyError handler narrow: only the lookups can legitimately
        # produce a "missing object key" failure.
        try:
            positions = _query_task_positions()
        except KeyError as e:
            print("[ERROR] Position lookup failed:", e)
            return

        _execute_oracle_plan(env, task, positions)

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the task only when executed directly, not on import.
    run_skeleton_task()