# run_combined_task.py
#
# This script executes the oracle plan specified in the “Specification” section.
# Only the predefined skills (rotate, move, pull, pick, place) are used.

import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# Import the predefined, already‑implemented skills
from skill_code import pick, place, move, rotate, pull


def _as_np(obj_pos):
    """Return *obj_pos* as a ``float32`` NumPy array.

    Accepts any array-like value (list, tuple, ndarray, or an object NumPy
    can convert such as an RLBench Vector3). An input that is already a
    float32 ndarray is returned as-is rather than copied.
    """
    target_dtype = np.float32
    converted = np.asarray(obj_pos, dtype=target_dtype)
    return converted


def _compute_pull_axis_and_distance(anchor, joint, safety_margin=0.05):
    """
    Decide which axis to pull along (x / y / z, positive or negative) and
    return both the axis string and a distance that is slightly larger than
    the anchor-to-joint offset along that axis so that the drawer surely opens.

    The axis is the coordinate with the largest absolute component of
    ``anchor - joint``; the sign of that component picks the direction.
    If ``anchor`` and ``joint`` coincide, every component is zero and the
    result falls back to ``('x', safety_margin)``.

    Args:
        anchor (array-like): Handle position (anchor-pos), 3 coordinates.
        joint  (array-like): Drawer joint/pivot position (joint-pos), 3 coordinates.
        safety_margin (float): Extra distance to guarantee full opening.

    Returns:
        axis_str (str): One of {'x', '-x', 'y', '-y', 'z', '-z'}.
        distance (float): ``|offset along axis| + safety_margin`` as a plain
            Python float, ready to pass to the pull() skill.

    Raises:
        ValueError: if ``anchor - joint`` is not a 3-element vector.
    """
    # asarray lets plain lists/tuples work too; ndarray inputs keep their dtype.
    diff = np.asarray(anchor) - np.asarray(joint)
    if diff.shape != (3,):
        raise ValueError(
            "anchor and joint must be 3-D points, "
            f"got a difference of shape {diff.shape}"
        )
    # Dominant axis = largest absolute displacement between handle and joint.
    axis_idx = int(np.argmax(np.abs(diff)))
    # Convert to a Python float so callers get the documented return type
    # rather than a NumPy scalar (e.g. np.float32 for float32 positions).
    component = float(diff[axis_idx])
    axis_str = ('-' if component < 0 else '') + 'xyz'[axis_idx]
    distance = abs(component) + safety_margin
    return axis_str, distance


def _require_positions(pos_dict, *names):
    """Return ``pos_dict[name]`` for each of *names*, in the given order.

    Raises:
        KeyError: if any name is missing. The message lists every missing
            name together with the available ones, which is easier to
            debug than a bare ``KeyError('tomato1')``.
    """
    missing = [name for name in names if name not in pos_dict]
    if missing:
        raise KeyError(
            f"object positions missing {missing}; available: {list(pos_dict)}"
        )
    return [pos_dict[name] for name in names]


def run_combined_task():
    """Execute the oracle plan step-by-step.

    Sets up the environment, wraps step()/get_observation() for video
    recording, and looks up the drawer and tomato/plate positions. It then
    runs nine skill calls in order:

    1. Rotate the gripper.
    2. Move to the drawer side.
    3. Move to the drawer anchor.
    4. Grasp the handle.
    5. Pull the drawer open.
    6-9. Pick and place each tomato onto the plate.

    After each of steps 1-8, a ``done`` signal from the skill ends the run
    early. Once setup_environment() has succeeded, the environment is shut
    down in a ``finally`` block, even if a later step raises.

    Raises:
        KeyError: if get_object_positions() lacks one of the required names.
    """
    print("===== [TASK] Start Combined‑Domain Task =====")

    # --------------------------------------------------
    # 1) Environment initialisation
    #    setup_environment() stays outside the try-block: if it fails there
    #    is no environment to shut down.
    # --------------------------------------------------
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # (Optional) video recording: wrap step() and get_observation() so
        # subsequent calls go through the recorders.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --------------------------------------------------
        # 2) Gather all object positions needed for the plan
        # --------------------------------------------------
        pos_dict = {k: _as_np(v) for k, v in get_object_positions().items()}
        (side_pos, anchor_pos, joint_pos,
         tomato1_pos, tomato2_pos, plate_pos) = _require_positions(
            pos_dict,
            'bottom_side_pos', 'bottom_anchor_pos', 'bottom_joint_pos',
            'tomato1', 'tomato2', 'plate',
        )

        # Pure geometry, computed before any skill runs so that a problem
        # with the positions surfaces before the robot moves.
        # NOTE(review): as_quat() returns scalar-last (x, y, z, w). Whether
        # rotate() treats this as an absolute orientation or one relative to
        # the current gripper pose is defined in skill_code — confirm.
        target_quat = R.from_euler('z', 90, degrees=True).as_quat()
        axis_str, distance = _compute_pull_axis_and_distance(anchor_pos, joint_pos)

        # --------------------------------------------------
        # 3) Execute oracle plan (Specification)
        #    Each entry: (print() args announcing the step,
        #                 zero-arg callable running the skill,
        #                 message printed if the skill reports done).
        # --------------------------------------------------
        plan = [
            (("\n[Plan‑Step 1] rotate gripper → 90° about Z",),
             lambda: rotate(env, task, target_quat),
             "[Task] Finished early after rotate."),
            (("\n[Plan‑Step 2] move to drawer side‑pos:", side_pos),
             lambda: move(env, task, side_pos),
             "[Task] Finished early after move‑to‑side."),
            (("\n[Plan‑Step 3] move to drawer anchor‑pos:", anchor_pos),
             lambda: move(env, task, anchor_pos),
             "[Task] Finished early after move‑to‑anchor."),
            (("\n[Plan‑Step 4] pick (grasp drawer handle at anchor‑pos)",),
             lambda: pick(env, task, anchor_pos, approach_distance=0.05),
             "[Task] Finished early after pick‑drawer."),
            (("\n[Plan‑Step 5] pull drawer open\n"
              f"         Chosen pull axis = {axis_str}, distance = {distance:.3f}",),
             lambda: pull(env, task, pull_distance=distance, pull_axis=axis_str),
             "[Task] Finished early after pull."),
            (("\n[Plan‑Step 6] pick tomato1 at table:", tomato1_pos),
             lambda: pick(env, task, tomato1_pos),
             "[Task] Finished early after picking tomato1."),
            (("\n[Plan‑Step 7] place tomato1 on plate:", plate_pos),
             lambda: place(env, task, plate_pos),
             "[Task] Finished early after placing tomato1."),
            (("\n[Plan‑Step 8] pick tomato2 at table:", tomato2_pos),
             lambda: pick(env, task, tomato2_pos),
             "[Task] Finished early after picking tomato2."),
        ]

        for announce, run_skill, early_msg in plan:
            print(*announce)
            _obs, _reward, done = run_skill()
            if done:
                print(early_msg)
                return

        # Step-9 is the final step: report the outcome instead of exiting early.
        print("\n[Plan‑Step 9] place tomato2 on plate:", plate_pos)
        _obs, _reward, done = place(env, task, plate_pos)
        if done:
            print("[Task] Task successfully finished!")
        else:
            print("[Task] Plan executed, but environment signalled done = False.")

    finally:
        # --------------------------------------------------
        # 4) Always shut the environment down properly
        # --------------------------------------------------
        shutdown_environment(env)
        print("===== [TASK] Shutdown complete =====")


if __name__ == "__main__":
    # Run the plan only when executed as a script, not when imported.
    run_combined_task()