import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions
from scipy.spatial.transform import Rotation as R

def dominant_axis_label(vec):
    """Return the signed axis label of the largest-magnitude component of *vec*.

    The axis is picked by ``argmax(|vec|)``; the label is ``'x'``/``'y'``/``'z'``
    when that component is strictly positive and ``'-x'``/``'-y'``/``'-z'``
    otherwise (so an exactly-zero component, e.g. the zero vector, yields the
    negative label).

    Args:
        vec: array-like with exactly three components (x, y, z).

    Returns:
        One of ``'x'``, ``'-x'``, ``'y'``, ``'-y'``, ``'z'``, ``'-z'``.

    Raises:
        ValueError: if *vec* does not have exactly three components.
    """
    arr = np.asarray(vec, dtype=float).ravel()
    if arr.shape != (3,):
        raise ValueError(f"expected a 3-component vector, got shape {arr.shape}")
    idx = int(np.argmax(np.abs(arr)))
    name = "xyz"[idx]
    return name if arr[idx] > 0 else "-" + name


def run_skeleton_task():
    """Run the oracle plan: open the bottom drawer, then put two tomatoes on the plate.

    Plan (every skill call is followed by a ``done`` check, and the plan
    returns early as soon as the episode reports ``done``):
      1. rotate the gripper 90 degrees about Z
      2. move to ``bottom_side_pos``
      3. move to ``bottom_anchor_pos``
      4. grasp the drawer handle at the anchor position
      5. pull the drawer along the dominant anchor->joint axis
      6-9. pick ``item1``/``item2`` and place each on ``plate``

    The environment is shut down in a ``finally`` clause, so it is released
    even on early return or when a step raises (the exception is re-raised).
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # Initialize video writers (optional)
        init_video_writers(obs)

        # Wrap step and get_observation so every call goes through the recorder
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # Retrieve positions of all relevant objects
        positions = get_object_positions()

        # === Oracle Plan Execution ===
        try:
            # [Frozen Code Start] -- the rotate/move/move skill calls of
            # steps 1-3 come from a frozen plan segment. They are issued in
            # their original order, once each, with `done` checks interleaved
            # so an already-finished episode is not driven further.

            # Step 1: Rotate gripper from zero_deg to ninety_deg
            print("[Task] Step 1: rotate gripper to 90° around Z")
            ninety_quat = R.from_euler('z', 90, degrees=True).as_quat()
            obs, reward, done = rotate(env, task, ninety_quat)
            if done:
                print("[Task] Terminated after rotate")
                return

            # Step 2: Move to the bottom drawer side position
            print("[Task] Step 2: move to bottom_side_pos")
            side_pos = np.array(positions['bottom_side_pos'])
            obs, reward, done = move(env, task, side_pos)
            if done:
                print("[Task] Terminated after move-to-side")
                return

            # Step 3: Move to the bottom drawer anchor position
            print("[Task] Step 3: move to bottom_anchor_pos")
            anchor_pos = np.array(positions['bottom_anchor_pos'])
            obs, reward, done = move(env, task, anchor_pos)
            # [Frozen Code End]
            if done:
                print("[Task] Terminated after move-to-anchor")
                return

            # Step 4: Grasp the drawer handle (pick-drawer)
            print("[Task] Step 4: pick-drawer at anchor_pos")
            obs, reward, done = pick(env, task, anchor_pos)
            if done:
                print("[Task] Terminated after pick-drawer")
                return

            # Step 5: Pull the drawer open
            print("[Task] Step 5: pull drawer open")
            joint_pos = np.array(positions['bottom_joint_pos'])
            vec = joint_pos - anchor_pos
            # NOTE(review): pull_dist is the full 3-D anchor->joint distance,
            # not its projection on pull_axis -- confirm this matches what
            # `pull` expects.
            pull_dist = np.linalg.norm(vec)
            # Pull along the dominant axis of the anchor->joint offset.
            pull_axis = dominant_axis_label(vec)
            obs, reward, done = pull(env, task, pull_dist, pull_axis=pull_axis)
            if done:
                print("[Task] Terminated after pull")
                return

            # Step 6: Pick up the first tomato (item1)
            print("[Task] Step 6: pick tomato1 (item1) from table")
            item1_pos = np.array(positions['item1'])
            obs, reward, done = pick(env, task, item1_pos)
            if done:
                print("[Task] Terminated after pick item1")
                return

            # Step 7: Place the first tomato onto the plate
            print("[Task] Step 7: place tomato1 on plate")
            plate_pos = np.array(positions['plate'])
            obs, reward, done = place(env, task, plate_pos)
            if done:
                print("[Task] Terminated after place item1")
                return

            # Step 8: Pick up the second tomato (item2)
            print("[Task] Step 8: pick tomato2 (item2) from table")
            item2_pos = np.array(positions['item2'])
            obs, reward, done = pick(env, task, item2_pos)
            if done:
                print("[Task] Terminated after pick item2")
                return

            # Step 9: Place the second tomato onto the plate
            print("[Task] Step 9: place tomato2 on plate")
            obs, reward, done = place(env, task, plate_pos)
            if done:
                print("[Task] Completed: both tomatoes placed")
            else:
                print("[Task] Executed all steps; done flag is", done)

        except Exception as e:
            # Report, then propagate; the outer finally still shuts down.
            print("[Task] Exception during execution:", e)
            raise

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")

if __name__ == "__main__":
    # Script entry point: run the plan only when executed directly, not on import.
    run_skeleton_task()