# run_skeleton_task.py (Completed Version – only predefined skills are used)

import time
import math
import inspect
import numpy as np

from pyrep.objects.shape import Shape              # ( kept – may be unused but part of skeleton )
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *                            # gives us: move, pick, place, rotate, pull, etc.
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
# Utility functions required by some skills (e.g. rotate from skill_code)
# ---------------------------------------------------------------------------
def normalize_quaternion(quat):
    """
    Return *quat* as a new float64 array scaled to unit Euclidean length.

    If the input has zero length (or the length is not a positive number),
    the float64 copy is returned unscaled instead of dividing by zero.
    """
    q = np.asarray(quat).astype(np.float64)
    magnitude = np.linalg.norm(q)
    if magnitude > 0:
        return q / magnitude
    return q


def euler_from_quat(quat):
    """
    Convert a quaternion ordered (x, y, z, w) to (roll, pitch, yaw) in radians.

    The sine of the pitch is clamped to [-1, 1] before ``asin`` so that small
    numerical overshoot (e.g. a slightly non-unit quaternion) cannot raise a
    math domain error.
    """
    qx, qy, qz, qw = quat

    roll = math.atan2(
        2.0 * (qw * qx + qy * qz),
        1.0 - 2.0 * (qx * qx + qy * qy),
    )

    sin_pitch = 2.0 * (qw * qy - qz * qx)
    # Keep the exact min-then-max order of the clamp (matters for NaN input).
    sin_pitch = max(min(sin_pitch, 1.0), -1.0)
    pitch = math.asin(sin_pitch)

    yaw = math.atan2(
        2.0 * (qw * qz + qx * qy),
        1.0 - 2.0 * (qy * qy + qz * qz),
    )

    return roll, pitch, yaw


def quat_from_axis_angle(axis, angle):
    """
    Build a unit quaternion (x, y, z, w) for a rotation of *angle* radians
    about *axis*.

    *axis* must have exactly three components; it does not need to be
    normalised, because it is scaled to unit length here.

    Returns a float64 numpy array of length 4.

    Raises ValueError if *axis* does not have shape (3,), or if its length
    is zero or non-finite.  Without this check, dividing by a zero norm
    silently produces a NaN quaternion.
    """
    axis = np.asarray(axis, dtype=np.float64)
    if axis.shape != (3,):
        raise ValueError(
            f"axis must have exactly 3 components, got shape {axis.shape}"
        )
    length = np.linalg.norm(axis)
    if not np.isfinite(length) or length == 0.0:
        raise ValueError("axis must be a finite, non-zero vector")

    half_angle = angle / 2.0
    x, y, z = (axis / length) * math.sin(half_angle)
    return np.array([x, y, z, math.cos(half_angle)], dtype=np.float64)


# ---------------------------------------------------------------------------
# Robust skill-invocation helper
# ---------------------------------------------------------------------------
def safe_call(skill_func, *pos_args, **kw_args):
    """
    Call a skill best-effort, forwarding only the keywords it can accept.

    Keyword arguments are filtered against the skill's signature, so one
    shared set of tuning options (e.g. ``timeout``) can be passed to skills
    that do not all declare the same parameters:

    * If the skill declares ``**kwargs``, every keyword is forwarded,
      because it can accept any name.
    * Otherwise only names of keyword-capable parameters are forwarded.
      Positional-only and ``*args`` names are dropped, since passing them
      by keyword would raise TypeError.
    * If no signature can be obtained (some builtins/C callables), all
      keywords are forwarded unchanged.

    Any exception raised by the skill is printed and swallowed.  In that case
    the neutral triple ``(None, 0.0, False)`` is returned, so callers that
    unpack ``obs, reward, done`` can keep going.  On success the skill's own
    return value is passed through untouched.
    """
    try:
        params = inspect.signature(skill_func).parameters
    except (TypeError, ValueError):
        params = None

    if params is None or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    ):
        allowed_kw = dict(kw_args)
    else:
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        allowed_kw = {
            k: v
            for k, v in kw_args.items()
            if k in params and params[k].kind in keyword_kinds
        }

    try:
        return skill_func(*pos_args, **allowed_kw)
    except Exception as exc:
        # functools.partial and similar callables have no __name__.
        name = getattr(skill_func, "__name__", repr(skill_func))
        print(f"[safe_call] Exception during {name}: {exc}")
        # A unified (obs, reward, done) triple so the caller can keep going
        return None, 0.0, False


# ---------------------------------------------------------------------------
# Main entry
# ---------------------------------------------------------------------------
def _ended(done, message):
    """Print *message* and return True if a skill reported ``done``."""
    if done:
        print(message)
        return True
    return False


def _explore_drawer(env, task, obj_name):
    """Rotate the gripper 90° about Z, then attempt a pull.

    Returns True as soon as either skill reports that the task is done.
    """
    print(f"[Exploration] Detected drawer-like object: {obj_name}")

    # --- ensure gripper rotated 90° about Z so it is perpendicular ---
    ninety_quat = quat_from_axis_angle((0, 0, 1), math.pi / 2.0)
    _, _, done = safe_call(
        rotate, env, task,
        target_quat=ninety_quat,
        max_steps=150, threshold=0.05, timeout=10.0,
    )
    if _ended(done, "[Exploration] Task ended during rotate."):
        return True

    # --- attempt pull ---
    # NOTE(review): the object was just released by the preceding `place`,
    # so nothing appears to be held here.  Confirm whether `pull` grasps the
    # handle itself, or whether a pick of the handle should come first.
    _, _, done = safe_call(pull, env, task)
    return _ended(done, "[Exploration] Task ended during pull.")


def _explore_object(env, task, obj_name, obj_pos):
    """Move to one object, pick it, place it back, and probe drawers.

    Drawer-like objects (name contains "drawer") additionally get the
    rotate + pull sequence.  Returns True as soon as any skill reports that
    the task is done, so the caller can stop the sweep.
    """
    print(f"\n[Exploration] Visiting object: {obj_name} @ {obj_pos}")

    # ---------- move close to the object ----------
    _, _, done = safe_call(
        move, env, task,
        target_pos=obj_pos,
        approach_distance=0.20, max_steps=100,
        threshold=0.02, timeout=10.0,
    )
    if _ended(done, "[Exploration] Task signalled completion while moving."):
        return True

    # ---------- attempt to pick the object ----------
    _, _, done = safe_call(
        pick, env, task,
        target_pos=obj_pos,
        approach_distance=0.15, max_steps=120,
        threshold=0.01, approach_axis='z', timeout=10.0,
    )
    if _ended(done, "[Exploration] Task ended after pick."):
        return True

    # ---------- Immediately place it back (same pose) ----------
    _, _, done = safe_call(
        place, env, task,
        target_pos=obj_pos,
        approach_distance=0.18, max_steps=120,
        threshold=0.02, approach_axis='z', timeout=10.0,
    )
    if _ended(done, "[Exploration] Task ended after place."):
        return True

    if "drawer" in obj_name.lower():
        return _explore_drawer(env, task, obj_name)
    return False


def _run_exploration(env, task, positions):
    """Visit every object in *positions*, stopping early if the task ends."""
    print(f"[Task] Retrieved {len(positions)} objects from world-model.")

    for obj_name, obj_pos in positions.items():
        if _explore_object(env, task, obj_name, obj_pos):
            return

    print("\n[Exploration] Completed sweep of every known object.")

    # -------------------------------------------------------------------
    # MAIN GOAL-ORIENTED BEHAVIOUR PLACEHOLDER
    # (In a competition setting the oracle plan would continue here)
    # -------------------------------------------------------------------
    print("[Task] All exploratory predicates should now be known.")
    print("[Task] No further high-level plan implemented (placeholder).")


def run_skeleton_task():
    """Explore the scene with predefined skills, then shut the environment down.

    Steps:
      1. Set up the environment and reset the task.
      2. Wrap ``task.step`` / ``task.get_observation`` for video recording.
      3. For each object returned by ``get_object_positions()``, run
         move → pick → place, plus rotate → pull for drawer-like objects.

    The environment is always shut down, even if a step raises.  The closing
    banner is printed on every normal exit, including early ones (no objects,
    or a skill reporting ``done``).  Exceptions propagate to the caller
    without the banner.
    """
    print("===== Starting Skeleton Task =====")

    # -----------------------------------------------------------------------
    #  Environment boot-up
    # -----------------------------------------------------------------------
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # -------------------------------------------------------------------
        #  (Optional) video capture: wrap step/get_observation for recording
        # -------------------------------------------------------------------
        init_video_writers(obs)
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # -------------------------------------------------------------------
        #  Gather all known objects & their positions, then explore them
        # -------------------------------------------------------------------
        positions = get_object_positions()           # dict { name : (x,y,z) }
        if not positions:
            print("[Task] No object positions found – ending early.")
        else:
            _run_exploration(env, task, positions)

    finally:
        # -------------------------------------------------------------------
        #  Environment clean-up
        # -------------------------------------------------------------------
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Only run the exploration routine when executed as a script, not on import.
    run_skeleton_task()