# run_skeleton_task.py (Completed)

import time
import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# All low-level skills come from this module
from skill_code import *     # noqa

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


def _quat_equal(q1, q2, tol=1e-3):
    """Returns True if two quaternions represent (almost) the same rotation."""
    dot = abs(np.dot(q1, q2))
    return (2 * np.arccos(np.clip(dot, -1.0, 1.0))) < tol


def _get_zero_quat():
    """By convention we treat ‘zero_deg’ as identity quaternion (x,y,z,w) = (0,0,0,1)."""
    return np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)


def _get_ninety_deg_quat(axis='z'):
    """90° rotation about the required axis."""
    if axis == 'x':
        return np.array([np.sin(np.pi/4), 0, 0, np.cos(np.pi/4)], dtype=np.float32)
    if axis == 'y':
        return np.array([0, np.sin(np.pi/4), 0, np.cos(np.pi/4)], dtype=np.float32)
    # default z
    return np.array([0, 0, np.sin(np.pi/4), np.cos(np.pi/4)], dtype=np.float32)


def exploration_phase(env, task):
    """
    Make the predicate (rotated gripper zero_deg) hold before planning.

    Reads the gripper orientation from the current observation (elements 3..6
    of ``gripper_pose``) and compares it with the identity quaternion.  If
    they already match nothing is done; otherwise the ``rotate`` skill is
    invoked towards the identity quaternion.  A failure of ``rotate`` is
    reported as a warning and not re-raised.
    """
    print("----- Exploration Phase : Ensuring (rotated gripper zero_deg) -----")
    gripper_quat = task.get_observation().gripper_pose[3:7]
    identity_quat = _get_zero_quat()

    if not _quat_equal(gripper_quat, identity_quat):
        # Orientation differs: hand over to the predefined low-level skill.
        try:
            rotate(env, task, target_quat=identity_quat,
                   max_steps=120, threshold=0.02, timeout=10.0)
            print("[Exploration] Successfully rotated to zero_deg – predicate now holds.")
        except Exception as err:
            print(f"[Exploration] WARNING – could not rotate to zero_deg. Error: {err}")
    else:
        print("[Exploration] Gripper already at zero_deg – predicate satisfied.")


def high_level_plan(env, task, positions):
    """
    Example high-level controller that attempts to open a drawer.

    Steps, each guarded so that a failure is printed rather than raised:
      1. Look up the drawer handle in ``positions`` under a few candidate
         keys; abort if none is present (or the stored value is None).
      2. ``pick`` the handle; abort if pick raises or reports the task done.
      3. Call ``pull`` when that name exists in this module's globals,
         otherwise ``move`` the gripper by -0.10 along Y from its current
         position (10 cm backwards, per the original comments).
      4. ``place`` at the originally looked-up handle position to release.

    All manipulation primitives are expected to come from skill_code.
    Returns None in every case.
    """
    # ------------------------------------------------------------------
    # 1. Locate the drawer handle
    # ------------------------------------------------------------------
    handle_aliases = (
        'drawer_handle',
        'handle',
        'drawer_handle_0',
        'drawer0_handle',
    )
    matched_alias = next(
        (alias for alias in handle_aliases if alias in positions), None
    )
    handle_xyz = None
    if matched_alias is not None:
        handle_xyz = positions[matched_alias]
        print(f"[Plan] Found drawer handle under key '{matched_alias}' @ {handle_xyz}")

    if handle_xyz is None:
        print("[Plan] Drawer handle not found in provided positions dictionary – aborting HL plan.")
        return

    # ------------------------------------------------------------------
    # 2. Approach and grasp the handle
    # ------------------------------------------------------------------
    try:
        print("[Plan] Attempting to PICK the drawer handle.")
        _, _, episode_over = pick(
            env, task,
            target_pos=handle_xyz,
            approach_distance=0.12,
            max_steps=150,
            threshold=0.01,
            approach_axis='z',
            timeout=12.0,
        )
        if episode_over:
            print("[Plan] Task ended while picking – abort.")
            return
    except Exception as err:
        print(f"[Plan] Pick failed – {err}")
        return

    # ------------------------------------------------------------------
    # 3. Open the drawer: dedicated pull skill, or a straight retreat
    # ------------------------------------------------------------------
    try:
        if 'pull' not in globals():
            print("[Plan] No pull() in skill_code – performing simple linear retreat.")
            # Straight-line goal 0.10 behind the current gripper position in Y.
            gripper_xyz = np.array(task.get_observation().gripper_pose[:3])
            retreat_xyz = gripper_xyz + np.array([0.0, -0.10, 0.0])
            move(
                env, task,
                target_pos=retreat_xyz,
                max_steps=120,
                threshold=0.002,
                timeout=10.0,
            )
        else:
            print("[Plan] Using predefined 'pull' skill.")
            pull(env, task, distance=0.10, max_steps=120, speed=0.01)
    except Exception as err:
        # A failed pull does not stop the plan; the handle is still released.
        print(f"[Plan] Pull/retreat failed – {err}")

    # ------------------------------------------------------------------
    # 4. Let go of the handle
    # ------------------------------------------------------------------
    try:
        print("[Plan] Releasing the handle (place back).")
        place(
            env, task,
            target_pos=handle_xyz,
            approach_distance=0.12,
            max_steps=80,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )
    except Exception as err:
        print(f"[Plan] Place failed – {err}")

    print("[Plan] High-level routine finished.")


def run_skeleton_task():
    """Run one full episode: set up, record, explore, plan, shut down.

    The environment is created first; everything after that runs inside a
    ``try``/``finally`` so ``shutdown_environment`` is always called, even
    when a step raises.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Bring the task to its initial state; only the observation is used.
        _, first_obs = task.reset()

        # Video capture is initialised from the first observation.
        init_video_writers(first_obs)

        # Route step / get_observation through the recording wrappers.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Per the original comment this is a dict: name -> (x, y, z).
        object_positions = get_object_positions()

        # Exploration first, so the orientation predicate holds for the plan.
        exploration_phase(env, task)

        # Main high-level plan.
        high_level_plan(env, task, object_positions)

        # (Additional custom logic could be added here.)
    finally:
        # Shut the environment down regardless of how the episode ended.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()