import numpy as np
from scipy.spatial.transform import Rotation as R

# the following two imports must stay even if they look unused – they initialise
# objects/sensors that RLBench relies on internally.
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# low-level manipulation primitives (pick, place, move, rotate, pull, …)
from skill_code import *        # DO NOT re-implement these skills!

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
#  Convenience helper
# ---------------------------------------------------------------------------
def _safe_lookup(pos_dict, alternatives):
    """
    Fetch the position stored under the first usable candidate name.

    Candidates in ``alternatives`` are tried in order. A candidate is usable
    when it is a key of ``pos_dict`` and its value is not ``None``. The
    winning value is returned as ``np.asarray(value, dtype=float)``. The key
    itself is not returned.

    Raises KeyError (listing the candidates and the available keys) when no
    candidate is usable, including when every present candidate maps to None.
    """
    # Local sentinel: present-but-None entries are filtered out, so any hit
    # is a real value and can be told apart from "nothing found".
    not_found = object()
    hit = next(
        (pos_dict[name]
         for name in alternatives
         if name in pos_dict and pos_dict[name] is not None),
        not_found,
    )
    if hit is not_found:
        raise KeyError(f"None of {alternatives} were found in the position "
                       f"dictionary.  Available: {list(pos_dict.keys())}")
    return np.asarray(hit, dtype=float)


# ---------------------------------------------------------------------------
#  Main logic – combined ‘open drawer & dispose rubbish’ task
# ---------------------------------------------------------------------------
# Parameters of the drawer pull (Step-5 of the plan).
_DRAWER_PULL_DISTANCE = 0.20   # metres, passed to ``pull`` as pull_distance
_DRAWER_PULL_AXIS = "-x"       # passed to ``pull`` as pull_axis


def _axis_unit_vector(axis):
    """Return the 3-D unit vector for a signed axis name like ``"-x"``."""
    sign = -1.0 if axis.startswith("-") else 1.0
    # KeyError for anything other than (+/-) x, y or z.
    index = {"x": 0, "y": 1, "z": 2}[axis.lstrip("+-")]
    vec = np.zeros(3)
    vec[index] = sign
    return vec


def _try_open_drawer(env, task, side_pos, anchor_pos,
                     pull_distance=_DRAWER_PULL_DISTANCE,
                     pull_axis=_DRAWER_PULL_AXIS):
    """
    Open a single drawer: approach its side, move to the handle anchor,
    grasp, pull, then release the handle.

    Returns True as soon as any skill reports ``done`` (the caller should
    stop the plan), otherwise False. Exceptions raised by the skills
    propagate so the caller can fall back to another drawer.
    """
    # Step-2  move-to-side
    _, _, done = move(env, task, side_pos)
    if done:
        return True

    # Step-3  move-to-anchor
    _, _, done = move(env, task, anchor_pos)
    if done:
        return True

    # Step-4  pick-drawer  (reuse generic 'pick')
    _, _, done = pick(env, task, target_pos=anchor_pos)
    if done:
        return True

    # Step-5  pull the drawer open
    _, _, done = pull(env, task,
                      pull_distance=pull_distance,
                      pull_axis=pull_axis)
    if done:
        return True

    # Release the handle where it is *now*. Placing at the original
    # anchor_pos would drive the grasped handle back and could re-close the
    # drawer. Assumes ``pull`` translates the gripper by pull_distance along
    # pull_axis in the world frame. TODO confirm against skill_code.pull.
    release_pos = anchor_pos + pull_distance * _axis_unit_vector(pull_axis)
    _, _, done = place(env, task, target_pos=release_pos)
    return bool(done)


def run_skeleton_task():
    """
    Execute the oracle plan: rotate the gripper, open the first drawer that
    can be opened (bottom -> middle -> top), pick the rubbish and drop it into
    the bin.

    The plan stops early whenever a skill reports ``done``. Errors are
    printed rather than raised. The environment is always shut down on exit.
    """
    import traceback  # local import: only used for error reporting below

    print("===== Starting Skeleton Task =====")

    # ------------------------------------------------------------
    # 1) environment initialisation
    # ------------------------------------------------------------
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # optional video logging: wrap step/get_observation so every frame
        # produced by the skills is recorded.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------
        # 2) query helper for all relevant object positions
        # ------------------------------------------------------------
        try:
            positions = get_object_positions()
        except Exception as exc:
            print("[Error] Could not obtain object positions:", exc)
            return

        # candidate drawers in preferred order: bottom → middle → top
        drawer_candidates = [
            ("bottom_side_pos", "bottom_anchor_pos"),
            ("middle_side_pos", "middle_anchor_pos"),
            ("top_side_pos", "top_anchor_pos"),
        ]
        rubbish_keys = ["item3", "item2", "item1", "rubbish", "trash"]
        bin_keys = ["bin", "trash_bin", "waste_bin", "disposal_bin"]

        # Look these up before moving so a missing object aborts the plan
        # before the robot does anything (KeyError -> outer handler).
        rubbish_pos = _safe_lookup(positions, rubbish_keys)
        bin_pos = _safe_lookup(positions, bin_keys)

        # ------------------------------------------------------------
        # 3) oracle plan execution
        # ------------------------------------------------------------
        # ---- Step-1  rotate (align gripper 90° about Z) -----------------------
        target_quat = R.from_euler("z", 90.0, degrees=True).as_quat()
        obs, reward, done = rotate(env, task, target_quat)
        if done:
            print("[Plan] Finished after rotate – task already done?")
            return

        # ---- Steps-2 … 5 : open an unlocked drawer ---------------------------
        drawer_opened = False
        for side_key, anchor_key in drawer_candidates:
            try:
                side_pos = _safe_lookup(positions, [side_key])
                anchor_pos = _safe_lookup(positions, [anchor_key])
            except KeyError:
                continue   # missing position information, try next drawer

            try:
                if _try_open_drawer(env, task, side_pos, anchor_pos):
                    return
            except Exception as exc:
                # Something failed (e.g. drawer locked) – try next candidate.
                print(f"[Warning] Drawer {anchor_key} failed: {exc}")
                # Best effort: open the gripper in case it still holds the
                # handle; report (but tolerate) a failure to do so.
                try:
                    place(env, task, target_pos=anchor_pos)
                except Exception as release_exc:
                    print(f"[Warning] Could not release gripper after drawer "
                          f"{anchor_key} failed: {release_exc}")
                continue

            # Only mark success once the whole open-and-release sequence ran.
            drawer_opened = True
            print(f"[Plan] Successfully opened drawer ‘{anchor_key}’.")
            break   # stop trying other drawers

        if not drawer_opened:
            print("[Error] Could not open any drawer – aborting.")
            return

        # The rubbish may sit inside the drawer that was just pulled open, so
        # the position captured before Step-2 can be stale. Re-query (assumes
        # get_object_positions reports the current scene. TODO confirm) and
        # keep the initial positions if that fails.
        try:
            fresh = get_object_positions()
            fresh_rubbish = _safe_lookup(fresh, rubbish_keys)
            fresh_bin = _safe_lookup(fresh, bin_keys)
        except Exception as exc:
            print("[Warning] Could not refresh object positions; "
                  "using initial ones:", exc)
        else:
            rubbish_pos, bin_pos = fresh_rubbish, fresh_bin

        # ---- Step-6  pick rubbish --------------------------------------------
        # Move to the rubbish first (helps if it's inside the drawer).
        obs, reward, done = move(env, task, rubbish_pos)
        if done:
            return
        obs, reward, done = pick(env, task, target_pos=rubbish_pos)
        if done:
            return

        # ---- Step-7  place rubbish into the bin ------------------------------
        obs, reward, done = move(env, task, bin_pos)
        if done:
            return
        obs, reward, done = place(env, task, target_pos=bin_pos)

        if done:
            print("[Plan] Task completed successfully!  Reward:", reward)
        else:
            print("[Plan] Plan finished but environment did not signal done.")
    except Exception as exc:
        print("[Error] Exception during task execution:", exc)
        traceback.print_exc()   # keep the stack trace for debugging
    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# ---------------------------------------------------------------------------
#  Entrypoint
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: run the complete plan once when executed directly
    # (importing this module does not start the simulation).
    run_skeleton_task()