# run_skeleton_task.py
#
# Concrete implementation that follows the seven-step oracle plan given in the
# specification.  The code relies **only** on the predefined primitives shipped
# in `skill_code.py`; no new low-level skills are introduced.

from typing import Iterable, Optional, Union

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape                    # kept (may aid debug)
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# import every predefined primitive/skill exactly as delivered
from skill_code import move, pick, place, rotate, pull

# video helpers (recording is optional – kept identical to skeleton)
from video import init_video_writers, recording_step, recording_get_observation

# object-position helper supplied by the framework
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------
def fetch_pos(
    name: str,
    positions: dict,
    aliases: Optional[Union[str, Iterable[str]]] = None,
) -> np.ndarray:
    """
    Return the position stored under ``name`` in ``positions`` as an ndarray.

    If ``name`` is not a key of ``positions``, each alias is tried in order and
    the first one present is returned.

    Args:
        name: Preferred key in ``positions``.
        positions: Mapping of object name -> position, e.g. the dict returned
            by ``get_object_positions()``.
        aliases: Optional fallback key(s). May be any iterable of keys (list,
            tuple, ...) or a single key string. A lone string is treated as one
            alias, not iterated character by character.

    Returns:
        ``np.asarray`` of the first matching value.

    Raises:
        KeyError: If neither ``name`` nor any alias is a key of ``positions``;
            the message lists the available keys.
    """
    # Normalise the fallbacks into a tuple. A bare string must not be
    # iterated, otherwise single-character keys could match by accident.
    if isinstance(aliases, str):
        fallbacks = (aliases,)
    elif aliases is None:
        fallbacks = ()
    else:
        fallbacks = tuple(aliases)

    for key in (name, *fallbacks):
        if key in positions:
            return np.asarray(positions[key])

    raise KeyError(
        f"[run_task] Could not find '{name}' (or aliases {aliases}) in "
        f"get_object_positions(). Available keys: {list(positions.keys())}"
    )


def axis_from_vector(vec: np.ndarray) -> str:
    """
    Map a 3-D direction to the axis string understood by ``pull()``.

    The component with the largest magnitude selects the axis letter
    ('x', 'y' or 'z'). A negative value on that component prefixes '-'.
    Ties go to the earliest axis, and an all-zero vector yields 'x'.
    """
    magnitudes = np.abs(vec)
    dominant = int(magnitudes.argmax())     # 0 -> x, 1 -> y, 2 -> z
    letter = ('x', 'y', 'z')[dominant]
    if vec[dominant] < 0:
        return '-' + letter
    return letter


def safe_skill_call(skill_fn, *args, **kwargs):
    """
    Call ``skill_fn(*args, **kwargs)`` and return its result.

    If the call raises, a one-line diagnostic naming the skill is printed and
    the original exception is re-raised unchanged. Callers' ``finally``
    clauses (e.g. environment shutdown) therefore still run, and see the real
    error.
    """
    try:
        return skill_fn(*args, **kwargs)
    except Exception as exc:
        # Not every callable has __name__ (functools.partial, callable
        # instances). Reading it unguarded would raise AttributeError here and
        # mask the skill's actual exception.
        label = getattr(skill_fn, '__name__', repr(skill_fn))
        print(f"[run_task] Exception during '{label}': {exc}")
        raise


# ---------------------------------------------------------------------------
# Oracle-plan execution (seven distinct steps from the specification)
# ---------------------------------------------------------------------------
def execute_plan(env, task, positions):
    """
    Run the fixed seven-step oracle plan against ``task``.

    Steps, in order:
      1. rotate the gripper 90 deg about z,
      2. move to the bottom drawer's side position,
      3. move to its anchor position,
      4. grasp the handle at the anchor,
      5. pull the drawer 0.20 along an inferred axis,
      6. pick the rubbish object,
      7. place it at the bin position.

    Every skill call is unpacked as ``(obs, reward, done)``. The plan returns
    early as soon as ``done`` is truthy. Position lookup failures
    (KeyError / RuntimeError) and skill exceptions propagate to the caller.
    """
    # 1 -- orient the gripper (scipy's as_quat() is xyzw ordered)
    print("\n--- STEP 1: rotate gripper to ninety_deg ---")
    quat_z90 = R.from_euler('z', 90, degrees=True).as_quat()
    _, _, finished = safe_skill_call(rotate, env, task, target_quat=quat_z90)
    if finished:
        return

    # 2 -- go to the side position of the bottom drawer
    print("\n--- STEP 2: move gripper to bottom_side_pos ---")
    side_target = fetch_pos('bottom_side_pos', positions,
                            aliases=['side-pos-bottom'])
    _, _, finished = safe_skill_call(move, env, task, target_pos=side_target)
    if finished:
        return

    # 3 -- go to the anchor (handle) position
    print("\n--- STEP 3: move gripper to bottom_anchor_pos ---")
    handle_target = fetch_pos('bottom_anchor_pos', positions,
                              aliases=['anchor-pos-bottom'])
    _, _, finished = safe_skill_call(move, env, task, target_pos=handle_target)
    if finished:
        return

    # 4 -- grasp the handle; a shorter approach is used to avoid collisions
    print("\n--- STEP 4: pick (grasp) bottom drawer handle ---")
    _, _, finished = safe_skill_call(
        pick, env, task,
        target_pos=handle_target,
        approach_distance=0.10,
        approach_axis='z',
    )
    if finished:
        return

    # 5 -- pull 0.20 along the dominant axis of (anchor - joint) when a joint
    #      position is available; otherwise fall back to 'x'.
    print("\n--- STEP 5: pull drawer outwards ---")
    try:
        hinge = fetch_pos('bottom_joint_pos', positions,
                          aliases=['bottom_joint'])
        axis = axis_from_vector(handle_target - hinge)
    except KeyError:
        axis = 'x'
    _, _, finished = safe_skill_call(
        pull, env, task,
        pull_distance=0.20,
        pull_axis=axis,
    )
    if finished:
        return

    # 6 -- pick the rubbish: 'rubbish' first, then item3, item2, item1
    print("\n--- STEP 6: pick rubbish object from table ---")
    trash_pos = None
    for candidate in ('rubbish', 'item3', 'item2', 'item1'):
        try:
            trash_pos = fetch_pos(candidate, positions)
        except KeyError:
            continue
        if candidate != 'rubbish':
            print(f"[run_task] Using alias '{candidate}' for rubbish.")
        break
    if trash_pos is None:
        raise RuntimeError("Could not locate a rubbish object (rubbish/item1-3).")

    _, _, finished = safe_skill_call(
        pick, env, task,
        target_pos=trash_pos,
        approach_distance=0.12,
        approach_axis='z',
    )
    if finished:
        return

    # 7 -- drop the rubbish at the bin position
    print("\n--- STEP 7: place rubbish into bin ---")
    bin_target = fetch_pos('bin', positions)
    _, _, finished = safe_skill_call(
        place, env, task,
        target_pos=bin_target,
        approach_distance=0.12,
        approach_axis='z',
    )

    if not finished:
        print("[run_task] Plan finished – done flag is False (env may accept "
              "more steps).")
    else:
        print("[run_task] Goal achieved – environment signalled done=True.")


# ---------------------------------------------------------------------------
# Main entry-point (environment setup / teardown boiler-plate)
# ---------------------------------------------------------------------------
def run_skeleton_task():
    """
    Script entry point: set up the environment, run the oracle plan, tear down.

    Before any skill runs:
      - the task is reset and its initial observation passed to
        ``init_video_writers``;
      - ``task.step`` and ``task.get_observation`` are replaced with the
        ``video`` module's recording wrappers.

    ``shutdown_environment`` runs in a ``finally`` block, so it executes even
    if a step raises. The closing banner prints only on normal completion.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        _, first_obs = task.reset()

        # Install the recording wrappers from `video` around the task API.
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        object_positions = get_object_positions()
        print("[run_task] Position keys from environment:",
              list(object_positions.keys()))

        execute_plan(env, task, object_positions)
    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the full setup -> oracle plan -> teardown cycle when executed as a script.
    run_skeleton_task()