# run_skeleton_task.py  (Completed Version)

import time
import traceback
from typing import Dict, Any, Optional

import numpy as np
from pyrep.objects.shape import Shape           #  <-- kept from skeleton
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import move, pick, place, rotate, pull   #  <-- ONLY predefined skills
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions        #  <-- helper utility


# ----------------------------------------------------------------------------- #
# ----------------------------- Helper Functions ------------------------------ #
# ----------------------------------------------------------------------------- #
def is_orientation_achievable(current_q: np.ndarray,
                              target_q: np.ndarray,
                              tolerance: float = 1e-3) -> bool:
    """
    Numerical sanity check on a pair of quaternions before rotating.

    This is *not* a kinematic reachability test.  It only guards against
    inputs that would make the angular distance meaningless: non-finite
    components (NaN / inf) and degenerate zero-length quaternions.  Both
    quaternions are normalised before comparison, so un-normalised input
    is accepted.

    Args:
        current_q: Current orientation quaternion (array-like).
        target_q:  Requested orientation quaternion (array-like).
        tolerance: Slack in radians added to the upper bound ``pi`` on the
                   angular distance.

    Returns:
        ``True`` if both quaternions are finite and non-zero and the angle
        between them is a valid number no larger than ``pi + tolerance``;
        ``False`` otherwise.  Always a plain Python ``bool``.
    """
    current = np.asarray(current_q, dtype=float)
    target = np.asarray(target_q, dtype=float)

    # NaN / inf components would silently propagate into the angle.
    if not (np.all(np.isfinite(current)) and np.all(np.isfinite(target))):
        return False

    current_norm = np.linalg.norm(current)
    target_norm = np.linalg.norm(target)
    # A zero-length quaternion does not describe any rotation; without this
    # guard it would yield dot == 0 -> angle == pi and be wrongly accepted.
    smallest_norm = np.finfo(float).eps
    if current_norm <= smallest_norm or target_norm <= smallest_norm:
        return False

    # abs() treats q and -q as the same rotation; the clip protects arccos
    # from round-off pushing the cosine marginally above 1.
    cos_half_angle = np.clip(
        abs(np.dot(current / current_norm, target / target_norm)), 0.0, 1.0)
    angle = 2.0 * np.arccos(cos_half_angle)
    return bool(np.isfinite(angle) and angle <= np.pi + tolerance)


def is_gripper_colliding(obs: Any, env: Any) -> bool:
    """
    Placeholder collision query for the gripper.

    Neither `obs` nor `env` is inspected: the function unconditionally
    reports "no collision" so callers proceed normally.  It exists as a
    hook where a real contact check against the physics engine can be
    plugged in later.
    """
    # NOTE: something along the lines of env._physics_test_for_collisions()
    #       *could* be queried here; that call is only an illustration.
    collision_detected = False
    return collision_detected


def safe_rotate(env,
                task,
                target_quat,
                max_steps: int = 100,
                threshold: float = 0.05,
                timeout: float = 10.0):
    """
    Guarded front-end for the predefined `rotate` skill.

    A fresh observation is taken and two pre-flight checks are run:
    `is_orientation_achievable` on elements 3..6 of `obs.gripper_pose`
    (used as the current quaternion) versus `target_quat`, and
    `is_gripper_colliding`.  If either check vetoes the motion a message is
    printed and `(obs, 0.0, False)` is returned without calling `rotate`.
    Otherwise the call is forwarded to `rotate` and its result is returned
    unchanged.
    """
    snapshot = task.get_observation()

    orientation_ok = is_orientation_achievable(snapshot.gripper_pose[3:7],
                                               target_quat)
    if not orientation_ok:
        print("[safe_rotate] Requested orientation appears unreachable.  "
              "Skipping rotate() to stay safe.")
        return snapshot, 0.0, False

    if is_gripper_colliding(snapshot, env):
        print("[safe_rotate] Gripper already colliding — aborting rotation.")
        return snapshot, 0.0, False

    # Both guards passed: hand over to the official primitive.
    limits = {
        'max_steps': max_steps,
        'threshold': threshold,
        'timeout': timeout,
    }
    return rotate(env, task, target_quat, **limits)


def find_missing_predicate_through_exploration(env,
                                               task,
                                               positions: Dict[str, np.ndarray]) -> Optional[str]:
    """
    Probe the scene by attempting to open a drawer with the stock skills.

    The routine runs move -> safe_rotate -> pick -> pull against the
    'drawer_handle' entry of `positions`.  Any exception raised by a stage
    is printed (with traceback) and ends the exploration.

    Returns:
        "holding-drawer" if the final `pull()` stage raised; `None` in every
        other case (handle unknown, an earlier stage raised, or the pull
        completed without raising).
    """
    handle_name = 'drawer_handle'
    if handle_name not in positions:
        print(f"[Exploration] No '{handle_name}' in positions dict.  "
              "Skipping exploration phase.")
        return None

    handle_xyz = positions[handle_name]
    print("[Exploration] Trying to grasp drawer handle at:", handle_xyz)

    def _attempt(failure_prefix, skill_call):
        # Run one skill; report and swallow any exception.  True on success.
        try:
            _obs, _reward, _done = skill_call()
        except Exception as err:
            print(failure_prefix, err)
            traceback.print_exc()
            return False
        return True

    preparation = (
        # (1) Bring the end-effector close to the handle.
        ("[Exploration] move() failed:",
         lambda: move(env,
                      task,
                      target_pos=handle_xyz,
                      approach_distance=0.10,
                      max_steps=150,
                      threshold=0.01,
                      approach_axis='z',
                      timeout=5.0)),
        # (2) Turn the gripper so the handle can be grasped sideways.
        ("[Exploration] rotate() failed:",
         lambda: safe_rotate(env,
                             task,
                             target_quat=np.array([0.0, 0.7071, 0.0, 0.7071]),
                             max_steps=120)),
        # (3) Close the fingers on the handle.
        ("[Exploration] pick() failed while trying to grasp handle:",
         lambda: pick(env,
                      task,
                      target_pos=handle_xyz,
                      approach_distance=0.04,
                      max_steps=120,
                      threshold=0.005,
                      approach_axis='z',
                      timeout=5.0)),
    )
    for failure_prefix, skill_call in preparation:
        if not _attempt(failure_prefix, skill_call):
            return None

    # (4) Try to pull the drawer; only a failure here yields a hypothesis.
    pulled = _attempt("[Exploration] pull() failed:",
                      lambda: pull(env, task, distance=0.15, max_steps=150))
    if pulled:
        print("[Exploration] Drawer pulled open successfully.")
        return None    # nothing appears to be missing

    print("[Exploration] Hypothesis: 'holding-drawer' predicate missing.")
    return "holding-drawer"


# -------------------------------------------------------------------------- #
# ----------------------------- Main Routine -------------------------------- #
# -------------------------------------------------------------------------- #
def run_skeleton_task():
    """
    Entry point: set up the environment, explore, then run the disposal plan.

    Sequence:
      1. Reset the task, start the video writers and wrap `task.step` and
         `task.get_observation` with the recording wrappers.
      2. Fetch the object-position dictionary.
      3. Run the drawer exploration routine and report its verdict.
      4. If both 'rubbish' and 'trash_bin' are known: move -> pick -> move
         -> place.  The function returns early as soon as one of the first
         three skills reports `done`.
      5. Print the closing banner.

    The environment is shut down in a `finally` clause, so that happens on
    normal completion, on early return and when an exception propagates.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # ------------------------------------------------------------------ #
        # 1)  Reset the task and hook in the recorders
        # ------------------------------------------------------------------ #
        _descriptions, first_obs = task.reset()
        init_video_writers(first_obs)

        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------ #
        # 2)  Look up object positions
        # ------------------------------------------------------------------ #
        positions: Dict[str, np.ndarray] = get_object_positions()
        print("[Main] Retrieved object positions keys:", list(positions.keys()))

        # ------------------------------------------------------------------ #
        # 3)  Exploration phase – look for a missing predicate
        # ------------------------------------------------------------------ #
        suspected_predicate = find_missing_predicate_through_exploration(
            env, task, positions)
        if suspected_predicate is None:
            print("[Main] No missing predicate detected during exploration.")
        else:
            print(f"[Main] Exploration indicates missing predicate: "
                  f"'{suspected_predicate}'.  Adapting plan...")

        # ------------------------------------------------------------------ #
        # 4)  Task-specific plan
        # ------------------------------------------------------------------ #
        # Illustrative sequence: take the 'rubbish' object (if present) and
        # drop it into the 'trash_bin'.  The actual object names depend on
        # the loaded RLBench task.
        rubbish_key = 'rubbish'
        bin_key = 'trash_bin'

        if rubbish_key not in positions or bin_key not in positions:
            print(f"[Main] Could not find both '{rubbish_key}' and "
                  f"'{bin_key}' in positions.  Skipping disposal step.")
        else:
            rubbish_xyz = positions[rubbish_key]
            bin_xyz = positions[bin_key]

            print("[Main] Proceeding to dispose rubbish:", rubbish_key)

            # 4-A .. 4-C: (skill, target, approach distance, threshold,
            #              message printed if the episode ends afterwards)
            approach_stages = (
                (move, rubbish_xyz, 0.10, 0.01,
                 "[Main] Episode finished unexpectedly after move()."),
                (pick, rubbish_xyz, 0.04, 0.005,
                 "[Main] Episode finished unexpectedly after pick()."),
                (move, bin_xyz, 0.10, 0.01,
                 "[Main] Episode finished unexpectedly after move to bin."),
            )
            for skill, goal, standoff, tol, abort_msg in approach_stages:
                _obs, _reward, episode_over = skill(env,
                                                    task,
                                                    target_pos=goal,
                                                    approach_distance=standoff,
                                                    max_steps=200,
                                                    threshold=tol,
                                                    approach_axis='z',
                                                    timeout=10.0)
                if episode_over:
                    print(abort_msg)
                    return

            # 4-D) Release the rubbish inside the bin
            _obs, _reward, episode_over = place(env,
                                                task,
                                                target_pos=bin_xyz,
                                                retreat_distance=0.10,
                                                max_steps=200,
                                                threshold=0.005,
                                                timeout=10.0)
            print("[Main] Task finished after placing rubbish."
                  if episode_over
                  else "[Main] Successfully disposed rubbish.")

        # ------------------------------------------------------------------ #
        # 5)  End-of-task bookkeeping
        # ------------------------------------------------------------------ #
        print("===== End of Skeleton Task =====")

    finally:
        # Runs on success, early return and exceptions alike.
        shutdown_environment(env)


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()