# run_skeleton_task.py (Completed Executable Script)


import time
import traceback
import numpy as np

from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *          # → pick, place, move, rotate, pull
from video import (init_video_writers,
                   recording_step,
                   recording_get_observation)
from object_positions import get_object_positions



# ---------------------------------------------------------
# Helper utilities – kept completely local to this module
# ---------------------------------------------------------
def euler_to_quat(roll: float, pitch: float, yaw: float):
    """Build a quaternion from roll / pitch / yaw angles given in radians.

    With the other two angles at zero, ``roll`` rotates about X, ``pitch``
    about Y and ``yaw`` about Z.  The components are computed in double
    precision and then stored as ``float32``.

    Returns:
        A length-4 ``np.float32`` array ordered ``[x, y, z, w]``.
    """
    half_roll, half_pitch, half_yaw = roll / 2.0, pitch / 2.0, yaw / 2.0

    cos_r, sin_r = np.cos(half_roll), np.sin(half_roll)
    cos_p, sin_p = np.cos(half_pitch), np.sin(half_pitch)
    cos_y, sin_y = np.cos(half_yaw), np.sin(half_yaw)

    components = (
        sin_r * cos_p * cos_y - cos_r * sin_p * sin_y,   # x
        cos_r * sin_p * cos_y + sin_r * cos_p * sin_y,   # y
        cos_r * cos_p * sin_y - sin_r * sin_p * cos_y,   # z
        cos_r * cos_p * cos_y + sin_r * sin_p * sin_y,   # w (scalar part last)
    )
    return np.array(components, dtype=np.float32)


def safe_skill_call(fn, *args, **kwargs):
    """Call a skill and turn any exception into a neutral step result.

    Args:
        fn: The callable to invoke.
        *args: Positional arguments forwarded to ``fn``.  When at least two
            are given, the second one is treated as the task object.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns on success.  If ``fn`` raises, the traceback
        is printed and ``(obs, 0.0, False)`` is returned, where ``obs`` is
        ``task.get_observation()`` when a task is available and that call
        succeeds, otherwise ``None``.
    """
    try:
        return fn(*args, **kwargs)
    except Exception:
        # functools.partial objects and callable instances have no
        # ``__name__``; fall back to repr so the handler itself cannot raise.
        skill_name = getattr(fn, "__name__", repr(fn))
        print(f"[WARNING] Skill `{skill_name}` threw an exception – "
              f"continuing. \n{traceback.format_exc()}")

        # Callers in this file pass (env, task, ...) positionally.  Also accept
        # a ``task=`` keyword — assumes skills name that parameter ``task``;
        # TODO confirm against skill_code.
        task = args[1] if len(args) > 1 else kwargs.get("task")

        obs = None
        if task is not None:
            try:
                obs = task.get_observation()
            except Exception:
                # Fetching the fallback observation is best-effort: a broken
                # task must not defeat the purpose of this wrapper.
                print("[WARNING] Could not fetch fallback observation. \n"
                      f"{traceback.format_exc()}")

        # Mimic a step-like (obs, reward, done) triple: zero reward, not done.
        return obs, 0.0, False



# ---------------------------------------------------------
# A very light-weight “exploration” routine.
#
#  Goal – try to discover the *lock-known* predicate that was
#  missing in the original domain.  We do this by:
#     1) walking around (move skill) to every known position
#     2) if an object’s name hints it is a “drawer” or “handle”,
#        we try a pull on it (which, in the exploration domain,
#        makes (lock-known ?obj) true).
# ---------------------------------------------------------
def exploration_phase(env, task, positions):
    """Visit every known position and try a pull on drawer/handle entries.

    Names are processed in sorted order.  For each one a ``move`` is issued;
    if the name contains "drawer" or "handle" (case-insensitive) a ``rotate``
    followed by a ``pull`` is attempted.  The loop stops as soon as any skill
    reports ``done``.

    Args:
        env: Environment handle forwarded to every skill.
        task: Task handle forwarded to every skill.
        positions: Mapping of object/location name to position.
    """
    print("\n========== [Exploration] START ==========")

    handle_markers = ("drawer", "handle")
    # Settings shared by every approach move in this phase.
    approach_settings = dict(
        approach_distance=0.15,
        max_steps=120,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0,
    )

    # Iterating in sorted-name order keeps the run deterministic.
    for obj_name in sorted(positions):
        location = positions[obj_name]
        print(f"[Exploration] Visiting `{obj_name}` @ {np.round(location, 3)}")

        # Step 1: approach the object/location.
        _, _, finished = safe_skill_call(
            move, env, task, target_pos=location, **approach_settings)
        if finished:
            print("[Exploration] Environment signalled termination during move.")
            break

        # Step 2: only names hinting at a drawer or handle get a pull attempt.
        lowered = obj_name.lower()
        if not any(marker in lowered for marker in handle_markers):
            continue

        print(f"[Exploration] `{obj_name}` looks like a drawer/handle – "
              f"trying to pull (→ discover lock-known).")

        # Rotate 90° about X first; per the original author's note this puts
        # the fingers vertical for grasping a drawer handle.
        grasp_quat = euler_to_quat(np.pi / 2.0, 0.0, 0.0)
        _, _, finished = safe_skill_call(
            rotate, env, task,
            target_quat=grasp_quat,
            max_steps=80,
            threshold=0.1,
            timeout=5.0,
        )
        if finished:
            break

        # No separate pick is issued here; `pull` is called directly.
        _, _, finished = safe_skill_call(pull, env, task)
        if finished:
            break

    print("========== [Exploration] DONE ==========\n")



# ---------------------------------------------------------
# An extremely simple “main task” routine that demonstrates
# usage of the built-in skills.  For most benchmark tasks
# the object to manipulate is called something like
#   • `target_object`   – we pick it up
#   • `goal_position`   – we place it there
# ---------------------------------------------------------
def main_task(env, task, positions):
    """Run a simple move -> pick -> move -> place sequence.

    The target object and goal location are inferred from the keys of
    ``positions``:

    * target: the last key containing "target"; otherwise the first key that
      contains neither "drawer" nor "handle".
    * goal: the last key containing "goal", "bin" or "basket"; otherwise the
      entry farthest (Euclidean distance) from the target.

    If no usable target exists (``positions`` is empty or holds only
    drawer/handle entries) a message is printed and the function returns
    without issuing any skill.  It also returns early, silently, whenever a
    skill reports ``done``.

    Args:
        env: Environment handle forwarded to every skill.
        task: Task handle forwarded to every skill.
        positions: Mapping of object/location name to position.
    """
    print("\n========== [Main Task] START ==========")

    # Try to infer plausible keys (when several keys match, the last one wins).
    target_obj_name = None
    goal_pos_name = None
    for key in positions:
        lowered = key.lower()
        if 'target' in lowered:
            target_obj_name = key
        if any(tag in lowered for tag in ('goal', 'bin', 'basket')):
            goal_pos_name = key

    # Fallbacks if heuristics failed
    if target_obj_name is None:
        # Pick the first non-drawer object.  A default is supplied so that an
        # empty/all-drawer mapping does not leak a bare StopIteration.
        target_obj_name = next(
            (k for k in positions
             if not (('drawer' in k.lower()) or ('handle' in k.lower()))),
            None,
        )
        if target_obj_name is None:
            print("[Main Task] No suitable target object found – "
                  "skipping main task.")
            return

    if goal_pos_name is None:
        # Choose the location farthest from the object.  Positions are coerced
        # to arrays so plain lists/tuples can be subtracted too.
        target_xyz = np.asarray(positions[target_obj_name], dtype=float)
        max_dist = -1.0
        for k, p in positions.items():
            d = np.linalg.norm(target_xyz - np.asarray(p, dtype=float))
            if d > max_dist:
                max_dist, goal_pos_name = d, k

    print(f"[Main Task]  target object = {target_obj_name}")
    print(f"[Main Task]  goal location = {goal_pos_name}")

    # ---- 1) Move above the target object
    _, _, done = safe_skill_call(
        move, env, task,
        target_pos=positions[target_obj_name],
        approach_distance=0.15,
        max_steps=120,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0
    )
    if done:
        return

    # ---- 2) Pick
    _, _, done = safe_skill_call(
        pick, env, task,
        target_pos=positions[target_obj_name],
        approach_distance=0.03,
        lift_up_distance=0.15,
        max_steps=150,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0
    )
    if done:
        return

    # ---- 3) Move over goal & place
    _, _, done = safe_skill_call(
        move, env, task,
        target_pos=positions[goal_pos_name],
        approach_distance=0.15,
        max_steps=120,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0
    )
    if done:
        return

    _, _, done = safe_skill_call(
        place, env, task,
        target_pos=positions[goal_pos_name],
        release_height=0.05,
        max_steps=150,
        threshold=0.01,
        approach_axis='z',
        timeout=5.0
    )
    if done:
        return

    print("========== [Main Task] DONE ==========")


# ---------------------------------------------------------
# Entry-point
# ---------------------------------------------------------
def run_skeleton_task():
    """Set up the environment, run exploration and the main task, then shut down.

    ``task.step`` and ``task.get_observation`` are replaced by the wrappers
    returned from ``recording_step`` / ``recording_get_observation`` before
    any skill runs.  ``shutdown_environment`` is always called, even when a
    phase raises.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # The first element returned by reset() is not used here.
        _, initial_obs = task.reset()

        # Video recording: initialise the writers, then wrap the two task
        # methods with their recording counterparts.
        init_video_writers(initial_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Object positions come from the helper module.
        object_map = get_object_positions()
        discovered = list(object_map.keys())
        print(f"[Info] Discovered {len(object_map)} objects from "
              f"`object_positions`: {discovered}")

        # Phase 1: exploration (tries to reveal the missing predicate).
        exploration_phase(env, task, object_map)

        # Phase 2: the demonstrative pick-and-place task.
        main_task(env, task, object_map)
    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")



# Run the skeleton task only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()