# run_combined_task.py

import numpy as np
from scipy.spatial.transform import Rotation as SciRot
from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ==== Pre‑implemented skills ====
from skill_code import rotate, move, pull, pick, place


def _compute_target_quat_z_90(obs):
    """
    Return the gripper orientation turned by +90° about the world‑Z axis.

    Elements 3..6 of ``obs.gripper_pose`` are read as the current orientation
    quaternion in xyzw order; the returned quaternion is xyzw as well.
    """
    current_orientation = SciRot.from_quat(obs.gripper_pose[3:7])
    quarter_turn_z = SciRot.from_euler('z', 90.0, degrees=True)
    # Left‑multiplication composes the quarter turn *after* the current
    # orientation, which is what makes it a rotation about the fixed Z axis.
    return (quarter_turn_z * current_orientation).as_quat()  # xyzw


def run_combined_task():
    """Run the scripted "Unlock-and-Dispose" plan once, then shut down.

    The plan executes six skill calls in order: ``rotate`` (target is the
    current gripper orientation turned 90° about Z), ``move`` to the bottom
    side position, ``move`` to the bottom anchor position, ``pull``,
    ``pick`` at the rubbish position and ``place`` at the bin position.
    Each skill returns ``(obs, reward, done)``; if ``done`` is truthy after
    any of the first five steps the function prints a notice and returns
    early. After the final step the outcome is printed.

    ``shutdown_environment(env)`` is always called once the environment has
    been set up, whether the plan finishes, returns early, or raises.

    Returns:
        None.

    Raises:
        RuntimeError: if one of the required keys is missing from the
            mapping returned by ``get_object_positions()``. The original
            ``KeyError`` is attached as the explicit cause.
    """
    print("===== [Task] Unlock‑and‑Dispose =====")

    # -------------------------------------------------------------
    # 1) Environment & Video initialisation
    # -------------------------------------------------------------
    env, task = setup_environment()
    try:
        # The task descriptions returned by reset() are not used here.
        _descriptions, obs = task.reset()
        init_video_writers(obs)

        # wrap step / get_observation so that video frames are captured
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------------------------------------------------
        # 2) Fetch all useful object positions from the scene
        # ---------------------------------------------------------
        positions = get_object_positions()

        # Mapping from spec names -> real RLBench object keys
        # (Some names differ slightly between PDDL & RLBench)
        obj_key = {
            'side-pos-bottom':  'bottom_side_pos',
            'anchor-pos-bottom': 'bottom_anchor_pos',
            'bin': 'bin',
            # Fall back to 'item3' when the scene exposes no 'rubbish' key.
            'rubbish': 'rubbish' if 'rubbish' in positions else 'item3'
        }

        try:
            side_pos_bottom   = positions[obj_key['side-pos-bottom']]
            anchor_pos_bottom = positions[obj_key['anchor-pos-bottom']]
            bin_pos           = positions[obj_key['bin']]
            rubbish_pos       = positions[obj_key['rubbish']]
        except KeyError as ke:
            # Chain explicitly so the traceback shows this as a deliberate
            # translation of the lookup failure, not a secondary error.
            raise RuntimeError(
                f"[Task] Required key missing in object_positions(): {ke}"
            ) from ke

        # ---------------------------------------------------------
        # 3) Execute oracle plan (Specification)
        # ---------------------------------------------------------

        def _rotate_step():
            # The observation is read when the step runs (after its banner
            # is printed) so the target is relative to the gripper pose at
            # that moment.
            target_quat = _compute_target_quat_z_90(task.get_observation())
            return rotate(
                env, task,
                target_quat=target_quat,
                max_steps=120, threshold=0.04, timeout=12.0
            )

        # (banner, early-finish message, zero-argument callable running the
        # skill). Callables defer execution until the loop reaches the step.
        early_exit_steps = (
            # STEP‑1 : rotate gripper → 90 deg about Z
            (
                "\n--- PLAN Step‑1 : rotate gripper 90° ---",
                "[Task] Finished early after rotate",
                _rotate_step,
            ),
            # STEP‑2 : move gripper to bottom side position
            (
                "\n--- PLAN Step‑2 : move → side‑pos‑bottom ---",
                "[Task] Finished early after move‑to‑side",
                lambda: move(
                    env, task,
                    target_pos=side_pos_bottom,
                    max_steps=120, threshold=0.01, timeout=12.0
                ),
            ),
            # STEP‑3 : move gripper to bottom anchor position
            (
                "\n--- PLAN Step‑3 : move → anchor‑pos‑bottom ---",
                "[Task] Finished early after move‑to‑anchor",
                lambda: move(
                    env, task,
                    target_pos=anchor_pos_bottom,
                    max_steps=120, threshold=0.01, timeout=12.0
                ),
            ),
            # STEP‑4 : pull the drawer handle outward (open drawer)
            # Heuristic values for pull distance/axis – tweak if necessary
            (
                "\n--- PLAN Step‑4 : pull drawer bottom ---",
                "[Task] Finished early after pull",
                lambda: pull(
                    env, task,
                    pull_distance=0.12,
                    pull_axis='x',  # pull along positive‑X in most RLBench drawer setups
                    max_steps=120, threshold=0.01, timeout=12.0
                ),
            ),
            # STEP‑5 : pick the rubbish object from table
            (
                "\n--- PLAN Step‑5 : pick rubbish ---",
                "[Task] Finished early after pick",
                lambda: pick(
                    env, task,
                    target_pos=rubbish_pos,
                    approach_distance=0.15,
                    max_steps=120, threshold=0.01,
                    approach_axis='z', timeout=12.0
                ),
            ),
        )

        for banner, early_msg, run_step in early_exit_steps:
            print(banner)
            _obs, _reward, done = run_step()
            if done:
                print(early_msg)
                return

        # STEP‑6 : place rubbish into bin
        print("\n--- PLAN Step‑6 : place rubbish in bin ---")
        _obs, reward, done = place(
            env, task,
            target_pos=bin_pos,
            approach_distance=0.15,
            max_steps=120, threshold=0.01,
            approach_axis='z', timeout=12.0
        )

        # ---------------------------------------------------------
        # 4) Final status
        # ---------------------------------------------------------
        if done:
            print(f"[Task] SUCCESS! Final reward: {reward}")
        else:
            print("[Task] Plan finished but environment signalled done=False.")

    finally:
        # Runs on normal completion, early return and on any exception.
        shutdown_environment(env)
        print("===== [Task] Shutdown complete =====")


# Run the scripted task only when this file is executed directly as a
# script; importing the module does not start the plan.
if __name__ == "__main__":
    run_combined_task()