import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment

# use all predefined skills exactly as provided
from skill_code import *           # move, pick, place, rotate, pull, …

from video import (
    init_video_writers,
    recording_step,
    recording_get_observation,
)

from object_positions import get_object_positions


# ----------------------------------------------------------------------
#  Helper utilities
# ----------------------------------------------------------------------
def _assert_position(positions, key):
    """
    Fail fast (with an informative message) if the requested key is
    missing from `object_positions`.  Returns the value as float32 np.array.
    """
    if key not in positions:
        raise KeyError(
            f"[PLAN] '{key}' not found in object_positions(). "
            f"Available keys: {list(positions.keys())}"
        )
    return np.asarray(positions[key], dtype=np.float32)


def _choose_pull_axis(start, target):
    """
    Pick the major axis (x, y, or z) that differs the most between
    `start` and `target`.  This offers a plausible guess for the drawer’s
    pull direction without hard-coding a specific axis.
    """
    delta = target - start
    axis_idx = int(np.argmax(np.abs(delta)))
    sign = "" if delta[axis_idx] >= 0 else "-"
    return f"{sign}{'xyz'[axis_idx]}"


def _first_available_position(positions, candidates, alias):
    """
    Return the first candidate key that exists in `positions`.
    Raise an error if none are found.
    """
    for k in candidates:
        if k in positions:
            return np.asarray(positions[k], dtype=np.float32)
    raise KeyError(
        f"[PLAN] None of the candidate keys for {alias} were found in "
        f"object_positions(). Tried: {candidates}"
    )


# ----------------------------------------------------------------------
#  Main entry-point
# ----------------------------------------------------------------------
def _ended_early(done, stage):
    """
    Log an unexpected episode end and report whether the plan must stop.

    Parameters
    ----------
    done : bool
        The ``done`` flag returned by the most recent skill call.
    stage : str
        Phrase describing where the plan stopped (e.g. ``"after rotate"``).
        It is inserted verbatim into the log line.

    Returns
    -------
    bool
        ``done`` unchanged, so callers can write
        ``if _ended_early(done, ...): return``.
    """
    if done:
        print(f"[PLAN] Task ended unexpectedly {stage}.")
    return done


def run_skeleton_task():
    """
    Execute the scripted open-drawer / dispose-rubbish plan end to end.

    Steps:
      1. rotate the gripper 90° about world-Z;
      2. move to the bottom drawer's side position;
      3. move to its anchor (handle) position;
      4. pick the handle;
      5. pull the drawer open;
      6. pick the rubbish;
      7. place the rubbish into the bin.

    The plan stops early, with a log line, as soon as any skill returns
    ``done`` before the final placement. The environment is always shut
    down. The closing banner reads "COMPLETE" only when every step was
    issued, and "ABORTED" when the plan stopped early or raised.
    """
    print("==========  STARTING  COMBINED TASK  ==========")

    # ------------------------------------------------------------------
    # 1)  Environment initialisation
    # ------------------------------------------------------------------
    env, task = setup_environment()
    # Flipped only after the final skill call returns. It drives the closing
    # banner so aborted or crashed runs are not logged as "COMPLETE".
    plan_finished = False
    try:
        descriptions, obs = task.reset()

        # Optional video capture: wrap step/get_observation for recording.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        # 2)  Fetch all object positions once at the beginning
        # ------------------------------------------------------------------
        positions = get_object_positions()

        # Drawer related keys (bottom drawer only, as per specification)
        bottom_side_pos = _assert_position(positions, "bottom_side_pos")
        bottom_anchor_pos = _assert_position(positions, "bottom_anchor_pos")

        # The anchor position is treated as the drawer-handle position.
        drawer_handle_pos = bottom_anchor_pos

        # Rubbish & bin: fall back to a list of likely keys.
        rubbish_pos = _first_available_position(
            positions,
            candidates=["rubbish", "item3", "item2", "item1"],
            alias="rubbish",
        )
        bin_pos = _first_available_position(
            positions,
            candidates=["bin", "bin_pos", "trash_bin"],
            alias="bin",
        )

        # ------------------------------------------------------------------
        # 3)  STEP-BY-STEP PLAN EXECUTION  (matches provided specification)
        # ------------------------------------------------------------------

        # Step-1 : rotate (gripper to 90° about world-Z)
        obs = task.get_observation()
        # NOTE(review): assumes gripper_pose[3:7] is a scalar-last
        # (x, y, z, w) quaternion, the order R.from_quat expects. Confirm
        # against the environment.
        current_quat = normalize_quaternion(obs.gripper_pose[3:7])
        # Composing the extra rotation on the left applies it in the world
        # frame rather than the gripper's local frame.
        rotate90_quat = (
            R.from_euler("z", 90.0, degrees=True) * R.from_quat(current_quat)
        ).as_quat()
        print("[PLAN] Step-1 : rotate gripper 90° about world-Z")
        obs, reward, done = rotate(
            env,
            task,
            target_quat=rotate90_quat,
            max_steps=150,
        )
        if _ended_early(done, "after rotate"):
            return

        # Step-2 : move to the side position of the bottom drawer
        print("[PLAN] Step-2 : move to bottom_side_pos →", bottom_side_pos)
        obs, reward, done = move(env, task, target_pos=bottom_side_pos)
        if _ended_early(done, "after move-to-side"):
            return

        # Step-3 : move from side-pos to anchor-pos (handle)
        print("[PLAN] Step-3 : move to bottom_anchor_pos →", bottom_anchor_pos)
        obs, reward, done = move(env, task, target_pos=bottom_anchor_pos)
        if _ended_early(done, "after move-to-anchor"):
            return

        # Step-4 : pick the drawer handle (grasp)
        print("[PLAN] Step-4 : pick drawer handle at", drawer_handle_pos)
        obs, reward, done = pick(env, task, target_pos=drawer_handle_pos)
        if _ended_early(done, "after pick-drawer"):
            return

        # Step-5 : pull the drawer open
        print("[PLAN] Step-5 : pull the drawer open")
        # Heuristic: pull along the dominant axis pointing from the anchor
        # back toward the side (approach) position. Verify for the scene.
        pull_axis = _choose_pull_axis(bottom_anchor_pos, bottom_side_pos)
        obs, reward, done = pull(
            env,
            task,
            pull_distance=0.25,      # 25 cm is usually enough
            pull_axis=pull_axis,
        )
        if _ended_early(done, "after pull"):
            return

        # Step-6 : pick the rubbish
        print("[PLAN] Step-6 : approach and pick rubbish at", rubbish_pos)
        # Pre-position slightly above to avoid collisions.
        obs, reward, done = move(
            env,
            task,
            target_pos=rubbish_pos + np.array([0.0, 0.0, 0.15]),
        )
        if _ended_early(done, "during pre-pick move"):
            return
        obs, reward, done = pick(env, task, target_pos=rubbish_pos)
        if _ended_early(done, "after pick-rubbish"):
            return

        # Step-7 : place the rubbish in the bin
        print("[PLAN] Step-7 : approach and place rubbish into bin at", bin_pos)
        # Move above the bin first.
        obs, reward, done = move(
            env,
            task,
            target_pos=bin_pos + np.array([0.0, 0.0, 0.20]),
        )
        if _ended_early(done, "during pre-place move"):
            return
        obs, reward, done = place(env, task, target_pos=bin_pos)
        if done:
            print("[PLAN] Task completed (done=True) after placing rubbish!")
        else:
            print("[PLAN] Plan executed. Success may be indicated by reward signal.")
        # Surface the reward that the message above refers to.
        print(f"[PLAN] Final reward: {reward}")
        plan_finished = True

    finally:
        shutdown_environment(env)
        if plan_finished:
            print("==========  TASK   COMPLETE  ==========")
        else:
            print("==========  TASK   ABORTED   ==========")


# ----------------------------------------------------------------------
#  Script execution guard
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Run the full plan only when this file is executed directly, not when
    # it is imported as a module.
    run_skeleton_task()