import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pull, pick, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# Drawers to try, in order of preference: (side key, anchor key).
_DRAWER_CANDIDATES = (
    ('bottom_side_pos', 'bottom_anchor_pos'),
    ('middle_side_pos', 'middle_anchor_pos'),
)
# Hover height above a pick/place target so the gripper approaches from
# above (+10 cm, positions are in metres).
_APPROACH_HEIGHT = 0.10
# How far to pull the drawer (~25 cm) and along which axis; assumes +X
# pulls the drawer outward.
_PULL_DISTANCE = 0.25
_PULL_AXIS = 'x'


def _resolve_drawer_positions(positions):
    """Return ``(side_pos, anchor_pos)`` for the first available drawer.

    The bottom drawer is preferred; the middle drawer is the fallback.

    Raises:
        KeyError: if no candidate drawer has both its side and anchor entries.
    """
    for side_key, anchor_key in _DRAWER_CANDIDATES:
        if side_key in positions and anchor_key in positions:
            return (np.array(positions[side_key], dtype=float),
                    np.array(positions[anchor_key], dtype=float))
    raise KeyError(
        f"No usable drawer in object_positions; tried {_DRAWER_CANDIDATES}."
    )


def _require_position(positions, key):
    """Return ``positions[key]`` as a float array, failing fast if absent.

    Without this check a missing entry becomes a 0-d NaN array and only
    fails later (mid-task) when it is indexed.

    Raises:
        KeyError: if *key* is missing or mapped to ``None``.
    """
    value = positions.get(key)
    if value is None:
        raise KeyError(f"'{key}' not found in object_positions.")
    return np.array(value, dtype=float)


def _collect_items(positions):
    """Return sorted item names (keys starting with ``item``) and positions.

    Raises:
        RuntimeError: if no item entries exist.
    """
    names = sorted(k for k in positions if k.startswith('item'))
    if not names:
        raise RuntimeError("No tomatoes/items found in object_positions.")
    return names, [np.array(positions[n], dtype=float) for n in names]


def _above(pos, height=_APPROACH_HEIGHT):
    """Return a copy of *pos* raised by *height* along Z."""
    hover = np.array(pos, dtype=float)  # copy, never mutate the caller's array
    hover[2] += height
    return hover


def _ended_early(result, message):
    """Log *message* and return True if a skill's ``(obs, reward, done)`` says done."""
    _, _, done = result
    if done:
        print(message)
    return bool(done)


def _execute_plan(env, task):
    """Reset the task, open the drawer and move every item onto the plate.

    Returns:
        True when every item was placed, False if the episode ended early.

    Raises:
        KeyError: if no drawer or no plate position is available.
        RuntimeError: if no items are found.
    """
    _, obs = task.reset()                   # RLBench reset
    init_video_writers(obs)                 # (Video) optional

    # Wrap RLBench's step / get_observation so recording works automatically
    task.step = recording_step(task.step)
    task.get_observation = recording_get_observation(task.get_observation)

    # Resolve and validate every target before moving the robot at all.
    positions = get_object_positions()
    side_pos, anchor_pos = _resolve_drawer_positions(positions)
    plate_pos = _require_position(positions, 'plate')
    item_names, item_positions = _collect_items(positions)

    # Rotate the gripper 90° about Z so it can align with the drawer.
    target_quat = R.from_euler('xyz', [0, 0, np.pi / 2]).as_quat()
    if _ended_early(rotate(env, task, target_quat),
                    "[Task] Finished right after rotate (unexpected)."):
        return False

    # Move to the drawer handle: side first, then anchor.
    if _ended_early(move(env, task, side_pos),
                    "[Task] Finished unexpectedly while moving to drawer side."):
        return False
    if _ended_early(move(env, task, anchor_pos),
                    "[Task] Finished unexpectedly while moving to drawer anchor."):
        return False

    # Pull the drawer open.
    if _ended_early(pull(env, task,
                         pull_distance=_PULL_DISTANCE,
                         pull_axis=_PULL_AXIS),
                    "[Task] Finished unexpectedly while pulling drawer."):
        return False

    # Loop-invariant: safe hover point above the plate.
    plate_above = _above(plate_pos)
    total = len(item_names)

    for idx, (name, pos) in enumerate(zip(item_names, item_positions), start=1):
        print(f"[Task] Handling {name} ({idx}/{total})")

        if _ended_early(move(env, task, _above(pos)),
                        "[Task] Finished unexpectedly while approaching tomato."):
            return False
        if _ended_early(pick(env, task, target_pos=pos),
                        "[Task] Finished unexpectedly while picking tomato."):
            return False
        if _ended_early(move(env, task, plate_above),
                        "[Task] Finished unexpectedly while moving above plate."):
            return False

        _, _, done = place(env, task, target_pos=plate_pos)
        # NOTE(review): `done` after the *final* placement is treated as the
        # task's success condition rather than an early stop — confirm
        # against skill_code's semantics for `done`.
        if done and idx < total:
            print("[Task] Finished unexpectedly while placing tomato.")
            return False

    print("[Task] All tomatoes moved to plate. SUCCESS!")
    return True


def run_skeleton_task():
    """Run the complete drawer-opening and tomato-relocation task.

    Sets up the environment, executes the plan, and always shuts the
    simulator down, even when an exception propagates. The end banner is
    printed whenever the run finishes without an exception, including
    early exits.
    """
    print("===== Starting Task =====")

    env, task = setup_environment()
    try:
        _execute_plan(env, task)
    finally:
        # Make sure simulator shuts down even if exceptions occur
        shutdown_environment(env)

    print("===== End of Task =====")


if __name__ == "__main__":
    # Script entry point: run the full task only when executed directly,
    # not when this module is imported.
    run_skeleton_task()