import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# === Pre-defined low-level skills ===
from skill_code import rotate, move, pick, pull, place


# ------------------------------------------------------------
# Utility helpers
# ------------------------------------------------------------
def require_pos(pos_dict, key):
    """Look up *key* in *pos_dict* and return it as a float32 NumPy array.

    The value is passed through ``np.asarray(..., dtype=np.float32)``; its
    shape is not checked here (callers presumably expect a 3-vector position).

    Raises:
        KeyError: if *key* is absent from *pos_dict*, with a message naming it.
    """
    if key in pos_dict:
        return np.asarray(pos_dict[key], dtype=np.float32)
    raise KeyError(f"[run_task] Missing position for key: '{key}'")


# ------------------------------------------------------------
# Main high-level routine that follows the oracle plan
# ------------------------------------------------------------
def run_task():
    """Run the 9-step oracle plan: open the bottom drawer, then put two tomatoes on the plate.

    Sets up the environment, resets the task, and wraps ``task.step`` and
    ``task.get_observation`` with the video-recording decorators. It then looks
    up the required world positions and executes the skill steps in order.

    If any of steps 1-8 reports ``done``, the function prints which phase
    finished and returns early. Step 9's ``done`` flag determines the final
    status message instead. Any exception is printed and re-raised. The
    environment is always shut down in ``finally``.

    Raises:
        KeyError: if a required position key is missing (via ``require_pos``).
    """
    print("\n=================  RUN TASK – START  =================")
    env, task = setup_environment()

    try:
        # ---------- Reset & optional recording ----------
        _, obs = task.reset()
        init_video_writers(obs)                      # safe-no-op if video disabled

        # Wrap step / get_observation for optional video recording
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------- Gather all required world positions ----------
        positions = get_object_positions()

        bottom_side_pos   = require_pos(positions, "bottom_side_pos")
        bottom_anchor_pos = require_pos(positions, "bottom_anchor_pos")

        tomato1_pos = require_pos(positions, "item1")
        tomato2_pos = require_pos(positions, "item2")
        plate_pos   = require_pos(positions, "plate")

        # ---------- Oracle-level constants ----------
        PULL_DISTANCE     = 0.15   # metres – tuned to drawer depth
        PULL_AXIS         = 'x'    # +x assumed to be outward from cabinet
        APPROACH_DISTANCE = 0.12   # shared by every pick/place (presumably metres, like PULL_DISTANCE)
        APPROACH_AXIS     = '-z'   # approach vertically downward

        # Step 1 target: gripper rotated from zero_deg → ninety_deg about z.
        # scipy's as_quat() returns scalar-last (x, y, z, w) order.
        quat_ninety_deg = R.from_euler('z', 90, degrees=True).as_quat()

        def _pick(target_pos):
            """Pick at *target_pos* using the shared vertical approach settings."""
            return pick(
                env, task,
                target_pos=target_pos,
                approach_distance=APPROACH_DISTANCE,
                approach_axis=APPROACH_AXIS,
            )

        def _place(target_pos):
            """Place at *target_pos* using the shared vertical approach settings."""
            return place(
                env, task,
                target_pos=target_pos,
                approach_distance=APPROACH_DISTANCE,
                approach_axis=APPROACH_AXIS,
            )

        # ========================================================
        # Oracle plan – 9 steps (matches specification)
        # Steps 1-8 may end the run early; each entry is
        # (banner, phase label for the early-exit message, action).
        # ========================================================
        plan = [
            ("\n[Step 1] rotate gripper ⇒ 90° about z-axis", "rotate",
             lambda: rotate(env, task, target_quat=quat_ninety_deg)),
            ("\n[Step 2] move ⇒ bottom_side_pos", "move-to-side",
             lambda: move(env, task, target_pos=bottom_side_pos)),
            ("\n[Step 3] move ⇒ bottom_anchor_pos", "move-to-anchor",
             lambda: move(env, task, target_pos=bottom_anchor_pos)),
            ("\n[Step 4] pick ⇒ drawer handle (bottom)", "drawer-pick",
             lambda: _pick(bottom_anchor_pos)),
            ("\n[Step 5] pull ⇒ open drawer", "pull",
             lambda: pull(env, task,
                          pull_distance=PULL_DISTANCE,
                          pull_axis=PULL_AXIS)),
            ("\n[Step 6] pick ⇒ tomato1 (item1)", "pick tomato1",
             lambda: _pick(tomato1_pos)),
            ("\n[Step 7] place ⇒ tomato1 onto plate", "place tomato1",
             lambda: _place(plate_pos)),
            ("\n[Step 8] pick ⇒ tomato2 (item2)", "pick tomato2",
             lambda: _pick(tomato2_pos)),
        ]

        for banner, phase, action in plan:
            print(banner)
            _, reward, done = action()
            if done:
                print(f"[Task] Finished early during {phase}.")
                return

        # Step 9 – place tomato2 onto plate; its done flag drives the final status.
        print("\n[Step 9] place ⇒ tomato2 onto plate")
        _, reward, done = _place(plate_pos)

        # ---------- Final status ----------
        if done:
            print(f"[Task] Environment signalled done. Reward: {reward}")
        else:
            print("[Task] Oracle plan finished. (done flag not set)")

    except Exception as exc:
        print(f"[Task] Exception encountered: {exc}")
        raise
    finally:
        shutdown_environment(env)
        print("=================  RUN TASK – END    ================")


if __name__ == "__main__":
    # Script entry point only; importing this module does not run the plan.
    run_task()