import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor   # noqa: F401  (kept for completeness)

from env import setup_environment, shutdown_environment
from skill_code import move, rotate, pick, place, pull
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ------------------------------------------------------------------
# Helper utilities
# ------------------------------------------------------------------
def _safe_get_shape_pos(name: str):
    """
    Best-effort lookup of a scene Shape's position.

    Args:
        name: Name passed to the ``Shape`` constructor.

    Returns:
        ``np.ndarray`` built from ``Shape(name).get_position()``, or ``None``
        when constructing the Shape or reading its position raises any
        ``Exception``.
    """
    try:
        raw_xyz = Shape(name).get_position()
        world_xyz = np.array(raw_xyz)
    except Exception:
        # Deliberately broad: callers treat "could not read it" as "absent".
        return None
    return world_xyz


def fetch_position(name: str, positions: dict):
    """
    Robustly obtain the 3-D world position of the given object / landmark.

    Strategy:
      1) Direct lookup in the supplied dictionary.
      2) Try a small set of underscore / hyphen substitutions, always in the
         same order: '-' -> '_', '_' -> '-', ' ' -> '_', ' ' -> '-'.
      3) Query the scene directly for a Shape of that name.

    Args:
        name: Object / landmark identifier to resolve.
        positions: Mapping from object name to a position; the stored value
            is passed through ``np.array``.

    Returns:
        ``np.ndarray`` holding the resolved position.

    Raises:
        KeyError if the position cannot be resolved.
    """
    # 1) direct hit
    if name in positions:
        return np.array(positions[name])

    # 2) simple name variants.
    # A tuple (not a set) keeps the probing order fixed.  String hashes are
    # randomised per process, so iterating a set could return a different
    # entry from run to run when several variants exist in `positions`.
    candidate_names = (
        name.replace('-', '_'),
        name.replace('_', '-'),
        name.replace(' ', '_'),
        name.replace(' ', '-'),
    )
    for candidate in candidate_names:
        if candidate in positions:
            return np.array(positions[candidate])

    # 3) direct scene query (original spelling only)
    pos = _safe_get_shape_pos(name)
    if pos is not None:
        return pos

    raise KeyError(f"[fetch_position] Unable to find position for “{name}”. "
                   "Ensure the object exists in the scene or in object_positions.")


def fetch_quaternion(angle_name: str):
    """
    Resolve a named orientation to a quaternion (xyzw).

    The scene is consulted first: if ``Shape(angle_name).get_quaternion()``
    succeeds, its value is returned as an array.  Otherwise one of the
    built-in constants is used.

    Args:
        angle_name: Name of a Shape in the scene, or one of the built-in
            identifiers ``'zero_deg'`` / ``'ninety_deg'``.

    Returns:
        ``np.ndarray`` with the quaternion components.

    Raises:
        ValueError: if the scene lookup fails and ``angle_name`` is not one
            of the built-in identifiers.
    """
    # Scene lookup; any failure simply means "use the constants below".
    scene_quat = None
    try:
        scene_quat = np.array(Shape(angle_name).get_quaternion())
    except Exception:
        scene_quat = None
    if scene_quat is not None:
        return scene_quat

    # Built-in constants (xyzw)
    half_sqrt2 = np.sqrt(2) / 2.0        # sin(45°) == cos(45°): 90° about Z-axis
    builtin_quats = {
        'zero_deg': (0., 0., 0., 1.),
        'ninety_deg': (0., 0., half_sqrt2, half_sqrt2),
    }
    if angle_name not in builtin_quats:
        raise ValueError(f"[fetch_quaternion] Unknown angle identifier “{angle_name}”. "
                         "Add a constant here or put a dummy Shape in the scene.")
    return np.array(builtin_quats[angle_name])


# ------------------------------------------------------------------
# Oracle-plan executor following the provided Specification
# ------------------------------------------------------------------
def run_oracle_plan():
    """
    Set up the environment, run the fixed nine-step oracle plan, and always
    shut the environment down afterwards.

    Plan (executed strictly in this order):
      1. rotate gripper to the 'ninety_deg' orientation
      2. move to the side position of the bottom drawer
      3. move to the anchor position of the bottom drawer
      4. pick the drawer handle
      5. pull the drawer along the 'x' axis
      6. pick tomato 1
      7. place tomato 1 on the plate
      8. pick tomato 2
      9. place tomato 2 on the plate

    Raises:
        RuntimeError: if any skill call reports ``done=True``.
        KeyError / ValueError: propagated from ``fetch_position`` /
            ``fetch_quaternion`` when a required landmark cannot be resolved.
    """
    print("========== [RUN] Oracle Plan Execution ==========")

    env, task = setup_environment()
    try:
        _, initial_obs = task.reset()

        # Optional video capture: replace the task callbacks with the
        # wrappers returned by the recording helpers.
        init_video_writers(initial_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---- Resolve every landmark up-front, before any motion ----
        scene_positions = get_object_positions()

        side_xyz = fetch_position('side-pos-bottom', scene_positions)
        anchor_xyz = fetch_position('anchor-pos-bottom', scene_positions)

        # The handle is sometimes exposed as “bottom_anchor_pos”; when that
        # name cannot be resolved, reuse the anchor position instead.
        try:
            handle_xyz = fetch_position('bottom_anchor_pos', scene_positions)
        except KeyError:
            handle_xyz = anchor_xyz

        tomato1_xyz = fetch_position('tomato1', scene_positions)
        tomato2_xyz = fetch_position('tomato2', scene_positions)
        plate_xyz = fetch_position('plate', scene_positions)

        gripper_quat = fetch_quaternion('ninety_deg')

        # ---- Plan table: (skill, extra positional args, keyword args) ----
        # Every skill is invoked as skill(env, task, *args, **kwargs).
        from_above = dict(approach_distance=0.15, approach_axis='-z')
        plan = (
            # 1: rotate gripper
            (rotate, (gripper_quat,), {}),
            # 2: side of the bottom drawer
            (move, (side_xyz,), {}),
            # 3: anchor of the bottom drawer
            (move, (anchor_xyz,), {}),
            # 4: grasp the handle, approaching from in-front of the drawer
            (pick, (), dict(target_pos=handle_xyz,
                            approach_distance=0.10,
                            approach_axis='-y')),
            # 5: 12 cm pull to open the drawer
            (pull, (), dict(pull_distance=0.12, pull_axis='x')),
            # 6-7: tomato 1 -> plate
            (pick, (), dict(target_pos=tomato1_xyz, **from_above)),
            (place, (), dict(target_pos=plate_xyz, **from_above)),
            # 8-9: tomato 2 -> plate
            (pick, (), dict(target_pos=tomato2_xyz, **from_above)),
            (place, (), dict(target_pos=plate_xyz, **from_above)),
        )

        # NOTE(review): the done check also runs after the final place. If the
        # skills report done=True on task success, the last step would raise
        # instead of reaching the success message — confirm against skill_code.
        for skill, extra_args, extra_kwargs in plan:
            _obs, _reward, done = skill(env, task, *extra_args, **extra_kwargs)
            if done:
                raise RuntimeError("[RUN] Task terminated prematurely (done=True).")

        print("[RUN] Oracle plan executed successfully – goal achieved!")

    finally:
        # Runs on success and on any exception raised above.
        shutdown_environment(env)
        print("========== [RUN] Environment shut down ==========")


# Script entry point: run the plan only when this file is executed directly,
# so importing the module (e.g. to reuse the fetch_* helpers) has no side effects.
if __name__ == "__main__":
    run_oracle_plan()