import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _lookup_position(positions, key):
    '''Fetch a named position from ``positions`` as a float numpy array.

    Args:
        positions: Mapping returned by ``get_object_positions()``.
        key: Name of the position, e.g. ``'bottom_side_pos'``.

    Returns:
        The position as a ``numpy.ndarray`` of floats.

    Raises:
        KeyError: If ``key`` is absent or maps to ``None``. Failing here
            gives a clear message instead of handing ``None`` to a motion
            skill or crashing later on ``None - array``.
    '''
    pos = positions.get(key)
    if pos is None:
        raise KeyError(f"Missing object position '{key}'")
    # No-op for float ndarrays; lets list/tuple inputs support the vector
    # arithmetic used in the pull computation.
    return np.asarray(pos, dtype=float)


def _pull_parameters(anchor_pos, joint_pos):
    '''Return ``(distance, axis)`` for pulling a drawer from its handle.

    Only the x component of ``joint_pos - anchor_pos`` is used: the axis is
    ``'x'`` when the joint lies at larger x than the anchor, otherwise
    ``'-x'``; the distance is the absolute x offset.
    '''
    delta = joint_pos - anchor_pos
    pull_axis = 'x' if delta[0] > 0 else '-x'
    return abs(delta[0]), pull_axis


def run_skeleton_task():
    '''Run the oracle drawer-opening plan in the simulation.

    For each drawer in the order bottom, middle, top, five steps are run:
    rotate the gripper, move to the drawer's side position, move to its
    anchor position, pick at the anchor, and pull along x toward the
    drawer's joint position. The plan stops as soon as any skill reports
    ``done``. The environment is shut down in all cases, including errors.

    Raises:
        KeyError: If ``get_object_positions()`` lacks a required
            ``<drawer>_side_pos``, ``<drawer>_anchor_pos`` or
            ``<drawer>_joint_pos`` entry.
    '''
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps for recording
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # === Retrieve Object Positions ===
        positions = get_object_positions()

        # Gripper orientations as (x, y, z, w) quaternions: identity, and a
        # 90-degree rotation about z.
        zero_quat = np.array([0.0, 0.0, 0.0, 1.0])
        ninety_quat = np.array([0.0, 0.0, 0.70710678, 0.70710678])

        # Oracle plan: for each drawer in order bottom, middle, top
        drawers = ['bottom', 'middle', 'top']
        done = False
        reward = None

        for idx, drawer in enumerate(drawers):
            base_step = idx * 5  # each drawer occupies five plan steps

            # Resolve every position this drawer needs before moving, so a
            # missing entry fails before a half-executed manipulation.
            side_pos = _lookup_position(positions, f'{drawer}_side_pos')
            anchor_pos = _lookup_position(positions, f'{drawer}_anchor_pos')
            joint_pos = _lookup_position(positions, f'{drawer}_joint_pos')

            # Step 1,6,11: rotate (orientation alternates ninety/zero/ninety)
            use_ninety = idx % 2 == 0
            target_quat = ninety_quat if use_ninety else zero_quat
            print(f"[Task] Step {1 + base_step}: rotate to {'ninety_deg' if use_ninety else 'zero_deg'} for {drawer} drawer")
            obs, reward, done = rotate(env, task, target_quat)
            # Previously unchecked: the move below would have run on a
            # finished episode.
            if done:
                print(f"[Task] Task ended early after rotate for {drawer}. Reward: {reward}")
                break

            # Step 2,7,12: move to side position
            print(f"[Task] Step {2 + base_step}: move to side position of {drawer} drawer ->", side_pos)
            obs, reward, done = move(env, task, side_pos)
            if done:
                print(f"[Task] Task ended early after move-to-side for {drawer}. Reward: {reward}")
                break

            # Step 3,8,13: move to anchor position
            print(f"[Task] Step {3 + base_step}: move to anchor position of {drawer} drawer ->", anchor_pos)
            obs, reward, done = move(env, task, anchor_pos)
            if done:
                print(f"[Task] Task ended early after move-to-anchor for {drawer}. Reward: {reward}")
                break

            # Step 4,9,14: pick at anchor
            print(f"[Task] Step {4 + base_step}: pick {drawer} drawer handle at", anchor_pos)
            obs, reward, done = pick(env, task, anchor_pos)
            if done:
                print(f"[Task] Task ended early after pick for {drawer}. Reward: {reward}")
                break

            # Step 5,10,15: pull drawer open along x toward the joint
            pull_distance, pull_axis = _pull_parameters(anchor_pos, joint_pos)
            print(f"[Task] Step {5 + base_step}: pull {drawer} drawer by {pull_distance:.3f} along axis {pull_axis}")
            obs, reward, done = pull(env, task, pull_distance, pull_axis)
            if done:
                print(f"[Task] {drawer} drawer opened and task ended. Reward: {reward}")
                break

        if not done:
            print("[Task] Completed plan for all drawers. Final reward:", reward)

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing.
    run_skeleton_task()