import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *
from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _pull_axis_and_distance(anchor_pos, joint_pos):
    """Return ``(axis, distance)`` for pulling from ``anchor_pos`` toward ``joint_pos``.

    The axis is chosen by the largest-magnitude component of
    ``joint_pos - anchor_pos``. The component's sign is encoded in the axis
    string (``'x'`` for positive or zero, ``'-x'`` for negative; same for y/z).
    The returned distance is the component's non-negative magnitude, so the
    direction is expressed only once.

    NOTE(review): this assumes ``skill_code.pull`` takes its direction from
    ``pull_axis`` and expects a positive ``pull_distance`` — confirm.
    """
    pull_vec = np.asarray(joint_pos, dtype=float) - np.asarray(anchor_pos, dtype=float)
    axis_index = int(np.argmax(np.abs(pull_vec)))
    axis = 'xyz'[axis_index]
    if pull_vec[axis_index] < 0:
        axis = '-' + axis
    return axis, float(abs(pull_vec[axis_index]))


def run_skeleton_task():
    """Run the bottom-drawer task: rotate, approach the handle, grasp it, then pull.

    Object poses and angle quaternions are read from ``get_object_positions()``.
    A missing required key raises ``KeyError``. That exception, like any other
    raised during execution, is printed with its traceback instead of being
    propagated. The environment is always shut down in ``finally``.
    """
    import traceback  # local import: only used to report failures below

    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # Initialize video writers for capturing the simulation (optional)
        init_video_writers(obs)

        # Wrap the task steps for recording
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions and Orientations ===
        positions = get_object_positions()

        # Extract angle quaternions. zero_deg is not used by the plan, but its
        # presence is still validated here.
        zero_quat = positions.get('zero_deg')
        ninety_quat = positions.get('ninety_deg')
        if zero_quat is None or ninety_quat is None:
            raise KeyError("Angle quaternions 'zero_deg' or 'ninety_deg' not found in positions.")

        # Extract handle positions for the bottom drawer
        side_pos_bottom = positions.get('side-pos-bottom')
        anchor_pos_bottom = positions.get('anchor-pos-bottom')
        if side_pos_bottom is None or anchor_pos_bottom is None:
            raise KeyError("Required handle positions for bottom drawer not found in positions.")

        # Extract joint position, used to compute the pull direction
        joint_pos_bottom = positions.get('bottom_joint_pos')
        if joint_pos_bottom is None:
            raise KeyError("Joint position 'bottom_joint_pos' not found in positions.")

        # === Execute Oracle Plan ===

        # Frozen Code Start
        obs, reward, done = rotate(env, task, target_quat=ninety_quat)
        obs, reward, done = move(env, task, target_pos=np.array(side_pos_bottom))
        obs, reward, done = move(env, task, target_pos=np.array(anchor_pos_bottom))
        # Frozen Code End

        # Do not try to grasp if the episode already terminated during the approach.
        if done:
            print("[Task] Ended during approach. Reward:", reward)
            return

        # Step 4: Pick (grasp) the drawer handle
        print("[Plan Step 4] pick up drawer handle (bottom)")
        obs, reward, done = pick(
            env,
            task,
            target_pos=np.array(anchor_pos_bottom),
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0
        )
        if done:
            print("[Task] Ended during pick. Reward:", reward)
            return

        # Step 5: Pull the drawer open along the dominant anchor->joint axis.
        # The sign lives in `axis`; `pull_distance` is a non-negative magnitude.
        axis, pull_distance = _pull_axis_and_distance(anchor_pos_bottom, joint_pos_bottom)
        print(f"[Plan Step 5] pull drawer along {axis} by {pull_distance:.3f} m")
        obs, reward, done = pull(
            env,
            task,
            pull_distance=pull_distance,
            pull_axis=axis,
            max_steps=100,
            threshold=0.01,
            timeout=10.0
        )
        if done:
            print("[Task] Ended during pull. Reward:", reward)
        else:
            print("[Task] Completed pull. Drawer should now be open. Reward:", reward)

    except Exception as e:
        # Top-level boundary: report the failure with its traceback instead of
        # hiding where it came from.
        print(f"[Task] Exception during execution: {e}")
        traceback.print_exc()
    finally:
        # Always ensure the environment is properly shut down
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")

if "__main__" == __name__:
    # Script entry point: run the drawer task only when executed directly, not on import.
    run_skeleton_task()