import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import move, pick, place, rotate, pull
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def _safe_lookup(pos_dict, key):
    """Utility: fetch a position from the dictionary and raise a clear error if missing."""
    if key not in pos_dict:
        raise KeyError(f"[Oracle-Plan] '{key}' was not found in object_positions!")
    return np.asarray(pos_dict[key], dtype=np.float32)


def run_oracle_plan():
    """Run the hand-crafted oracle plan that satisfies the task goals."""
    print("===========  ORACLE PLAN: START  ===========")
    env, task = setup_environment()

    try:
        # -------------------------------------------------------------
        #  Reset the environment & wrap the task for optional recording
        # -------------------------------------------------------------
        _, obs = task.reset()
        init_video_writers(obs)

        #  Replace step / get_observation so every action is recorded
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # -------------------------------------------------------------
        #              Retrieve all relevant object poses
        # -------------------------------------------------------------
        pos = get_object_positions()

        side_pos_bottom   = _safe_lookup(pos, 'bottom_side_pos')
        anchor_pos_bottom = _safe_lookup(pos, 'bottom_anchor_pos')
        tomato1_pos       = _safe_lookup(pos, 'tomato1')
        tomato2_pos       = _safe_lookup(pos, 'tomato2')
        plate_pos         = _safe_lookup(pos, 'plate')

        # -------------------------------------------------------------
        #                    Execute Oracle Action List
        # -------------------------------------------------------------
        # STEP-1  : rotate(gripper, zero_deg, ninety_deg)
        print("\n[STEP-1] rotate gripper to 90° about world-Z")
        current_quat = task.get_observation().gripper_pose[3:7]
        target_quat  = R.from_euler('z', 90, degrees=True).as_quat()   # xyzw
        obs, reward, done = rotate(env, task, target_quat)
        if done:
            print("[Oracle-Plan] Episode ended during STEP-1"); return

        # STEP-2  : move-to-side (bottom drawer)
        print("\n[STEP-2] move to side position of bottom drawer")
        obs, reward, done = move(env, task, target_pos=side_pos_bottom)
        if done:
            print("[Oracle-Plan] Episode ended during STEP-2"); return

        # STEP-3  : move-to-anchor (bottom drawer)
        print("\n[STEP-3] move to anchor position of bottom drawer")
        obs, reward, done = move(env, task, target_pos=anchor_pos_bottom)
        if done:
            print("[Oracle-Plan] Episode ended during STEP-3"); return

        # STEP-4  : pick-drawer  (close gripper on handle)
        print("\n[STEP-4] grasp the drawer handle")
        obs, reward, done = pick(env, task,
                                 target_pos=anchor_pos_bottom,
                                 approach_distance=0.05,   # closer approach for handle
                                 max_steps=80,
                                 threshold=0.005,
                                 approach_axis='z')
        if done:
            print("[Oracle-Plan] Episode ended during STEP-4"); return

        # STEP-5  : pull drawer open
        print("\n[STEP-5] pull the drawer straight out")
        obs, reward, done = pull(env, task,
                                 pull_distance=0.18,        # pull ~18 cm
                                 pull_axis='x')             # assuming drawer opens +X
        if done:
            print("[Oracle-Plan] Episode ended during STEP-5"); return

        # STEP-6  : pick tomato-1
        print("\n[STEP-6] pick tomato1 from table")
        obs, reward, done = pick(env, task,
                                 target_pos=tomato1_pos,
                                 approach_distance=0.15,
                                 approach_axis='z')
        if done:
            print("[Oracle-Plan] Episode ended during STEP-6"); return

        # STEP-7  : place tomato-1 on plate
        print("\n[STEP-7] place tomato1 on plate")
        obs, reward, done = place(env, task,
                                  target_pos=plate_pos,
                                  approach_distance=0.15,
                                  approach_axis='z')
        if done:
            print("[Oracle-Plan] Episode ended during STEP-7"); return

        # STEP-8  : pick tomato-2
        print("\n[STEP-8] pick tomato2 from table")
        obs, reward, done = pick(env, task,
                                 target_pos=tomato2_pos,
                                 approach_distance=0.15,
                                 approach_axis='z')
        if done:
            print("[Oracle-Plan] Episode ended during STEP-8"); return

        # STEP-9  : place tomato-2 on plate
        print("\n[STEP-9] place tomato2 on plate")
        obs, reward, done = place(env, task,
                                  target_pos=plate_pos,
                                  approach_distance=0.15,
                                  approach_axis='z')
        if done:
            print("[Oracle-Plan] Episode ended during STEP-9"); return

        print("\n===========  ORACLE PLAN: SUCCESS  ===========")

    finally:
        shutdown_environment(env)


if __name__ == "__main__":
    # Run the oracle plan only when executed as a script, not when imported.
    run_oracle_plan()