import numpy as np
import math
import traceback

from pyrep.objects.shape import Shape                  # needed by skill_code even if not used here
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Low-level skills that the simulator already provides
from skill_code import move, rotate, pick, pull, place

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
#  Helper utilities
# ---------------------------------------------------------------------------
def quaternion_about_z(angle_rad: float) -> np.ndarray:
    """Build an (x, y, z, w) quaternion for a rotation of *angle_rad* about world Z.

    Args:
        angle_rad: Rotation angle in radians.

    Returns:
        A float32 array ``[0, 0, sin(angle_rad / 2), cos(angle_rad / 2)]``.
    """
    half_angle = 0.5 * angle_rad
    z_component = math.sin(half_angle)
    w_component = math.cos(half_angle)
    return np.array((0.0, 0.0, z_component, w_component), dtype=np.float32)


def safe_call(skill_fn, *args, **kwargs):
    """Invoke a skill, logging any exception before re-raising it.

    This helper does NOT shut the environment down itself; it reports the
    failure (message plus traceback) and propagates the original exception
    so that the caller's ``finally`` clause can perform cleanup.

    Args:
        skill_fn: The skill callable to run.
        *args: Positional arguments forwarded unchanged to ``skill_fn``.
        **kwargs: Keyword arguments forwarded unchanged to ``skill_fn``.

    Returns:
        Whatever ``skill_fn`` returns.

    Raises:
        Exception: Any exception raised by ``skill_fn`` is re-raised unchanged.
    """
    try:
        return skill_fn(*args, **kwargs)
    except Exception as exc:
        # functools.partial objects and other callables may lack __name__;
        # fall back to repr() so the error report never masks the real failure.
        skill_name = getattr(skill_fn, "__name__", repr(skill_fn))
        print(f"[ERROR] Exception in skill {skill_name}: {exc}")
        traceback.print_exc()
        raise


# ---------------------------------------------------------------------------
#  Main routine – executes the oracle plan in the specification (steps 1-7)
# ---------------------------------------------------------------------------
def run_skeleton_task():
    """Execute the oracle plan (specification steps 1-7) in the simulator.

    The plan opens the bottom drawer (rotate, move to side, move to anchor,
    grasp handle, pull, release), then picks up the rubbish and drops it in
    the bin. It returns early as soon as any skill reports ``done``. The
    environment is always shut down, even when a step raises.

    Raises:
        KeyError: If ``get_object_positions()`` lacks a key the plan needs.
    """
    print("===== Starting Skeleton Task =====")

    # ------------------------------------------------------------------ #
    # 0)  Simulator set-up
    # ------------------------------------------------------------------ #
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        #  Optional video recording
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------ #
        # 1)  Query object positions that the oracle plan needs
        # ------------------------------------------------------------------ #
        positions = get_object_positions()

        # Mandatory keys for the oracle plan -> description used in the error.
        required = {
            'bottom_side_pos'  : 'side-pos-bottom',
            'bottom_anchor_pos': 'anchor-pos-bottom',
            'item3'            : 'rubbish',
            'bin'              : 'trash-bin'
        }
        missing = [f"“{key}” ({desc})"
                   for key, desc in required.items() if key not in positions]
        if missing:
            raise KeyError("[Setup] get_object_positions() did not return "
                           + ", ".join(missing) + ".")

        side_pos    = np.asarray(positions['bottom_side_pos'],   dtype=np.float32)
        anchor_pos  = np.asarray(positions['bottom_anchor_pos'], dtype=np.float32)
        rubbish_pos = np.asarray(positions['item3'],             dtype=np.float32)
        bin_pos     = np.asarray(positions['bin'],               dtype=np.float32)
        hover_offset = np.array([0.0, 0.0, 0.10])   # 10 cm above a target

        # ------------------------------------------------------------------ #
        # 2)  Oracle plan (specification steps 1 → 7)
        # ------------------------------------------------------------------ #

        # Step-1  rotate(gripper, zero_deg, ninety_deg)
        print("\n--- [Step-1] Rotate gripper +90 deg about Z ---")
        obs, done = _run_step("Step-1", rotate, env, task,
                              target_quat=quaternion_about_z(math.pi / 2.0))
        if done:
            return

        # Step-2  move-to-side
        print("\n--- [Step-2] Move to drawer side position ---")
        obs, done = _run_step("Step-2", move, env, task, target_pos=side_pos)
        if done:
            return

        # Step-3  move-to-anchor
        print("\n--- [Step-3] Move to drawer anchor (handle) ---")
        obs, done = _run_step("Step-3", move, env, task, target_pos=anchor_pos)
        if done:
            return

        # Step-4  pick-drawer   (use generic pick to grasp the handle)
        print("\n--- [Step-4] Grasp drawer handle ---")
        obs, done = _run_step("Step-4", pick, env, task,
                              target_pos=anchor_pos,
                              approach_distance=0.10,
                              approach_axis='z')
        if done:
            return

        # Step-5  pull
        print("\n--- [Step-5] Pull drawer open ---")
        obs, done = _run_step("Step-5", pull, env, task,
                              pull_distance=0.22,     # pull ~22 cm
                              pull_axis='x')          # drawer assumed +X
        if done:
            return

        # (Optional) Release the handle so that the gripper can reopen.
        # After the pull the gripper is no longer at ``anchor_pos``; sending it
        # back there while still grasping would push the drawer closed. Open
        # the gripper where it currently is instead.
        print("\n--- [Step-5b] Release drawer handle ---")
        release_pos = _current_gripper_position(obs, default=anchor_pos)
        obs, done = _run_step("Step-5b", place, env, task,
                              target_pos=release_pos,
                              approach_distance=0.00,
                              approach_axis='z')
        if done:
            return

        # Step-6  pick rubbish
        print("\n--- [Step-6] Move above rubbish ---")
        obs, done = _run_step("Step-6 (move above rubbish)", move, env, task,
                              target_pos=rubbish_pos + hover_offset)
        if done:
            return

        print("\n--- [Step-6] Pick rubbish ---")
        obs, done = _run_step("Step-6 (pick rubbish)", pick, env, task,
                              target_pos=rubbish_pos,
                              approach_distance=0.10,
                              approach_axis='z')
        if done:
            return

        # Step-7  place rubbish in bin
        print("\n--- [Step-7] Move above bin ---")
        obs, done = _run_step("Step-7 (move above bin)", move, env, task,
                              target_pos=bin_pos + hover_offset)
        if done:
            return

        print("\n--- [Step-7] Drop rubbish in bin ---")
        obs, done = _run_step("Step-7 (place)", place, env, task,
                              target_pos=bin_pos,
                              approach_distance=0.08,
                              approach_axis='z')
        if done:
            return

        print("\n===== Oracle plan finished – Goal achieved! =====")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


def _run_step(label, skill_fn, *args, **kwargs):
    """Run one plan step via ``safe_call`` and announce early termination.

    Args:
        label: Step name used in the termination message (e.g. ``"Step-1"``).
        skill_fn: Skill callable forwarded to ``safe_call``.
        *args: Positional arguments forwarded unchanged to the skill.
        **kwargs: Keyword arguments forwarded unchanged to the skill.

    Returns:
        ``(obs, done)``: the observation and termination flag from the
        skill's ``(obs, reward, done)`` result (reward is discarded).
    """
    obs, _reward, done = safe_call(skill_fn, *args, **kwargs)
    if done:
        print(f"[Task] Terminated after {label}.")
    return obs, done


def _current_gripper_position(obs, default):
    """Return the gripper's current xyz from *obs*, or *default* if unavailable."""
    # NOTE(review): assumes an RLBench-style ``gripper_pose`` whose first three
    # entries are the xyz position — confirm against the observation type.
    pose = getattr(obs, "gripper_pose", None)
    if pose is None:
        return default
    return np.asarray(pose[:3], dtype=np.float32)


# ---------------------------------------------------------------------------
#  Entrypoint
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: run the oracle plan end-to-end when this file is
    # executed directly (not when it is imported).
    run_skeleton_task()