import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment

# Import every predefined skill exactly as given
from skill_code import move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def _safe_pos(positions, key_list):
    """
    Helper that tries each key in key_list and returns the first
    position that exists.  Raises KeyError if none are found.
    """
    for k in key_list:
        if k in positions and positions[k] is not None:
            return positions[k]
    raise KeyError(f"None of the keys {key_list} found in object_positions.")


def _select_drawer(positions, candidates=("bottom", "middle", "top")):
    """
    Pick the first drawer in ``candidates`` whose side and anchor poses exist.

    A pose counts as available only when its key is present in ``positions``
    AND its value is not ``None``.  This is the same rule ``_safe_pos``
    applies, so a drawer with a ``None`` pose is skipped instead of being
    handed to ``move``.

    Returns:
        ``(name, side_pos, anchor_pos)`` for the chosen drawer.

    Raises:
        RuntimeError: if no candidate drawer has both poses available.
    """
    for name in candidates:
        side_key = f"{name}_side_pos"
        anchor_key = f"{name}_anchor_pos"
        if all(k in positions and positions[k] is not None
               for k in (side_key, anchor_key)):
            return name, positions[side_key], positions[anchor_key]
    raise RuntimeError("No drawer positions available!")


def run_task_open_drawer_and_dispose():
    """
    Open a drawer, then pick the rubbish object and drop it into the bin.

    Plan (executed in order).  The run stops early as soon as a skill reports
    ``done``:

    1.  Rotate the gripper (90° about the Z-axis – satisfies the 'rotate' step).
    2.  Move to the drawer-side pose (bottom drawer by preference).
    3.  Move to the anchor (handle) pose.
    4.  Pick (grasp) the drawer handle.
    5.  Pull the drawer open.
    6.  Move to a hover pose 10 cm above the rubbish object (item3).
    7.  Pick the rubbish.
    8.  Move to a hover pose 10 cm above the bin.
    9.  Place/drop the rubbish into the bin.

    The environment is always shut down (``finally``).  The END banner is
    printed on normal completion and on early termination alike.  If an
    exception propagates, the banner is not printed.
    """
    print("==========  [TASK]  START  ==========")

    # Offset used for "hover above target" poses (10 cm along +z).
    hover_offset = np.array([0.0, 0.0, 0.10])

    # ------------------------------------------------------------------
    #  Environment initialisation
    # ------------------------------------------------------------------
    env, task = setup_environment()
    try:
        # Reset to the provided initial state.
        _descriptions, obs = task.reset()

        # Wrap step/observation with the video-recording helpers.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        #  Retrieve static object positions from the helper
        # ------------------------------------------------------------------
        positions = get_object_positions()

        drawer, side_pos, anchor_pos = _select_drawer(positions)
        print(f"[INFO] using drawer: {drawer}")

        # Rubbish object (item3) and bin target.
        rubbish_pos = _safe_pos(positions, ["item3", "rubbish"])
        bin_pos = _safe_pos(positions, ["bin"])

        # np.asarray makes the offset addition explicit even when the helper
        # returns plain lists/tuples.
        rubbish_hover = np.asarray(rubbish_pos) + hover_offset
        bin_hover = np.asarray(bin_pos) + hover_offset

        # scipy returns a scalar-last (x, y, z, w) quaternion.
        # NOTE(review): confirm rotate() expects that ordering.
        target_quat = R.from_euler("z", 90, degrees=True).as_quat()

        # ------------------------------------------------------------------
        #  PLAN EXECUTION  (follows the 9-step specification exactly)
        #  Each entry: (log label, zero-arg callable returning
        #  (obs, reward, done)).
        # ------------------------------------------------------------------
        steps = [
            ("rotate → 90° about Z",
             lambda: rotate(env, task, target_quat)),
            (f"move → drawer side pose: {side_pos}",
             lambda: move(env, task, side_pos)),
            (f"move → drawer anchor pose: {anchor_pos}",
             lambda: move(env, task, anchor_pos)),
            ("pick → drawer handle (anchor pose)",
             lambda: pick(env, task, anchor_pos, approach_axis='-z')),
            # NOTE(review): pull_axis='x' is assumed to mean +x; confirm
            # against skill_code.pull.
            ("pull → open drawer (0.10 m along +x)",
             lambda: pull(env, task, pull_distance=0.10, pull_axis='x')),
            (f"move → rubbish (item3) pose: {rubbish_pos}",
             lambda: move(env, task, rubbish_hover)),
            ("pick → rubbish (item3)",
             lambda: pick(env, task, rubbish_pos, approach_axis='-z')),
            (f"move → bin pose: {bin_hover}",
             lambda: move(env, task, bin_hover)),
            ("place → rubbish into bin",
             lambda: place(env, task, bin_pos, approach_axis='-z')),
        ]

        for idx, (label, action) in enumerate(steps, start=1):
            print(f"\n[STEP {idx}] {label}")
            _obs, _reward, done = action()
            if done:
                print(f"Task finished during step {idx}!")
                break  # fall through to shutdown + END banner
        else:
            # Every step ran and none reported done.
            print("Task completed (done flag False but plan finished).")

    finally:
        shutdown_environment(env)

    print("==========  [TASK]  END  ==========")


if __name__ == "__main__":
    # Script entry point: run the drawer-opening / rubbish-disposal plan once.
    run_task_open_drawer_and_dispose()