import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Explicitly import only the skills we need
from skill_code import pick, place, move, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


class _TaskFinishedEarly(Exception):
    """Internal signal: a skill reported ``done`` before the plan was exhausted."""


def run_skeleton_task():
    """Execute the scripted drawer-and-tomato plan in the simulated environment.

    The plan is:
      1) rotate the gripper 90 degrees about z,
      2) move to the bottom drawer's side and anchor positions, pick at the
         handle position and pull 0.20 along the x axis,
      3) pick tomato1 and tomato2 in turn and place each one on the plate.

    If any skill reports ``done`` the remaining steps are skipped and the
    function returns normally.  The environment is shut down exactly once,
    in the ``finally`` clause, on every exit path.

    Raises:
        RuntimeError: if ``get_object_positions()`` lacks a required key.
        Exception: anything raised by the environment or a skill is printed
            and re-raised unchanged.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # -----------------------------------------------------------
        # Reset task / set up video capture
        # -----------------------------------------------------------
        _, obs = task.reset()
        init_video_writers(obs)

        # Wrap step / get_observation with the recording helpers so that
        # every call made through ``task`` goes through the video module.
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # -----------------------------------------------------------
        # Retrieve all object positions we may need
        # -----------------------------------------------------------
        positions = get_object_positions()

        # Mandatory objects for the oracle plan
        try:
            side_pos       = positions['bottom_side_pos']
            anchor_pos     = positions['bottom_anchor_pos']
            drawer_handle  = positions['bottom_joint_pos']   # handle or joint centre
            tomato1_pos    = positions['tomato1']
            tomato2_pos    = positions['tomato2']
            plate_pos      = positions['plate']
        except KeyError as e:
            # Chain the KeyError so the traceback shows which lookup failed.
            raise RuntimeError(
                f"[Task] Required object missing from get_object_positions(): {e}"
            ) from e

        def exec_and_check(skill_fn, *args, **kwargs):
            """Run one skill and abort the plan if it reports ``done``.

            Returns ``(obs, reward)`` from the skill.  Raises
            ``_TaskFinishedEarly`` instead of exiting the interpreter, so
            cleanup is left to the single ``finally`` of the enclosing
            function and the environment is never shut down twice.
            """
            obs, reward, done = skill_fn(*args, **kwargs)
            if done:
                print("[Task] Task finished early – exiting execution loop.")
                raise _TaskFinishedEarly
            return obs, reward

        # -----------------------------------------------------------
        # Oracle Plan Execution
        # -----------------------------------------------------------
        print("\n--- Step 1: rotate gripper zero_deg -> ninety_deg ---")
        obs = task.get_observation()
        start_quat = obs.gripper_pose[3:7]  # xyzw
        # Compose the current orientation with a 90 degree turn about z.
        quarter_turn = R.from_euler('z', 90, degrees=True)
        target_quat = (R.from_quat(start_quat) * quarter_turn).as_quat()
        exec_and_check(
            rotate,
            env, task,
            target_quat=target_quat,
            max_steps=120, threshold=0.05, timeout=10.0
        )

        print("\n--- Step 2: move-to-side (nowhere-pos -> bottom_side_pos) ---")
        exec_and_check(
            move,
            env, task,
            target_pos=side_pos,
            max_steps=120, threshold=0.005, timeout=10.0
        )

        print("\n--- Step 3: move-to-anchor (side -> anchor) ---")
        exec_and_check(
            move,
            env, task,
            target_pos=anchor_pos,
            max_steps=120, threshold=0.005, timeout=10.0
        )

        print("\n--- Step 4: pick drawer handle (simulating pick‑drawer) ---")
        exec_and_check(
            pick,
            env, task,
            target_pos=drawer_handle,
            approach_distance=0.10,
            max_steps=120, threshold=0.005, approach_axis='z', timeout=10.0
        )

        print("\n--- Step 5: pull drawer out ---")
        # Positive X-direction is an educated guess; adjust if needed
        exec_and_check(
            pull,
            env, task,
            pull_distance=0.20,
            pull_axis='x',
            max_steps=120, threshold=0.005, timeout=10.0
        )

        # ===========================================================
        # Tomato 1
        # ===========================================================
        print("\n--- Step 6: pick tomato1 ---")
        exec_and_check(
            pick,
            env, task,
            target_pos=tomato1_pos,
            approach_distance=0.15, max_steps=120, threshold=0.005,
            approach_axis='z', timeout=10.0
        )

        print("\n--- Step 7: place tomato1 on plate ---")
        exec_and_check(
            place,
            env, task,
            target_pos=plate_pos,
            approach_distance=0.15, max_steps=120, threshold=0.005,
            approach_axis='z', timeout=10.0
        )

        # ===========================================================
        # Tomato 2
        # ===========================================================
        print("\n--- Step 8: pick tomato2 ---")
        exec_and_check(
            pick,
            env, task,
            target_pos=tomato2_pos,
            approach_distance=0.15, max_steps=120, threshold=0.005,
            approach_axis='z', timeout=10.0
        )

        print("\n--- Step 9: place tomato2 on plate ---")
        obs, reward = exec_and_check(
            place,
            env, task,
            target_pos=plate_pos,
            approach_distance=0.15, max_steps=120, threshold=0.005,
            approach_axis='z', timeout=10.0
        )

        print("\n===== Task completed successfully! Reward:", reward, "=====")

    except _TaskFinishedEarly:
        # Normal early termination: the message was already printed by
        # exec_and_check; fall through to the shared cleanup below.
        pass

    except Exception as exc:
        print("[Task] Exception occurred:", exc)
        raise

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    run_skeleton_task()
