# run_combined_task.py

import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ---- import available skills ------------------------------------------------
from skill_code import rotate, move, pick, pull, place
# -----------------------------------------------------------------------------

# ----------------------------------------------------------------------------- 
# Helper utilities
# -----------------------------------------------------------------------------
def _extract_position(obj_info):
    """Return the 3-D position stored in *obj_info* as a flat length-3 array.

    Accepted layouts:

    * a dict holding the position under ``"position"``, ``"pos"``, ``"xyz"``
      or ``"translation"`` (tried in that order), or under ``"pose"``, whose
      first three values are taken as the position;
    * a list / tuple / ndarray whose first three values are the position.

    Every candidate is flattened before slicing, so nested inputs such as a
    ``(1, 7)`` row vector still yield a one-dimensional result of length 3.
    A candidate with fewer than three values is skipped rather than returned.

    Raises:
        ValueError: if *obj_info* is ``None`` or no candidate provides at
            least three values.
    """
    if obj_info is None:
        raise ValueError("Object information is None.")

    if isinstance(obj_info, dict):
        # Dedicated position keys win over the combined "pose" entry.
        candidates = [
            obj_info[key]
            for key in ("position", "pos", "xyz", "translation", "pose")
            if key in obj_info
        ]
    elif isinstance(obj_info, (list, tuple, np.ndarray)):
        candidates = [obj_info]
    else:
        candidates = []

    for candidate in candidates:
        if candidate is None:
            continue
        # Flatten first: `size` counts all elements, so slicing must also
        # operate on all elements rather than on axis 0 only.
        flat = np.asarray(candidate).ravel()
        if flat.size >= 3:
            return flat[:3]

    raise ValueError(f"Could not parse position from {obj_info}")


def _extract_quat(obj_info):
    """Return the quaternion stored in *obj_info* as a flat length-4 array.

    The values are returned in the order they are stored; the surrounding
    code treats quaternions as ``(x, y, z, w)``.

    Accepted layouts:

    * a dict holding exactly four values under ``"quaternion"``, ``"quat"``,
      ``"orientation"`` or ``"orn"`` (tried in that order), or a ``"pose"``
      entry with at least seven values, of which elements 3..6 are used;
    * a list / tuple / ndarray with seven values (elements 3..6 are used) or
      exactly four values (returned as-is).

    Inputs are flattened before inspection, so a ``(1, 7)`` or ``(1, 4)`` row
    vector is handled the same way as its one-dimensional equivalent.

    Raises:
        ValueError: if *obj_info* is ``None`` or no quaternion can be found.
    """
    if obj_info is None:
        raise ValueError("Object information is None.")

    if isinstance(obj_info, dict):
        for key in ("quaternion", "quat", "orientation", "orn"):
            if key in obj_info:
                quat = np.asarray(obj_info[key]).ravel()
                # A wrongly sized entry is skipped so a later key can match.
                if quat.size == 4:
                    return quat
        if "pose" in obj_info:
            # Flattening also turns a None/scalar pose into a size-1 array,
            # which fails the size check instead of raising TypeError.
            pose = np.asarray(obj_info["pose"]).ravel()
            if pose.size >= 7:
                return pose[3:7]

    if isinstance(obj_info, (list, tuple, np.ndarray)):
        # Flatten first: `size` counts all elements, so the slice below must
        # index all elements rather than axis 0 only.
        flat = np.asarray(obj_info).ravel()
        if flat.size == 7:
            return flat[3:7]
        if flat.size == 4:
            return flat

    raise ValueError(f"Could not parse quaternion from {obj_info}")


def _quat_z(angle_deg):
    """Build the quaternion for a rotation of *angle_deg* degrees about Z.

    The result uses SciPy's default scalar-last ``(x, y, z, w)`` ordering.
    """
    yaw = R.from_euler('z', angle_deg, degrees=True)
    return yaw.as_quat()


# -----------------------------------------------------------------------------
# Mapping from specification names → real object names used in RLBench
# -----------------------------------------------------------------------------
# Keys are the symbolic names used in the task specification; values are the
# names used to index the mapping returned by get_object_positions().
# NOTE(review): "nowhere-pos" is not referenced anywhere in the visible code.
SPEC2OBJ = {
    "nowhere-pos": "waypoint1",            # generic waiting pose
    "side-pos-bottom": "bottom_side_pos",
    "anchor-pos-bottom": "bottom_anchor_pos",
    "rubbish": "rubbish",
    "bin": "bin",
}


# -----------------------------------------------------------------------------
# Main task runner (executes the oracle plan in specification order)
# -----------------------------------------------------------------------------
def run_combined_task():
    """Open the bottom drawer, then pick up the rubbish and put it in the bin.

    The plan runs strictly in order: rotate the gripper, move to the drawer's
    side position, move to the handle anchor, grasp the handle, pull, retreat,
    pick the rubbish and place it into the bin.  Each skill call is unpacked
    as ``(obs, reward, done)``; when ``done`` is set before the final step
    the run stops early.  The environment is shut down on every exit path.

    Raises:
        KeyError: if a required object name is missing from the mapping
            returned by ``get_object_positions()``.
        ValueError: if an object's entry cannot be parsed into a position.
    """
    print("===== [Task] Combined Drawer & Disposal =====")

    env, task = setup_environment()
    try:
        # Only the initial observation is needed; the task descriptions
        # returned alongside it are not used by this script.
        _, obs = task.reset()

        # Optional video recording: wrap the task's step / observation hooks.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        def pos(name):
            """Return the position of *name*, queried afresh on every call."""
            positions = get_object_positions()
            if name not in positions:
                raise KeyError(
                    f"Object '{name}' not found in scene; "
                    f"available: {list(positions)}"
                )
            return _extract_position(positions[name])

        # ------------------------------------------------------------------
        # STEP-BY-STEP EXECUTION (matches provided specification)
        # ------------------------------------------------------------------

        # Step 1: rotate gripper from zero_deg to ninety_deg
        target_quat = _quat_z(90)   # 90-degree rotation about the z-axis
        print("\n[Step 1] rotate gripper → 90°")
        obs, reward, done = rotate(env, task, target_quat)
        if done:
            print("[Early Termination] after rotate.")
            return

        # Step 2: move gripper to side position of bottom drawer
        print("\n[Step 2] move to bottom_side_pos")
        side_pos = pos(SPEC2OBJ["side-pos-bottom"])
        obs, reward, done = move(env, task, side_pos)
        if done:
            print("[Early Termination] after move‑to‑side.")
            return

        # Step 3: move gripper to anchor position (handle) of bottom drawer
        print("\n[Step 3] move to bottom_anchor_pos")
        anchor_pos = pos(SPEC2OBJ["anchor-pos-bottom"])
        obs, reward, done = move(env, task, anchor_pos)
        if done:
            print("[Early Termination] after move‑to‑anchor.")
            return

        # Step 4: pick the drawer handle (equivalent to pick-drawer)
        print("\n[Step 4] pick drawer handle")
        obs, reward, done = pick(env, task, target_pos=anchor_pos,
                                 approach_distance=0.07,
                                 approach_axis='y')   # assume handle approached from y
        if done:
            print("[Early Termination] after pick‑drawer.")
            return

        # Step 5: pull drawer outward (open)
        print("\n[Step 5] pull drawer outward")
        # NOTE(review): axis and distance are assumptions carried over from
        # the original script (+x opens the drawer, 0.18 m travel) — confirm
        # against the scene.
        obs, reward, done = pull(env, task, pull_distance=0.18, pull_axis='x')
        if done:
            print("[Early Termination] after pull.")
            return

        # Retreat slightly to avoid collision with the drawer front.  The
        # anchor is re-queried because the handle has just been pulled: the
        # position captured in step 3 may be stale, and retreating relative
        # to it could drive the gripper back toward the closed position.
        # Assumes get_object_positions() reports the post-pull anchor pose —
        # TODO confirm; if the anchor is static this equals the old value.
        # NOTE(review): no release skill is imported, so the gripper state
        # during this move is whatever `pull` leaves it in — confirm.
        retreat_offset = np.array([0.05, 0.0, 0.05])
        safe_pos = pos(SPEC2OBJ["anchor-pos-bottom"]) + retreat_offset
        obs, reward, done = move(env, task, safe_pos)
        if done:
            print("[Early Termination] after retreat.")
            return

        # Step 6: pick rubbish from table
        print("\n[Step 6] pick rubbish on table")
        rubbish_pos = pos(SPEC2OBJ["rubbish"])
        obs, reward, done = pick(env, task, target_pos=rubbish_pos,
                                 approach_distance=0.12,
                                 approach_axis='z')
        if done:
            print("[Early Termination] after picking rubbish.")
            return

        # Step 7: place rubbish into bin
        print("\n[Step 7] place rubbish into bin")
        bin_pos = pos(SPEC2OBJ["bin"])
        obs, reward, done = place(env, task, target_pos=bin_pos,
                                  approach_distance=0.12,
                                  approach_axis='z')
        if done:
            print("[Task Completed] Trash disposed successfully! Reward:", reward)
        else:
            print("[Task Finished] Execution ended without 'done' flag.")

    finally:
        # Runs on normal completion, early return and exceptions alike.
        shutdown_environment(env)
        print("===== [Task] Environment shutdown complete =====")


# Run the task only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_combined_task()