import numpy as np
from scipy.spatial.transform import Rotation as R
import traceback

from env import setup_environment, shutdown_environment

# keep the wildcard import so every primitive in skill_code is available
from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# -----------------------------------------------------------
# Helper utilities
# -----------------------------------------------------------
def safe_skill_call(skill_fn, *args, **kwargs):
    """
    Invoke ``skill_fn(*args, **kwargs)`` and hand its return value back
    to the caller unchanged.

    If the skill raises an ``Exception``, an error banner and the full
    traceback are printed and the very same exception is re-raised, so
    callers (and their ``finally`` clean-up blocks) still observe it.
    """
    try:
        outcome = skill_fn(*args, **kwargs)
    except Exception:
        # Report first, then let the original exception keep travelling
        # upwards so the caller's clean-up logic is triggered.
        print("\n[ERROR] Exception during skill execution:")
        traceback.print_exc()
        raise
    else:
        return outcome


def run_skeleton_task():
    """Execute the fixed seven-step oracle plan and always shut the env down.

    Plan (in order): rotate the gripper, move to the drawer side position,
    move to the drawer anchor position, grasp the handle, pull the drawer,
    pick the rubbish, place it into the bin.

    Every skill is invoked through `safe_skill_call` and is expected to
    return an ``(obs, reward, done)`` triple.  If any step *before* the last
    one reports ``done``, an info message is printed and the function returns
    early.  After the last step a final status message is printed.
    `shutdown_environment(env)` runs in a ``finally`` clause in every case.

    Raises:
        KeyError: if a logical position name cannot be resolved through
            ``name_map`` / the table returned by `get_object_positions`.
        Exception: anything raised by a skill is re-raised unchanged by
            `safe_skill_call` after its traceback has been printed.
    """
    print("===== Starting Skeleton Task =====")

    # -------------------------------------------------------
    # 1) Environment set-up
    # -------------------------------------------------------
    env, task = setup_environment()
    try:
        # Reset and get the very first observation
        _descriptions, obs = task.reset()

        # Optional: enable video recording helpers
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------------------------------------------
        # 2) Acquire all relevant object positions
        # ---------------------------------------------------
        # Scene-object names → positions, fetched once up front.
        positions = get_object_positions()

        # Logical-name → scene-object mapping table
        #
        #   PDDL / natural-language             scene object
        # -----------------------------------------------------------------
        #   nowhere-pos                    →     waypoint1            (float pt)
        #   side-pos-bottom                →     bottom_side_pos      (drawer side)
        #   anchor-pos-bottom              →     bottom_anchor_pos    (drawer handle)
        #   rubbish                        →     rubbish              (piece of trash)
        #   bin                            →     bin                  (trash can)
        #
        name_map = {
            "nowhere-pos": "waypoint1",
            "side-pos-bottom": "bottom_side_pos",
            "anchor-pos-bottom": "bottom_anchor_pos",
            "rubbish": "rubbish",
            "bin": "bin",
        }

        def p(name: str) -> np.ndarray:
            """Resolve a logical position name to its position as an ndarray."""
            try:
                return np.asarray(positions[name_map[name]])
            except KeyError as err:
                # Chain explicitly so the traceback shows which lookup
                # (name_map or positions) was the direct cause.
                raise KeyError(
                    f"[Mapping] No position found for '{name}'.   "
                    f"Check `name_map` or `object_positions`."
                ) from err

        # ---------------------------------------------------
        # 3) Execute the oracle plan step-by-step
        # ---------------------------------------------------
        # Each entry is (header to print, zero-argument callable running the
        # skill).  The callables are evaluated lazily, so skill names and
        # positions are only resolved when the step is actually reached.
        plan = (
            (
                # STEP 1 : rotate(gripper, zero_deg, ninety_deg)
                "\n[STEP 1] rotate gripper: zero_deg → ninety_deg",
                lambda: safe_skill_call(
                    rotate,
                    env,
                    task,
                    # The target orientation is a +90° rotation about Z.
                    target_quat=R.from_euler("z", 90, degrees=True).as_quat(),
                ),
            ),
            (
                # STEP 2 : move-to-side
                "\n[STEP 2] move gripper: nowhere-pos → side-pos-bottom",
                lambda: safe_skill_call(
                    move, env, task, target_pos=p("side-pos-bottom")
                ),
            ),
            (
                # STEP 3 : move-to-anchor
                "\n[STEP 3] move gripper: side-pos-bottom → anchor-pos-bottom",
                lambda: safe_skill_call(
                    move, env, task, target_pos=p("anchor-pos-bottom")
                ),
            ),
            (
                # STEP 4 : pick-drawer (generic pick reused to grasp the handle)
                "\n[STEP 4] pick drawer handle at anchor-pos-bottom",
                lambda: safe_skill_call(
                    pick,
                    env,
                    task,
                    target_pos=p("anchor-pos-bottom"),
                    approach_distance=0.10,   # a bit closer than object picks
                    approach_axis="z",
                ),
            ),
            (
                # STEP 5 : pull the drawer open
                # Heuristic: pull 0.20 m along the X axis.
                "\n[STEP 5] pull drawer to open",
                lambda: safe_skill_call(
                    pull,
                    env,
                    task,
                    pull_distance=0.20,
                    pull_axis="x",
                ),
            ),
            (
                # STEP 6 : pick the rubbish
                "\n[STEP 6] pick rubbish from table",
                lambda: safe_skill_call(
                    pick,
                    env,
                    task,
                    target_pos=p("rubbish"),
                    approach_distance=0.15,
                    approach_axis="z",
                ),
            ),
            (
                # STEP 7 : place rubbish into bin
                "\n[STEP 7] place rubbish into bin",
                lambda: safe_skill_call(
                    place,
                    env,
                    task,
                    target_pos=p("bin"),
                    approach_distance=0.15,
                    approach_axis="z",
                ),
            ),
        )

        done = False
        final_step = len(plan)
        for step_no, (header, run_step) in enumerate(plan, start=1):
            print(header)
            _obs, _reward, done = run_step()
            # Only intermediate steps return early; the last step falls
            # through to the final status message below.
            if done and step_no < final_step:
                print(f"[Info] Environment signalled `done` after STEP {step_no} ― task finished early.")
                return

        # ---------------------------------------------------
        # 4) Final status message
        # ---------------------------------------------------
        if done:
            print("[Success] Task completed and environment returned done=True.")
        else:
            # The environment may not signal done; give user feedback anyway.
            print("[Success] Final step executed.  Environment did not signal done, "
                  "but the high-level goal should now be satisfied.")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()