import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


def _require_position(positions, key):
    '''Return ``positions[key]``; raise RuntimeError if the key is absent or None.'''
    pos = positions.get(key)
    if pos is None:
        raise RuntimeError(f"{key} position not found in object_positions.")
    return pos


def _select_drawer(positions):
    '''Choose the first drawer (bottom, then middle, then top) with usable positions.

    A drawer is usable when both its ``<name>_side_pos`` and
    ``<name>_anchor_pos`` entries are present and not None. This only checks
    that the positions exist; it does not check whether the drawer is locked.

    Returns:
        ``(side_pos, anchor_pos)`` of the chosen drawer.

    Raises:
        RuntimeError: if no drawer has both positions available.
    '''
    for drawer in ('bottom', 'middle', 'top'):
        side_pos = positions.get(f'{drawer}_side_pos')
        anchor_pos = positions.get(f'{drawer}_anchor_pos')
        if side_pos is not None and anchor_pos is not None:
            return side_pos, anchor_pos
    raise RuntimeError("No drawer positions found in object_positions.")


def run_skeleton_task():
    '''Task: Pull open any unlocked drawer, then drop the 2 tomatoes onto the plate.

    Sequence:
      1. Open one drawer. The first of bottom/middle/top with known side and
         anchor positions is used.
      2. Pick ``item1`` and ``item2`` (treated as the two tomatoes) and place
         each one on ``plate``.

    Every required object position is looked up and checked before the robot
    moves, so a missing object aborts the run before any motion. Each skill
    returns ``(obs, reward, done)``. The run stops early as soon as ``done`` is
    true. Any exception is reported with its traceback. The environment is
    always shut down, and the end banner is always printed after shutdown.
    '''
    import traceback

    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Imported here (as in the original flow) so an import failure is
        # reported by the handler below and the environment is still shut down.
        from scipy.spatial.transform import Rotation as R

        # Reset the task to its initial state
        _, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap step/get_observation so every call goes through the recorders.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve and validate object positions up front ===
        # Object names for this task:
        #   '{bottom,middle,top}_{anchor,joint,side}_pos',
        #   'item1', 'item2', 'item3', 'plate'
        # 'item1' and 'item2' are treated as the two tomatoes.
        positions = get_object_positions()
        drawer_side_pos, drawer_anchor_pos = _select_drawer(positions)
        item1_pos = _require_position(positions, 'item1')
        item2_pos = _require_position(positions, 'item2')
        plate_pos = _require_position(positions, 'plate')

        approach_axis = 'z'  # every grasp/placement approaches from above
        # scipy's as_quat() returns scalar-last [x, y, z, w] quaternions.
        quat_ninety_deg = R.from_euler('z', 90, degrees=True).as_quat()
        quat_zero_deg = R.from_euler('z', 0, degrees=True).as_quat()

        # Step 1: Move to side position of the drawer
        print("[Task] Step 1: Move to drawer side position:", drawer_side_pos)
        obs, reward, done = move(
            env,
            task,
            target_pos=drawer_side_pos,
            max_steps=100,
            threshold=0.01,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after move to drawer side position!")
            return

        # Step 2: Yaw the gripper 90 degrees about z to line up with the handle
        print("[Task] Step 2: Rotate gripper to 90 degrees")
        obs, reward, done = rotate(
            env,
            task,
            target_quat=quat_ninety_deg,
            max_steps=100,
            threshold=0.05,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after rotate!")
            return

        # Step 3: Grasp the drawer handle at the anchor position
        print("[Task] Step 3: Pick drawer handle at anchor position:", drawer_anchor_pos)
        obs, reward, done = pick(
            env,
            task,
            target_pos=drawer_anchor_pos,
            approach_distance=0.10,
            max_steps=100,
            threshold=0.01,
            approach_axis=approach_axis,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after picking drawer handle!")
            return

        # Step 4: Pull the drawer open by 0.15 m along x
        print("[Task] Step 4: Pull drawer open")
        obs, reward, done = pull(
            env,
            task,
            pull_distance=0.15,
            pull_axis='x',
            max_steps=100,
            threshold=0.01,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after pulling drawer!")
            return

        # NOTE(review): the handle is never explicitly released here. This
        # presumably relies on pick() opening the gripper before it grasps.
        # Verify this in skill_code.

        # Step 5: Move to item1 (tomato1) position
        print("[Task] Step 5: Move to item1 (tomato1) position:", item1_pos)
        obs, reward, done = move(
            env,
            task,
            target_pos=item1_pos,
            max_steps=100,
            threshold=0.01,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after move to item1!")
            return

        # Step 6: Return the gripper yaw to 0 degrees for the tomato grasps
        print("[Task] Step 6: Rotate gripper to 0 degrees")
        obs, reward, done = rotate(
            env,
            task,
            target_quat=quat_zero_deg,
            max_steps=100,
            threshold=0.05,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after rotate to 0 deg!")
            return

        # Step 7: Pick item1 (tomato1)
        print("[Task] Step 7: Pick item1 (tomato1)")
        obs, reward, done = pick(
            env,
            task,
            target_pos=item1_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis=approach_axis,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after picking item1!")
            return

        # Step 8: Place item1 on plate
        print("[Task] Step 8: Place item1 on plate:", plate_pos)
        obs, reward, done = place(
            env,
            task,
            target_pos=plate_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis=approach_axis,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after placing item1!")
            return

        # Step 9: Move to item2 (tomato2). The gripper yaw is still 0 from step 6.
        print("[Task] Step 9: Move to item2 (tomato2) position:", item2_pos)
        obs, reward, done = move(
            env,
            task,
            target_pos=item2_pos,
            max_steps=100,
            threshold=0.01,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after move to item2!")
            return

        # Step 10: Pick item2 (tomato2)
        print("[Task] Step 10: Pick item2 (tomato2)")
        obs, reward, done = pick(
            env,
            task,
            target_pos=item2_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis=approach_axis,
            timeout=10.0
        )
        if done:
            print("[Task] Task ended after picking item2!")
            return

        # Step 11: Place item2 on plate. This is the final step, so `done` now
        # means the task succeeded.
        print("[Task] Step 11: Place item2 on plate:", plate_pos)
        obs, reward, done = place(
            env,
            task,
            target_pos=plate_pos,
            approach_distance=0.15,
            max_steps=100,
            threshold=0.01,
            approach_axis=approach_axis,
            timeout=10.0
        )
        if done:
            print("[Task] Task completed successfully! Reward:", reward)
        else:
            print("[Task] Task not completed yet (done=False).")

    except Exception as e:
        # Top-level script boundary: report the failure with its traceback
        # instead of propagating it. Shutdown still happens in `finally`.
        print("[Task] Exception occurred:", str(e))
        traceback.print_exc()
    finally:
        # Always shut the environment down. The end banner goes here so it is
        # printed on early returns as well as on normal completion.
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the drawer-and-tomatoes sequence once.
    run_skeleton_task()