import numpy as np
from scipy.spatial.transform import Rotation as R

from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# === Pre-defined skill functions ===
from skill_code import rotate, move, pull, pick, place

# === Video / logging helpers ===
from video import init_video_writers, recording_step, recording_get_observation

# === Utility for retrieving positions of relevant objects in the scene ===
from object_positions import get_object_positions


def _vector_to_axis_str(vec: np.ndarray) -> str:
    """Return the pull-axis string (‘x’, ‘-x’, ‘y’, …) that best matches vec."""
    axis_idx = int(np.argmax(np.abs(vec)))
    axis_chr = ['x', 'y', 'z'][axis_idx]
    return ('' if vec[axis_idx] >= 0 else '-') + axis_chr


# Tomatoes moved onto the plate, in the order they are handled.
_TOMATO_NAMES = ('tomato1', 'tomato2')

# Entries that get_object_positions() must provide for this task.
_REQUIRED_POSITION_KEYS = (
    'bottom_side_pos', 'bottom_anchor_pos', 'bottom_joint_pos',
    *_TOMATO_NAMES, 'plate',
)

# Extra distance added to the joint→anchor span when pulling, so the drawer
# ends up fully open.
_PULL_MARGIN = 0.05


def _step_ended_episode(result, message: str) -> bool:
    """Return True if a skill result signals episode end, printing *message* if so.

    Args:
        result: the ``(obs, reward, done)`` tuple returned by a skill function.
        message: text printed when ``done`` is truthy.
    """
    _obs, _reward, done = result
    if done:
        print(message)
        return True
    return False


def _load_task_positions() -> dict:
    """Fetch the scene positions this task needs, converted to float numpy arrays.

    Returns:
        Mapping from each key in ``_REQUIRED_POSITION_KEYS`` to an ``np.ndarray``.

    Raises:
        KeyError: if ``get_object_positions()`` lacks any required key.
    """
    positions = get_object_positions()
    missing = [key for key in _REQUIRED_POSITION_KEYS if key not in positions]
    if missing:
        raise KeyError(f"[Task] Missing object positions for: {missing}")
    return {
        key: np.asarray(positions[key], dtype=float)
        for key in _REQUIRED_POSITION_KEYS
    }


def _open_bottom_drawer(env, task, pos: dict) -> bool:
    """Run plan steps 1-5: align the gripper, grasp the bottom drawer handle, pull it open.

    Returns:
        True if the environment signalled ``done`` during any of these steps.
    """
    # Step 1: rotate zero_deg -> ninety_deg (90° about Y) to align the
    # fingers with the drawer handle.
    target_quat = R.from_euler('y', 90, degrees=True).as_quat()
    if _step_ended_episode(rotate(env, task, target_quat),
                           "[Task] Finished during rotate."):
        return True

    # Step 2: move-to-side, via the generic move skill.
    if _step_ended_episode(move(env, task, pos['bottom_side_pos']),
                           "[Task] Finished during move-to-side."):
        return True

    # Step 3: move-to-anchor, again via move.
    if _step_ended_episode(move(env, task, pos['bottom_anchor_pos']),
                           "[Task] Finished during move-to-anchor."):
        return True

    # Step 4: pick-drawer, gripping the handle at the anchor position.
    if _step_ended_episode(pick(env, task, pos['bottom_anchor_pos']),
                           "[Task] Finished during pick-drawer."):
        return True

    # Step 5: pull along the dominant axis of the joint→anchor vector, by
    # that vector's length plus a small margin.
    pull_vec = pos['bottom_anchor_pos'] - pos['bottom_joint_pos']
    pull_distance = np.linalg.norm(pull_vec) + _PULL_MARGIN
    return _step_ended_episode(
        pull(env, task, pull_distance, pull_axis=_vector_to_axis_str(pull_vec)),
        "[Task] Finished during pull.",
    )


def _place_tomatoes_on_plate(env, task, pos: dict) -> bool:
    """Run plan steps 6-9: pick each tomato from the table and place it on the plate.

    Returns:
        True if the environment signalled ``done`` during any pick or place.
    """
    # NOTE(review): the gripper is still closed on the drawer handle after the
    # pull. This plan relies on pick() opening the gripper before it approaches
    # the tomato; confirm in skill_code.pick.
    for name in _TOMATO_NAMES:
        print(f"[Task] Handling {name} …")

        # Step 6 / 8: pick the tomato from the table.
        if _step_ended_episode(pick(env, task, pos[name]),
                               f"[Task] Finished while picking {name}."):
            return True

        # Step 7 / 9: place the tomato on the plate.
        if _step_ended_episode(place(env, task, pos['plate']),
                               f"[Task] Finished while placing {name}."):
            return True
    return False


def run_skeleton_task() -> None:
    """Execute the oracle plan that opens the bottom drawer and puts tomatoes on the plate.

    The environment is always shut down. The closing banner is printed both when
    the plan runs to completion and when the environment signals ``done`` early.

    Raises:
        KeyError: if a required object position is unavailable.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # Start video recording and wrap step/get_observation with the
        # recording helpers.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        pos = _load_task_positions()

        # `or` short-circuits: tomatoes are only handled if the drawer phase
        # did not already end the episode.
        ended_early = (_open_bottom_drawer(env, task, pos)
                       or _place_tomatoes_on_plate(env, task, pos))
        if not ended_early:
            print("[Task] Oracle plan completed (environment did not signal done flag).")
    finally:
        # Always shut down the environment to free resources.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    run_skeleton_task()