import numpy as np
from pyrep.objects.shape import Shape          # noqa: F401  (kept for possible side‑effects in RLBench)
from pyrep.objects.proximity_sensor import ProximitySensor   # noqa: F401
import traceback

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# ====== Skill functions (already provided, DO NOT REDEFINE) ======
from skill_code import rotate, move, pick, pull, place


# ---------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------
# Candidate object names for the tomatoes, tried in this order against the
# object_positions mapping; both generic "item*" and explicit "tomato*"
# spellings are accepted.
TOMATO_CANDIDATES = [
    'item1',
    'item2',
    'item3',
    'tomato1',
    'tomato2',
    'tomato3',
]

# Candidate position keys for the drawer handle anchor and the drawer side,
# in priority order (bottom, middle, top).  Entry i of each list names the
# same drawer level.
DRAWER_ANCHOR_KEYS = [
    'bottom_anchor_pos',
    'middle_anchor_pos',
    'top_anchor_pos',
]
DRAWER_SIDE_KEYS = [
    'bottom_side_pos',
    'middle_side_pos',
    'top_side_pos',
]


def _safe_get(positions, keys, description):
    """
    Utility to fetch the first existing key from a list.
    Raises a RuntimeError if none are found.
    """
    for k in keys:
        if k in positions:
            return positions[k]
    raise RuntimeError(f"[Task] Cannot find {description} in object_positions. "
                       f"Checked keys: {keys}")


# ---------------------------------------------------------------
# Main task runner
# ---------------------------------------------------------------
def _select_drawer_positions(positions):
    """
    Return ``(anchor_pos, side_pos)`` taken from the same drawer level.

    Levels are tried in the order given by DRAWER_ANCHOR_KEYS /
    DRAWER_SIDE_KEYS, and the first level providing BOTH keys wins.  This
    keeps the grasped handle and the retract target on the same drawer.
    Looking the two keys up independently could pair e.g. the bottom anchor
    with the middle side position when ``bottom_side_pos`` is missing.

    If no level has both keys, fall back to the independent lookup.  That
    lookup raises RuntimeError when either kind of key is absent altogether.
    """
    for anchor_key, side_key in zip(DRAWER_ANCHOR_KEYS, DRAWER_SIDE_KEYS):
        if anchor_key in positions and side_key in positions:
            return positions[anchor_key], positions[side_key]
    return (
        _safe_get(positions, DRAWER_ANCHOR_KEYS, "drawer anchor position"),
        _safe_get(positions, DRAWER_SIDE_KEYS, "drawer side position"),
    )


def _halted(done, message):
    """Print *message* and return True when a skill reported ``done``."""
    if done:
        print(message)
    return bool(done)


def run_skeleton_task():
    """
    Run the drawer-opening and tomato-transfer plan in one episode.

    Plan:
      1. rotate to the gripper's current orientation;
      2. move to the drawer anchor, pick it, and pull 0.20 m along -x;
      3. retract to 0.10 m above the drawer side position;
      4. for each tomato found: move to it, pick it, move to the plate,
         and place it there.

    A ``done`` flag from any step before ``place`` aborts the run early.
    A ``done`` from ``place`` is treated as task completion and ends the
    tomato loop.  Exceptions are caught and their traceback printed.  The
    environment is always shut down in ``finally``.  Returns None.
    """
    print("\n==========  Start Combined‑Task Script  ==========\n")

    # 1) Environment set-up (outside the try: if this fails there is no
    #    environment to shut down).
    env, task = setup_environment()
    try:
        # Reset simulation
        _descriptions, obs = task.reset()

        # Optional video capture: wrap step/observation calls with recorders
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # 2) Fetch object positions
        positions = get_object_positions()

        drawer_anchor_pos, drawer_side_pos = _select_drawer_positions(positions)
        plate_pos = _safe_get(positions, ['plate'], "plate position")

        tomato_names = [name for name in TOMATO_CANDIDATES if name in positions]
        if not tomato_names:
            raise RuntimeError("[Task] No tomato objects were found!")
        print(f"[Task] Tomatoes to move  : {tomato_names}")
        print(f"[Task] Drawer anchor pos : {drawer_anchor_pos}")
        print(f"[Task] Plate position    : {plate_pos}")

        # 3) Execute high-level plan (spec-guided)

        # STEP-1 : rotate (keep current orientation - satisfies skill call)
        current_quat = obs.gripper_pose[3:7]
        print("\n--- Step 1 : rotate ---")
        obs, reward, done = rotate(env, task, target_quat=current_quat)
        if _halted(done, "[Task] Finished unexpectedly after rotate."):
            return

        # STEP-2 : move to drawer handle region
        print("\n--- Step 2 : move (drawer_handle_pos) ---")
        obs, reward, done = move(env, task, target_pos=drawer_anchor_pos)
        if _halted(done, "[Task] Finished unexpectedly after move to drawer."):
            return

        # STEP-3 : pick drawer handle, approaching from above
        print("\n--- Step 3 : pick (drawer_handle_pos) ---")
        obs, reward, done = pick(env, task, target_pos=drawer_anchor_pos,
                                 approach_axis='z')
        if _halted(done, "[Task] Finished unexpectedly after picking drawer handle."):
            return

        # STEP-4 : pull to open drawer (0.2 m along -x)
        print("\n--- Step 4 : pull (0.2 m, -x) ---")
        obs, reward, done = pull(env, task, pull_distance=0.20, pull_axis='-x')
        if _halted(done, "[Task] Finished unexpectedly after pulling drawer."):
            return

        # Retract 10 cm above the drawer side position; intended to keep the
        # gripper clear of the opened drawer before going for the tomatoes.
        safe_retract = (np.asarray(drawer_side_pos, dtype=float)
                        + np.array([0.0, 0.0, 0.10]))
        obs, reward, done = move(env, task, target_pos=safe_retract)
        if _halted(done, "[Task] Finished unexpectedly after safe retract."):
            return

        # Handle EACH tomato: move -> pick -> move to plate -> place
        for idx, t_name in enumerate(tomato_names, start=1):
            print(f"\n=== Tomato #{idx} : {t_name} ===")
            tomato_pos = positions[t_name]

            print(f"--- move to {t_name} ---")
            obs, reward, done = move(env, task, target_pos=tomato_pos)
            if _halted(done, "[Task] Finished unexpectedly during approach to tomato."):
                return

            print(f"--- pick {t_name} ---")
            obs, reward, done = pick(env, task, target_pos=tomato_pos,
                                     approach_axis='z')
            if _halted(done, "[Task] Finished unexpectedly during pick tomato."):
                return

            print("--- move to plate ---")
            obs, reward, done = move(env, task, target_pos=plate_pos)
            if _halted(done, "[Task] Finished unexpectedly during move to plate."):
                return

            print("--- place on plate ---")
            obs, reward, done = place(env, task, target_pos=plate_pos,
                                      approach_axis='z')
            # A done after placing means the task itself reports success.
            if _halted(done, "[Task] Task signalled completion."):
                break

        print("\n==========  Task Finished Successfully  ==========")

    except Exception:
        # Top-level boundary of a script: report the failure and fall through
        # to the cleanup below instead of propagating.
        print("\n[Task] Exception caught! Detailed traceback below:\n")
        traceback.print_exc()
    finally:
        # Always close the simulation cleanly
        shutdown_environment(env)
        print("\n==========  Environment shutdown completed  ==========")


# -----------------------------------------------------------------
# Entry-point guard
# -----------------------------------------------------------------
if __name__ == "__main__":
    # Run the task only when this file is executed as a script, not on import.
    run_skeleton_task()
