import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use provided skills: move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


def run_skeleton_task():
    '''Open the bottom drawer and place item1, item2 and item3 on the plate.

    Sets up the simulation and wraps ``task.step`` / ``task.get_observation``
    with the video-recording helpers. It then executes a fixed sequence of
    skill calls (``move``, ``rotate``, ``pick``, ``pull``, ``place``). Each
    skill is called as ``skill(env, task, *args, **kwargs)`` and returns
    ``(obs, reward, done)``. The sequence stops as soon as a skill reports
    ``done``.

    Any exception raised after environment setup is reported with its
    traceback and is not re-raised. The environment is always shut down,
    and the closing banner is printed on every path that reaches ``finally``.
    '''
    import traceback

    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps for recording
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # === Retrieve Object Positions ===
        positions = get_object_positions()
        # Known object names:
        #   'bottom_anchor_pos', 'bottom_joint_pos', 'bottom_side_pos',
        #   'middle_anchor_pos', 'middle_joint_pos', 'middle_side_pos',
        #   'top_anchor_pos', 'top_joint_pos', 'top_side_pos',
        #   'item1', 'item2', 'item3', 'plate'
        # The bottom drawer is opened; item1..item3 are the tomatoes and
        # 'plate' is the destination.
        required_keys = [
            'bottom_side_pos', 'bottom_anchor_pos', 'item1', 'item2', 'item3', 'plate'
        ]
        for k in required_keys:
            if k not in positions:
                raise RuntimeError(f"Required object position '{k}' not found in get_object_positions().")

        # Import scipy and build target orientations BEFORE the robot moves,
        # so a missing dependency fails fast rather than mid-plan.
        from scipy.spatial.transform import Rotation as R
        # NOTE(review): assumes rotate() expects an [x, y, z, w] quaternion
        # (scipy's as_quat() order) — confirm against skill_code.
        ninety_deg_quat = R.from_euler('z', 90, degrees=True).as_quat()
        zero_deg_quat = R.from_euler('z', 0, degrees=True).as_quat()
        item1_above = np.array(positions['item1']) + np.array([0, 0, 0.08])

        # Shared keyword arguments for top-down picks/places of the tomatoes.
        # (Each call unpacks it into a fresh dict, so sharing is safe.)
        top_down = {'approach_distance': 0.10, 'approach_axis': '-z'}

        # Plan table. Each entry is:
        #   (label printed before the step, message printed if the task ends
        #    there, skill function, extra positional args, keyword args)
        plan = [
            ("[Plan] Step 1: Move to bottom_side_pos",
             "[Task] Task ended after move to bottom_side_pos!",
             move, (positions['bottom_side_pos'],), {}),
            ("[Plan] Step 2: Rotate gripper to 90 deg",
             "[Task] Task ended after rotate to 90 deg!",
             rotate, (ninety_deg_quat,), {}),
            ("[Plan] Step 3: Move to bottom_anchor_pos",
             "[Task] Task ended after move to bottom_anchor_pos!",
             move, (positions['bottom_anchor_pos'],), {}),
            # Drawer handle grasp; approach distance kept short to avoid collision.
            ("[Plan] Step 4: Pick drawer handle (bottom)",
             "[Task] Task ended after picking drawer handle!",
             pick, (positions['bottom_anchor_pos'],),
             {'approach_distance': 0.10, 'approach_axis': 'z'}),
            # Assumes the drawer opens along +x — adjust if needed.
            ("[Plan] Step 5: Pull drawer (bottom)",
             "[Task] Task ended after pulling drawer!",
             pull, (), {'pull_distance': 0.18, 'pull_axis': 'x'}),
            ("[Plan] Step 6: Move to item1",
             "[Task] Task ended after move to item1!",
             move, (positions['item1'],), {}),
            ("[Plan] Step 7: Rotate gripper to 0 deg",
             "[Task] Task ended after rotate to 0 deg!",
             rotate, (zero_deg_quat,), {}),
            ("[Plan] Step 8: Move above item1",
             "[Task] Task ended after move above item1!",
             move, (item1_above,), {}),
            ("[Plan] Step 9: Pick item1",
             "[Task] Task ended after picking item1!",
             pick, (positions['item1'],), top_down),
            ("[Plan] Step 10: Place item1 on plate",
             "[Task] Task ended after placing item1!",
             place, (positions['plate'],), top_down),
            ("[Plan] Step 11: Move to item2",
             "[Task] Task ended after move to item2!",
             move, (positions['item2'],), {}),
            ("[Plan] Step 12: Pick item2",
             "[Task] Task ended after picking item2!",
             pick, (positions['item2'],), top_down),
            ("[Plan] Step 13: Place item2 on plate",
             "[Task] Task ended after placing item2!",
             place, (positions['plate'],), top_down),
            # Steps 14-16 extend the specified plan to the third tomato.
            ("[Plan] Step 14: Move to item3",
             "[Task] Task ended after move to item3!",
             move, (positions['item3'],), {}),
            ("[Plan] Step 15: Pick item3",
             "[Task] Task ended after picking item3!",
             pick, (positions['item3'],), top_down),
            ("[Plan] Step 16: Place item3 on plate",
             "[Task] Task ended after placing item3!",
             place, (positions['plate'],), top_down),
        ]

        for label, done_message, skill, args, kwargs in plan:
            print(label)
            obs, reward, done = skill(env, task, *args, **kwargs)
            if done:
                print(done_message)
                return

        print("[Task] All tomatoes placed on plate. Task completed successfully!")

    except Exception as e:
        # Best-effort script: report the failure (with traceback) and still
        # fall through to cleanup rather than crashing.
        print(f"[Task] Exception occurred: {e}")
        traceback.print_exc()

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)
        # Printed on every path (including early returns), after shutdown.
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: runs the task only when executed directly, not on import.
    run_skeleton_task()