# run_skeleton_task.py  (Filled-in version)

import time
import numpy as np

from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import rotate, pick, place, move, pull     # use only predefined skills
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ------------------------------------------------------------
# Utility helpers (simple math helpers are OK to define)
# ------------------------------------------------------------
def euler_to_quat(roll: float, pitch: float, yaw: float):
    """Build an ``[x, y, z, w]`` float32 quaternion from Euler angles in radians.

    The rotation produced is ``Rz(yaw) @ Ry(pitch) @ Rx(roll)``: roll about X
    first, then pitch about Y, then yaw about Z (all about fixed axes).
    """
    # Half-angles feed the standard closed-form conversion.
    half_roll, half_pitch, half_yaw = 0.5 * roll, 0.5 * pitch, 0.5 * yaw

    c_r, s_r = np.cos(half_roll), np.sin(half_roll)
    c_p, s_p = np.cos(half_pitch), np.sin(half_pitch)
    c_y, s_y = np.cos(half_yaw), np.sin(half_yaw)

    w = c_r * c_p * c_y + s_r * s_p * s_y
    x = s_r * c_p * c_y - c_r * s_p * s_y
    y = c_r * s_p * c_y + s_r * c_p * s_y
    z = c_r * c_p * s_y - s_r * s_p * c_y

    # Scalar-last (xyzw) ordering.
    return np.array((x, y, z, w), dtype=np.float32)


def safe_call(skill_fn, *args, **kwargs):
    """Call ``skill_fn(*args, **kwargs)``, printing and swallowing any exception.

    This is a deliberate best-effort wrapper: a failing skill is reported but
    does not abort the surrounding script.

    Returns:
        Whatever ``skill_fn`` returns on success.  If it raises, the
        placeholder ``(None, None, False)`` is returned instead, so callers
        that unpack ``obs, reward, done`` keep working (``obs is None``
        signals the failure; ``done`` is False).
    """
    try:
        return skill_fn(*args, **kwargs)
    except Exception as exc:
        # Not every callable has __name__ (e.g. functools.partial objects or
        # callable instances).  Fall back to repr() so the handler itself
        # cannot raise AttributeError and mask the original exception.
        skill_name = getattr(skill_fn, "__name__", repr(skill_fn))
        print(f"[safe_call] Skill {skill_name} failed: {repr(exc)}")
        return None, None, False


# ------------------------------------------------------------
# Exploration logic – discover missing predicate(s)
# ------------------------------------------------------------
def exploration_phase(env, task):
    """
    Probe whether the 'rotated' predicate appears to be missing.

    Issues a single rotate() to a 90 degree yaw (via safe_call).  If the
    call fails -- i.e. safe_call hands back a None observation -- the
    predicate is reported as missing.  This is only a heuristic: no domain
    file is inspected.

    Returns:
        ["rotated"] when the rotate attempt failed, otherwise [].
    """
    print("----- Exploration phase: looking for missing predicates -----")

    # A successful quarter-turn about Z is expected to establish
    # (rotated gripper ninety_deg).
    quarter_turn_yaw = euler_to_quat(0.0, 0.0, np.pi / 2)
    rotate_obs, _reward, _done = safe_call(rotate, env, task, quarter_turn_yaw)

    rotate_failed = rotate_obs is None
    if not rotate_failed:
        print("[exploration] rotate() executed successfully; ‘rotated’ probably present.")

    detected = ["rotated"] if rotate_failed else []
    print(f"[exploration] Missing predicates detected: {detected}")
    return detected


# ------------------------------------------------------------
# Main routine
# ------------------------------------------------------------
def run_skeleton_task():
    """Run a demonstration episode: explore, then pick & place one object.

    Steps:
      1. Set up the environment, reset the task, and wrap ``task.step`` /
         ``task.get_observation`` with the video-recording helpers.
      2. Run :func:`exploration_phase` and report whether 'rotated' looked
         missing (nothing corrective is done).
      3. Take the first object reported by ``get_object_positions()``, pick it
         and place it back at the same position; if its name contains
         ``'drawer_handle'``, additionally attempt a ``pull``.

    Skills are invoked through :func:`safe_call`, so a failing skill is
    printed and the demo continues.  The function returns early if a skill
    reports ``done``.  The environment is always shut down (``finally``), on
    normal completion, early return, or exception.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # ----------------------------------------------------
        # Reset task & initialise video recording
        # ----------------------------------------------------
        _descriptions, obs = task.reset()
        init_video_writers(obs)

        # Wrap step / get_observation so that every interaction is recorded
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # ----------------------------------------------------
        # Retrieve environment-specific object positions
        # ----------------------------------------------------
        positions = get_object_positions()
        print(f"[info] Known object positions from helper: {positions}")

        # ----------------------------------------------------
        # 1)  Exploration – discover what is missing
        # ----------------------------------------------------
        missing_predicates = exploration_phase(env, task)

        # This branch is reached only when rotate() failed, so the gripper
        # may NOT be aligned.  The demo merely reports it and carries on.
        if "rotated" in missing_predicates:
            print("[task] Detected missing predicate ‘rotated’.  "
                  "Subsequent high-level planning will account for it.")

        # ----------------------------------------------------
        # 2)  Example sequence that shows use of predefined skills
        #     NOTE:  Because the actual benchmark goal is unknown here,
        #     we simply demonstrate calls without hard-coding a plan.
        # ----------------------------------------------------
        # Take the first entry reported by the helper (iteration order of
        # positions.items()); (None, None) when nothing is reported.
        movable_obj, movable_pos = next(iter(positions.items()), (None, None))

        if movable_obj is not None:
            print(f"[task] Demonstration: pick & place '{movable_obj}' at {movable_pos}")

            # Pick and place share the same motion parameters.
            motion_kwargs = dict(
                target_pos=movable_pos,
                approach_distance=0.12,
                max_steps=120,
                threshold=0.01,
                approach_axis='z',
                timeout=8.0,
            )

            # Pick
            obs, reward, done = safe_call(pick, env, task, **motion_kwargs)
            if done:
                print("[task] Episode finished right after pick.")
                return

            # Place (back to original position for demo).  Note: this is
            # attempted even if pick failed, matching the best-effort demo.
            obs, reward, done = safe_call(place, env, task, **motion_kwargs)
            if done:
                print("[task] Episode finished right after place.")
                return

            # Optional: pull if the object name marks it as a drawer handle
            if 'drawer_handle' in movable_obj:
                print(f"[task] Attempting to pull '{movable_obj}'")
                safe_call(
                    pull,
                    env,
                    task,
                    target_pos=movable_pos,          # pull() signature may vary; adjust if required
                    max_steps=150,
                    distance=0.15,
                    timeout=10.0
                )

        # ----------------------------------------------------
        # 3)  Wrap-up
        # ----------------------------------------------------
        print("===== Skeleton Task Finished (no early termination) =====")

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)
        print("Environment shut down.")


if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not on import.
    run_skeleton_task()
