# run_skeleton_task.py (Filled-in version)

import time
import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# bring all predefined skills into scope
from skill_code import *        # noqa: F401,F403
from video import (
    init_video_writers,
    recording_step,
    recording_get_observation,
)

from object_positions import get_object_positions


def _safe_call(fn, *args, **kwargs):
    """
    Utility wrapper: executes a skill and catches any unexpected errors
    so that the overall script can keep running.  Returns (obs, reward, done)
    if the skill itself produced them, otherwise returns (None, None, False).
    """
    try:
        ret = fn(*args, **kwargs)
        # If a skill returns a 3-tuple we assume it is (obs, reward, done)
        if isinstance(ret, tuple) and len(ret) == 3:
            return ret
        return (None, None, False)
    except Exception as exc:           # noqa: BLE001
        print(f"[WARNING] Skill {fn.__name__} raised an exception: {exc}")
        return (None, None, False)


def _guess_drawer_handle_key(positions_dict):
    """
    In many RLBench scenes the drawer handle appears in the dictionary
    with one of a handful of common substrings; we heuristically pick
    the first matching key so the script stays generic.
    """
    prefer = ["drawer_handle", "handle", "drawer", "knob"]
    for p in prefer:
        for key in positions_dict:
            if p in key.lower():
                return key
    # Fallback to an arbitrary key so the script does *something*
    return next(iter(positions_dict))


def _quat_from_euler(roll, pitch, yaw):
    """
    Convert Euler angles (rad) → quaternion  (x,y,z,w).  We re-implement
    a tiny helper here so we do not depend on any external libs.
    """
    cr, sr = np.cos(roll / 2.0), np.sin(roll / 2.0)
    cp, sp = np.cos(pitch / 2.0), np.sin(pitch / 2.0)
    cy, sy = np.cos(yaw / 2.0), np.sin(yaw / 2.0)

    w = cr * cp * cy + sr * sp * sy
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy
    return np.array([x, y, z, w], dtype=np.float32)


def run_skeleton_task():
    """
    Run the drawer-opening skeleton task end to end.

    Sets up the environment and resets the task.  Wraps ``task.step`` and
    ``task.get_observation`` with the video-recording helpers, then visits
    every known object position once (exploration).  Finally it moves to
    the heuristically chosen drawer handle, rotates the gripper and pulls.
    Every skill is invoked through ``_safe_call``, so a failing skill is
    logged instead of aborting the run.  The function returns early
    whenever a skill reports ``done``, or when no object positions are
    available.  ``shutdown_environment`` is always called once setup has
    succeeded.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # ------------------------------------------------------------------
        # 1) Reset the task, start recording
        # ------------------------------------------------------------------
        _descriptions, obs = task.reset()
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        # 2) Collect rough object/pose information from helper
        # ------------------------------------------------------------------
        positions = get_object_positions()
        print(f"[INFO] Retrieved {len(positions)} object positions from helper.")
        if not positions:
            # Nothing to explore and no drawer handle to select; bail out
            # instead of failing later inside the handle-key heuristic.
            print("[Early-Exit] No object positions available; nothing to do.")
            return

        # ------------------------------------------------------------------
        # 3) Exploration Phase – visit each location once via “move”
        #     This is intentionally simple; our purpose is only to discover
        #     missing predicates (see provided Exploration domain description).
        # ------------------------------------------------------------------
        print("[Phase] Exploration ‑ acquiring extra knowledge …")
        for name, pos in positions.items():
            print(f"[Explore]   moving near '{name}' at {pos}")
            _, _, done = _safe_call(
                move,
                env,
                task,
                target_pos=pos,
                approach_distance=0.15,
                max_steps=60,
                threshold=0.02,
                approach_axis="z",
                timeout=6.0,
            )
            if done:
                # Same policy as the phase-5 skills: stop stepping once the
                # environment reports the episode is done.
                print("[Early-Exit] Environment signaled done during exploration.")
                return

        # ------------------------------------------------------------------
        # 4) Missing predicate – currently hard-coded; it is NOT computed
        #    from the exploration results above.
        # ------------------------------------------------------------------
        missing_predicate = "lock-known"
        print(f"[Discover] Missing predicate likely: '{missing_predicate}'")

        # ------------------------------------------------------------------
        # 5) Main Task Logic – open the drawer
        # ------------------------------------------------------------------
        drawer_key = _guess_drawer_handle_key(positions)
        drawer_pos = positions[drawer_key]
        print(f"[Task] Selected drawer handle key '{drawer_key}' at {drawer_pos}")

        # 5-a) Move the gripper to the drawer handle
        print("[Task] Moving to drawer handle …")
        obs, reward, done = _safe_call(
            move,
            env,
            task,
            target_pos=drawer_pos,
            approach_distance=0.12,
            max_steps=120,
            threshold=0.015,
            approach_axis="z",
            timeout=10.0,
        )
        if done:
            print("[Early-Exit] Environment signaled done during move.")
            return

        # 5-b) Rotate the gripper to a yaw of π/2 (zero roll/pitch) to align
        #      with the handle.  NOTE(review): whether `rotate` treats
        #      target_quat as an absolute orientation or relative to the
        #      current gripper pose is defined in skill_code — confirm.
        print("[Task] Rotating gripper 90° …")
        quarter_turn_quat = _quat_from_euler(0.0, 0.0, np.pi / 2.0)
        obs, reward, done = _safe_call(
            rotate,
            env,
            task,
            target_quat=quarter_turn_quat,
            max_steps=80,
            threshold=0.04,
            timeout=8.0,
        )
        if done:
            print("[Early-Exit] Environment signaled done during rotate.")
            return

        # 5-c) Pull the drawer straight out a short distance
        print("[Task] Pulling the drawer …")
        obs, reward, done = _safe_call(
            pull,
            env,
            task,
            pull_distance=0.15,
            max_steps=120,
            threshold=0.02,
            timeout=10.0,
        )
        if done:
            print("[Success] Drawer opened – task complete.")
        else:
            print("[INFO] Finished pull action; check environment for success.")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the task only when executed directly, not on import.
    run_skeleton_task()