# run_skeleton_task.py (Completed)

import time
import inspect
import numpy as np
from pyrep.objects.shape import Shape          # kept – required by skeleton
from pyrep.objects.proximity_sensor import ProximitySensor   # kept – required by skeleton

# Environment / skills / helpers (all externally provided)
from env import setup_environment, shutdown_environment
from skill_code import *          # we do NOT redefine any primitive skills
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ----------------------------------------------------------------------------- 
# Utility helpers (safe-define so we do not clash with versions that may already
# exist inside skill_code).  They are NOT “skills”, only small math helpers.
# -----------------------------------------------------------------------------
def _safe_define(name, fn):
    """Bind *fn* to *name* in this module's namespace unless *name* exists.

    Lets this file supply fallback helpers without overriding same-named
    helpers that ``from skill_code import *`` may already have brought in.
    Returns None.
    """
    module_namespace = globals()
    # setdefault only inserts when the key is absent, leaving any existing
    # binding untouched.
    module_namespace.setdefault(name, fn)


# ––– quaternion helpers –––
# Fallback normaliser, registered only if no `normalize_quaternion` already
# exists in this module (e.g. from skill_code).  The 1e-12 term avoids a
# division by zero, so an all-zero input maps to all zeros.
_safe_define(
    'normalize_quaternion',
    lambda quat: np.asarray(quat) / (np.linalg.norm(quat) + 1e-12),
)

def _euler_from_quat(q):
    """Convert an (x, y, z, w) quaternion into (roll, pitch, yaw) radians.

    Args:
        q: 4-element sequence ordered x, y, z, w.

    Returns:
        numpy array ``[roll, pitch, yaw]``.

    The pitch sine is clamped to [-1, 1] before ``arcsin`` so rounding
    drift (or a slightly non-unit quaternion) never leaves its domain.
    """
    qx, qy, qz, qw = q

    # Rotation about the x axis.
    roll = np.arctan2(2.0 * (qw * qx + qy * qz),
                      1.0 - 2.0 * (qx * qx + qy * qy))

    # Rotation about the y axis, with the argument clamped for arcsin.
    sin_pitch = 2.0 * (qw * qy - qz * qx)
    if sin_pitch > 1.0:
        sin_pitch = 1.0
    elif sin_pitch < -1.0:
        sin_pitch = -1.0
    pitch = np.arcsin(sin_pitch)

    # Rotation about the z axis.
    yaw = np.arctan2(2.0 * (qw * qz + qx * qy),
                     1.0 - 2.0 * (qy * qy + qz * qz))

    return np.array([roll, pitch, yaw])

# Publish the converter under its public name unless one already exists
# in this module (e.g. imported from skill_code).
_safe_define(name='euler_from_quat', fn=_euler_from_quat)


# ----------------------------------------------------------------------------- 
#  Generic helper to call a skill when we do not know the exact signature.
#  This keeps the code robust to minor differences in skill implementations.
# -----------------------------------------------------------------------------
def call_skill(skill_fn, *extra_args, **extra_kwargs):
    """
    Call a skill as ``skill_fn(env, task, *args, **kwargs)``, adapting to its signature.

    Candidate argument lists are tried in this order; the first one the
    skill's signature accepts is used:

    1. every positional and keyword argument exactly as given;
    2. only ``(env, task)``;
    3. ``(env, task, args[0])`` -- the first extra positional argument
       (commonly a target position).

    Compatibility is checked with ``inspect.signature(...).bind`` *before*
    calling, so the skill is invoked at most once.  Retrying whenever the
    call raised ``TypeError`` would re-run a skill that failed part-way
    through with an internal ``TypeError``, repeating robot motion.

    Parameters
    ----------
    skill_fn : callable
        The skill to invoke.
    *extra_args
        ``env``, ``task``, then any skill-specific positional arguments.
    **extra_kwargs
        Keyword arguments; forwarded only with candidate 1.

    Returns
    -------
    Whatever the skill returns.

    Raises
    ------
    TypeError
        If fewer than two positional arguments (env, task) are supplied, or
        if no candidate matches the signature (the mismatch error for
        candidate 1 is raised).  Exceptions raised by the skill itself
        propagate unchanged.
    """
    if len(extra_args) < 2:
        raise TypeError(
            "call_skill() requires at least env and task positional arguments")
    env, task = extra_args[0], extra_args[1]          # always first two
    remaining = extra_args[2:]                        # the rest (if any)

    candidates = [(extra_args, extra_kwargs), ((env, task), {})]
    if remaining:
        candidates.append(((env, task, remaining[0]), {}))

    try:
        signature = inspect.signature(skill_fn)
    except (TypeError, ValueError):
        # Signature cannot be introspected (e.g. some builtins): call as given.
        return skill_fn(*extra_args, **extra_kwargs)

    first_error = None
    for args, kwargs in candidates:
        try:
            signature.bind(*args, **kwargs)
        except TypeError as err:
            if first_error is None:
                first_error = err
            continue
        # Signature matches: run the skill exactly once.
        return skill_fn(*args, **kwargs)

    raise first_error


# ----------------------------------------------------------------------------- 
#  Main task runner
# -----------------------------------------------------------------------------
def _find_drawer_handle(positions):
    """Pick the object name to treat as the drawer handle, or return None.

    A name containing "handle" is preferred over one that only contains
    "drawer", so a drawer body listed before its handle is not chosen by
    accident.  Matching is case-insensitive.

    Args:
        positions: mapping of object name -> position, as returned by
            get_object_positions().
    """
    for keyword in ('handle', 'drawer'):
        for name in positions:
            if keyword in str(name).lower():
                return name
    return None


def _handle_displacement(before, after, name):
    """Return how far *name* moved between two position snapshots.

    The result is the Euclidean distance between ``after[name]`` and
    ``before[name]``.  Returns None when *name* is missing from either
    snapshot, because the motion cannot be measured in that case.
    """
    if name not in before or name not in after:
        return None
    delta = np.asarray(after[name]) - np.asarray(before[name])
    return float(np.linalg.norm(delta))


def _explore_drawer_lock(env, task, handle_name, handle_pos,
                         moved_threshold=0.01):
    """Probe whether the drawer is locked by grasping and pulling its handle.

    Steps:
      1) move the end-effector near the handle,
      2) try to pick the handle,
      3) try to pull the drawer,
      4) compare the handle position before and after the pull; a
         displacement above *moved_threshold* is read as "drawer opened".

    The inferred lock state (the missing "lock-known" predicate) is printed.
    If the handle position cannot be read before and after the pull, the
    state is reported as unknown rather than assumed locked.

    Returns:
        True if the task reported ``done`` during a skill (the caller should
        stop), otherwise False.
    """
    print(f"[Exploration] Target drawer handle: {handle_name}, "
          f"pos = {handle_pos}")

    # 1) Move
    print("[Exploration] Step 1 – move close to handle")
    _, _, done = call_skill(move, env, task,
                            handle_pos,           # target_pos
                            0.12,                 # approach_distance (fallback)
                            120,                  # max_steps
                            0.01,                 # threshold
                            'xy',                 # approach_axis
                            10.0)                 # timeout
    if done:
        print("[Exploration] Task finished unexpectedly during move()")
        return True

    # 2) Pick
    print("[Exploration] Step 2 – attempt to pick handle")
    _, _, done = call_skill(pick, env, task,
                            handle_pos,           # target_pos / object pos
                            0.03,                 # approach_distance
                            120,                  # max_steps
                            0.005,                # threshold
                            'z',                  # approach_axis
                            10.0)                 # timeout
    if done:
        print("[Exploration] Task finished unexpectedly during pick()")
        return True

    # 3) Pull
    print("[Exploration] Step 3 – attempt to pull drawer")
    pull_before_positions = get_object_positions()
    # call_skill selects whichever argument list pull() accepts
    _, _, done = call_skill(pull, env, task)
    if done:
        print("[Exploration] Task finished unexpectedly during pull()")
        return True

    # NOTE(review): if the simulator only advances inside task.step(), this
    # wall-clock sleep does not actually let the drawer settle -- confirm.
    time.sleep(0.5)

    # 4) Determine whether the drawer actually moved
    pull_after_positions = get_object_positions()
    displacement = _handle_displacement(pull_before_positions,
                                        pull_after_positions, handle_name)
    if displacement is None:
        # No measurement: claiming "locked" here would be an unfounded inference.
        print("[Result] Handle position unavailable – lock state UNKNOWN.")
        return False

    print(f"[Exploration] Handle displacement = {displacement:.4f} m")
    if displacement > moved_threshold:
        print("[Result] Drawer opened – therefore it was UNLOCKED.")
        print("[Predicate] (not (is-locked drawer1))  /  lock-known = true, state = unlocked")
    else:
        print("[Result] Drawer did NOT open – therefore it is LOCKED.")
        print("[Predicate] (is-locked drawer1)        /  lock-known = true, state = locked")
    return False


def run_skeleton_task():
    """Run the skeleton task: environment setup, recording, and lock exploration.

    Resets the task and wraps ``task.step`` / ``task.get_observation`` with
    the video-recording wrappers.  It then locates a drawer handle among the
    scene objects and runs an exploration phase that infers whether the
    drawer is locked.  The environment is always shut down, including on an
    early return or an exception.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps so every call is also recorded
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # === Retrieve Object Positions ===
        positions = get_object_positions()
        print("[Info] Objects detected in scene:", list(positions.keys()))

        drawer_handle_name = _find_drawer_handle(positions)
        if drawer_handle_name is None:
            print("[Warning] No drawer handle found – skipping exploration phase!")
        elif _explore_drawer_lock(env, task, drawer_handle_name,
                                  positions[drawer_handle_name]):
            # Task reported done mid-exploration; `finally` still shuts down.
            return

        # -----------------------------------------------------------------
        #  At this point we could proceed with the rest of the oracle plan
        #  (e.g., rotate gripper, move side, pick-drawer, pull, place object,
        #  etc.).  For the skeleton's purposes, demonstrating the discovery
        #  of the missing predicate through the exploration phase is enough.
        # -----------------------------------------------------------------

        print("[Task] Skeleton task finished normally.")

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the task only when this file is executed as a script, not on import.
    run_skeleton_task()