import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Import every predefined skill exactly as they are
from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

# Utility that is expected to be provided (returns {name: np.ndarray([x, y, z]), ...})
from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
# Helper utilities                                                            #
# --------------------------------------------------------------------------- #

def quat_from_euler(roll: float, pitch: float, yaw: float, seq: str = 'xyz'):
    """Convert Euler angles (in degrees) to a quaternion.

    The three angles are applied in the axis order named by *seq*, following
    SciPy's ``Rotation.from_euler`` convention. The quaternion is returned in
    SciPy's default ``as_quat`` layout, ``[x, y, z, w]`` (scalar last).
    """
    angles_deg = [roll, pitch, yaw]
    rotation = R.from_euler(seq, angles_deg, degrees=True)
    return rotation.as_quat()


def fetch_position(name: str, cache: dict):
    """
    Fetch an object's world-space position as a float32 numpy array.

    Lookup priority:
      1) the cached dictionary returned by get_object_positions()
         (skipped when *cache* is None)
      2) a direct PyRep lookup via Shape(name).get_position()

    Raises:
        KeyError: if the object cannot be located. The underlying lookup
            error is chained as ``__cause__`` so its traceback is preserved.
    """
    if cache is not None and name in cache:
        return np.asarray(cache[name], dtype=np.float32)

    # Keep the try body to the one call that can fail so unrelated errors
    # are not misreported as "cannot find position".
    try:
        position = Shape(name).get_position()
    except Exception as e:
        raise KeyError(f"[fetch_position] Cannot find position for '{name}': {e}") from e
    return np.asarray(position, dtype=np.float32)


# --------------------------------------------------------------------------- #
# Main task logic                                                             #
# --------------------------------------------------------------------------- #

def _open_bottom_drawer(env, task, pos_bottom_side, pos_bottom_anchor) -> bool:
    """Run oracle steps 1-5: rotate, approach, grasp and pull the bottom drawer.

    Returns True as soon as any skill reports ``done`` (after printing an
    early-exit message), so the caller can stop; False if all steps ran.
    """
    # STEP-1  rotate(gripper, ninety_deg) – rotate 90° around Z axis
    target_quat = quat_from_euler(0, 0, 90)   # [x,y,z,w]
    _, _, done = rotate(env, task, target_quat)
    if done:
        print("[Early‑Exit] Task ended after rotate.")
        return True

    # STEP-2  move(gripper, side-pos-bottom)
    _, _, done = move(env, task, pos_bottom_side)
    if done:
        print("[Early‑Exit] Task ended after move‑to‑side.")
        return True

    # STEP-3  move(gripper, anchor-pos-bottom)
    _, _, done = move(env, task, pos_bottom_anchor)
    if done:
        print("[Early‑Exit] Task ended after move‑to‑anchor.")
        return True

    # STEP-4  pick(gripper, bottom) → grasp bottom drawer handle
    _, _, done = pick(
        env,
        task,
        target_pos=pos_bottom_anchor,
        approach_distance=0.05,
        max_steps=120,
        threshold=0.005,
        approach_axis='z',
        timeout=15.0
    )
    if done:
        print("[Early‑Exit] Task ended during drawer pick.")
        return True

    # STEP-5  pull(gripper, bottom) – pull 0.20 along the 'x' axis (tune if necessary)
    _, _, done = pull(
        env,
        task,
        pull_distance=0.20,
        pull_axis='x',
        max_steps=120,
        threshold=0.005,
        timeout=15.0
    )
    if done:
        print("[Early‑Exit] Task ended during pull.")
        return True

    return False


def _transfer_items_to_plate(env, task, items, pos_plate) -> bool:
    """Pick each ``(name, position)`` in *items* and place it at *pos_plate*.

    Returns True as soon as any skill reports ``done`` (after printing an
    early-exit message), so the caller can stop; False if every item was
    processed.
    """
    for item_name, item_pos in items:
        print(f"\n===== Processing {item_name} =====")

        # pick(tomatoX, table)
        _, _, done = pick(
            env,
            task,
            target_pos=item_pos,
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis='z',
            timeout=15.0
        )
        if done:
            print(f"[Early‑Exit] Task ended while picking {item_name}.")
            return True

        # place(tomatoX, plate)
        _, _, done = place(
            env,
            task,
            target_pos=pos_plate,
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis='-z',
            timeout=15.0
        )
        if done:
            print(f"[Early‑Exit] Task ended while placing {item_name}.")
            return True

    return False


def run_skeleton_task():
    """Execute the oracle plan end-to-end using only predefined skill calls.

    Plan: open the bottom drawer (rotate, approach, grasp, pull), then move
    item1 and item2 onto the plate. The run stops early when any skill reports
    ``done``. Any exception raised after environment setup is reported
    (message plus full traceback) and not re-raised. The environment is
    always shut down.
    """
    import traceback  # local import: only used to report unexpected failures

    print("==========  STARTING TASK  ==========")

    # Environment initialisation (outside the try: there is nothing to shut
    # down if setup itself fails).
    env, task = setup_environment()
    try:
        # Ensure deterministic start
        _, obs = task.reset()

        # Optional video recording: wrap step/observation so frames are captured
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Retrieve all relevant positions up front so a missing object fails
        # before the robot moves.
        cached_positions = get_object_positions()
        pos_bottom_side   = fetch_position('bottom_side_pos',   cached_positions)
        pos_bottom_anchor = fetch_position('bottom_anchor_pos', cached_positions)
        pos_item1 = fetch_position('item1', cached_positions)   # tomato1
        pos_item2 = fetch_position('item2', cached_positions)   # tomato2
        pos_plate = fetch_position('plate', cached_positions)

        if _open_bottom_drawer(env, task, pos_bottom_side, pos_bottom_anchor):
            return

        items = [('item1', pos_item1), ('item2', pos_item2)]
        if _transfer_items_to_plate(env, task, items, pos_plate):
            return

        print("\n==========  TASK COMPLETED SUCCESSFULLY!  ==========")

    except Exception as exc:
        # Best-effort script: report but do not propagate. Printing the full
        # traceback keeps failures inside skills debuggable.
        print(f"[ERROR] Exception during task execution: {exc}")
        traceback.print_exc()

    finally:
        shutdown_environment(env)
        print("==========  ENVIRONMENT SHUTDOWN  ==========")


if __name__ == '__main__':
    # Run the task only when executed as a script, not on import.
    run_skeleton_task()
