# run_combined_task.py
#
# Combined, specification-compliant task:
#   1) Rotate the gripper by 90° about the z-axis (→ “ninety_deg”)
#   2) Move to the drawer’s side position
#   3) Move to the drawer’s anchor / handle
#   4) Grasp the handle (PDDL “pick-drawer” ≈ skill_code.pick on the anchor)
#   5) Pull the drawer open
#   6) Pick the rubbish that is still on the table
#   7) Place the rubbish into the bin

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape                     # kept for completeness
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pick, pull, place    # predefined skills
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
#  Small helpers                                                              #
# --------------------------------------------------------------------------- #
def _angle_name_to_quaternion(name: str) -> np.ndarray:
    """Translate a symbolic angle label into a float32 ``xyzw`` quaternion.

    Recognised labels (matched case-insensitively):
      * ``zero_deg``   – the identity rotation ``[0, 0, 0, 1]``
      * ``ninety_deg`` – a +90° rotation about the z-axis

    Raises:
        ValueError: if the (lower-cased) label is not recognised.
    """
    label = name.lower()
    if label == 'ninety_deg':
        # 90 ° about the z-axis; scipy's as_quat() yields scalar-last (xyzw)
        quat_xyzw = R.from_euler('z', 90, degrees=True).as_quat()
        return quat_xyzw.astype(np.float32)
    if label != 'zero_deg':
        raise ValueError(f"Unknown angle symbol: {label}")
    return np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)


def _derive_pull_axis_and_distance(anchor: np.ndarray,
                                   joint: np.ndarray,
                                   extra_margin: float = 0.02):
    """
    Derive the drawer's pull axis and a conservative pull distance.

    Given two reference points of the drawer (anchor on the handle and an
    inner joint point), the dominant component of ``anchor - joint`` selects
    the axis along which the drawer translates.

    Args:
        anchor: 3-D position of the handle/anchor (any array-like of 3 numbers).
        joint:  3-D position of the inner joint point (array-like of 3 numbers).
        extra_margin: safety margin added to the measured offset.

    Returns:
        ``(axis, distance)`` where ``axis`` is one of ``'x'``, ``'y'``,
        ``'z'`` optionally prefixed with ``'-'`` (sign of the dominant
        component of ``anchor - joint``), and ``distance`` is
        ``|dominant component| + extra_margin`` as a Python float.

    Raises:
        ValueError: if either point is not a 3-element position, or if the
            two points coincide (no axis can be derived).
    """
    # Accept lists/tuples as well as ndarrays; flatten e.g. (1, 3) to (3,).
    anchor = np.asarray(anchor, dtype=float).reshape(-1)
    joint = np.asarray(joint, dtype=float).reshape(-1)
    if anchor.shape != (3,) or joint.shape != (3,):
        raise ValueError(
            "[Task] Drawer reference points must be 3-D positions, got shapes "
            f"{anchor.shape} and {joint.shape}"
        )

    delta = anchor - joint
    major = int(np.argmax(np.abs(delta)))           # 0:x, 1:y, 2:z
    if delta[major] == 0.0:
        # All components are zero: argmax would silently pick +x.
        raise ValueError("[Task] Drawer anchor and joint coincide – cannot derive pull axis")

    axis_sign = '-' if delta[major] < 0 else ''
    axis = axis_sign + 'xyz'[major]
    distance = float(abs(delta[major]) + extra_margin)  # add a safety margin
    return axis, distance


def _fetch_or_die(dic: dict, *keys):
    """Look up every requested key, failing loudly if any is absent or None.

    Returns:
        A list with the value for each key, in the order the keys were given.

    Raises:
        KeyError: naming every key that is missing or mapped to ``None``.
    """
    values = [dic.get(key) for key in keys]
    absent = [key for key, value in zip(keys, values) if value is None]
    if absent:
        raise KeyError(f"[Task] Missing object positions for: {', '.join(absent)}")
    return values


def _pick_first_available(dic: dict, *candidates):
    """Return the earliest candidate key whose value in *dic* is present and not None.

    Raises:
        KeyError: if no candidate qualifies.
    """
    nothing = object()  # sentinel, so a legitimate key of None is not mistaken for "no match"
    chosen = next((cand for cand in candidates if dic.get(cand) is not None), nothing)
    if chosen is nothing:
        raise KeyError(f"[Task] None of the candidate keys present: {candidates}")
    return chosen


# --------------------------------------------------------------------------- #
#  Main task logic                                                            #
# --------------------------------------------------------------------------- #
def _episode_ended(done, stage: str, reward) -> bool:
    """Return ``bool(done)``, printing a notice when the episode has ended.

    Used after every plan step so that an early stop (a skill or
    ``task.step`` reporting ``done``) shows up in the log instead of the plan
    silently returning.
    """
    if done:
        print(f"[Task] Episode ended during {stage} – reward: {reward}")
    return bool(done)


def _release_handle(env, task, n_steps: int = 5):
    """Best-effort: command the gripper open for up to ``n_steps`` steps.

    Args:
        env: environment from ``setup_environment``; its ``action_shape``
            sizes the action vector.
        task: task handle whose ``step`` is called with the open command.
        n_steps: maximum number of open-gripper steps to send.

    Returns:
        The ``(obs, reward, done)`` triple of the last step that completed,
        or ``None`` if no step completed. Stepping stops as soon as ``done``
        is set. Exceptions are reported and swallowed because the release is
        a convenience, not a required plan step.
    """
    last = None
    try:
        # NOTE(review): the non-gripper action components are left at zero;
        # whether that is a valid "hold" command depends on the configured
        # action mode – confirm. Failures are tolerated below either way.
        open_cmd = np.zeros(env.action_shape)
        open_cmd[-1] = 1.0  # open gripper
        for _ in range(n_steps):
            obs, reward, done = task.step(open_cmd)
            last = (obs, reward, done)
            if done:
                break
    except Exception as exc:  # deliberate best-effort: report, don't crash
        print(f"[Task] Warning: handle release failed ({exc!r}) – continuing")
    return last


def run_combined_task():
    """Execute the seven-step drawer-opening and rubbish-disposal plan.

    Sets up the environment and installs the video-recording wrappers on
    ``task.step`` / ``task.get_observation``. It then looks up the required
    object positions and runs: rotate → move(side) → move(anchor) →
    pick(handle) → pull → [release gripper] → pick(rubbish) → place(bin).
    The plan stops early whenever a step reports ``done``. Any exception is
    reported with its traceback, and the environment is always shut down.
    """
    import traceback  # local import: only needed for error reporting

    print("===== Combined Drawer-Opening & Disposal Task =====")

    env, task = setup_environment()
    try:
        # ------------------------------------------------ Environment reset
        _descriptions, obs = task.reset()

        # --------- Optional: video / recording hooks
        init_video_writers(obs)
        task.step            = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------ Retrieve positions
        pos = get_object_positions()               # dict[str, np.ndarray]

        # Drawer reference points (use the *bottom* drawer, assumed unlocked)
        side_pos, anchor_pos, joint_pos = _fetch_or_die(
            pos,
            'bottom_side_pos',
            'bottom_anchor_pos',
            'bottom_joint_pos'
        )

        # Rubbish item – allow for naming variations
        rubbish_key = _pick_first_available(
            pos,
            'rubbish', 'item3', 'item2', 'item1'
        )
        rubbish_pos = pos[rubbish_key]
        # Bin – validated like the drawer points, so a missing entry gives a
        # descriptive error instead of a bare KeyError('bin').
        (bin_pos,) = _fetch_or_die(pos, 'bin')

        # Pure computation: done before the robot moves, so any error it
        # raises aborts the plan up front rather than half-way through.
        axis, dist = _derive_pull_axis_and_distance(anchor_pos, joint_pos)

        # ------------------------------------------------ Execute oracle plan
        reward = 0.0

        # ── 1) Rotate gripper to ninety_deg ────────────────────────────────
        print("[Plan] 1/7  – rotate → ninety_deg")
        obs, reward, done = rotate(
            env, task,
            target_quat=_angle_name_to_quaternion('ninety_deg')
        )
        if _episode_ended(done, "step 1/7", reward):
            return

        # ── 2) Move to the drawer’s side position ──────────────────────────
        print("[Plan] 2/7  – move → bottom_side_pos")
        obs, reward, done = move(
            env, task,
            target_pos=side_pos
        )
        if _episode_ended(done, "step 2/7", reward):
            return

        # ── 3) Move to the drawer’s anchor / handle ────────────────────────
        print("[Plan] 3/7  – move → bottom_anchor_pos")
        obs, reward, done = move(
            env, task,
            target_pos=anchor_pos
        )
        if _episode_ended(done, "step 3/7", reward):
            return

        # ── 4) Grasp the handle (≈ pick-drawer) ────────────────────────────
        print("[Plan] 4/7  – grasp handle @ anchor")
        obs, reward, done = pick(
            env, task,
            target_pos=anchor_pos,
            approach_distance=0.06,    # slightly closer than default
            approach_axis='z'
        )
        if _episode_ended(done, "step 4/7", reward):
            return

        # ── 5) Pull the drawer open ────────────────────────────────────────
        print(f"[Plan] 5/7  – pull along {axis} by {dist:.3f} m")
        obs, reward, done = pull(
            env, task,
            pull_distance=dist,
            pull_axis=axis
        )
        if _episode_ended(done, "step 5/7", reward):
            return

        # (Optional) Release the handle so the gripper can open again
        released = _release_handle(env, task)
        if released is not None:
            obs, reward, done = released
            if _episode_ended(done, "handle release", reward):
                return

        # ── 6) Pick the rubbish from the table ─────────────────────────────
        print(f"[Plan] 6/7  – pick {rubbish_key} @ table")
        obs, reward, done = pick(
            env, task,
            target_pos=rubbish_pos
        )
        if _episode_ended(done, "step 6/7", reward):
            return

        # ── 7) Place the rubbish into the bin ──────────────────────────────
        print("[Plan] 7/7  – place rubbish → bin")
        obs, reward, done = place(
            env, task,
            target_pos=bin_pos
        )
        if _episode_ended(done, "step 7/7", reward):
            return

        # ------------------------------------------------ Finished successfully
        print("[Task] Plan executed successfully – final reward:", reward)

    except Exception as exc:
        # Top-level boundary: report (with traceback) rather than propagate,
        # so the environment is still shut down cleanly below.
        print("[Task] Exception during task execution:", exc)
        traceback.print_exc()

    finally:
        shutdown_environment(env)
        print("===== Task finished – environment closed =====")


# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Script entry point: run the full plan once when executed directly.
    run_combined_task()