import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _run_skill(action, error_label, done_message):
    """Invoke one skill call and classify its outcome.

    Args:
        action: zero-argument callable wrapping a skill call; its result is
            unpacked as an ``(obs, reward, done)`` triple.
        error_label: short description used in the ``[Error] ... failed``
            message if ``action`` raises.
        done_message: message printed when the skill reports ``done``.

    Returns:
        ``"failed"`` if ``action`` raised (the error is printed and swallowed),
        ``"done"`` if the skill reported ``done`` (``done_message`` is printed),
        otherwise ``"ok"``.
    """
    try:
        _obs, _reward, done = action()
    except Exception as e:
        # Deliberate best-effort: a failing skill is reported and the caller
        # decides whether to skip ahead, rather than aborting the whole run.
        print(f"[Error] {error_label} failed: {e}")
        return "failed"
    if done:
        print(done_message)
        return "done"
    return "ok"


def _open_drawer(env, task, positions, level):
    """Run the move-to-side / move-to-anchor / pick / pull sequence for one drawer.

    Args:
        env: environment handle returned by ``setup_environment``.
        task: task handle returned by ``setup_environment``.
        positions: mapping from ``get_object_positions``; it must contain
            ``<level>_side_pos``, ``<level>_anchor_pos`` and
            ``<level>_joint_pos``, otherwise the drawer is skipped.
        level: drawer name prefix, e.g. ``"bottom"``.

    Returns:
        True if a skill reported ``done`` (the caller should stop entirely);
        False otherwise, including when a key is missing or a skill call
        raised (the caller should move on to the next drawer).
    """
    print(f"[Task] Processing {level} drawer")
    side_key = f"{level}_side_pos"
    anchor_key = f"{level}_anchor_pos"
    joint_key = f"{level}_joint_pos"

    # Check that all required positions exist
    if not all(k in positions for k in (side_key, anchor_key, joint_key)):
        print(f"[Warning] Missing one of {side_key}, {anchor_key}, {joint_key}; skipping {level} drawer")
        return False

    side_pos = positions[side_key]
    anchor_pos = positions[anchor_key]
    joint_pos = positions[joint_key]

    # Each entry: (announcement, skill call, error label, message when done).
    # Steps run in order; the announcement is printed only when the step is reached.
    steps = (
        # 1) Move to the side-of-handle position
        (
            f"[Task] Moving to side position of {level} drawer at {side_pos}",
            lambda: move(env, task, side_pos),
            "move to side",
            "[Task] Task ended unexpectedly after moving to side",
        ),
        # 2) Move to the anchor (grasp) position
        (
            f"[Task] Moving to anchor position of {level} drawer at {anchor_pos}",
            lambda: move(env, task, anchor_pos),
            "move to anchor",
            "[Task] Task ended unexpectedly after moving to anchor",
        ),
        # 3) Pick the drawer handle (joint) to prepare for pulling
        (
            f"[Task] Picking handle of {level} drawer at {joint_pos}",
            lambda: pick(
                env,
                task,
                target_pos=joint_pos,
                approach_distance=0.05,
                max_steps=100,
                threshold=0.01,
                approach_axis='z',
                timeout=10.0,
            ),
            "pick handle",
            "[Task] Task ended during pick of handle",
        ),
        # 4) Pull the handle to open the drawer
        (
            f"[Task] Pulling {level} drawer handle",
            lambda: pull(env, task),
            "pull",
            f"[Task] Task ended during pull of {level} drawer",
        ),
    )

    for announcement, action, error_label, done_message in steps:
        print(announcement)
        status = _run_skill(action, error_label, done_message)
        if status != "ok":
            # "failed" -> skip the rest of this drawer; "done" -> stop the run.
            return status == "done"

    print(f"[Task] {level.capitalize()} drawer should now be open")
    return False


def _run_drawer_sequence(env, task):
    """Reset the task, enable recording, align the gripper, then open each drawer.

    Returns (without raising) as soon as a skill reports ``done`` or the
    initial rotation fails. Exceptions outside skill calls propagate.
    """
    # Reset the task to its initial state
    _descriptions, obs = task.reset()

    # Initialize video writers if needed
    init_video_writers(obs)

    # Wrap step and get_observation for recording
    original_step = task.step
    task.step = recording_step(original_step)
    original_get_obs = task.get_observation
    task.get_observation = recording_get_observation(original_get_obs)

    # Retrieve the static object positions from the helper module
    positions = get_object_positions()
    print("[Task] Retrieved object positions:", positions.keys())

    # First align the gripper orientation for pulling the drawer handle
    print("[Task] Aligning gripper orientation to ninety degrees")
    # NOTE(review): presumably (x, y, z, w) order, i.e. ~90 degrees about z —
    # confirm against skill_code.rotate.
    ninety_deg_quat = [0.0, 0.0, 0.7071, 0.7071]
    if _run_skill(
        lambda: rotate(env, task, ninety_deg_quat),
        "rotate",
        "[Task] Task ended during initial rotation",
    ) != "ok":
        # Both a failed rotation and an early "done" end the run here.
        return

    # Define the drawer levels we will process
    drawer_levels = ['bottom', 'middle', 'top']
    for level in drawer_levels:
        if _open_drawer(env, task, positions, level):
            return

    print("[Task] All drawers processed")


def run_skeleton_task():
    '''Generic skeleton for running any task in your simulation.

    Sets up the environment, runs the drawer-opening sequence, and always
    shuts the environment down. The closing banner is printed on every
    non-exceptional exit, including early stops when a skill reports
    ``done``; exceptions propagate after shutdown without printing it.
    '''
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        _run_drawer_sequence(env, task)
    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    # Kept outside the try so early returns inside the sequence no longer
    # bypass it.
    print("===== End of Skeleton Task =====")

if __name__ == "__main__":
    # Run the skeleton task only when this file is executed directly, not on import.
    run_skeleton_task()