import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # Use provided skills: move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


def _above(pos, height=0.10):
    """Return ``pos`` offset by ``height`` along its third (z) coordinate."""
    return np.array(pos) + np.array([0, 0, height])


def _execute_plan(env, task, plan):
    """Run plan steps in order, stopping as soon as a step reports ``done``.

    Each step is a tuple ``(announce, skill, kwargs, end_msg)``:
      - ``announce`` is printed before the step runs,
      - ``skill`` is called as ``skill(env, task, **kwargs)`` and must return
        ``(obs, reward, done)``,
      - ``end_msg`` is printed if that step reports ``done``.

    Returns:
        ``(ended_early, reward)``: ``ended_early`` is True when some step
        reported ``done``; ``reward`` is the reward from the last executed
        step (None if ``plan`` is empty).
    """
    reward = None
    for announce, skill, kwargs, end_msg in plan:
        print(announce)
        _obs, reward, done = skill(env, task, **kwargs)
        if done:
            print(end_msg)
            return True, reward
    return False, reward


def run_skeleton_task():
    '''Task: Open a drawer fully, then pick up all the tomatoes and leave them on the plate.

    Opens the bottom drawer (move to its side position, rotate the gripper,
    move to the anchor, grasp, pull), then moves item1, item2 and item3
    onto the plate one at a time. Progress is printed for every step, and
    the run stops early as soon as any skill reports ``done``.

    Once ``setup_environment()`` has succeeded, the environment is always
    shut down. Exceptions raised after setup are printed together with
    their traceback and are not propagated.
    '''
    import traceback

    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        _descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap step/get_observation so every simulator call is recorded.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve Object Positions ===
        positions = get_object_positions()

        # Only validate the keys this plan actually reads; other drawers'
        # positions are irrelevant and must not abort the run.
        tomato_names = ['item1', 'item2', 'item3']
        required_keys = ['bottom_anchor_pos', 'bottom_side_pos', *tomato_names, 'plate']
        missing = [k for k in required_keys if k not in positions]
        if missing:
            raise RuntimeError(f"Missing object position for: {', '.join(missing)}")

        side_pos = positions['bottom_side_pos']
        anchor_pos = positions['bottom_anchor_pos']
        plate_pos = positions['plate']
        above_plate = _above(plate_pos)
        above_tomato = {name: _above(positions[name]) for name in tomato_names}

        # Gripper orientations. scipy's as_quat() returns scalar-last
        # (x, y, z, w) quaternions; NOTE(review): assumes rotate() expects the
        # same order and that a 90-degree yaw is right for the drawer grasp.
        from scipy.spatial.transform import Rotation as R
        ninety_deg_quat = R.from_euler('z', 90, degrees=True).as_quat()
        zero_deg_quat = R.from_euler('z', 0, degrees=True).as_quat()  # identity

        # Estimated travel to open the drawer fully -- TODO confirm per scene.
        pull_distance = 0.20

        def grasp_z(target):
            # Shared top-down approach used for tomatoes and the plate.
            return {'target_pos': target, 'approach_distance': 0.10, 'approach_axis': 'z'}

        # === PLAN EXECUTION ===
        # (announce, skill, kwargs, message printed if the step ends the task)
        plan = [
            # --- Open the bottom drawer ---
            ("[Step 1] Move to side position of bottom drawer",
             move, {'target_pos': side_pos},
             "[Task] Task ended after move to side position!"),
            ("[Step 2] Rotate gripper to 90 degrees",
             rotate, {'target_quat': ninety_deg_quat},
             "[Task] Task ended after rotate!"),
            ("[Step 3] Move to anchor position of bottom drawer",
             move, {'target_pos': anchor_pos},
             "[Task] Task ended after move to anchor!"),
            ("[Step 4] Pick the drawer handle at anchor position",
             pick, {'target_pos': anchor_pos, 'approach_distance': 0.10, 'approach_axis': 'y'},
             "[Task] Task ended after picking drawer handle!"),
            ("[Step 5] Pull the drawer open",
             pull, {'pull_distance': pull_distance, 'pull_axis': 'x'},
             "[Task] Task ended after pulling drawer!"),
            # --- First tomato: restore the top-down orientation, then re-align ---
            ("[Step 6] Move to above first tomato (item1)",
             move, {'target_pos': above_tomato['item1']},
             "[Task] Task ended after move to tomato1!"),
            ("[Step 7] Rotate gripper to zero degrees",
             rotate, {'target_quat': zero_deg_quat},
             "[Task] Task ended after rotate to zero!"),
            ("[Step 8] Move to above first tomato (item1) again",
             move, {'target_pos': above_tomato['item1']},
             "[Task] Task ended after move to tomato1 (again)!"),
            ("[Step 9] Pick first tomato (item1)",
             pick, grasp_z(positions['item1']),
             "[Task] Task ended after picking tomato1!"),
            ("[Step 10] Move to above plate",
             move, {'target_pos': above_plate},
             "[Task] Task ended after move to plate!"),
            ("[Step 11] Place first tomato on plate",
             place, grasp_z(plate_pos),
             "[Task] Task ended after placing tomato1!"),
            # --- Second tomato ---
            ("[Step 12] Move to above second tomato (item2)",
             move, {'target_pos': above_tomato['item2']},
             "[Task] Task ended after move to tomato2!"),
            ("[Step 13] Pick second tomato (item2)",
             pick, grasp_z(positions['item2']),
             "[Task] Task ended after picking tomato2!"),
            ("[Step 14] Move to above plate (again)",
             move, {'target_pos': above_plate},
             "[Task] Task ended after move to plate (again)!"),
            ("[Step 15] Place second tomato on plate",
             place, grasp_z(plate_pos),
             "[Task] Task ended after placing tomato2!"),
            # --- Third tomato ---
            ("[Extra] Move to above third tomato (item3)",
             move, {'target_pos': above_tomato['item3']},
             "[Task] Task ended after move to tomato3!"),
            ("[Extra] Pick third tomato (item3)",
             pick, grasp_z(positions['item3']),
             "[Task] Task ended after picking tomato3!"),
            ("[Extra] Move to above plate (final)",
             move, {'target_pos': above_plate},
             "[Task] Task ended after move to plate (final)!"),
            ("[Extra] Place third tomato on plate",
             place, grasp_z(plate_pos),
             "[Task] Task ended after placing tomato3!"),
        ]

        ended_early, reward = _execute_plan(env, task, plan)
        if not ended_early:
            print("[Task] All tomatoes placed on plate. Task completed successfully! Reward:", reward)

    except Exception as e:
        # Top-level boundary: report (with traceback) but do not propagate.
        print(f"[Task] Exception occurred: {e}")
        traceback.print_exc()
    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the task only when this file is executed directly, not on import.
    run_skeleton_task()